23333
This commit is contained in:
@@ -3,3 +3,4 @@ import model.base.normalization
|
||||
import model.image_translation.UGATIT
|
||||
import model.image_translation.CycleGAN
|
||||
import model.image_translation.pix2pixHD
|
||||
import model.image_translation.GauGAN
|
||||
@@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
|
||||
|
||||
from model import MODEL
|
||||
|
||||
class StyleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
|
||||
@@ -122,7 +122,7 @@ class ImprovedSPADEGenerator(nn.Module):
|
||||
def forward(self, seg, style=None):
|
||||
pass
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class SPADEGenerator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
|
||||
padding_mode='reflect', activation_type="LeakyReLU"):
|
||||
@@ -156,11 +156,8 @@ class SPADEGenerator(nn.Module):
|
||||
)
|
||||
))
|
||||
self.sequence = nn.Sequential(*sequence)
|
||||
self.output_converter = nn.Sequential(
|
||||
ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
|
||||
padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
|
||||
nn.Tanh()
|
||||
)
|
||||
self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
|
||||
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")
|
||||
|
||||
def forward(self, seg, z=None):
|
||||
if self.use_vae:
|
||||
|
||||
@@ -65,7 +65,8 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
|
||||
elif classname.find('BatchNorm2d') != -1:
|
||||
# BatchNorm Layer's weight is not a matrix;
|
||||
# only normal distribution applies.
|
||||
normal_init(m, 1.0, init_gain)
|
||||
if m.weight is not None:
|
||||
normal_init(m, 1.0, init_gain)
|
||||
|
||||
assert isinstance(module, nn.Module)
|
||||
module.apply(init_func)
|
||||
|
||||
Reference in New Issue
Block a user