TAFG update
This commit is contained in:
@@ -20,6 +20,10 @@ class TAFGEngineKernel(EngineKernel):
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
|
||||
style_loss_cfg = OmegaConf.to_container(config.loss.style)
|
||||
style_loss_cfg.pop("weight")
|
||||
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
@@ -68,14 +72,14 @@ class TAFGEngineKernel(EngineKernel):
|
||||
contents = dict()
|
||||
images = dict()
|
||||
with torch.set_grad_enabled(not inference):
|
||||
contents["a"], styles["a"] = generator.encode(batch["a"]["edge"], batch["a"]["img"], "a", "a")
|
||||
contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
|
||||
for ph in "ab":
|
||||
contents[ph], styles[ph] = generator.encode(batch[ph]["edge"], batch[ph]["img"], ph, ph)
|
||||
for ph in ("a2b", "b2a"):
|
||||
images[f"fake_{ph[-1]}"] = generator.decode(contents[ph[0]], styles[ph[-1]], ph[-1])
|
||||
contents["recon_a"], styles["recon_b"] = generator.encode(
|
||||
self.edge_loss.edge_extractor(images["fake_b"]), images["fake_b"], "b", "b")
|
||||
images["a2a"] = generator.decode(contents["a"], styles["a"], "a")
|
||||
images["b2b"] = generator.decode(contents["b"], styles["recon_b"], "b")
|
||||
images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
|
||||
images["a2b"] = generator.decode(contents["a"], styles["b"], "b")
|
||||
contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
|
||||
images["a2b"], "b", "b")
|
||||
images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
|
||||
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
|
||||
return dict(styles=styles, contents=contents, images=images)
|
||||
|
||||
@@ -87,35 +91,38 @@ class TAFGEngineKernel(EngineKernel):
|
||||
loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
|
||||
generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
|
||||
|
||||
pred_fake = self.discriminators[ph](generated["images"][f"fake_{ph}"])
|
||||
pred_fake = self.discriminators[ph](generated["images"][f"a2{ph}"])
|
||||
loss[f"gan_{ph}"] = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
|
||||
loss[f"recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
|
||||
loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
|
||||
generated["contents"]["a"], generated["contents"]["recon_a"]
|
||||
)
|
||||
loss[f"recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
generated["styles"]["b"], generated["styles"]["recon_b"]
|
||||
)
|
||||
|
||||
for ph in ("a2b", "b2a"):
|
||||
if self.config.loss.perceptual.weight > 0:
|
||||
loss[f"perceptual_{ph}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
|
||||
batch[ph[0]]["img"], generated["images"][f"fake_{ph[-1]}"]
|
||||
)
|
||||
if self.config.loss.edge.weight > 0:
|
||||
loss[f"edge_a"] = self.config.loss.edge.weight * self.edge_loss(
|
||||
generated["images"]["fake_b"], batch["a"]["edge"][:, 0:1, :, :]
|
||||
)
|
||||
loss[f"edge_b"] = self.config.loss.edge.weight * self.edge_loss(
|
||||
generated["images"]["fake_a"], batch["b"]["edge"]
|
||||
if self.config.loss.perceptual.weight > 0:
|
||||
loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
|
||||
batch["a"]["img"], generated["images"]["a2b"]
|
||||
)
|
||||
|
||||
if self.config.loss.cycle.weight > 0:
|
||||
loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
|
||||
batch["a"]["img"], generated["images"]["cycle_a"]
|
||||
for ph in "ab":
|
||||
if self.config.loss.cycle.weight > 0:
|
||||
loss[f"cycle_{ph}"] = self.config.loss.cycle.weight * self.cycle_loss(
|
||||
batch[ph]["img"], generated["images"][f"cycle_{ph}"]
|
||||
)
|
||||
if self.config.loss.style.weight > 0:
|
||||
loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
|
||||
batch[ph]["img"], generated["images"][f"a2{ph}"]
|
||||
)
|
||||
|
||||
if self.config.loss.edge.weight > 0:
|
||||
loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
|
||||
generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
@@ -123,7 +130,7 @@ class TAFGEngineKernel(EngineKernel):
|
||||
# batch = self._process_batch(batch)
|
||||
for phase in self.discriminators.keys():
|
||||
pred_real = self.discriminators[phase](batch[phase]["img"])
|
||||
pred_fake = self.discriminators[phase](generated["images"][f"fake_{phase}"].detach())
|
||||
pred_fake = self.discriminators[phase](generated["images"][f"a2{phase}"].detach())
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for i in range(len(pred_fake)):
|
||||
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
|
||||
@@ -142,13 +149,13 @@ class TAFGEngineKernel(EngineKernel):
|
||||
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
|
||||
batch["a"]["img"].detach(),
|
||||
generated["images"]["a2a"].detach(),
|
||||
generated["images"]["fake_b"].detach(),
|
||||
generated["images"]["a2b"].detach(),
|
||||
generated["images"]["cycle_a"].detach(),
|
||||
],
|
||||
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
|
||||
batch["b"]["img"].detach(),
|
||||
generated["images"]["b2b"].detach(),
|
||||
generated["images"]["fake_a"].detach()]
|
||||
generated["images"]["cycle_b"].detach()]
|
||||
)
|
||||
|
||||
def change_engine(self, config, trainer):
|
||||
|
||||
Reference in New Issue
Block a user