TAHG 0.0.3

This commit is contained in:
2020-09-01 09:02:04 +08:00
parent 89b54105c7
commit e71e8d95d0
8 changed files with 97 additions and 36 deletions

View File

@@ -87,10 +87,17 @@ class ContentEncoder(nn.Module):
class Decoder(nn.Module):
def __init__(self, out_channels, base_channels=64, num_down_sampling=2, padding_mode='reflect', norm_type="LN"):
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
norm_type="LN"):
super(Decoder, self).__init__()
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
res_block_channels = (2 ** 2) * base_channels
self.resnet = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
# up sampling
submodules = []
for i in range(num_down_sampling):
@@ -109,6 +116,7 @@ class Decoder(nn.Module):
)
def forward(self, x):
x = self.resnet(x)
x = self.decoder(x)
x = self.end_conv(x)
return x
@@ -159,8 +167,8 @@ class Generator(nn.Module):
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
])
self.decoders = nn.ModuleDict({
"a": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode),
"b": Decoder(out_channels, base_channels, norm_type="LN", padding_mode=padding_mode)
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
})
self.fc = nn.Sequential(