TAHG 0.0.2

This commit is contained in:
2020-08-30 14:44:40 +08:00
parent 715a2e64a1
commit 89b54105c7
8 changed files with 172 additions and 17 deletions

View File

@@ -85,9 +85,7 @@ def get_trainer(config, logger):
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real = dict(a=batch["a"], b=batch["b"])
edge = batch["edge"]
additional_info = batch["additional_info"]
content_img = torch.cat([edge, additional_info], dim=1)
content_img = batch["edge"]
fake = dict(
a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"),
b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"),
@@ -101,7 +99,7 @@ def get_trainer(config, logger):
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
_, t = perceptual_loss(fake[d], real[d])
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
loss_g["edge"] = config.loss.edge.weight * edge_loss(fake["b"], real["a"], gt_is_edge=False)
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], content_img)
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
sum(loss_g.values()).backward()
optimizers["g"].step()