TAFG 0.01

This commit is contained in:
2020-09-03 09:34:38 +08:00
parent 14d4247112
commit 2469bf15fe
6 changed files with 37 additions and 388 deletions

View File

@@ -15,6 +15,7 @@ from tqdm import tqdm
from .transform import transform_pipeline
from .registry import DATASET
from .util import dlib_landmark
def default_transform_way(transform, sample):
@@ -178,20 +179,38 @@ class GenerationUnpairedDataset(Dataset):
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
self.landmarks_path = Path(landmarks_path)
assert self.edges_path.exists()
assert self.landmarks_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
def get_edge(self, origin_path):
op = Path(origin_path)
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
img = Image.open(edge_path).resize(self.size)
return F.to_tensor(img)
if self.edge_type.startswith("landmark_"):
edge_type = self.edge_type.lstrip("landmark_")
use_landmark = True
else:
edge_type = self.edge_type
use_landmark = False
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size))
if not use_landmark:
return origin_edge
else:
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.{edge_type}.txt"
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
dist_tensor = torch.from_numpy(dlib_landmark.dist_tensor(key_points))
part_labels = torch.from_numpy(part_labels)
edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
edges = part_edge + edges
return torch.cat([edges, dist_tensor, part_labels], dim=0)
def __getitem__(self, idx):
a_idx = idx % len(self.A)
@@ -200,7 +219,6 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
output["a"], path_a = self.A[a_idx]
output["b"], path_b = self.B[b_idx]
output["edge_a"] = self.get_edge(path_a)
output["edge_b"] = self.get_edge(path_b)
return output
def __len__(self):