move sn to engine
This commit is contained in:
@@ -1,10 +1,17 @@
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model import MODEL
|
||||
from util.misc import add_spectral_norm
|
||||
|
||||
|
||||
def add_spectral_norm(module):
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
||||
return nn.utils.spectral_norm(module)
|
||||
else:
|
||||
return module
|
||||
|
||||
|
||||
def build_model(cfg):
|
||||
|
||||
Reference in New Issue
Block a user