This commit is contained in:
2020-09-17 09:34:53 +08:00
parent 2ff4a91057
commit 61e04de8a5
9 changed files with 168 additions and 288 deletions

View File

@@ -203,7 +203,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
op = Path(origin_path)
if self.edge_type.startswith("landmark_"):
edge_type = self.edge_type.lstrip("landmark_")
use_landmark = True
use_landmark = op.parent.name.endswith("A")
else:
edge_type = self.edge_type
use_landmark = False
@@ -225,14 +225,11 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
if self.with_path:
output = {"a": self.A[a_idx], "b": self.B[b_idx]}
output["edge_a"] = output["a"][1]
return output
output = dict()
output["a"], path_a = self.A[a_idx]
output["b"], path_b = self.B[b_idx]
output["edge_a"] = self.get_edge(path_a)
output = dict(a={}, b={})
output["a"]["img"], output["a"]["path"] = self.A[a_idx]
output["b"]["img"], output["b"]["path"] = self.B[b_idx]
for p in "ab":
output[p]["edge"] = self.get_edge(output[p]["path"])
return output
def __len__(self):