working
This commit is contained in:
@@ -3,7 +3,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import MODEL
|
||||
from model.normalization import AdaptiveInstanceNorm2d
|
||||
from model.normalization import select_norm_layer
|
||||
|
||||
|
||||
@@ -62,7 +61,9 @@ class Interpolation(nn.Module):
|
||||
class FADE(nn.Module):
|
||||
def __init__(self, use_spectral, features_channels, in_channels, affine=False, track_running_stats=True):
|
||||
super().__init__()
|
||||
self.bn = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
|
||||
# self.norm = nn.BatchNorm2d(num_features=in_channels, affine=affine, track_running_stats=track_running_stats)
|
||||
self.norm = nn.InstanceNorm2d(num_features=in_channels)
|
||||
|
||||
self.alpha_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
|
||||
padding_mode="zeros")
|
||||
self.beta_conv = conv_block(use_spectral, features_channels, in_channels, kernel_size=3, padding=1,
|
||||
@@ -71,7 +72,7 @@ class FADE(nn.Module):
|
||||
def forward(self, x, feature):
|
||||
alpha = self.alpha_conv(feature)
|
||||
beta = self.beta_conv(feature)
|
||||
x = self.bn(x)
|
||||
x = self.norm(x)
|
||||
return alpha * x + beta
|
||||
|
||||
|
||||
@@ -122,9 +123,7 @@ class TSITGenerator(nn.Module):
|
||||
self.use_spectral = use_spectral
|
||||
|
||||
self.content_input_layer = self.build_input_layer(content_in_channels, base_channels, input_layer_type)
|
||||
self.style_input_layer = self.build_input_layer(style_in_channels, base_channels, input_layer_type)
|
||||
self.content_stream = self.build_stream()
|
||||
self.style_stream = self.build_stream()
|
||||
self.generator = self.build_generator()
|
||||
self.end_conv = nn.Sequential(
|
||||
conv_block(use_spectral, base_channels, out_channels, kernel_size=7, padding=3, padding_mode="zeros"),
|
||||
@@ -138,11 +137,9 @@ class TSITGenerator(nn.Module):
|
||||
m = self.num_blocks - i
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** m, 2 ** 4)
|
||||
stream_sequence.append(nn.Sequential(
|
||||
AdaptiveInstanceNorm2d(multiple_prev * self.base_channels),
|
||||
stream_sequence.append(
|
||||
FADEResBlock(self.use_spectral, multiple_prev * self.base_channels, multiple_prev * self.base_channels,
|
||||
multiple_now * self.base_channels)
|
||||
))
|
||||
multiple_now * self.base_channels))
|
||||
return nn.ModuleList(stream_sequence)
|
||||
|
||||
def build_input_layer(self, in_channels, out_channels, input_layer_type="conv7x7"):
|
||||
@@ -171,22 +168,16 @@ class TSITGenerator(nn.Module):
|
||||
))
|
||||
return nn.ModuleList(stream_sequence)
|
||||
|
||||
def forward(self, content_img, style_img):
|
||||
def forward(self, content_img):
|
||||
c = self.content_input_layer(content_img)
|
||||
s = self.style_input_layer(style_img)
|
||||
content_features = []
|
||||
style_features = []
|
||||
for i in range(self.num_blocks):
|
||||
s = self.style_stream[i](s)
|
||||
c = self.content_stream[i](c)
|
||||
content_features.append(c)
|
||||
style_features.append(s)
|
||||
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
|
||||
|
||||
for i in range(self.num_blocks):
|
||||
m = - i - 1
|
||||
layer = self.generator[i]
|
||||
layer[0].set_style(torch.cat(torch.std_mean(style_features[m], dim=[2, 3]), dim=1))
|
||||
z = layer[0](z)
|
||||
z = layer[1](z, content_features[m])
|
||||
z = layer(z, content_features[m])
|
||||
return self.end_conv(z)
|
||||
|
||||
Reference in New Issue
Block a user