almost same as mmedit

This commit is contained in:
2020-08-08 13:17:26 +08:00
parent 7cf235781d
commit a5133e6795
3 changed files with 61 additions and 18 deletions

27
util/build.py Normal file
View File

@@ -0,0 +1,27 @@
import torch
import torch.optim as optim
import ignite.distributed as idist
from omegaconf import OmegaConf
from model import MODEL
from util.distributed import auto_model
def build_model(cfg, distributed_args=None):
cfg = OmegaConf.to_container(cfg)
model_distributed_config = cfg.pop("_distributed", dict())
model = MODEL.build_with(cfg)
if model_distributed_config.get("bn_to_syncbn"):
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
return auto_model(model, **distributed_args)
def build_optimizer(params, cfg):
assert "_type" in cfg
cfg = OmegaConf.to_container(cfg)
optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
return idist.auto_optim(optimizer)