rewrite
This commit is contained in:
@@ -1,16 +1,15 @@
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import ignite.distributed as idist
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from loss.gan import GANLoss
|
||||
from model.GAN.UGATIT import RhoClipper
|
||||
from model.GAN.base import GANImageBuffer
|
||||
from util.image import attention_colored_map
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.image_translation.UGATIT import RhoClipper
|
||||
from util.image import attention_colored_map
|
||||
|
||||
|
||||
def mse_loss(x, target_flag):
|
||||
@@ -30,9 +29,8 @@ class UGATITEngineKernel(EngineKernel):
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
|
||||
self.mgc_loss = MyLoss()
|
||||
self.rho_clipper = RhoClipper(0, 1)
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
self.train_generator_first = False
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
@@ -82,6 +80,9 @@ class UGATITEngineKernel(EngineKernel):
|
||||
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
|
||||
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
|
||||
generated["images"][f"{phase}2{phase}"])
|
||||
if self.config.loss.mgc.weight > 0:
|
||||
loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss(
|
||||
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
|
||||
for dk in "lg":
|
||||
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
|
||||
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
|
||||
|
||||
Reference in New Issue
Block a user