base code for pytorch distributed, add cyclegan
This commit is contained in:
34
data/transform.py
Normal file
34
data/transform.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets.folder import default_loader
|
||||
|
||||
from .registry import TRANSFORM
|
||||
|
||||
# from https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html
|
||||
_VALID_TORCHVISION_TRANSFORMS = ["ToTensor", "ToPILImage", "Normalize", "Resize",
|
||||
"Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder",
|
||||
"RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
|
||||
"RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter",
|
||||
"RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective",
|
||||
"RandomErasing"]
|
||||
|
||||
for vtt in _VALID_TORCHVISION_TRANSFORMS:
|
||||
TRANSFORM.register_module(module=getattr(transforms, vtt))
|
||||
|
||||
|
||||
@TRANSFORM.register_module()
|
||||
class Load:
|
||||
def __init__(self, loader=default_loader):
|
||||
self.loader = loader
|
||||
|
||||
def __call__(self, image_path):
|
||||
return self.loader(image_path)
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "()"
|
||||
|
||||
|
||||
def transform_pipeline(pipeline_description):
|
||||
if len(pipeline_description) == 0:
|
||||
return lambda x: x
|
||||
transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
|
||||
return transforms.Compose(transform_list)
|
||||
Reference in New Issue
Block a user