TAFG good result
This commit is contained in:
@@ -1,20 +1,17 @@
|
||||
from itertools import chain
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events
|
||||
|
||||
from omegaconf import read_write, OmegaConf
|
||||
|
||||
from model.weight_init import generation_init_weights
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from loss.I2I.edge_loss import EdgeLoss
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class TAFGEngineKernel(EngineKernel):
|
||||
@@ -24,6 +21,10 @@ class TAFGEngineKernel(EngineKernel):
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
|
||||
style_loss_cfg = OmegaConf.to_container(config.loss.style)
|
||||
style_loss_cfg.pop("weight")
|
||||
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
@@ -32,6 +33,9 @@ class TAFGEngineKernel(EngineKernel):
|
||||
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
|
||||
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
|
||||
|
||||
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
|
||||
idist.device())
|
||||
|
||||
def _process_batch(self, batch, inference=False):
|
||||
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
|
||||
return batch
|
||||
@@ -74,7 +78,9 @@ class TAFGEngineKernel(EngineKernel):
|
||||
batch = self._process_batch(batch)
|
||||
loss = dict()
|
||||
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
|
||||
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
|
||||
_, loss_style = self.style_loss(generated["a"], batch["a"])
|
||||
loss["style"] = self.config.loss.style.weight * loss_style
|
||||
loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
|
||||
for phase in "ab":
|
||||
pred_fake = self.discriminators[phase](generated[phase])
|
||||
loss[f"gan_{phase}"] = 0
|
||||
@@ -93,10 +99,7 @@ class TAFGEngineKernel(EngineKernel):
|
||||
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
|
||||
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
|
||||
loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
|
||||
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
# self.generators["main"].module.style_encoders["b"](batch["b"]),
|
||||
# self.generators["main"].module.style_encoders["b"](generated["b"])
|
||||
# )
|
||||
loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
|
||||
Reference in New Issue
Block a user