TAFG
This commit is contained in:
@@ -203,7 +203,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
op = Path(origin_path)
|
||||
if self.edge_type.startswith("landmark_"):
|
||||
edge_type = self.edge_type.lstrip("landmark_")
|
||||
use_landmark = True
|
||||
use_landmark = op.parent.name.endswith("A")
|
||||
else:
|
||||
edge_type = self.edge_type
|
||||
use_landmark = False
|
||||
@@ -225,14 +225,11 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
if self.with_path:
|
||||
output = {"a": self.A[a_idx], "b": self.B[b_idx]}
|
||||
output["edge_a"] = output["a"][1]
|
||||
return output
|
||||
output = dict()
|
||||
output["a"], path_a = self.A[a_idx]
|
||||
output["b"], path_b = self.B[b_idx]
|
||||
output["edge_a"] = self.get_edge(path_a)
|
||||
output = dict(a={}, b={})
|
||||
output["a"]["img"], output["a"]["path"] = self.A[a_idx]
|
||||
output["b"]["img"], output["b"]["path"] = self.B[b_idx]
|
||||
for p in "ab":
|
||||
output[p]["edge"] = self.get_edge(output[p]["path"])
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
|
||||
Reference in New Issue
Block a user