add flag to switch to norm-activ-conv

This commit is contained in:
2020-10-11 19:02:42 +08:00
parent 9c08b4cd09
commit 06b2abd19a
4 changed files with 70 additions and 61 deletions

View File

@@ -52,7 +52,8 @@ class LinearBlock(nn.Module):
class Conv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, bias=None,
activation_type="ReLU", norm_type="NONE", **conv_kwargs):
activation_type="ReLU", norm_type="NONE",
additional_norm_kwargs=None, **conv_kwargs):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
@@ -61,40 +62,13 @@ class Conv2dBlock(nn.Module):
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
self.normalization = _normalization(norm_type, out_channels)
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
self.activation = _activation(activation_type)
def forward(self, x):
return self.activation(self.normalization(self.convolution(x)))
class ResidualBlock(nn.Module):
def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
activation_type="ReLU", norm_type="IN", out_activation_type=None):
super().__init__()
self.norm_type = norm_type
if out_channels is None:
out_channels = num_channels
if out_activation_type is None:
out_activation_type = "NONE"
self.learn_skip_connection = num_channels != out_channels
self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=activation_type)
self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
if self.learn_skip_connection:
self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
def forward(self, x):
res = x if not self.learn_skip_connection else self.res_conv(x)
return self.conv2(self.conv1(x)) + res
class ReverseConv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int,
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
@@ -107,19 +81,44 @@ class ReverseConv2dBlock(nn.Module):
return self.convolution(self.activation(self.normalization(x)))
class ReverseResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, padding_mode="reflect",
norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
class ResidualBlock(nn.Module):
def __init__(self, in_channels,
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
out_channels=None, out_activation_type=None):
"""
Residual Conv Block
:param in_channels:
:param out_channels:
:param padding_mode:
:param activation_type:
:param norm_type:
:param out_activation_type:
:param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
"""
super().__init__()
self.norm_type = norm_type
if out_channels is None:
out_channels = in_channels
if out_activation_type is None:
# if not specify `out_activation_type`, using default `out_activation_type`
# `out_activation_type` default mode:
# "NONE" for not full pre-activation
# `norm_type` for full pre-activation
out_activation_type = "NONE" if not pre_activation else norm_type
self.learn_skip_connection = in_channels != out_channels
self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
kernel_size=3, padding=1, padding_mode=padding_mode)
self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
kernel_size=3, padding=1, padding_mode=padding_mode)
conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
self.conv1 = conv_block(in_channels, in_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=activation_type)
self.conv2 = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
if self.learn_skip_connection:
self.res_conv = ReverseConv2dBlock(
in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
kernel_size=3, padding=1, padding_mode=padding_mode)
self.res_conv = conv_block(in_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
def forward(self, x):
res = x if not self.learn_skip_connection else self.res_conv(x)