rewrite
This commit is contained in:
@@ -93,7 +93,7 @@ class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
|
||||
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
|
||||
out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
|
||||
out = out * gamma + beta
|
||||
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
|
||||
return out
|
||||
|
||||
|
||||
@@ -115,7 +115,7 @@ class ILN(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return _instance_layer_normalization(
|
||||
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
|
||||
x, self.gamma.view(1, -1), self.beta.view(1, -1), self.rho.view(1, -1, 1, 1), self.eps)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("AdaILN")
|
||||
@@ -136,7 +136,6 @@ class AdaILN(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
assert self.have_set_condition
|
||||
out = _instance_layer_normalization(
|
||||
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
|
||||
out = _instance_layer_normalization(x, self.gamma, self.beta, self.rho.view(1, -1, 1, 1), self.eps)
|
||||
self.have_set_condition = False
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user