This commit is contained in:
2020-10-11 10:02:33 +08:00
parent 6ea13df465
commit 04c6366c07
24 changed files with 483 additions and 968 deletions

View File

@@ -30,7 +30,24 @@ def _activation(activation):
elif activation == "Tanh":
return nn.Tanh()
else:
raise NotImplemented(activation)
raise NotImplementedError(f"{activation} not valid")
class LinearBlock(nn.Module):
def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
bias = _use_bias_checker(norm_type) if bias is None else bias
self.linear = nn.Linear(in_features, out_features, bias)
self.normalization = _normalization(norm_type, out_features)
self.activation = _activation(activation_type)
def forward(self, x):
return self.activation(self.normalization(self.linear(x)))
class Conv2dBlock(nn.Module):

View File

@@ -93,7 +93,7 @@ class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
out = out * gamma + beta
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
return out
@@ -115,7 +115,7 @@ class ILN(nn.Module):
def forward(self, x):
return _instance_layer_normalization(
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
x, self.gamma.view(1, -1), self.beta.view(1, -1), self.rho.view(1, -1, 1, 1), self.eps)
@NORMALIZATION.register_module("AdaILN")
@@ -136,7 +136,6 @@ class AdaILN(nn.Module):
def forward(self, x):
assert self.have_set_condition
out = _instance_layer_normalization(
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
out = _instance_layer_normalization(x, self.gamma, self.beta, self.rho.view(1, -1, 1, 1), self.eps)
self.have_set_condition = False
return out