TAFG 0.01
This commit is contained in:
@@ -15,6 +15,7 @@ from tqdm import tqdm
|
||||
|
||||
from .transform import transform_pipeline
|
||||
from .registry import DATASET
|
||||
from .util import dlib_landmark
|
||||
|
||||
|
||||
def default_transform_way(transform, sample):
|
||||
@@ -178,20 +179,38 @@ class GenerationUnpairedDataset(Dataset):
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
|
||||
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
|
||||
self.edge_type = edge_type
|
||||
self.size = size
|
||||
self.edges_path = Path(edges_path)
|
||||
self.landmarks_path = Path(landmarks_path)
|
||||
assert self.edges_path.exists()
|
||||
assert self.landmarks_path.exists()
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def get_edge(self, origin_path):
|
||||
op = Path(origin_path)
|
||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
|
||||
img = Image.open(edge_path).resize(self.size)
|
||||
return F.to_tensor(img)
|
||||
if self.edge_type.startswith("landmark_"):
|
||||
edge_type = self.edge_type.lstrip("landmark_")
|
||||
use_landmark = True
|
||||
else:
|
||||
edge_type = self.edge_type
|
||||
use_landmark = False
|
||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
|
||||
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size))
|
||||
if not use_landmark:
|
||||
return origin_edge
|
||||
else:
|
||||
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.{edge_type}.txt"
|
||||
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
|
||||
dist_tensor = torch.from_numpy(dlib_landmark.dist_tensor(key_points))
|
||||
part_labels = torch.from_numpy(part_labels)
|
||||
edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
|
||||
edges = part_edge + edges
|
||||
return torch.cat([edges, dist_tensor, part_labels], dim=0)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
@@ -200,7 +219,6 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
output["a"], path_a = self.A[a_idx]
|
||||
output["b"], path_b = self.B[b_idx]
|
||||
output["edge_a"] = self.get_edge(path_a)
|
||||
output["edge_b"] = self.get_edge(path_b)
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
|
||||
Reference in New Issue
Block a user