working
This commit is contained in:
@@ -51,31 +51,19 @@ class TSITEngineKernel(EngineKernel):
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
with torch.set_grad_enabled(not inference):
|
||||
fake = dict(
|
||||
b=self.generators["main"](content_img=batch["a"], style_img=batch["b"])
|
||||
b=self.generators["main"](content_img=batch["a"])
|
||||
)
|
||||
return fake
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
|
||||
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
|
||||
loss["perceptual"] = self.perceptual_loss(generated["b"], batch["a"]) * self.config.loss.perceptual.weight
|
||||
for phase in "b":
|
||||
pred_fake = self.discriminators[phase](generated[phase])
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
|
||||
|
||||
if self.config.loss.fm.weight > 0 and phase == "b":
|
||||
pred_real = self.discriminators[phase](batch[phase])
|
||||
loss_fm = 0
|
||||
num_scale_discriminator = len(pred_fake)
|
||||
for i in range(num_scale_discriminator):
|
||||
# last output is the final prediction, so we exclude it
|
||||
num_intermediate_outputs = len(pred_fake[i]) - 1
|
||||
for j in range(num_intermediate_outputs):
|
||||
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
|
||||
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
|
||||
Reference in New Issue
Block a user