TANG 0.0.1
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
@@ -171,3 +172,33 @@ class GenerationUnpairedDataset(Dataset):
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type):
|
||||
self.edge_type = edge_type
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def get_edge(self, origin_path):
|
||||
op = Path(origin_path)
|
||||
add = torch.load(op.parent / f"{op.stem}.add")
|
||||
return {"edge": add["edge"].float().unsqueeze(dim=0),
|
||||
"additional_info": torch.cat([add["seg"].float(), add["dist"].float()], dim=0)}
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
output = dict()
|
||||
output["a"], path_a = self.A[a_idx]
|
||||
output.update(self.get_edge(path_a))
|
||||
output["b"] = self.B[b_idx]
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
return max(len(self.A), len(self.B))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
|
||||
Reference in New Issue
Block a user