change a lot
This commit is contained in:
@@ -1,26 +1,23 @@
|
||||
from itertools import chain
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from loss.gan import GANLoss
|
||||
from model.GAN.base import GANImageBuffer
|
||||
from engine.util.container import GANImageBuffer, LossContainer
|
||||
from engine.util.loss import pixel_loss, gan_loss
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class TAFGEngineKernel(EngineKernel):
|
||||
class CycleGANEngineKernel(EngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
|
||||
self.gan_loss = gan_loss(config.loss.gan)
|
||||
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
|
||||
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss())
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
|
||||
@@ -56,21 +53,19 @@ class TAFGEngineKernel(EngineKernel):
|
||||
images["b2a"] = self.generators["b2a"](batch["b"])
|
||||
images["a2b2a"] = self.generators["b2a"](images["a2b"])
|
||||
images["b2a2b"] = self.generators["a2b"](images["b2a"])
|
||||
if self.config.loss.id.weight > 0:
|
||||
if self.id_loss.weight > 0:
|
||||
images["a2a"] = self.generators["b2a"](batch["a"])
|
||||
images["b2b"] = self.generators["a2b"](batch["b"])
|
||||
return images
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
for phase in ["a2b", "b2a"]:
|
||||
loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
|
||||
generated[f"{phase}2{phase[0]}"], batch[phase[0]])
|
||||
loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
|
||||
self.discriminators[phase[-1]](generated[phase]), True)
|
||||
if self.config.loss.id.weight > 0:
|
||||
loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
|
||||
generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
|
||||
for ph in "ab":
|
||||
loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
|
||||
loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
|
||||
loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
|
||||
loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(
|
||||
self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"]), True)
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
@@ -97,5 +92,5 @@ class TAFGEngineKernel(EngineKernel):
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
kernel = TAFGEngineKernel(config)
|
||||
kernel = CycleGANEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
@@ -1,38 +1,31 @@
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
from engine.util.container import LossContainer
|
||||
from engine.util.loss import bce_loss, mse_loss, pixel_loss, gan_loss
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.image_translation.UGATIT import RhoClipper
|
||||
from util.image import attention_colored_map
|
||||
|
||||
|
||||
def pixel_loss(level):
|
||||
return nn.L1Loss() if level == 1 else nn.MSELoss()
|
||||
class RhoClipper(object):
|
||||
def __init__(self, clip_min, clip_max):
|
||||
self.clip_min = clip_min
|
||||
self.clip_max = clip_max
|
||||
assert clip_min < clip_max
|
||||
|
||||
|
||||
def mse_loss(x, target_flag):
|
||||
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
|
||||
def bce_loss(x, target_flag):
|
||||
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
def __call__(self, module):
|
||||
if hasattr(module, 'rho'):
|
||||
w = module.rho.data
|
||||
w = w.clamp(self.clip_min, self.clip_max)
|
||||
module.rho.data = w
|
||||
|
||||
|
||||
class UGATITEngineKernel(EngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
|
||||
self.gan_loss = gan_loss(config.loss.gan)
|
||||
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
|
||||
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
|
||||
|
||||
25
engine/util/loss.py
Normal file
25
engine/util/loss.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from loss.gan import GANLoss
|
||||
|
||||
|
||||
def gan_loss(config):
    """Construct a GANLoss criterion from an OmegaConf node.

    The node is converted to a plain dict, its "weight" entry is removed
    (weighting is handled by the caller, not the criterion), and the
    resulting loss module is moved to the current distributed device.
    """
    kwargs = OmegaConf.to_container(config)
    del kwargs["weight"]  # raises KeyError if absent, same as pop()
    criterion = GANLoss(**kwargs)
    return criterion.to(idist.device())
|
||||
|
||||
|
||||
def pixel_loss(level):
    """Return a pixel-wise criterion: L1 for level 1, MSE otherwise."""
    if level == 1:
        return nn.L1Loss()
    return nn.MSELoss()
|
||||
|
||||
|
||||
def mse_loss(x, target_flag):
    """Mean-squared error of *x* against a constant target.

    The target is all-ones when *target_flag* is truthy (real label),
    all-zeros otherwise (fake label).
    """
    if target_flag:
        target = torch.ones_like(x)
    else:
        target = torch.zeros_like(x)
    return F.mse_loss(x, target)
|
||||
|
||||
|
||||
def bce_loss(x, target_flag):
    """Binary cross-entropy (with logits) of *x* against a constant target.

    The target is all-ones when *target_flag* is truthy (real label),
    all-zeros otherwise (fake label). *x* is expected to be raw logits.
    """
    if target_flag:
        target = torch.ones_like(x)
    else:
        target = torch.zeros_like(x)
    return F.binary_cross_entropy_with_logits(x, target)
|
||||
Reference in New Issue
Block a user