This commit is contained in:
2020-10-25 20:46:34 +08:00
parent 0bec02bf6d
commit 8998c30c23
8 changed files with 174 additions and 55 deletions

View File

@@ -16,7 +16,7 @@ class GauGANEngineKernel(EngineKernel):
self.gan_loss = gan_loss(config.loss.gan)
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "exponential_decline"))
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in

View File

@@ -38,7 +38,7 @@ def feature_match_loss(level, weight_policy):
def fm_loss(generated_features, target_features):
num_scale = len(generated_features)
loss = 0
loss = torch.zeros(1, device=idist.device())
for s_i in range(num_scale):
for i in range(len(generated_features[s_i]) - 1):
weight = 1 if weight_policy == "same" else 2 ** i