update
This commit is contained in:
@@ -90,10 +90,10 @@ class TAFGEngineKernel(EngineKernel):
|
||||
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
|
||||
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
|
||||
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
|
||||
loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
self.generators["main"].module.style_encoders["b"](batch["b"]),
|
||||
self.generators["main"].module.style_encoders["b"](generated["b"])
|
||||
)
|
||||
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
# self.generators["main"].module.style_encoders["b"](batch["b"]),
|
||||
# self.generators["main"].module.style_encoders["b"](generated["b"])
|
||||
# )
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
|
||||
Reference in New Issue
Block a user