TAHG 0.0.3

This commit is contained in:
2020-09-01 09:02:04 +08:00
parent 89b54105c7
commit e71e8d95d0
8 changed files with 97 additions and 36 deletions

View File

@@ -184,22 +184,23 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
self.edges_path = Path(edges_path)
assert self.edges_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
def get_edge(self, origin_path):
op = Path(origin_path)
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
img = Image.open(edge_path).resize(self.size)
return {"edge": F.to_tensor(img)}
return F.to_tensor(img)
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output = dict()
output["a"], path_a = self.A[a_idx]
output.update(self.get_edge(path_a))
output["b"] = self.B[b_idx]
output["b"], path_b = self.B[b_idx]
output["edge_a"] = self.get_edge(path_a)
output["edge_b"] = self.get_edge(path_b)
return output
def __len__(self):