This commit is contained in:
2020-10-11 10:02:33 +08:00
parent 6ea13df465
commit 04c6366c07
24 changed files with 483 additions and 968 deletions

View File

@@ -1,16 +1,15 @@
from omegaconf import OmegaConf
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
from omegaconf import OmegaConf
from loss.gan import GANLoss
from model.GAN.UGATIT import RhoClipper
from model.GAN.base import GANImageBuffer
from util.image import attention_colored_map
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from loss.gan import GANLoss
from model.image_translation.UGATIT import RhoClipper
from util.image import attention_colored_map
def mse_loss(x, target_flag):
@@ -30,9 +29,8 @@ class UGATITEngineKernel(EngineKernel):
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.mgc_loss = MyLoss()
self.rho_clipper = RhoClipper(0, 1)
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
self.train_generator_first = False
def build_models(self) -> (dict, dict):
@@ -82,6 +80,9 @@ class UGATITEngineKernel(EngineKernel):
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
generated["images"][f"{phase}2{phase}"])
if self.config.loss.mgc.weight > 0:
loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss(
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
for dk in "lg":
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)

View File

@@ -64,7 +64,7 @@ class EngineKernel(object):
self.engine = engine
def build_models(self) -> (dict, dict):
raise NotImplemented
raise NotImplementedError
def to_save(self):
to_save = {}
@@ -73,19 +73,19 @@ class EngineKernel(object):
return to_save
def setup_after_g(self):
raise NotImplemented
raise NotImplementedError
def setup_before_g(self):
raise NotImplemented
raise NotImplementedError
def forward(self, batch, inference=False) -> dict:
raise NotImplemented
raise NotImplementedError
def criterion_generators(self, batch, generated) -> dict:
raise NotImplemented
raise NotImplementedError
def criterion_discriminators(self, batch, generated) -> dict:
raise NotImplemented
raise NotImplementedError
def intermediate_images(self, batch, generated) -> dict:
"""
@@ -94,7 +94,7 @@ class EngineKernel(object):
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
raise NotImplemented
raise NotImplementedError
def change_engine(self, config, engine: Engine):
pass

View File

@@ -1,18 +1,21 @@
import torch
import ignite.distributed as idist
import torch
import torch.optim as optim
from omegaconf import OmegaConf
from model import MODEL
import torch.optim as optim
from util.misc import add_spectral_norm
def build_model(cfg):
cfg = OmegaConf.to_container(cfg)
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
add_spectral_norm_flag = cfg.pop("_add_spectral_norm", False)
model = MODEL.build_with(cfg)
if bn_to_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if add_spectral_norm_flag:
model.apply(add_spectral_norm)
return idist.auto_model(model)