This commit is contained in:
2020-09-25 18:31:12 +08:00
parent fbea96f6d7
commit acf243cb12
11 changed files with 542 additions and 115 deletions

View File

@@ -3,7 +3,6 @@ import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
from model.normalization import AdaptiveInstanceNorm2d
from model.normalization import select_norm_layer
@@ -62,7 +61,9 @@ class Interpolation(nn.Module):
class FADE(nn.Module):
def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
super().__init__()
self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
# self.norm = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
self.norm = nn.InstanceNorm2d(num_features=in_channels)
self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
padding_mode="zeros")
self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
@@ -71,7 +72,7 @@ class FADE(nn.Module):
def forward(self, x, feature):
alpha = self.alpha_conv(feature)
beta = self.beta_conv(feature)
x = self.bn(x)
x = self.norm(x)
return alpha * x + beta
@@ -122,9 +123,7 @@ class TSITGenerator(nn.Module):
self.use_spectral = use_spectral
self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
self.content_stream = self.build_stream()
self.style_stream = self.build_stream()
self.generator = self.build_generator()
self.end_conv = nn.Sequential(
conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
@@ -138,11 +137,9 @@ class TSITGenerator(nn.Module):
m = self.num_blocks - i
multiple_prev = multiple_now
multiple_now = min(2 ** m, 2 ** 4)
stream_sequence.append(nn.Sequential(
AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
stream_sequence.append(
FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
multiple_now * self.base_channels)
))
multiple_now * self.base_channels))
return nn.ModuleList(stream_sequence)
def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
@@ -171,22 +168,16 @@ class TSITGenerator(nn.Module):
))
return nn.ModuleList(stream_sequence)
def forward(self, content_img, style_img):
def forward(self, content_img):
c = self.content_input_layer(content_img)
s = self.style_input_layer(style_img)
content_features = []
style_features = []
for i in range(self.num_blocks):
s = self.style_stream[i](s)
c = self.content_stream[i](c)
content_features.append(c)
style_features.append(s)
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
for i in range(self.num_blocks):
m = - i - 1
layer = self.generator[i]
layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
z = layer[0](z)
z = layer[1](z, content_features[m])
z = layer(z, content_features[m])
return self.end_conv(z)