TAHG 0.0.2
This commit is contained in:
@@ -85,9 +85,7 @@ def get_trainer(config, logger):
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real = dict(a=batch["a"], b=batch["b"])
|
||||
edge = batch["edge"]
|
||||
additional_info = batch["additional_info"]
|
||||
content_img = torch.cat([edge, additional_info], dim=1)
|
||||
content_img = batch["edge"]
|
||||
fake = dict(
|
||||
a=generator(content_img=content_img, style_img=real["a"], which_decoder="a"),
|
||||
b=generator(content_img=content_img, style_img=real["b"], which_decoder="b"),
|
||||
@@ -101,7 +99,7 @@ def get_trainer(config, logger):
|
||||
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
|
||||
_, t = perceptual_loss(fake[d], real[d])
|
||||
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
|
||||
loss_g["edge"] = config.loss.edge.weight * edge_loss(fake["b"], real["a"], gt_is_edge=False)
|
||||
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], content_img)
|
||||
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
|
||||
sum(loss_g.values()).backward()
|
||||
optimizers["g"].step()
|
||||
|
||||
Reference in New Issue
Block a user