add GauGAN
This commit is contained in:
@@ -20,13 +20,13 @@ def _normalization(norm, num_features, additional_kwargs=None):
|
||||
return NORMALIZATION.build_with(kwargs)
|
||||
|
||||
|
||||
def _activation(activation):
|
||||
def _activation(activation, inplace=True):
|
||||
if activation == "NONE":
|
||||
return _DO_NO_THING_FUNC
|
||||
elif activation == "ReLU":
|
||||
return nn.ReLU(inplace=True)
|
||||
return nn.ReLU(inplace=inplace)
|
||||
elif activation == "LeakyReLU":
|
||||
return nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
|
||||
elif activation == "Tanh":
|
||||
return nn.Tanh()
|
||||
else:
|
||||
@@ -74,7 +74,7 @@ class ReverseConv2dBlock(nn.Module):
|
||||
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
|
||||
super().__init__()
|
||||
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
|
||||
self.activation = _activation(activation_type)
|
||||
self.activation = _activation(activation_type, inplace=False)
|
||||
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -84,7 +84,7 @@ class ReverseConv2dBlock(nn.Module):
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels,
|
||||
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
|
||||
out_channels=None, out_activation_type=None):
|
||||
out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
|
||||
"""
|
||||
Residual Conv Block
|
||||
:param in_channels:
|
||||
@@ -110,15 +110,15 @@ class ResidualBlock(nn.Module):
|
||||
self.learn_skip_connection = in_channels != out_channels
|
||||
|
||||
conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
|
||||
conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
|
||||
additional_norm_kwargs=additional_norm_kwargs,
|
||||
padding_mode=padding_mode)
|
||||
|
||||
self.conv1 = conv_block(in_channels, in_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=activation_type)
|
||||
self.conv2 = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=out_activation_type)
|
||||
self.conv1 = conv_block(in_channels, in_channels, **conv_param)
|
||||
self.conv2 = conv_block(in_channels, out_channels, **conv_param)
|
||||
|
||||
if self.learn_skip_connection:
|
||||
self.res_conv = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=out_activation_type)
|
||||
self.res_conv = conv_block(in_channels, out_channels, **conv_param)
|
||||
|
||||
def forward(self, x):
|
||||
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||
|
||||
Reference in New Issue
Block a user