base code for pytorch distributed, add cyclegan
This commit is contained in:
66
util/distributed.py
Normal file
66
util/distributed.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ignite.distributed import utils as idist
|
||||
from ignite.distributed.comp_models import native as idist_native
|
||||
from ignite.utils import setup_logger
|
||||
|
||||
|
||||
def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module:
|
||||
"""Helper method to adapt provided model for non-distributed and distributed configurations (supporting
|
||||
all available backends from :meth:`~ignite.distributed.utils.available_backends()`).
|
||||
|
||||
Internally, we perform to following:
|
||||
|
||||
- send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
|
||||
- wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
|
||||
- wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ignite.distribted as idist
|
||||
|
||||
model = idist.auto_model(model)
|
||||
|
||||
In addition with NVidia/Apex, it can be used in the following way:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ignite.distribted as idist
|
||||
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
|
||||
model = idist.auto_model(model)
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model to adapt.
|
||||
|
||||
Returns:
|
||||
torch.nn.Module
|
||||
|
||||
.. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
|
||||
.. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
|
||||
"""
|
||||
logger = setup_logger(__name__ + ".auto_model")
|
||||
|
||||
# Put model's parameters to device if its parameters are not on the device
|
||||
device = idist.device()
|
||||
if not all([p.device == device for p in model.parameters()]):
|
||||
model.to(device)
|
||||
|
||||
# distributed data parallel model
|
||||
if idist.get_world_size() > 1:
|
||||
if idist.backend() == idist_native.NCCL:
|
||||
lrank = idist.get_local_rank()
|
||||
logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs)
|
||||
elif idist.backend() == idist_native.GLOO:
|
||||
logger.info("Apply torch DistributedDataParallel on model")
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs)
|
||||
|
||||
# not distributed but multiple GPUs reachable so data parallel model
|
||||
elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
|
||||
logger.info("Apply torch DataParallel on model")
|
||||
model = torch.nn.parallel.DataParallel(model, **additional_kwargs)
|
||||
|
||||
return model
|
||||
Reference in New Issue
Block a user