base code for pytorch distributed, add cyclegan
This commit is contained in:
4
data/__init__.py
Normal file
4
data/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
import data.dataset
|
||||
import data.transform
|
||||
from data.registry import DATASET, TRANSFORM
|
||||
|
||||
73
data/dataset.py
Normal file
73
data/dataset.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
|
||||
|
||||
import lmdb
|
||||
|
||||
from .transform import transform_pipeline
|
||||
from .registry import DATASET
|
||||
|
||||
|
||||
class LMDBDataset(Dataset):
|
||||
def __init__(self, lmdb_path, output_transform=None, map_size=2 ** 40, readonly=True, **lmdb_kwargs):
|
||||
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
|
||||
**lmdb_kwargs)
|
||||
self.output_transform = output_transform
|
||||
with self.db.begin(write=False) as txn:
|
||||
self._len = pickle.loads(txn.get(b"__len__"))
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __getitem__(self, idx):
|
||||
with self.db.begin(write=False) as txn:
|
||||
sample = pickle.loads(txn.get("{}".format(idx).encode()))
|
||||
if self.output_transform is not None:
|
||||
sample = self.output_transform(sample)
|
||||
return sample
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class SingleFolderDataset(Dataset):
|
||||
def __init__(self, root, pipeline):
|
||||
assert os.path.isdir(root)
|
||||
self.root = root
|
||||
samples = []
|
||||
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
|
||||
for fn in sorted(fns):
|
||||
path = os.path.join(r, fn)
|
||||
if has_file_allowed_extension(path, IMG_EXTENSIONS):
|
||||
samples.append(path)
|
||||
self.samples = samples
|
||||
self.pipeline = transform_pipeline(pipeline)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.pipeline(self.samples[idx])
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SingleFolderDataset root={self.root} len={len(self)}>"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDataset(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline):
|
||||
self.A = SingleFolderDataset(root_a, pipeline)
|
||||
self.B = SingleFolderDataset(root_b, pipeline)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
return dict(a=self.A[a_idx], b=self.B[b_idx])
|
||||
|
||||
def __len__(self):
|
||||
return max(len(self.A), len(self.B))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
4
data/registry.py
Normal file
4
data/registry.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from util.registry import Registry
|
||||
|
||||
DATASET = Registry("dataset")
|
||||
TRANSFORM = Registry("transform")
|
||||
34
data/transform.py
Normal file
34
data/transform.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets.folder import default_loader
|
||||
|
||||
from .registry import TRANSFORM
|
||||
|
||||
# from https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html
|
||||
_VALID_TORCHVISION_TRANSFORMS = ["ToTensor", "ToPILImage", "Normalize", "Resize",
|
||||
"Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder",
|
||||
"RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
|
||||
"RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter",
|
||||
"RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective",
|
||||
"RandomErasing"]
|
||||
|
||||
for vtt in _VALID_TORCHVISION_TRANSFORMS:
|
||||
TRANSFORM.register_module(module=getattr(transforms, vtt))
|
||||
|
||||
|
||||
@TRANSFORM.register_module()
|
||||
class Load:
|
||||
def __init__(self, loader=default_loader):
|
||||
self.loader = loader
|
||||
|
||||
def __call__(self, image_path):
|
||||
return self.loader(image_path)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "()"
|
||||
|
||||
|
||||
def transform_pipeline(pipeline_description):
|
||||
if len(pipeline_description) == 0:
|
||||
return lambda x: x
|
||||
transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
|
||||
return transforms.Compose(transform_list)
|
||||
Reference in New Issue
Block a user