add GauGAN
This commit is contained in:
@@ -16,18 +16,19 @@ for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
|
||||
|
||||
@NORMALIZATION.register_module("ADE")
|
||||
class AdaptiveDenormalization(nn.Module):
|
||||
def __init__(self, num_features, base_norm_type="BN"):
|
||||
def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.base_norm_type = base_norm_type
|
||||
self.norm = self.base_norm(num_features)
|
||||
self.gamma = None
|
||||
self.gamma_bias = gamma_bias
|
||||
self.beta = None
|
||||
self.have_set_condition = False
|
||||
|
||||
def base_norm(self, num_features):
|
||||
if self.base_norm_type == "IN":
|
||||
return nn.InstanceNorm2d(num_features)
|
||||
return nn.InstanceNorm2d(num_features, affine=False)
|
||||
elif self.base_norm_type == "BN":
|
||||
return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
|
||||
|
||||
@@ -38,13 +39,13 @@ class AdaptiveDenormalization(nn.Module):
|
||||
def forward(self, x):
|
||||
assert self.have_set_condition
|
||||
x = self.norm(x)
|
||||
x = self.gamma * x + self.beta
|
||||
x = (self.gamma + self.gamma_bias) * x + self.beta
|
||||
self.have_set_condition = False
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
|
||||
f"base_norm_type={self.base_norm_type})"
|
||||
#
|
||||
# def __repr__(self):
|
||||
# return f"{self.__class__.__name__}(num_features={self.num_features}, " \
|
||||
# f"base_norm_type={self.base_norm_type})"
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("AdaIN")
|
||||
@@ -61,8 +62,9 @@ class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
|
||||
|
||||
@NORMALIZATION.register_module("FADE")
|
||||
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
|
||||
super().__init__(num_features, base_norm_type)
|
||||
def __init__(self, num_features: int, condition_in_channels,
|
||||
base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0):
|
||||
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
|
||||
self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
|
||||
padding_mode=padding_mode)
|
||||
self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
|
||||
@@ -77,9 +79,9 @@ class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
@NORMALIZATION.register_module("SPADE")
|
||||
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
|
||||
activation_type="ReLU", padding_mode="zeros"):
|
||||
super().__init__(num_features, base_norm_type)
|
||||
self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type,
|
||||
activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
|
||||
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
|
||||
self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
|
||||
self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
|
||||
Reference in New Issue
Block a user