TANG 0.0.1

This commit is contained in:
2020-08-30 09:34:23 +08:00
parent 7a85499edf
commit 715a2e64a1
10 changed files with 690 additions and 2 deletions

View File

@@ -1,5 +1,6 @@
import os
import pickle
from pathlib import Path
from collections import defaultdict
import torch
@@ -171,3 +172,33 @@ class GenerationUnpairedDataset(Dataset):
def __repr__(self):
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type):
self.edge_type = edge_type
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
self.random_pair = random_pair
def get_edge(self, origin_path):
op = Path(origin_path)
add = torch.load(op.parent / f"{op.stem}.add")
return {"edge": add["edge"].float().unsqueeze(dim=0),
"additional_info": torch.cat([add["seg"].float(), add["dist"].float()], dim=0)}
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output = dict()
output["a"], path_a = self.A[a_idx]
output.update(self.get_edge(path_a))
output["b"] = self.B[b_idx]
return output
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"