working
This commit is contained in:
@@ -53,6 +53,59 @@ class VGG19StyleEncoder(nn.Module):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-ResGenerator")
class ResGenerator(nn.Module):
    """Purely content-driven generator: a content encoder feeding a decoder.

    Unlike the style-conditioned variants in this file it takes no style
    input; the decoder runs with plain instance normalization ("IN").
    """

    def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
        super().__init__()
        # Content path: 2 downsampling stages plus a stack of residual blocks.
        self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks,
                                              use_spectral_norm=use_spectral_norm)
        # Channel count at the bottleneck after 2 downsamplings: 2**2 * base_channels.
        bottleneck_channels = 2 ** 2 * base_channels
        # No AdaIN blocks (third positional arg after upsample count is 0);
        # normalization inside the decoder is "IN" with "LN" norm_type.
        self.decoder = Decoder(bottleneck_channels, out_channels, 2,
                               0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect")

    def forward(self, x):
        """Encode ``x`` and decode it straight back to an output image."""
        encoded = self.content_encoder(x)
        return self.decoder(encoded)
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-SingleGenerator")
class SingleGenerator(nn.Module):
    """Style-conditioned generator with a single style pathway.

    A content image is encoded by ``ContentEncoder``; a style image is
    encoded into a style code which is converted (``Fusion``) into AdaIN
    scale/shift parameters and injected into the decoder's residual
    blocks before decoding.

    Args:
        style_in_channels: channels of the style image.
        content_in_channels: channels of the content image.
        out_channels: channels of the generated image.
        use_spectral_norm: apply spectral norm to conv layers.
        style_encoder_type: "StyleEncoder" or "VGG19StyleEncoder".
        num_style_conv: conv layers in the plain style encoder.
        style_dim: dimensionality of the style code.
        num_adain_blocks: AdaIN residual blocks in the decoder.
        num_res_blocks: residual blocks in the content encoder.
        base_channels: channel multiplier base.
        padding_mode: padding mode for conv layers.

    Raises:
        NotImplementedError: if ``style_encoder_type`` is unknown.
    """

    def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
                 style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
                 num_res_blocks=8, base_channels=64, padding_mode="reflect"):
        super().__init__()
        self.num_adain_blocks = num_adain_blocks
        if style_encoder_type == "StyleEncoder":
            self.style_encoder = StyleEncoder(
                style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
                max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
            )
        elif style_encoder_type == "VGG19StyleEncoder":
            self.style_encoder = VGG19StyleEncoder(
                style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
            )
        else:
            # BUG FIX: was ``raise NotImplemented(...)`` — ``NotImplemented``
            # is a non-callable singleton, so that line raised ``TypeError``
            # instead of the intended exception. ``NotImplementedError`` is
            # the correct exception type.
            raise NotImplementedError(f"do not support {style_encoder_type}")

        # Channel count at the bottleneck after 2 downsamplings: 2**2 * base_channels.
        resnet_channels = 2 ** 2 * base_channels
        # Each AdaIN block consumes two parameter chunks (conv1 and conv2),
        # each of size 2 * resnet_channels (scale and shift) — hence the
        # num_adain_blocks * 2 * resnet_channels * 2 output width.
        self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
                                      n_blocks=3, norm_type="NONE")
        self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
                                              use_spectral_norm=use_spectral_norm)
        self.decoder = Decoder(resnet_channels, out_channels, 2,
                               num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)

    def forward(self, content_img, style_img):
        """Render ``content_img`` in the style of ``style_img``."""
        content = self.content_encoder(content_img)
        style = self.style_encoder(style_img)
        # Split the converted style code into 2 chunks per AdaIN block
        # (one for conv1's normalization, one for conv2's).
        as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
        # Inject style parameters into the decoder's AdaIN layers before decoding.
        for i, blk in enumerate(self.decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
|
||||
|
||||
@@ -3,7 +3,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import MODEL
|
||||
from model.normalization import AdaptiveInstanceNorm2d
|
||||
from model.normalization import select_norm_layer
|
||||
|
||||
|
||||
@@ -62,7 +61,9 @@ class Interpolation(nn.Module):
|
||||
class FADE(nn.Module):
|
||||
def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
|
||||
super().__init__()
|
||||
self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
|
||||
# self.norm = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
|
||||
self.norm = nn.InstanceNorm2d(num_features=in_channels)
|
||||
|
||||
self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
|
||||
padding_mode="zeros")
|
||||
self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
|
||||
@@ -71,7 +72,7 @@ class FADE(nn.Module):
|
||||
def forward(self, x, feature):
|
||||
alpha = self.alpha_conv(feature)
|
||||
beta = self.beta_conv(feature)
|
||||
x = self.bn(x)
|
||||
x = self.norm(x)
|
||||
return alpha * x + beta
|
||||
|
||||
|
||||
@@ -122,9 +123,7 @@ class TSITGenerator(nn.Module):
|
||||
self.use_spectral = use_spectral
|
||||
|
||||
self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
|
||||
self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
|
||||
self.content_stream = self.build_stream()
|
||||
self.style_stream = self.build_stream()
|
||||
self.generator = self.build_generator()
|
||||
self.end_conv = nn.Sequential(
|
||||
conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
|
||||
@@ -138,11 +137,9 @@ class TSITGenerator(nn.Module):
|
||||
m = self.num_blocks - i
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** m, 2 ** 4)
|
||||
stream_sequence.append(nn.Sequential(
|
||||
AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
|
||||
stream_sequence.append(
|
||||
FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
|
||||
multiple_now * self.base_channels)
|
||||
))
|
||||
multiple_now * self.base_channels))
|
||||
return nn.ModuleList(stream_sequence)
|
||||
|
||||
def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
|
||||
@@ -171,22 +168,16 @@ class TSITGenerator(nn.Module):
|
||||
))
|
||||
return nn.ModuleList(stream_sequence)
|
||||
|
||||
def forward(self, content_img, style_img):
|
||||
def forward(self, content_img):
|
||||
c = self.content_input_layer(content_img)
|
||||
s = self.style_input_layer(style_img)
|
||||
content_features = []
|
||||
style_features = []
|
||||
for i in range(self.num_blocks):
|
||||
s = self.style_stream[i](s)
|
||||
c = self.content_stream[i](c)
|
||||
content_features.append(c)
|
||||
style_features.append(s)
|
||||
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
|
||||
|
||||
for i in range(self.num_blocks):
|
||||
m = - i - 1
|
||||
layer = self.generator[i]
|
||||
layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
|
||||
z = layer[0](z)
|
||||
z = layer[1](z, content_features[m])
|
||||
z = layer(z, content_features[m])
|
||||
return self.end_conv(z)
|
||||
|
||||
Reference in New Issue
Block a user