base code for pytorch distributed, add cyclegan

This commit is contained in:
2020-08-07 09:48:09 +08:00
commit f7843de45d
32 changed files with 1444 additions and 0 deletions

0
util/__init__.py Normal file
View File

66
util/distributed.py Normal file
View File

@@ -0,0 +1,66 @@
import torch
import torch.nn as nn
from ignite.distributed import utils as idist
from ignite.distributed.comp_models import native as idist_native
from ignite.utils import setup_logger
def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module:
    """Helper method to adapt provided model for non-distributed and distributed configurations (supporting
    all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

    Internally, we perform the following:

    - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
    - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
    - wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA device is available.

    Examples:

    .. code-block:: python

        import ignite.distributed as idist

        model = idist.auto_model(model)

    In addition with NVidia/Apex, it can be used in the following way:

    .. code-block:: python

        import ignite.distributed as idist

        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
        model = idist.auto_model(model)

    Args:
        model (torch.nn.Module): model to adapt.
        **additional_kwargs: extra keyword arguments forwarded to ``DistributedDataParallel`` /
            ``DataParallel`` when the model is wrapped.

    Returns:
        torch.nn.Module

    .. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
    .. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
    """
    logger = setup_logger(__name__ + ".auto_model")

    # Put model's parameters to device if its parameters are not on the device.
    # Use a generator (not a list) so the scan can short-circuit on the first mismatch.
    device = idist.device()
    if not all(p.device == device for p in model.parameters()):
        model.to(device)

    # distributed data parallel model
    if idist.get_world_size() > 1:
        if idist.backend() == idist_native.NCCL:
            lrank = idist.get_local_rank()
            logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs)
        elif idist.backend() == idist_native.GLOO:
            logger.info("Apply torch DistributedDataParallel on model")
            model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs)
    # not distributed but multiple GPUs reachable so data parallel model
    # (reuse the device resolved above instead of calling idist.device() a second time)
    elif torch.cuda.device_count() > 1 and "cuda" in device.type:
        logger.info("Apply torch DataParallel on model")
        model = torch.nn.parallel.DataParallel(model, **additional_kwargs)

    return model

21
util/handler.py Normal file
View File

@@ -0,0 +1,21 @@
from pathlib import Path
import torch
from ignite.engine import Engine
from ignite.handlers import Checkpoint
class Resumer:
    """Engine handler that restores training state from a checkpoint file.

    Args:
        to_load (dict): mapping of names to stateful objects, in the form
            expected by :meth:`ignite.handlers.Checkpoint.load_objects`.
        checkpoint_path (str | Path | None): path of the checkpoint to load;
            when ``None`` the handler does nothing.

    Raises:
        ValueError: if ``checkpoint_path`` is given but does not exist.
    """

    def __init__(self, to_load, checkpoint_path):
        self.to_load = to_load
        if checkpoint_path is None:
            self.checkpoint_path = None
        else:
            path = Path(checkpoint_path)
            if not path.exists():
                raise ValueError(f"Checkpoint '{path}' is not found")
            self.checkpoint_path = path

    def __call__(self, engine: Engine):
        # No checkpoint configured -> nothing to restore.
        if self.checkpoint_path is None:
            return
        ckp = torch.load(self.checkpoint_path.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp)
        engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}")

10
util/image.py Normal file
View File

@@ -0,0 +1,10 @@
import torchvision.utils
def make_2d_grid(tensors, padding=0, normalize=True, range=None, scale_each=False, pad_value=0):
    """Arrange several image batches into one 2-D grid.

    Each batch in ``tensors`` is first stacked vertically into a single-column
    strip; the strips are then concatenated horizontally into one image.

    Args:
        tensors: iterable of image batches, each accepted by
            :func:`torchvision.utils.make_grid`.
        padding, normalize, range, scale_each, pad_value: forwarded to
            :func:`torchvision.utils.make_grid` for the per-batch strips.
            (``range`` intentionally shadows the builtin to mirror
            torchvision's keyword of the same name.)

    Returns:
        Tensor: the combined image grid.
    """
    columns = []
    for img_batch in tensors:
        # nrow=1 stacks the whole batch in the `y` direction.
        column = torchvision.utils.make_grid(img_batch, padding=padding, nrow=1, normalize=normalize,
                                             range=range, scale_each=scale_each, pad_value=pad_value)
        columns.append(column)
    # One column per batch side by side: merge in the `x` direction.
    return torchvision.utils.make_grid(columns, padding=0, nrow=len(columns))

163
util/registry.py Normal file
View File

@@ -0,0 +1,163 @@
import inspect
from omegaconf.dictconfig import DictConfig
from omegaconf import OmegaConf
from types import ModuleType
class _Registry:
def __init__(self, name):
self._name = name
def get(self, key):
raise NotImplemented
def keys(self):
raise NotImplemented
def __len__(self):
len(self.keys())
def __contains__(self, key):
return self.get(key) is not None
@property
def name(self):
return self._name
def __repr__(self):
return f"{self.__class__.__name__}(name={self._name}, items={self.keys()})"
def build_with(self, cfg, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
default_args (dict, optional): Default initialization arguments.
Returns:
object: The constructed object.
"""
if isinstance(cfg, DictConfig):
cfg = OmegaConf.to_container(cfg)
if isinstance(cfg, dict):
if '_type' in cfg:
args = cfg.copy()
obj_type = args.pop('_type')
elif len(cfg) == 1:
obj_type, args = list(cfg.items())[0]
else:
raise KeyError(f'the cfg dict must contain the key "_type", but got {cfg}')
elif isinstance(cfg, str):
obj_type = cfg
args = dict()
else:
raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
if isinstance(obj_type, str):
obj_cls = self.get(obj_type)
if obj_cls is None:
raise KeyError(f'{obj_type} is not in the {self.name} registry')
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_cls(**args)
class ModuleRegistry(_Registry):
    """Registry whose entries are the attributes of an existing module.

    Args:
        name (str): Registry name.
        module (ModuleType): module whose attributes are looked up.
        predefined_valid_list (iterable, optional): allow-list of attribute
            names; the exposed keys are its intersection with the module's
            actual attributes. When ``None``, every attribute is exposed.
    """

    def __init__(self, name, module, predefined_valid_list=None):
        super().__init__(name)
        assert isinstance(module, ModuleType), f"module must be ModuleType, but got {type(module)}"
        self._module = module
        module_names = set(self._module.__dict__.keys())
        if predefined_valid_list is None:
            self._valid_set = module_names
        else:
            self._valid_set = module_names & set(predefined_valid_list)

    def keys(self):
        return tuple(self._valid_set)

    def get(self, key):
        """Get the registry record.
        Args:
            key (str): The class name in string format.
        Returns:
            class: The corresponding attribute, or ``None`` when the key is
            not in the valid set.
        """
        return getattr(self._module, key) if key in self._valid_set else None
class Registry(_Registry):
    """A registry to map strings to classes.
    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        super().__init__(name)
        self._module_dict = dict()

    def keys(self):
        return tuple(self._module_dict.keys())

    def get(self, key):
        """Get the registry record.
        Args:
            key (str): The class name in string format.
        Returns:
            class: The corresponding class, or ``None`` if not registered.
        """
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, module_name=None, force=False):
        # Shared registration core for both call styles of register_module.
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')
        key = module_class.__name__ if module_name is None else module_name
        if key in self._module_dict and not force:
            raise KeyError(f'{key} is already registered '
                           f'in {self.name}')
        self._module_dict[key] = module_class

    def register_module(self, name=None, force=False, module=None):
        """Register a module.
        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.
        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module_class=module, module_name=name, force=force)
            return module

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')

        # use it as a decorator: @x.register_module()
        def _decorator(cls):
            self._register_module(module_class=cls, module_name=name, force=force)
            return cls

        return _decorator