TAFG update
This commit is contained in:
@@ -8,7 +8,7 @@ from model.normalization import select_norm_layer
|
||||
|
||||
class StyleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False,
|
||||
padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
|
||||
max_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
|
||||
super(StyleEncoder, self).__init__()
|
||||
|
||||
sequence = [Conv2dBlock(
|
||||
@@ -19,7 +19,7 @@ class StyleEncoder(nn.Module):
|
||||
multiple_now = 1
|
||||
for i in range(1, num_conv + 1):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** i, 2 ** 2)
|
||||
multiple_now = min(2 ** i, 2 ** max_multiple)
|
||||
sequence.append(Conv2dBlock(
|
||||
multiple_prev * base_channels, multiple_now * base_channels,
|
||||
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
|
||||
@@ -50,12 +50,8 @@ class ContentEncoder(nn.Module):
|
||||
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
|
||||
))
|
||||
|
||||
for _ in range(num_res_blocks):
|
||||
sequence.append(
|
||||
ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
|
||||
activation_type)
|
||||
)
|
||||
|
||||
sequence += [ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
|
||||
activation_type) for _ in range(num_res_blocks)]
|
||||
self.sequence = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@@ -4,7 +4,7 @@ from torchvision.models import vgg19
|
||||
|
||||
from model.normalization import select_norm_layer
|
||||
from model.registry import MODEL
|
||||
from .MUNIT import ContentEncoder, Fusion, Decoder
|
||||
from .MUNIT import ContentEncoder, Fusion, Decoder, StyleEncoder
|
||||
from .base import ResBlock
|
||||
|
||||
|
||||
@@ -56,17 +56,26 @@ class VGG19StyleEncoder(nn.Module):
|
||||
@MODEL.register_module("TAFG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
|
||||
style_dim=512, style_use_fc=True,
|
||||
num_adain_blocks=8, num_res_blocks=8,
|
||||
base_channels=64, padding_mode="reflect"):
|
||||
style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
|
||||
num_res_blocks=8, base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_adain_blocks = num_adain_blocks
|
||||
self.style_encoders = nn.ModuleDict(dict(
|
||||
a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
|
||||
norm_type="NONE"),
|
||||
b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
|
||||
norm_type="NONE", fix_vgg19=False)
|
||||
))
|
||||
if style_encoder_type == "StyleEncoder":
|
||||
self.style_encoders = nn.ModuleDict(dict(
|
||||
a=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
|
||||
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
|
||||
b=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
|
||||
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
|
||||
))
|
||||
elif style_encoder_type == "VGG19StyleEncoder":
|
||||
self.style_encoders = nn.ModuleDict(dict(
|
||||
a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
|
||||
norm_type="NONE"),
|
||||
b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
|
||||
norm_type="NONE", fix_vgg19=False)
|
||||
))
|
||||
else:
|
||||
raise NotImplemented(f"do not support {style_encoder_type}")
|
||||
resnet_channels = 2 ** 2 * base_channels
|
||||
self.style_converters = nn.ModuleDict(dict(
|
||||
a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
|
||||
|
||||
Reference in New Issue
Block a user