rewrite
This commit is contained in:
@@ -8,7 +8,7 @@ import torch.nn as nn
|
||||
|
||||
|
||||
def add_spectral_norm(module):
|
||||
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
|
||||
return nn.utils.spectral_norm(module)
|
||||
else:
|
||||
return module
|
||||
|
||||
Reference in New Issue
Block a user