change
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .residual_generator import ResidualBlock
|
||||
from .base import ResidualBlock
|
||||
from model.registry import MODEL
|
||||
from torchvision.models import vgg19
|
||||
from model.normalization import select_norm_layer
|
||||
@@ -148,48 +148,65 @@ class Fusion(nn.Module):
|
||||
return self.end_fc(x)
|
||||
|
||||
|
||||
@MODEL.register_module("TAHG-Generator")
|
||||
class StyleGenerator(nn.Module):
|
||||
def __init__(self, style_in_channels, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"):
|
||||
super().__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.style_encoder = VGG19StyleEncoder(
|
||||
style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE")
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(style_dim, style_dim),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
|
||||
norm_type="NONE")
|
||||
|
||||
def forward(self, x):
|
||||
styles = self.fusion(self.fc(self.style_encoder(x)))
|
||||
return styles
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
|
||||
base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.style_encoders = nn.ModuleDict({
|
||||
"a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE"),
|
||||
"b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE")
|
||||
"a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
|
||||
base_channels=base_channels, padding_mode=padding_mode),
|
||||
"b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
|
||||
base_channels=base_channels, padding_mode=padding_mode),
|
||||
})
|
||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
|
||||
padding_mode=padding_mode, norm_type="IN")
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
self.adain_res = nn.ModuleList([
|
||||
self.adain_resnet_a = nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
self.adain_resnet_b = nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
self.decoders = nn.ModuleDict({
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
|
||||
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
|
||||
})
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(style_dim, style_dim),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
|
||||
norm_type="NONE")
|
||||
self.decoders = nn.ModuleDict({
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
|
||||
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode)
|
||||
})
|
||||
|
||||
def forward(self, content_img, style_img, which_decoder: str = "a"):
|
||||
x = self.content_encoder(content_img)
|
||||
styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img)))
|
||||
styles = self.style_encoders[which_decoder](style_img)
|
||||
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
|
||||
for i, ar in enumerate(self.adain_res):
|
||||
resnet = self.adain_resnet_a if which_decoder == "a" else self.adain_resnet_b
|
||||
for i, ar in enumerate(resnet):
|
||||
ar.norm1.set_style(styles[2 * i])
|
||||
ar.norm2.set_style(styles[2 * i + 1])
|
||||
x = ar(x)
|
||||
return self.decoders[which_decoder](x)
|
||||
|
||||
|
||||
@MODEL.register_module("TAHG-Discriminator")
|
||||
@MODEL.register_module("TAFG-Discriminator")
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
|
||||
padding_mode="reflect"):
|
||||
@@ -7,7 +7,7 @@ from model import MODEL
|
||||
|
||||
|
||||
# based SPADE or pix2pixHD Discriminator
|
||||
@MODEL.register_module("base-PatchDiscriminator")
|
||||
@MODEL.register_module("pix2pixHD-PatchDiscriminator")
|
||||
class PatchDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
|
||||
need_intermediate_feature=False):
|
||||
@@ -59,3 +59,26 @@ class PatchDiscriminator(nn.Module):
|
||||
for layer in self.conv_blocks:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
|
||||
super(ResidualBlock, self).__init__()
|
||||
if use_bias is None:
|
||||
# Only for IN, use bias since it does not have affine parameters.
|
||||
use_bias = norm_type == "IN"
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
bias=use_bias)
|
||||
self.norm1 = norm_layer(num_channels)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
bias=use_bias)
|
||||
self.norm2 = norm_layer(num_channels)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
x = self.relu1(self.norm1(self.conv1(x)))
|
||||
x = self.norm2(self.conv2(x))
|
||||
return x + res
|
||||
|
||||
@@ -58,27 +58,29 @@ class GANImageBuffer(object):
|
||||
return return_images
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
|
||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
if use_bias is None:
|
||||
# Only for IN, use bias since it does not have affine parameters.
|
||||
use_bias = norm_type == "IN"
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
bias=use_bias)
|
||||
self.norm1 = norm_layer(num_channels)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
bias=use_bias)
|
||||
self.norm2 = norm_layer(num_channels)
|
||||
models = [nn.Sequential(
|
||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
||||
norm_layer(num_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)]
|
||||
if use_dropout:
|
||||
models.append(nn.Dropout(0.5))
|
||||
models.append(nn.Sequential(
|
||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
||||
norm_layer(num_channels),
|
||||
))
|
||||
self.block = nn.Sequential(*models)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
x = self.relu1(self.norm1(self.conv1(x)))
|
||||
x = self.norm2(self.conv2(x))
|
||||
return x + res
|
||||
return x + self.block(x)
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from model.registry import MODEL
|
||||
import model.GAN.residual_generator
|
||||
import model.GAN.TAHG
|
||||
import model.GAN.TAFG
|
||||
import model.GAN.UGATIT
|
||||
import model.fewshot
|
||||
import model.GAN.wrapper
|
||||
|
||||
Reference in New Issue
Block a user