Base code for PyTorch distributed training; add CycleGAN
This commit is contained in:
0
util/__init__.py
Normal file
0
util/__init__.py
Normal file
66
util/distributed.py
Normal file
66
util/distributed.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ignite.distributed import utils as idist
|
||||
from ignite.distributed.comp_models import native as idist_native
|
||||
from ignite.utils import setup_logger
|
||||
|
||||
|
||||
def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module:
    """Helper method to adapt provided model for non-distributed and distributed
    configurations (supporting all available backends from
    :meth:`~ignite.distributed.utils.available_backends()`).

    Internally, we perform the following:

    - send model to current :meth:`~ignite.distributed.utils.device()` if model's
      parameters are not on the device.
    - wrap the model to `torch DistributedDataParallel`_ for native torch
      distributed if world size is larger than 1.
    - wrap the model to `torch DataParallel`_ if no distributed context found and
      more than one CUDA device is available.

    Examples:

    .. code-block:: python

        import ignite.distributed as idist

        model = idist.auto_model(model)

    In addition with NVidia/Apex, it can be used in the following way:

    .. code-block:: python

        import ignite.distributed as idist

        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
        model = idist.auto_model(model)

    Args:
        model (torch.nn.Module): model to adapt.
        **additional_kwargs: forwarded verbatim to the ``DistributedDataParallel``
            or ``DataParallel`` wrapper, when one is applied.

    Returns:
        torch.nn.Module

    .. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
    .. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
    """
    logger = setup_logger(__name__ + ".auto_model")

    # Put model's parameters to device only if some are not already there.
    device = idist.device()
    if not all(p.device == device for p in model.parameters()):
        model.to(device)

    # Distributed data parallel model.
    if idist.get_world_size() > 1:
        backend = idist.backend()  # hoisted: queried once instead of per-branch
        if backend == idist_native.NCCL:
            # NCCL: each process drives a single GPU, identified by local rank.
            lrank = idist.get_local_rank()
            logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs)
        elif backend == idist_native.GLOO:
            logger.info("Apply torch DistributedDataParallel on model")
            model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs)

    # Not distributed, but multiple GPUs reachable: use data parallel model.
    elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
        logger.info("Apply torch DataParallel on model")
        model = torch.nn.parallel.DataParallel(model, **additional_kwargs)

    return model
|
||||
21
util/handler.py
Normal file
21
util/handler.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from ignite.engine import Engine
|
||||
from ignite.handlers import Checkpoint
|
||||
|
||||
|
||||
class Resumer:
    """Engine handler that restores training state from a checkpoint file.

    Args:
        to_load: mapping of names to restorable objects (model, optimizer, ...)
            understood by :meth:`ignite.handlers.Checkpoint.load_objects`.
        checkpoint_path: path to an existing checkpoint file, or ``None`` to
            make this handler a no-op.

    Raises:
        ValueError: if ``checkpoint_path`` is given but does not exist.
    """

    def __init__(self, to_load, checkpoint_path):
        self.to_load = to_load
        if checkpoint_path is not None:
            checkpoint_path = Path(checkpoint_path)
            if not checkpoint_path.exists():
                raise ValueError(f"Checkpoint '{checkpoint_path}' is not found")
        self.checkpoint_path = checkpoint_path

    def __call__(self, engine: Engine):
        # Guard clause: nothing to do when no checkpoint was configured.
        if self.checkpoint_path is None:
            return
        # Load on CPU; objects move themselves to the right device afterwards.
        state = torch.load(self.checkpoint_path.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=self.to_load, checkpoint=state)
        engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}")
|
||||
10
util/image.py
Normal file
10
util/image.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import torchvision.utils
|
||||
|
||||
|
||||
def make_2d_grid(tensors, padding=0, normalize=True, range=None, scale_each=False, pad_value=0):
    """Tile several image batches into a single 2-D image grid.

    Each batch in ``tensors`` is first merged vertically (``nrow=1`` — one
    column per batch), then the resulting columns are merged horizontally.

    NOTE(review): ``range`` shadows the builtin; the name mirrors the
    torchvision ``make_grid`` keyword it is forwarded to, so renaming it
    would break keyword callers.
    """
    # Merge each batch in the `y` direction first.
    columns = []
    for batch in tensors:
        column = torchvision.utils.make_grid(
            batch, padding=padding, nrow=1, normalize=normalize,
            range=range, scale_each=scale_each, pad_value=pad_value,
        )
        columns.append(column)
    # Then merge the columns in the `x` direction.
    return torchvision.utils.make_grid(columns, padding=0, nrow=len(columns))
|
||||
163
util/registry.py
Normal file
163
util/registry.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import inspect
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from omegaconf import OmegaConf
|
||||
from types import ModuleType
|
||||
|
||||
|
||||
class _Registry:
|
||||
def __init__(self, name):
|
||||
self._name = name
|
||||
|
||||
def get(self, key):
|
||||
raise NotImplemented
|
||||
|
||||
def keys(self):
|
||||
raise NotImplemented
|
||||
|
||||
def __len__(self):
|
||||
len(self.keys())
|
||||
|
||||
def __contains__(self, key):
|
||||
return self.get(key) is not None
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(name={self._name}, items={self.keys()})"
|
||||
|
||||
def build_with(self, cfg, default_args=None):
|
||||
"""Build a module from config dict.
|
||||
Args:
|
||||
cfg (dict): Config dict. It should at least contain the key "type".
|
||||
default_args (dict, optional): Default initialization arguments.
|
||||
Returns:
|
||||
object: The constructed object.
|
||||
"""
|
||||
if isinstance(cfg, DictConfig):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
if isinstance(cfg, dict):
|
||||
if '_type' in cfg:
|
||||
args = cfg.copy()
|
||||
obj_type = args.pop('_type')
|
||||
elif len(cfg) == 1:
|
||||
obj_type, args = list(cfg.items())[0]
|
||||
else:
|
||||
raise KeyError(f'the cfg dict must contain the key "_type", but got {cfg}')
|
||||
elif isinstance(cfg, str):
|
||||
obj_type = cfg
|
||||
args = dict()
|
||||
else:
|
||||
raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
|
||||
|
||||
if not (isinstance(default_args, dict) or default_args is None):
|
||||
raise TypeError('default_args must be a dict or None, '
|
||||
f'but got {type(default_args)}')
|
||||
|
||||
if isinstance(obj_type, str):
|
||||
obj_cls = self.get(obj_type)
|
||||
if obj_cls is None:
|
||||
raise KeyError(f'{obj_type} is not in the {self.name} registry')
|
||||
elif inspect.isclass(obj_type):
|
||||
obj_cls = obj_type
|
||||
else:
|
||||
raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')
|
||||
|
||||
if default_args is not None:
|
||||
for name, value in default_args.items():
|
||||
args.setdefault(name, value)
|
||||
return obj_cls(**args)
|
||||
|
||||
|
||||
class ModuleRegistry(_Registry):
    """Registry backed by an existing Python module's attributes.

    Args:
        name (str): Registry name.
        module (ModuleType): module whose attributes become the records.
        predefined_valid_list (iterable, optional): when given, restricts the
            registry to these attribute names (intersected with the names the
            module actually defines).
    """

    def __init__(self, name, module, predefined_valid_list=None):
        super().__init__(name)

        assert isinstance(module, ModuleType), f"module must be ModuleType, but got {type(module)}"
        self._module = module
        available = set(self._module.__dict__.keys())
        if predefined_valid_list is None:
            self._valid_set = available
        else:
            # Silently drop requested names the module does not define.
            self._valid_set = available & set(predefined_valid_list)

    def keys(self):
        """Return the registered attribute names as a tuple."""
        return tuple(self._valid_set)

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class, or None if *key* is not registered.
        """
        return getattr(self._module, key) if key in self._valid_set else None
|
||||
|
||||
|
||||
class Registry(_Registry):
    """A registry to map strings to classes.

    Args:
        name (str): Registry name.
    """

    def __init__(self, name):
        super().__init__(name)
        self._module_dict = dict()

    def keys(self):
        """Return all registered names as a tuple."""
        return tuple(self._module_dict)

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format.

        Returns:
            class: The corresponding class, or None if *key* is not registered.
        """
        return self._module_dict.get(key)

    def _register_module(self, module_class, module_name=None, force=False):
        # Internal: validate and insert a single class into the table.
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        module_name = module_class.__name__ if module_name is None else module_name
        if module_name in self._module_dict and not force:
            raise KeyError(f'{module_name} is already registered '
                           f'in {self.name}')
        self._module_dict[module_name] = module_class

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Args:
            name (str | None): The module name to be registered. If not
                specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module_class=module, module_name=name, force=force)
            return module

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')

        # use it as a decorator: @x.register_module()
        def _decorator(cls):
            self._register_module(module_class=cls, module_name=name, force=force)
            return cls

        return _decorator
|
||||
Reference in New Issue
Block a user