TAFG 0.01

This commit is contained in:
2020-09-03 09:34:38 +08:00
parent 14d4247112
commit 2469bf15fe
6 changed files with 37 additions and 388 deletions

View File

@@ -65,12 +65,14 @@ class TAFGEngineKernel(EngineKernel):
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
loss["perceptual"], _, = self.perceptual_loss(generated["b"], batch["b"]) * self.config.loss.perceptual.weight
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
for phase in "ab":
pred_fake = self.discriminators[phase](generated[phase])
for i, sub_pred_fake in enumerate(pred_fake):
loss[f"gan_{phase}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{phase}_sub_{i}"] = self.gan_loss(sub_pred_fake[-1], True)
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
if self.config.loss.fm.weight > 0 and phase == "b":
pred_real = self.discriminators[phase](batch[phase])