almost same as mmedit
This commit is contained in:
27
util/build.py
Normal file
27
util/build.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import ignite.distributed as idist
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model import MODEL
|
||||
from util.distributed import auto_model
|
||||
|
||||
|
||||
def build_model(cfg, distributed_args=None):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
model_distributed_config = cfg.pop("_distributed", dict())
|
||||
model = MODEL.build_with(cfg)
|
||||
|
||||
if model_distributed_config.get("bn_to_syncbn"):
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
|
||||
distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
|
||||
return auto_model(model, **distributed_args)
|
||||
|
||||
|
||||
def build_optimizer(params, cfg):
|
||||
assert "_type" in cfg
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
|
||||
return idist.auto_optim(optimizer)
|
||||
Reference in New Issue
Block a user