TAFG
This commit is contained in:
@@ -4,16 +4,17 @@ from torchvision.models import vgg19
|
||||
|
||||
from model.normalization import select_norm_layer
|
||||
from model.registry import MODEL
|
||||
from .base import ResidualBlock
|
||||
from .MUNIT import ContentEncoder, Fusion, Decoder
|
||||
from .base import ResBlock
|
||||
|
||||
|
||||
class VGG19StyleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
|
||||
vgg19_layers=(0, 5, 10, 19)):
|
||||
vgg19_layers=(0, 5, 10, 19), fix_vgg19=True):
|
||||
super().__init__()
|
||||
self.vgg19_layers = vgg19_layers
|
||||
self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
|
||||
self.vgg19.requires_grad_(False)
|
||||
self.vgg19.requires_grad_(not fix_vgg19)
|
||||
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
|
||||
@@ -52,203 +53,57 @@ class VGG19StyleEncoder(nn.Module):
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
class ContentEncoder(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, num_blocks=8, padding_mode='reflect', norm_type="IN"):
|
||||
super().__init__()
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
|
||||
self.start_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
|
||||
bias=True),
|
||||
norm_layer(num_features=base_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
# down sampling
|
||||
submodules = []
|
||||
num_down_sampling = 2
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** i
|
||||
submodules += [
|
||||
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
||||
kernel_size=4, stride=2, padding=1, bias=True),
|
||||
norm_layer(num_features=base_channels * multiple * 2),
|
||||
nn.ReLU(inplace=True)
|
||||
]
|
||||
self.encoder = nn.Sequential(*submodules)
|
||||
res_block_channels = num_down_sampling ** 2 * base_channels
|
||||
self.resnet = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.start_conv(x)
|
||||
x = self.encoder(x)
|
||||
x = self.resnet(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
|
||||
norm_type="LN"):
|
||||
super(Decoder, self).__init__()
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
res_block_channels = (2 ** 2) * base_channels
|
||||
|
||||
self.resnet = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
submodules = []
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** (num_down_sampling - i)
|
||||
submodules += [
|
||||
nn.Upsample(scale_factor=2),
|
||||
nn.Conv2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=5, stride=1,
|
||||
padding=2, padding_mode=padding_mode, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
self.decoder = nn.Sequential(*submodules)
|
||||
self.end_conv = nn.Sequential(
|
||||
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.resnet(x)
|
||||
x = self.decoder(x)
|
||||
x = self.end_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Fusion(nn.Module):
|
||||
def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
|
||||
super().__init__()
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
self.start_fc = nn.Sequential(
|
||||
nn.Linear(in_features, base_features),
|
||||
norm_layer(base_features),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
self.fcs = nn.Sequential(*[
|
||||
nn.Sequential(
|
||||
nn.Linear(base_features, base_features),
|
||||
norm_layer(base_features),
|
||||
nn.ReLU(True),
|
||||
) for _ in range(n_blocks - 2)
|
||||
])
|
||||
self.end_fc = nn.Sequential(
|
||||
nn.Linear(base_features, out_features),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.start_fc(x)
|
||||
x = self.fcs(x)
|
||||
return self.end_fc(x)
|
||||
|
||||
|
||||
class StyleGenerator(nn.Module):
|
||||
def __init__(self, style_in_channels, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"):
|
||||
super().__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.style_encoder = VGG19StyleEncoder(
|
||||
style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE")
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(style_dim, style_dim),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
|
||||
norm_type="NONE")
|
||||
|
||||
def forward(self, x):
|
||||
styles = self.fusion(self.fc(self.style_encoder(x)))
|
||||
return styles
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512,
|
||||
num_adain_blocks=8, num_res_blocks=4,
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
|
||||
style_dim=512, style_use_fc=True,
|
||||
num_adain_blocks=8, num_res_blocks=8,
|
||||
base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_adain_blocks=num_adain_blocks
|
||||
self.style_encoders = nn.ModuleDict({
|
||||
"a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
|
||||
base_channels=base_channels, padding_mode=padding_mode),
|
||||
"b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
|
||||
base_channels=base_channels, padding_mode=padding_mode),
|
||||
})
|
||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8,
|
||||
padding_mode=padding_mode, norm_type="IN")
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
|
||||
self.resnet = nn.ModuleDict({
|
||||
"a": nn.Sequential(*[
|
||||
ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
|
||||
]),
|
||||
"b": nn.Sequential(*[
|
||||
ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
|
||||
])
|
||||
})
|
||||
self.adain_resnet = nn.ModuleDict({
|
||||
"a": nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
|
||||
]),
|
||||
"b": nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
|
||||
])
|
||||
self.num_adain_blocks = num_adain_blocks
|
||||
self.style_encoders = nn.ModuleDict(dict(
|
||||
a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
|
||||
norm_type="NONE"),
|
||||
b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
|
||||
norm_type="NONE", fix_vgg19=False)
|
||||
))
|
||||
resnet_channels = 2 ** 2 * base_channels
|
||||
self.style_converters = nn.ModuleDict(dict(
|
||||
a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
|
||||
norm_type="NONE"),
|
||||
b=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
|
||||
norm_type="NONE"),
|
||||
))
|
||||
self.content_encoders = nn.ModuleDict({
|
||||
"a": ContentEncoder(content_in_channels, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm),
|
||||
"b": ContentEncoder(1, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm)
|
||||
})
|
||||
|
||||
self.decoders = nn.ModuleDict({
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
|
||||
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode)
|
||||
})
|
||||
self.content_resnet = nn.Sequential(*[
|
||||
ResBlock(resnet_channels, use_spectral_norm, padding_mode, "IN")
|
||||
for _ in range(num_res_blocks)
|
||||
])
|
||||
self.decoders = nn.ModuleDict(dict(
|
||||
a=Decoder(resnet_channels, out_channels, 2,
|
||||
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
|
||||
b=Decoder(resnet_channels, out_channels, 2,
|
||||
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
|
||||
))
|
||||
|
||||
def forward(self, content_img, style_img, which_decoder: str = "a"):
|
||||
x = self.content_encoder(content_img)
|
||||
x = self.resnet[which_decoder](x)
|
||||
styles = self.style_encoders[which_decoder](style_img)
|
||||
styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1)
|
||||
for i, ar in enumerate(self.adain_resnet[which_decoder]):
|
||||
ar.norm1.set_style(styles[2 * i])
|
||||
ar.norm2.set_style(styles[2 * i + 1])
|
||||
x = ar(x)
|
||||
return self.decoders[which_decoder](x)
|
||||
def encode(self, content_img, style_img, which_content, which_style):
|
||||
content = self.content_resnet(self.content_encoders[which_content](content_img))
|
||||
style = self.style_encoders[which_style](style_img)
|
||||
return content, style
|
||||
|
||||
def decode(self, content, style, which):
|
||||
decoder = self.decoders[which]
|
||||
as_param_style = torch.chunk(self.style_converters[which](style), self.num_adain_blocks * 2, dim=1)
|
||||
# set style for decoder
|
||||
for i, blk in enumerate(decoder.res_blocks):
|
||||
blk.conv1.normalization.set_style(as_param_style[2 * i])
|
||||
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
|
||||
return decoder(content)
|
||||
|
||||
@MODEL.register_module("TAFG-Discriminator")
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
|
||||
padding_mode="reflect"):
|
||||
super(Discriminator, self).__init__()
|
||||
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
sequence = [nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
|
||||
bias=use_bias),
|
||||
norm_layer(num_features=base_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)]
|
||||
# stacked intermediate layers,
|
||||
# gradually increasing the number of filters
|
||||
multiple_now = 1
|
||||
for n in range(1, num_down_sampling + 1):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** n, 4)
|
||||
sequence += [
|
||||
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
|
||||
padding=1, stride=2, bias=use_bias),
|
||||
norm_layer(base_channels * multiple_now),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
]
|
||||
for _ in range(num_blocks):
|
||||
sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
|
||||
self.model = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
def forward(self, content_img, style_img, which_content, which_style):
|
||||
content, style = self.encode(content_img, style_img, which_content, which_style)
|
||||
return self.decode(content, style, which_style)
|
||||
|
||||
@@ -185,7 +185,7 @@ class Conv2dBlock(nn.Module):
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect',
|
||||
norm_type="IN", activation_type="relu", use_bias=None):
|
||||
norm_type="IN", activation_type="ReLU", use_bias=None):
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
if use_bias is None:
|
||||
|
||||
Reference in New Issue
Block a user