base
This commit is contained in:
61
model/GAN/base.py
Normal file
61
model/GAN/base.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from model.normalization import select_norm_layer
|
||||
from model import MODEL
|
||||
|
||||
|
||||
# based SPADE or pix2pixHD Discriminator
|
||||
@MODEL.register_module("base-PatchDiscriminator")
|
||||
class PatchDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
|
||||
need_intermediate_feature=False):
|
||||
super().__init__()
|
||||
self.need_intermediate_feature = need_intermediate_feature
|
||||
|
||||
kernel_size = 4
|
||||
padding = math.ceil((kernel_size - 1.0) / 2)
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
padding_mode = "zeros"
|
||||
|
||||
sequence = [nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
|
||||
nn.LeakyReLU(0.2, False)
|
||||
)]
|
||||
multiple_now = 1
|
||||
for i in range(1, num_conv):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** i, 2 ** 3)
|
||||
stride = 1 if i == num_conv - 1 else 2
|
||||
sequence.append(nn.Sequential(
|
||||
self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
|
||||
kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
|
||||
norm_layer(base_channels * multiple_now),
|
||||
nn.LeakyReLU(0.2, inplace=False),
|
||||
))
|
||||
multiple_now = min(2 ** num_conv, 8)
|
||||
sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
|
||||
padding_mode=padding_mode))
|
||||
self.conv_blocks = nn.ModuleList(sequence)
|
||||
|
||||
@staticmethod
|
||||
def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
|
||||
bias=True, padding_mode: str = 'zeros'):
|
||||
conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias, padding_mode=padding_mode)
|
||||
if not use_spectral:
|
||||
return conv
|
||||
return nn.utils.spectral_norm(conv)
|
||||
|
||||
def forward(self, x):
|
||||
if self.need_intermediate_feature:
|
||||
intermediate_feature = []
|
||||
for layer in self.conv_blocks:
|
||||
x = layer(x)
|
||||
intermediate_feature.append(x)
|
||||
return tuple(intermediate_feature)
|
||||
else:
|
||||
for layer in self.conv_blocks:
|
||||
x = layer(x)
|
||||
return x
|
||||
25
model/GAN/wrapper.py
Normal file
25
model/GAN/wrapper.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import MODEL
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class MultiScaleDiscriminator(nn.Module):
|
||||
def __init__(self, num_scale, discriminator_cfg):
|
||||
super().__init__()
|
||||
|
||||
self.discriminator_list = nn.ModuleList([
|
||||
MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def down_sample(x):
|
||||
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
|
||||
|
||||
def forward(self, x):
|
||||
results = []
|
||||
for discriminator in self.discriminator_list:
|
||||
results.append(discriminator(x))
|
||||
x = self.down_sample(x)
|
||||
return results
|
||||
@@ -3,3 +3,5 @@ import model.GAN.residual_generator
|
||||
import model.GAN.TAHG
|
||||
import model.GAN.UGATIT
|
||||
import model.fewshot
|
||||
import model.GAN.wrapper
|
||||
import model.GAN.base
|
||||
|
||||
Reference in New Issue
Block a user