TAFG good result
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .base import ResidualBlock
|
||||
from model.registry import MODEL
|
||||
from torchvision.models import vgg19
|
||||
|
||||
from model.normalization import select_norm_layer
|
||||
from model.registry import MODEL
|
||||
from .base import ResidualBlock
|
||||
|
||||
|
||||
class VGG19StyleEncoder(nn.Module):
|
||||
@@ -169,25 +170,37 @@ class StyleGenerator(nn.Module):
|
||||
|
||||
@MODEL.register_module("TAFG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512,
|
||||
num_adain_blocks=8, num_res_blocks=4,
|
||||
base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.num_adain_blocks=num_adain_blocks
|
||||
self.style_encoders = nn.ModuleDict({
|
||||
"a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
|
||||
"a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
|
||||
base_channels=base_channels, padding_mode=padding_mode),
|
||||
"b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
|
||||
"b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
|
||||
base_channels=base_channels, padding_mode=padding_mode),
|
||||
})
|
||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
|
||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8,
|
||||
padding_mode=padding_mode, norm_type="IN")
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
self.adain_resnet_a = nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
self.adain_resnet_b = nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
self.resnet = nn.ModuleDict({
|
||||
"a": nn.Sequential(*[
|
||||
ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
|
||||
]),
|
||||
"b": nn.Sequential(*[
|
||||
ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
|
||||
])
|
||||
})
|
||||
self.adain_resnet = nn.ModuleDict({
|
||||
"a": nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
|
||||
]),
|
||||
"b": nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
|
||||
])
|
||||
})
|
||||
|
||||
self.decoders = nn.ModuleDict({
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
|
||||
@@ -196,10 +209,10 @@ class Generator(nn.Module):
|
||||
|
||||
def forward(self, content_img, style_img, which_decoder: str = "a"):
|
||||
x = self.content_encoder(content_img)
|
||||
x = self.resnet[which_decoder](x)
|
||||
styles = self.style_encoders[which_decoder](style_img)
|
||||
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
|
||||
resnet = self.adain_resnet_a if which_decoder == "a" else self.adain_resnet_b
|
||||
for i, ar in enumerate(resnet):
|
||||
styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1)
|
||||
for i, ar in enumerate(self.adain_resnet[which_decoder]):
|
||||
ar.norm1.set_style(styles[2 * i])
|
||||
ar.norm2.set_style(styles[2 * i + 1])
|
||||
x = ar(x)
|
||||
|
||||
Reference in New Issue
Block a user