move encoder, decoder to CycleGAN
This commit is contained in:
@@ -2,7 +2,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from model import MODEL
|
||||
from model.base.module import Conv2dBlock, ResidualBlock, LinearBlock
|
||||
from model.base.module import Conv2dBlock, LinearBlock
|
||||
from model.image_translation.CycleGAN import Encoder, Decoder
|
||||
|
||||
|
||||
class RhoClipper(object):
|
||||
@@ -46,27 +47,11 @@ class Generator(nn.Module):
|
||||
|
||||
self.light = light
|
||||
|
||||
sequence = [Conv2dBlock(
|
||||
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
||||
activation_type=activation_type, norm_type=norm_type
|
||||
)]
|
||||
|
||||
n_down_sampling = 2
|
||||
for i in range(n_down_sampling):
|
||||
mult = 2 ** i
|
||||
sequence.append(Conv2dBlock(
|
||||
base_channels * mult, base_channels * mult * 2,
|
||||
kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
|
||||
activation_type=activation_type, norm_type=norm_type
|
||||
))
|
||||
|
||||
self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
|
||||
padding_mode=padding_mode, activation_type=activation_type,
|
||||
down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type)
|
||||
mult = 2 ** n_down_sampling
|
||||
sequence += [
|
||||
ResidualBlock(base_channels * mult, base_channels * mult, padding_mode, activation_type=activation_type,
|
||||
norm_type=norm_type)
|
||||
for _ in range(num_blocks)]
|
||||
self.encoder = nn.Sequential(*sequence)
|
||||
|
||||
self.cam = CAMClassifier(base_channels * mult, activation_type)
|
||||
|
||||
# Gamma, Beta block
|
||||
@@ -85,25 +70,12 @@ class Generator(nn.Module):
|
||||
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
|
||||
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
|
||||
|
||||
# Up-Sampling Bottleneck
|
||||
self.up_bottleneck = nn.ModuleList(
|
||||
[ResidualBlock(base_channels * mult, base_channels * mult, padding_mode,
|
||||
activation_type, norm_type="AdaILN") for _ in range(num_blocks)])
|
||||
|
||||
sequence = list()
|
||||
channels = base_channels * mult
|
||||
for i in range(n_down_sampling):
|
||||
sequence.append(nn.Sequential(
|
||||
nn.Upsample(scale_factor=2),
|
||||
Conv2dBlock(channels, channels // 2,
|
||||
kernel_size=3, stride=1, padding=1, bias=False, padding_mode=padding_mode,
|
||||
activation_type=activation_type, norm_type="ILN"),
|
||||
))
|
||||
channels = channels // 2
|
||||
sequence.append(Conv2dBlock(channels, out_channels,
|
||||
kernel_size=7, stride=1, padding=3, padding_mode="reflect",
|
||||
activation_type="Tanh", norm_type="NONE"))
|
||||
self.decoder = nn.Sequential(*sequence)
|
||||
self.decoder = Decoder(
|
||||
base_channels * mult, out_channels, n_down_sampling, num_blocks,
|
||||
activation_type=activation_type, padding_mode=padding_mode,
|
||||
up_conv_kernel_size=3, up_conv_norm_type="ILN",
|
||||
res_norm_type="AdaILN"
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
@@ -119,10 +91,9 @@ class Generator(nn.Module):
|
||||
x_ = self.fc(x.view(x.shape[0], -1))
|
||||
gamma, beta = self.gamma(x_), self.beta(x_)
|
||||
|
||||
for blk in self.up_bottleneck:
|
||||
for blk in self.decoder.residual_blocks:
|
||||
blk.conv1.normalization.set_condition(gamma, beta)
|
||||
blk.conv2.normalization.set_condition(gamma, beta)
|
||||
x = blk(x)
|
||||
return self.decoder(x), cam_logit, heatmap
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user