This commit is contained in:
2020-09-05 22:00:17 +08:00
parent 39c754374c
commit e3c760d0c5
12 changed files with 122 additions and 43 deletions

View File

@@ -90,10 +90,10 @@ class TAFGEngineKernel(EngineKernel):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
self.generators["main"].module.style_encoders["b"](batch["b"]),
self.generators["main"].module.style_encoders["b"](generated["b"])
)
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
# self.generators["main"].module.style_encoders["b"](batch["b"]),
# self.generators["main"].module.style_encoders["b"](generated["b"])
# )
return loss
def criterion_discriminators(self, batch, generated) -> dict: