TAHG 0.0.3
This commit is contained in:
@@ -184,22 +184,23 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
self.edges_path = Path(edges_path)
|
||||
assert self.edges_path.exists()
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def get_edge(self, origin_path):
|
||||
op = Path(origin_path)
|
||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
|
||||
img = Image.open(edge_path).resize(self.size)
|
||||
return {"edge": F.to_tensor(img)}
|
||||
return F.to_tensor(img)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
output = dict()
|
||||
output["a"], path_a = self.A[a_idx]
|
||||
output.update(self.get_edge(path_a))
|
||||
output["b"] = self.B[b_idx]
|
||||
output["b"], path_b = self.B[b_idx]
|
||||
output["edge_a"] = self.get_edge(path_a)
|
||||
output["edge_b"] = self.get_edge(path_b)
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
|
||||
Reference in New Issue
Block a user