This commit is contained in:
2020-10-22 22:42:01 +08:00
parent 0019d4034c
commit 376f5caeb7
11 changed files with 140 additions and 29 deletions

View File

@@ -2,3 +2,4 @@ from model.registry import MODEL, NORMALIZATION
import model.base.normalization
import model.image_translation.UGATIT
import model.image_translation.CycleGAN
import model.image_translation.pix2pixHD

View File

@@ -0,0 +1,29 @@
import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"):
super().__init__()
assert down_sample_method in ["avg", "bilinear"]
self.down_sample_method = down_sample_method
self.discriminator_list = nn.ModuleList([
MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
])
def down_sample(self, x):
if self.down_sample_method == "avg":
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
if self.down_sample_method == "bilinear":
return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
def forward(self, x):
results = []
for discriminator in self.discriminator_list:
results.append(discriminator(x))
x = self.down_sample(x)
return results