rewrite
This commit is contained in:
@@ -1,18 +1,21 @@
|
||||
import torch
|
||||
import ignite.distributed as idist
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model import MODEL
|
||||
import torch.optim as optim
|
||||
from util.misc import add_spectral_norm
|
||||
|
||||
|
||||
def build_model(cfg):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
|
||||
add_spectral_norm_flag = cfg.pop("_add_spectral_norm", False)
|
||||
model = MODEL.build_with(cfg)
|
||||
if bn_to_sync_bn:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
if add_spectral_norm_flag:
|
||||
model.apply(add_spectral_norm)
|
||||
return idist.auto_model(model)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user