add GauGAN

This commit is contained in:
2020-10-11 23:05:38 +08:00
parent 06b2abd19a
commit 6070f08835
3 changed files with 120 additions and 23 deletions

View File

@@ -20,13 +20,13 @@ def _normalization(norm, num_features, additional_kwargs=None):
return NORMALIZATION.build_with(kwargs)
def _activation(activation):
def _activation(activation, inplace=True):
if activation == "NONE":
return _DO_NO_THING_FUNC
elif activation == "ReLU":
return nn.ReLU(inplace=True)
return nn.ReLU(inplace=inplace)
elif activation == "LeakyReLU":
return nn.LeakyReLU(negative_slope=0.2, inplace=True)
return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
elif activation == "Tanh":
return nn.Tanh()
else:
@@ -74,7 +74,7 @@ class ReverseConv2dBlock(nn.Module):
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
super().__init__()
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
self.activation = _activation(activation_type)
self.activation = _activation(activation_type, inplace=False)
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
def forward(self, x):
@@ -84,7 +84,7 @@ class ReverseConv2dBlock(nn.Module):
class ResidualBlock(nn.Module):
def __init__(self, in_channels,
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
out_channels=None, out_activation_type=None):
out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
"""
Residual Conv Block
:param in_channels:
@@ -110,15 +110,15 @@ class ResidualBlock(nn.Module):
self.learn_skip_connection = in_channels != out_channels
conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
additional_norm_kwargs=additional_norm_kwargs,
padding_mode=padding_mode)
self.conv1 = conv_block(in_channels, in_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=activation_type)
self.conv2 = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
self.conv1 = conv_block(in_channels, in_channels, **conv_param)
self.conv2 = conv_block(in_channels, out_channels, **conv_param)
if self.learn_skip_connection:
self.res_conv = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
self.res_conv = conv_block(in_channels, out_channels, **conv_param)
def forward(self, x):
res = x if not self.learn_skip_connection else self.res_conv(x)