TAFG update

This commit is contained in:
2020-09-18 12:03:44 +08:00
parent 61e04de8a5
commit b01016edb5
6 changed files with 91 additions and 59 deletions

View File

@@ -8,7 +8,7 @@ from model.normalization import select_norm_layer
class StyleEncoder(nn.Module):
def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False,
padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
max_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
super(StyleEncoder, self).__init__()
sequence = [Conv2dBlock(
@@ -19,7 +19,7 @@ class StyleEncoder(nn.Module):
multiple_now = 1
for i in range(1, num_conv + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 2)
multiple_now = min(2 ** i, 2 ** max_multiple)
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
@@ -50,12 +50,8 @@ class ContentEncoder(nn.Module):
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
))
for _ in range(num_res_blocks):
sequence.append(
ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
activation_type)
)
sequence += [ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
activation_type) for _ in range(num_res_blocks)]
self.sequence = nn.Sequential(*sequence)
def forward(self, x):

View File

@@ -4,7 +4,7 @@ from torchvision.models import vgg19
from model.normalization import select_norm_layer
from model.registry import MODEL
from .MUNIT import ContentEncoder, Fusion, Decoder
from .MUNIT import ContentEncoder, Fusion, Decoder, StyleEncoder
from .base import ResBlock
@@ -56,17 +56,26 @@ class VGG19StyleEncoder(nn.Module):
@MODEL.register_module("TAFG-Generator")
class Generator(nn.Module):
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
style_dim=512, style_use_fc=True,
num_adain_blocks=8, num_res_blocks=8,
base_channels=64, padding_mode="reflect"):
style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
num_res_blocks=8, base_channels=64, padding_mode="reflect"):
super(Generator, self).__init__()
self.num_adain_blocks = num_adain_blocks
self.style_encoders = nn.ModuleDict(dict(
a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
norm_type="NONE"),
b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
norm_type="NONE", fix_vgg19=False)
))
if style_encoder_type == "StyleEncoder":
self.style_encoders = nn.ModuleDict(dict(
a=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
b=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
))
elif style_encoder_type == "VGG19StyleEncoder":
self.style_encoders = nn.ModuleDict(dict(
a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
norm_type="NONE"),
b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
norm_type="NONE", fix_vgg19=False)
))
else:
raise NotImplementedError(f"do not support {style_encoder_type}")
resnet_channels = 2 ** 2 * base_channels
self.style_converters = nn.ModuleDict(dict(
a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,