add MUNIT

This commit is contained in:
2020-09-14 22:30:05 +08:00
parent f70658eaed
commit 2ff4a91057
7 changed files with 510 additions and 6 deletions

154
model/GAN/MUNIT.py Normal file
View File

@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
from model import MODEL
from model.GAN.base import Conv2dBlock, ResBlock
from model.normalization import select_norm_layer
class StyleEncoder(nn.Module):
def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False,
padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
super(StyleEncoder, self).__init__()
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
)]
multiple_now = 1
for i in range(1, num_conv + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 2)
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
))
sequence.append(nn.AdaptiveAvgPool2d(1))
# conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
sequence.append(nn.Conv2d(multiple_now * base_channels, out_dim, kernel_size=1, stride=1, padding=0))
self.model = nn.Sequential(*sequence)
def forward(self, x):
return self.model(x).view(x.size(0), -1)
class ContentEncoder(nn.Module):
def __init__(self, in_channels, num_down_sampling, num_res_blocks, base_channels=64, use_spectral_norm=False,
padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
super().__init__()
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
)]
for i in range(num_down_sampling):
sequence.append(Conv2dBlock(
base_channels * (2 ** i), base_channels * (2 ** (i + 1)),
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
))
for _ in range(num_res_blocks):
sequence.append(
ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
activation_type)
)
self.sequence = nn.Sequential(*sequence)
def forward(self, x):
return self.sequence(x)
class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, num_up_sampling, num_res_blocks,
use_spectral_norm=False, res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU",
padding_mode='reflect'):
super(Decoder, self).__init__()
self.res_norm_type = res_norm_type
self.res_blocks = nn.ModuleList([
ResBlock(in_channels, use_spectral_norm, padding_mode, res_norm_type, activation_type=activation_type)
for _ in range(num_res_blocks)
])
sequence = list()
channels = in_channels
for i in range(num_up_sampling):
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2,
kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
),
))
channels = channels // 2
sequence.append(
Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect",
use_spectral_norm=use_spectral_norm, activation_type="Tanh", norm_type="NONE"))
self.sequence = nn.Sequential(*sequence)
def forward(self, x):
for blk in self.res_blocks:
x = blk(x)
return self.sequence(x)
class Fusion(nn.Module):
def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
super().__init__()
norm_layer = select_norm_layer(norm_type)
self.start_fc = nn.Sequential(
nn.Linear(in_features, base_features),
norm_layer(base_features),
nn.ReLU(True),
)
self.fcs = nn.Sequential(*[
nn.Sequential(
nn.Linear(base_features, base_features),
norm_layer(base_features),
nn.ReLU(True),
) for _ in range(n_blocks - 2)
])
self.end_fc = nn.Sequential(
nn.Linear(base_features, out_features),
)
def forward(self, x):
x = self.start_fc(x)
x = self.fcs(x)
return self.end_fc(x)
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels, num_sampling, num_style_dim, num_style_conv,
num_content_res_blocks, num_decoder_res_blocks, num_fusion_dim, num_fusion_blocks,
use_spectral_norm=False, activation_type="ReLU", padding_mode='reflect'):
super().__init__()
self.num_decoder_res_blocks = num_decoder_res_blocks
self.content_encoder = ContentEncoder(in_channels, num_sampling, num_content_res_blocks, base_channels,
use_spectral_norm, padding_mode, activation_type, norm_type="IN")
self.style_encoder = StyleEncoder(in_channels, num_style_dim, num_style_conv, base_channels, use_spectral_norm,
padding_mode, activation_type, norm_type="NONE")
content_channels = base_channels * (2 ** 2)
self.decoder = Decoder(content_channels, out_channels, num_sampling,
num_decoder_res_blocks, use_spectral_norm, "AdaIN", norm_type="LN",
activation_type=activation_type, padding_mode=padding_mode)
self.fusion = Fusion(num_style_dim, num_decoder_res_blocks * 2 * content_channels * 2,
base_features=num_fusion_dim, n_blocks=num_fusion_blocks, norm_type="NONE")
def encode(self, x):
return self.content_encoder(x), self.style_encoder(x)
def decode(self, content, style):
as_param_style = torch.chunk(self.fusion(style), self.num_decoder_res_blocks * 2, dim=1)
# set style for decoder
for i, blk in enumerate(self.decoder.res_blocks):
blk.conv1.normalization.set_style(as_param_style[2 * i])
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
return self.decoder(content)
def forward(self, x):
content, style = self.encode(x)
return self.decode(content, style)

View File

@@ -1,10 +1,11 @@
import math
from functools import partial
import math
import torch
import torch.nn as nn
from model.normalization import select_norm_layer
from model import MODEL
from model.normalization import select_norm_layer
class GANImageBuffer(object):
@@ -137,3 +138,66 @@ class ResidualBlock(nn.Module):
x = self.relu1(self.norm1(self.conv1(x)))
x = self.norm2(self.conv2(x))
return x + res
_DO_NO_THING_FUNC = lambda x: x
def select_activation(t):
if t == "ReLU":
return partial(nn.ReLU, inplace=True)
elif t == "LeakyReLU":
return partial(nn.LeakyReLU, negative_slope=0.2, inplace=True)
elif t == "Tanh":
return partial(nn.Tanh)
elif t == "NONE":
return _DO_NO_THING_FUNC
else:
raise NotImplemented
def _use_bias_checker(norm_type):
return norm_type not in ["IN", "BN", "AdaIN"]
class Conv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, use_spectral_norm=False, activation_type="ReLU",
bias=None, norm_type="NONE", **conv_kwargs):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
self.convolution = nn.utils.spectral_norm(conv) if use_spectral_norm else conv
if norm_type != "NONE":
self.normalization = select_norm_layer(norm_type)(out_channels)
if activation_type != "NONE":
self.activation = select_activation(activation_type)()
def forward(self, x):
x = self.convolution(x)
if self.norm_type != "NONE":
x = self.normalization(x)
if self.activation_type != "NONE":
x = self.activation(x)
return x
class ResBlock(nn.Module):
def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect',
norm_type="IN", activation_type="relu", use_bias=None):
super().__init__()
self.norm_type = norm_type
if use_bias is None:
# bias will be canceled after channel wise normalization
use_bias = _use_bias_checker(norm_type)
self.conv1 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias,
norm_type=norm_type, activation_type=activation_type)
self.conv2 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias,
norm_type=norm_type, activation_type="NONE")
def forward(self, x):
return self.conv2(self.conv1(x)) + x

View File

@@ -4,4 +4,5 @@ import model.GAN.TAFG
import model.GAN.UGATIT
import model.GAN.wrapper
import model.GAN.base
import model.GAN.TSIT
import model.GAN.TSIT
import model.GAN.MUNIT