Move Encoder and Decoder to CycleGAN

This commit is contained in:
2020-10-11 11:09:16 +08:00
parent 04c6366c07
commit 9c08b4cd09
4 changed files with 106 additions and 128 deletions

View File

@@ -2,99 +2,29 @@ import torch
import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock
def _get_down_sampling_sequence(in_channels, base_channels, num_conv, max_down_sampling_multiple=2,
                                padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
    """Build the list of down-sampling convolution blocks.

    The list starts with one 7x7, stride-1 ``Conv2dBlock`` mapping
    ``in_channels`` to ``base_channels``, followed by ``num_conv`` 4x4,
    stride-2 ``Conv2dBlock``s. The channel multiplier of stage ``i`` is
    ``2 ** i``, capped at ``2 ** max_down_sampling_multiple``.

    :return: tuple ``(blocks, out_channels)`` where ``blocks`` is a plain
        Python list and ``out_channels`` is the channel count produced by
        the last block.
    """
    # Keyword arguments shared by every block in the stack.
    shared = dict(padding_mode=padding_mode, activation_type=activation_type, norm_type=norm_type)

    blocks = [Conv2dBlock(in_channels, base_channels, kernel_size=7, stride=1, padding=3, **shared)]

    multiplier_cap = 2 ** max_down_sampling_multiple
    multiplier = 1
    for stage in range(1, num_conv + 1):
        next_multiplier = min(2 ** stage, multiplier_cap)
        blocks.append(Conv2dBlock(
            multiplier * base_channels, next_multiplier * base_channels,
            kernel_size=4, stride=2, padding=1, **shared
        ))
        multiplier = next_multiplier

    return blocks, multiplier * base_channels
from model.base.module import LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class StyleEncoder(nn.Module):
    """Encode an image into a flat style vector of size ``out_dim``.

    The image is first passed through ``Encoder`` (configured with
    ``num_res=0``), then globally average-pooled to 1x1 and projected to
    ``out_dim`` channels by a 1x1 convolution.

    :param in_channels: channel count of the input image.
    :param out_dim: length of the returned style vector.
    :param num_conv: forwarded to ``Encoder`` as its third positional argument.
    :param base_channels: forwarded to ``Encoder`` as its second positional argument.
    :param max_down_sampling_multiple: forwarded to ``Encoder``.
    :param padding_mode: forwarded to ``Encoder``.
    :param activation_type: forwarded to ``Encoder``.
    :param norm_type: forwarded to ``Encoder`` as ``down_conv_norm_type``.
    """

    def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
                 max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        self.down_encoder = Encoder(
            in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
            padding_mode=padding_mode, activation_type=activation_type,
            down_conv_norm_type=norm_type, down_conv_kernel_size=4,
        )
        sequence = list()
        sequence.append(nn.AdaptiveAvgPool2d(1))
        # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
        sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim, kernel_size=1, stride=1, padding=0))
        self.sequence = nn.Sequential(*sequence)

    def forward(self, image):
        """Return the style code of ``image`` flattened to ``(image.size(0), -1)``."""
        # The pooling/projection head expects ``down_encoder.out_channels``
        # input channels, so the image must go through the down-sampling
        # encoder first; feeding the raw image to ``self.sequence`` would
        # hand the 1x1 conv the wrong channel count.
        features = self.down_encoder(image)
        return self.sequence(features).view(image.size(0), -1)
class ContentEncoder(nn.Module):
    """Down-sampling convolution stack followed by residual blocks.

    The down-sampling part comes from ``_get_down_sampling_sequence``;
    ``num_residual_blocks`` ``ResidualBlock``s that keep the channel count
    unchanged are appended after it.
    """

    def __init__(self, in_channels, num_down_sampling, num_residual_blocks, base_channels=64,
                 max_down_sampling_multiple=2,
                 padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
        super().__init__()
        layers, channels = _get_down_sampling_sequence(
            in_channels, base_channels, num_down_sampling,
            max_down_sampling_multiple, padding_mode, activation_type, norm_type
        )
        # Residual blocks operate at the channel width produced by the last
        # down-sampling block.
        for _ in range(num_residual_blocks):
            layers.append(ResidualBlock(channels, channels, padding_mode, activation_type, norm_type))
        self.sequence = nn.Sequential(*layers)

    def forward(self, image):
        """Run ``image`` through the whole stack and return the result."""
        features = self.sequence(image)
        return features
class Decoder(nn.Module):
    """Style-conditioned decoder: residual blocks, then up-sampling convolutions.

    ``forward`` splits ``style`` into two chunks per residual block and hands
    them to the normalization layers of each block's ``conv1`` / ``conv2``
    via ``set_style`` before running the block.
    """

    def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                 res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU", padding_mode='reflect'):
        super().__init__()

        res_blocks = []
        for _ in range(num_residual_blocks):
            res_blocks.append(
                ResidualBlock(in_channels, in_channels, padding_mode, activation_type, norm_type=res_norm_type)
            )
        self.residual_blocks = nn.ModuleList(res_blocks)

        stages = []
        width = in_channels
        for _ in range(num_up_sampling):
            # Each stage doubles the spatial size and halves the channel count.
            halved = width // 2
            stages.append(nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(width, halved,
                            kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
                            activation_type=activation_type, norm_type=norm_type),
            ))
            width = halved
        # Output projection: 7x7 conv with Tanh activation and no normalization.
        stages.append(Conv2dBlock(width, out_channels,
                                  kernel_size=7, stride=1, padding=3, padding_mode="reflect",
                                  activation_type="Tanh", norm_type="NONE"))
        self.up_sequence = nn.Sequential(*stages)

    def forward(self, x, style):
        """Decode ``x`` conditioned on ``style`` and return the up-sampled output."""
        style_chunks = torch.chunk(style, 2 * len(self.residual_blocks), dim=1)
        # set style for decoder
        for idx, block in enumerate(self.residual_blocks):
            offset = 2 * idx
            block.conv1.normalization.set_style(style_chunks[offset])
            block.conv2.normalization.set_style(style_chunks[offset + 1])
            x = block(x)
        return self.up_sequence(x)
class MLPFusion(nn.Module):
def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
super().__init__()
@@ -119,10 +49,13 @@ class Generator(nn.Module):
encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
padding_mode='reflect', activation_type="ReLU"):
super().__init__()
self.content_encoder = ContentEncoder(
in_channels, num_content_down_sampling, encoder_num_residual_blocks,
base_channels, max_down_sampling_multiple,
padding_mode, activation_type, norm_type="IN")
self.content_encoder = Encoder(
in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
max_down_sampling_multiple=num_content_down_sampling,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type="IN", down_conv_kernel_size=4,
res_norm_type="IN"
)
self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
max_down_sampling_multiple, padding_mode, activation_type,
@@ -134,15 +67,21 @@ class Generator(nn.Module):
num_mlp_base_feature, num_mlp_blocks, activation_type,
norm_type="NONE")
self.decoder = Decoder(content_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
res_norm_type="AdaIN", norm_type="LN", activation_type=activation_type,
padding_mode=padding_mode)
self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
activation_type=activation_type, padding_mode=padding_mode,
up_conv_kernel_size=5, up_conv_norm_type="LN",
res_norm_type="AdaIN")
def encode(self, x):
    """Return ``(content, style)`` from the content and style encoders applied to ``x``."""
    content = self.content_encoder(x)
    style = self.style_encoder(x)
    return content, style
def decode(self, content, style):
    """Decode ``content`` conditioned on ``style`` and return the decoder output.

    ``style`` is passed through ``self.fusion`` and split along ``dim=1``
    into two chunks per decoder residual block. The chunks are handed to the
    normalization layers of each block's ``conv1`` and ``conv2`` via
    ``set_style`` before the decoder runs.

    :param content: content features passed to ``self.decoder``.
    :param style: style code passed to ``self.fusion``.
    :return: whatever ``self.decoder(content)`` returns.
    """
    residual_blocks = self.decoder.residual_blocks
    as_param_style = torch.chunk(self.fusion(style), 2 * len(residual_blocks), dim=1)
    # set style for decoder
    for i, blk in enumerate(residual_blocks):
        blk.conv1.normalization.set_style(as_param_style[2 * i])
        blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
    # The decoder output must be returned; previously it was computed and
    # discarded, so callers received None.
    return self.decoder(content)
def forward(self, x):
content, style = self.encode(x)