This commit is contained in:
2020-10-11 10:02:33 +08:00
parent 6ea13df465
commit 04c6366c07
24 changed files with 483 additions and 968 deletions

View File

@@ -1,16 +1,15 @@
from omegaconf import OmegaConf
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
from omegaconf import OmegaConf
from loss.gan import GANLoss
from model.GAN.UGATIT import RhoClipper
from model.GAN.base import GANImageBuffer
from util.image import attention_colored_map
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from loss.gan import GANLoss
from model.image_translation.UGATIT import RhoClipper
from util.image import attention_colored_map
def mse_loss(x, target_flag):
@@ -30,9 +29,8 @@ class UGATITEngineKernel(EngineKernel):
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.mgc_loss = MyLoss()
self.rho_clipper = RhoClipper(0, 1)
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
self.train_generator_first = False
def build_models(self) -> (dict, dict):
@@ -82,6 +80,9 @@ class UGATITEngineKernel(EngineKernel):
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
generated["images"][f"{phase}2{phase}"])
if self.config.loss.mgc.weight > 0:
loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss(
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
for dk in "lg":
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)