TAFG 0.01
This commit is contained in:
@@ -65,12 +65,14 @@ class TAFGEngineKernel(EngineKernel):
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
loss["perceptual"], _, = self.perceptual_loss(generated["b"], batch["b"]) * self.config.loss.perceptual.weight
|
||||
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
|
||||
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
|
||||
for phase in "ab":
|
||||
pred_fake = self.discriminators[phase](generated[phase])
|
||||
for i, sub_pred_fake in enumerate(pred_fake):
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss[f"gan_{phase}_sub_{i}"] = self.gan_loss(sub_pred_fake[-1], True)
|
||||
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
|
||||
|
||||
if self.config.loss.fm.weight > 0 and phase == "b":
|
||||
pred_real = self.discriminators[phase](batch[phase])
|
||||
|
||||
Reference in New Issue
Block a user