TAFG good result

This commit is contained in:
2020-09-09 14:46:07 +08:00
parent 87cbcf34d3
commit 7ea9c6d0d5
4 changed files with 76 additions and 55 deletions

View File

@@ -1,20 +1,17 @@
from itertools import chain
from omegaconf import OmegaConf
import ignite.distributed as idist
import torch
import torch.nn as nn
import ignite.distributed as idist
from ignite.engine import Events
from omegaconf import read_write, OmegaConf
from model.weight_init import generation_init_weights
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
class TAFGEngineKernel(EngineKernel):
@@ -24,6 +21,10 @@ class TAFGEngineKernel(EngineKernel):
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
style_loss_cfg = OmegaConf.to_container(config.loss.style)
style_loss_cfg.pop("weight")
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
@@ -32,6 +33,9 @@ class TAFGEngineKernel(EngineKernel):
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
idist.device())
def _process_batch(self, batch, inference=False):
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
return batch
@@ -74,7 +78,9 @@ class TAFGEngineKernel(EngineKernel):
batch = self._process_batch(batch)
loss = dict()
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
_, loss_style = self.style_loss(generated["a"], batch["a"])
loss["style"] = self.config.loss.style.weight * loss_style
loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
for phase in "ab":
pred_fake = self.discriminators[phase](generated[phase])
loss[f"gan_{phase}"] = 0
@@ -93,10 +99,7 @@ class TAFGEngineKernel(EngineKernel):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
# self.generators["main"].module.style_encoders["b"](batch["b"]),
# self.generators["main"].module.style_encoders["b"](generated["b"])
# )
loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
return loss
def criterion_discriminators(self, batch, generated) -> dict: