rewrite
This commit is contained in:
@@ -30,7 +30,24 @@ def _activation(activation):
|
||||
elif activation == "Tanh":
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplemented(activation)
|
||||
raise NotImplementedError(f"{activation} not valid")
|
||||
|
||||
|
||||
class LinearBlock(nn.Module):
    """Fully connected block: linear layer -> normalization -> activation.

    Normalization and activation are resolved by the module-level helpers
    ``_normalization`` / ``_activation`` from the given type strings.
    """

    def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type

        # When bias is not given explicitly, infer it from the normalization
        # type (norms with their own affine shift make a linear bias redundant).
        if bias is None:
            bias = _use_bias_checker(norm_type)
        self.linear = nn.Linear(in_features, out_features, bias)

        self.normalization = _normalization(norm_type, out_features)
        self.activation = _activation(activation_type)

    def forward(self, x):
        # Apply the three stages in order: linear, then norm, then activation.
        h = self.linear(x)
        h = self.normalization(h)
        return self.activation(h)
|
||||
|
||||
|
||||
class Conv2dBlock(nn.Module):
|
||||
|
||||
@@ -93,7 +93,7 @@ class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
|
||||
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
|
||||
out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
|
||||
out = out * gamma + beta
|
||||
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
|
||||
return out
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@ class ILN(nn.Module):
|
||||
|
||||
def forward(self, x):
    """Apply instance-layer normalization with this module's parameters.

    ``gamma`` and ``beta`` are reshaped to (1, C) and ``rho`` to
    (1, C, 1, 1) so they broadcast correctly inside
    ``_instance_layer_normalization`` (the superseded ``expand_as``
    variant from the diff residue is dropped).
    """
    return _instance_layer_normalization(
        x, self.gamma.view(1, -1), self.beta.view(1, -1), self.rho.view(1, -1, 1, 1), self.eps)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("AdaILN")
|
||||
@@ -136,7 +136,6 @@ class AdaILN(nn.Module):
|
||||
|
||||
def forward(self, x):
    """Adaptive instance-layer normalization; a condition must be set first.

    ``gamma`` and ``beta`` are presumably populated per-sample by a
    ``set_condition``-style method elsewhere in the class — TODO confirm.
    The flag is cleared after use so a stale condition cannot be silently
    reused on the next forward pass.
    """
    assert self.have_set_condition, "condition must be set before calling forward"
    # Single corrected call: rho reshaped to (1, C, 1, 1) for broadcasting
    # (the superseded expand_as variant from the diff residue is dropped).
    out = _instance_layer_normalization(x, self.gamma, self.beta, self.rho.view(1, -1, 1, 1), self.eps)
    # Force callers to provide a fresh condition for the next pass.
    self.have_set_condition = False
    return out
|
||||
|
||||
Reference in New Issue
Block a user