Commit acf243cb12 (parent fbea96f6d7), authored 2020-09-25 18:31:12 +08:00 — 11 changed files with 542 additions and 115 deletions.

View File

@@ -53,6 +53,59 @@ class VGG19StyleEncoder(nn.Module):
return x.view(x.size(0), -1)
@MODEL.register_module("TAFG-ResGenerator")
class ResGenerator(nn.Module):
    """Style-free residual generator: a content encoder feeding a decoder.

    Args:
        in_channels: channels of the input image fed to the content encoder.
        out_channels: channels of the generated image (default 3, RGB).
        use_spectral_norm: whether convolutions use spectral normalization.
        num_res_blocks: number of residual blocks in the content encoder.
        base_channels: channel width of the first conv stage.
    """

    def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
        super().__init__()
        self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks,
                                              use_spectral_norm=use_spectral_norm)
        # 2 downsampling stages in the encoder -> base_channels * 2**2 features.
        decoder_in_channels = base_channels * 4
        # 0 AdaIN blocks: this variant uses plain "IN" normalization, no style input.
        self.decoder = Decoder(decoder_in_channels, out_channels, 2,
                               0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect")

    def forward(self, x):
        # Encode content, then decode straight back to image space.
        features = self.content_encoder(x)
        return self.decoder(features)
@MODEL.register_module("TAFG-SingleGenerator")
class SingleGenerator(nn.Module):
    """Style-transfer generator driven by a single style image.

    Encodes the content image, maps the style image's embedding to per-block
    AdaIN parameters via ``Fusion``, injects those parameters into the
    decoder's residual blocks, and decodes the content features.

    Args:
        style_in_channels: channels of the style image.
        content_in_channels: channels of the content image.
        out_channels: channels of the generated image (default 3).
        use_spectral_norm: whether convolutions use spectral normalization.
        style_encoder_type: "StyleEncoder" or "VGG19StyleEncoder".
        num_style_conv: conv stages in the plain StyleEncoder.
        style_dim: dimensionality of the style embedding.
        num_adain_blocks: number of AdaIN residual blocks in the decoder.
        num_res_blocks: residual blocks in the content encoder.
        base_channels: channel width of the first conv stage.
        padding_mode: padding mode passed to the sub-networks.

    Raises:
        NotImplementedError: if ``style_encoder_type`` is not supported.
    """

    def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
        super().__init__()
        self.num_adain_blocks = num_adain_blocks
        if style_encoder_type == "StyleEncoder":
            self.style_encoder = StyleEncoder(
                style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
                max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
            )
        elif style_encoder_type == "VGG19StyleEncoder":
            self.style_encoder = VGG19StyleEncoder(
                style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
            )
        else:
            # BUG FIX: was `raise NotImplemented(...)`. NotImplemented is a
            # constant (not an exception class), so raising it caused a
            # TypeError instead of the intended error message.
            raise NotImplementedError(f"do not support {style_encoder_type}")
        # Channels entering the decoder after the encoder's 2 downsampling stages.
        resnet_channels = 2 ** 2 * base_channels
        # Each AdaIN block consumes 2 parameter chunks (its conv1 and conv2
        # normalizations), each chunk of size 2 * resnet_channels — presumably
        # scale and shift; confirm against AdaptiveInstanceNorm2d.set_style.
        self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
                                      n_blocks=3, norm_type="NONE")
        self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
                                              use_spectral_norm=use_spectral_norm)
        self.decoder = Decoder(resnet_channels, out_channels, 2,
                               num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)

    def forward(self, content_img, style_img):
        """Generate an image carrying ``style_img``'s style onto ``content_img``."""
        content = self.content_encoder(content_img)
        style = self.style_encoder(style_img)
        # Split the converted style vector into 2 chunks per AdaIN block.
        as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
        # Push the style parameters into each decoder residual block before decoding.
        for i, blk in enumerate(self.decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)
@MODEL.register_module("TAFG-Generator")
class Generator(nn.Module):
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,

View File

@@ -3,7 +3,6 @@ import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
from model.normalization import AdaptiveInstanceNorm2d
from model.normalization import select_norm_layer
@@ -62,7 +61,9 @@ class Interpolation(nn.Module):
class FADE(nn.Module):
def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
super().__init__()
self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
# self.norm = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
self.norm = nn.InstanceNorm2d(num_features=in_channels)
self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
padding_mode="zeros")
self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
@@ -71,7 +72,7 @@ class FADE(nn.Module):
def forward(self, x, feature):
alpha = self.alpha_conv(feature)
beta = self.beta_conv(feature)
x = self.bn(x)
x = self.norm(x)
return alpha * x + beta
@@ -122,9 +123,7 @@ class TSITGenerator(nn.Module):
self.use_spectral = use_spectral
self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
self.content_stream = self.build_stream()
self.style_stream = self.build_stream()
self.generator = self.build_generator()
self.end_conv = nn.Sequential(
conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
@@ -138,11 +137,9 @@ class TSITGenerator(nn.Module):
m = self.num_blocks - i
multiple_prev = multiple_now
multiple_now = min(2 ** m, 2 ** 4)
stream_sequence.append(nn.Sequential(
AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
stream_sequence.append(
FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
multiple_now * self.base_channels)
))
multiple_now * self.base_channels))
return nn.ModuleList(stream_sequence)
def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
@@ -171,22 +168,16 @@ class TSITGenerator(nn.Module):
))
return nn.ModuleList(stream_sequence)
def forward(self, content_img, style_img):
def forward(self, content_img):
c = self.content_input_layer(content_img)
s = self.style_input_layer(style_img)
content_features = []
style_features = []
for i in range(self.num_blocks):
s = self.style_stream[i](s)
c = self.content_stream[i](c)
content_features.append(c)
style_features.append(s)
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
for i in range(self.num_blocks):
m = - i - 1
layer = self.generator[i]
layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
z = layer[0](z)
z = layer[1](z, content_features[m])
z = layer(z, content_features[m])
return self.end_conv(z)