move encoder, decoder to CycleGAN

This commit is contained in:
2020-10-11 11:09:16 +08:00
parent 04c6366c07
commit 9c08b4cd09
4 changed files with 106 additions and 128 deletions

View File

@@ -2,7 +2,8 @@ import torch
import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock
from model.base.module import Conv2dBlock, LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class RhoClipper(object):
@@ -46,27 +47,11 @@ class Generator(nn.Module):
self.light = light
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
)]
n_down_sampling = 2
for i in range(n_down_sampling):
mult = 2 ** i
sequence.append(Conv2dBlock(
base_channels * mult, base_channels * mult * 2,
kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
))
self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type)
mult = 2 ** n_down_sampling
sequence += [
ResidualBlock(base_channels * mult, base_channels * mult, padding_mode, activation_type=activation_type,
norm_type=norm_type)
for _ in range(num_blocks)]
self.encoder = nn.Sequential(*sequence)
self.cam = CAMClassifier(base_channels * mult, activation_type)
# Gamma, Beta block
@@ -85,25 +70,12 @@ class Generator(nn.Module):
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
# Up-Sampling Bottleneck
self.up_bottleneck = nn.ModuleList(
[ResidualBlock(base_channels * mult, base_channels * mult, padding_mode,
activation_type, norm_type="AdaILN") for _ in range(num_blocks)])
sequence = list()
channels = base_channels * mult
for i in range(n_down_sampling):
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2,
kernel_size=3, stride=1, padding=1, bias=False, padding_mode=padding_mode,
activation_type=activation_type, norm_type="ILN"),
))
channels = channels // 2
sequence.append(Conv2dBlock(channels, out_channels,
kernel_size=7, stride=1, padding=3, padding_mode="reflect",
activation_type="Tanh", norm_type="NONE"))
self.decoder = nn.Sequential(*sequence)
self.decoder = Decoder(
base_channels * mult, out_channels, n_down_sampling, num_blocks,
activation_type=activation_type, padding_mode=padding_mode,
up_conv_kernel_size=3, up_conv_norm_type="ILN",
res_norm_type="AdaILN"
)
def forward(self, x):
x = self.encoder(x)
@@ -119,10 +91,9 @@ class Generator(nn.Module):
x_ = self.fc(x.view(x.shape[0], -1))
gamma, beta = self.gamma(x_), self.beta(x_)
for blk in self.up_bottleneck:
for blk in self.decoder.residual_blocks:
blk.conv1.normalization.set_condition(gamma, beta)
blk.conv2.normalization.set_condition(gamma, beta)
x = blk(x)
return self.decoder(x), cam_logit, heatmap