TANG 0.0.1
This commit is contained in:
@@ -142,7 +142,7 @@ class Fusion(nn.Module):
|
||||
|
||||
@MODEL.register_module("TAHG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels, out_channels, style_dim=512, num_blocks=8,
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
|
||||
base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
@@ -175,3 +175,38 @@ class Generator(nn.Module):
|
||||
ar.norm2.set_style(styles[2 * i + 1])
|
||||
x = ar(x)
|
||||
return self.decoders[which_decoder](x)
|
||||
|
||||
|
||||
@MODEL.register_module("TAHG-Discriminator")
class Discriminator(nn.Module):
    """Convolutional discriminator: a 7x7 stem, `num_down_sampling` stride-2
    down-sampling convs that double the channel count (capped at 4x), followed
    by `num_blocks` residual blocks.

    Args:
        in_channels: channels of the input image (default 3, RGB).
        base_channels: channel count after the stem conv; later stages are
            multiples of this.
        num_down_sampling: number of stride-2 down-sampling convs.
        num_blocks: number of trailing ResidualBlocks at the deepest width.
        norm_type: key passed to ``select_norm_layer`` (e.g. "IN").
        padding_mode: padding mode used by every conv in this module.
    """

    def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
                 padding_mode="reflect"):
        super(Discriminator, self).__init__()

        norm_layer = select_norm_layer(norm_type)
        # InstanceNorm keeps per-sample statistics, so a conv bias is still
        # meaningful; for other norms the bias would be normalised away.
        use_bias = norm_type == "IN"

        # 7x7 stem: in_channels -> base_channels, spatial size preserved.
        sequence = [nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
                      bias=use_bias),
            norm_layer(num_features=base_channels),
            nn.ReLU(inplace=True)
        )]
        # stacked intermediate layers,
        # gradually increasing the number of filters (multiplier capped at 4)
        multiple_now = 1
        for n in range(1, num_down_sampling + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** n, 4)
            sequence += [
                # FIX: forward `padding_mode` here too — previously only the
                # stem conv honoured it and these convs silently used zero
                # padding, ignoring the constructor argument.
                nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
                          padding=1, stride=2, padding_mode=padding_mode, bias=use_bias),
                norm_layer(base_channels * multiple_now),
                nn.LeakyReLU(0.2, inplace=True)
            ]
        # Residual tail at the deepest channel width.
        for _ in range(num_blocks):
            sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        """Run the discriminator; returns the feature/score map of `x`."""
        return self.model(x)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from model.registry import MODEL
|
||||
import model.GAN.residual_generator
|
||||
import model.GAN.TAHG
|
||||
import model.GAN.UGATIT
|
||||
import model.fewshot
|
||||
|
||||
@@ -37,7 +37,6 @@ class LayerNorm2d(nn.Module):
|
||||
def forward(self, x):
    """Layer-normalise `x` over dims (1, 2, 3) and optionally apply the affine transform.

    Mean and variance are computed per sample across all channel/spatial
    positions (keepdim=True so they broadcast back), i.e. layer norm rather
    than instance norm.
    """
    ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
    x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
    # FIX: removed leftover debug `print(x.size())` that spammed stdout on
    # every forward pass (the surrounding commit removed exactly this line).
    if self.affine:
        # NOTE(review): channel_gamma/channel_beta are presumably shaped to
        # broadcast per channel — confirm against __init__ (not visible here).
        return self.channel_gamma * x + self.channel_beta
    return x
|
||||
|
||||
Reference in New Issue
Block a user