TAHG 0.0.2
This commit is contained in:
@@ -146,8 +146,12 @@ class Generator(nn.Module):
|
||||
base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.style_encoder = VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE")
|
||||
self.style_encoders = nn.ModuleDict({
|
||||
"a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE"),
|
||||
"b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE")
|
||||
})
|
||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
|
||||
padding_mode=padding_mode, norm_type="IN")
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
@@ -168,7 +172,7 @@ class Generator(nn.Module):
|
||||
|
||||
def forward(self, content_img, style_img, which_decoder: str = "a"):
|
||||
x = self.content_encoder(content_img)
|
||||
styles = self.fusion(self.fc(self.style_encoder(style_img)))
|
||||
styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img)))
|
||||
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
|
||||
for i, ar in enumerate(self.adain_res):
|
||||
ar.norm1.set_style(styles[2 * i])
|
||||
|
||||
Reference in New Issue
Block a user