23333
This commit is contained in:
@@ -7,7 +7,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
|
||||
|
||||
from model import MODEL
|
||||
|
||||
class StyleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
|
||||
@@ -122,7 +122,7 @@ class ImprovedSPADEGenerator(nn.Module):
|
||||
def forward(self, seg, style=None):
|
||||
pass
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class SPADEGenerator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
|
||||
padding_mode='reflect', activation_type="LeakyReLU"):
|
||||
@@ -156,11 +156,8 @@ class SPADEGenerator(nn.Module):
|
||||
)
|
||||
))
|
||||
self.sequence = nn.Sequential(*sequence)
|
||||
self.output_converter = nn.Sequential(
|
||||
ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
|
||||
padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
|
||||
nn.Tanh()
|
||||
)
|
||||
self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
|
||||
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")
|
||||
|
||||
def forward(self, seg, z=None):
|
||||
if self.use_vae:
|
||||
|
||||
Reference in New Issue
Block a user