TAHG 0.0.3
This commit is contained in:
@@ -87,10 +87,17 @@ class ContentEncoder(nn.Module):
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, out_channels, base_channels=64, num_down_sampling=2, padding_mode='reflect', norm_type="LN"):
|
||||
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
|
||||
norm_type="LN"):
|
||||
super(Decoder, self).__init__()
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
res_block_channels = (2 ** 2) * base_channels
|
||||
|
||||
self.resnet = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
submodules = []
|
||||
for i in range(num_down_sampling):
|
||||
@@ -109,6 +116,7 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.resnet(x)
|
||||
x = self.decoder(x)
|
||||
x = self.end_conv(x)
|
||||
return x
|
||||
@@ -159,8 +167,8 @@ class Generator(nn.Module):
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
self.decoders = nn.ModuleDict({
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode),
|
||||
"b": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode)
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
|
||||
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
|
||||
})
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
|
||||
Reference in New Issue
Block a user