update tester
This commit is contained in:
@@ -1,20 +1,19 @@
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.transforms import functional as F
|
||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
|
||||
from pathlib import Path
|
||||
|
||||
import lmdb
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
|
||||
from torchvision.transforms import functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
from .transform import transform_pipeline
|
||||
from .registry import DATASET
|
||||
from .transform import transform_pipeline
|
||||
from .util import dlib_landmark
|
||||
|
||||
|
||||
@@ -160,9 +159,9 @@ class SingleFolderDataset(Dataset):
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDataset(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline):
|
||||
self.A = SingleFolderDataset(root_a, pipeline)
|
||||
self.B = SingleFolderDataset(root_b, pipeline)
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def __getitem__(self, idx):
|
||||
@@ -186,7 +185,8 @@ def normalize_tensor(tensor):
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
|
||||
with_path=False):
|
||||
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
|
||||
self.edge_type = edge_type
|
||||
self.size = size
|
||||
@@ -197,6 +197,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||
self.random_pair = random_pair
|
||||
self.with_path = with_path
|
||||
|
||||
def get_edge(self, origin_path):
|
||||
op = Path(origin_path)
|
||||
@@ -224,6 +225,10 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
if self.with_path:
|
||||
output = {"a": self.A[a_idx], "b": self.B[b_idx]}
|
||||
output["edge_a"] = output["a"][1]
|
||||
return output
|
||||
output = dict()
|
||||
output["a"], path_a = self.A[a_idx]
|
||||
output["b"], path_b = self.B[b_idx]
|
||||
|
||||
Reference in New Issue
Block a user