add GauGAN

This commit is contained in:
2020-10-11 23:05:38 +08:00
parent 06b2abd19a
commit 6070f08835
3 changed files with 120 additions and 23 deletions

View File

@@ -16,18 +16,19 @@ for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
@NORMALIZATION.register_module("ADE")
class AdaptiveDenormalization(nn.Module):
def __init__(self, num_features, base_norm_type="BN"):
def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
super().__init__()
self.num_features = num_features
self.base_norm_type = base_norm_type
self.norm = self.base_norm(num_features)
self.gamma = None
self.gamma_bias = gamma_bias
self.beta = None
self.have_set_condition = False
def base_norm(self, num_features):
if self.base_norm_type == "IN":
return nn.InstanceNorm2d(num_features)
return nn.InstanceNorm2d(num_features, affine=False)
elif self.base_norm_type == "BN":
return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
@@ -38,13 +39,13 @@ class AdaptiveDenormalization(nn.Module):
def forward(self, x):
assert self.have_set_condition
x = self.norm(x)
x = self.gamma * x + self.beta
x = (self.gamma + self.gamma_bias) * x + self.beta
self.have_set_condition = False
return x
def __repr__(self):
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
f"base_norm_type={self.base_norm_type})"
#
# def __repr__(self):
# return f"{self.__class__.__name__}(num_features={self.num_features}, " \
# f"base_norm_type={self.base_norm_type})"
@NORMALIZATION.register_module("AdaIN")
@@ -61,8 +62,9 @@ class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
@NORMALIZATION.register_module("FADE")
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
super().__init__(num_features, base_norm_type)
def __init__(self, num_features: int, condition_in_channels,
base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0):
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
padding_mode=padding_mode)
self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
@@ -77,9 +79,9 @@ class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
@NORMALIZATION.register_module("SPADE")
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
activation_type="ReLU", padding_mode="zeros"):
super().__init__(num_features, base_norm_type)
self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type,
activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)