rewrite
This commit is contained in:
@@ -1,16 +1,15 @@
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import ignite.distributed as idist
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from loss.gan import GANLoss
|
||||
from model.GAN.UGATIT import RhoClipper
|
||||
from model.GAN.base import GANImageBuffer
|
||||
from util.image import attention_colored_map
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.image_translation.UGATIT import RhoClipper
|
||||
from util.image import attention_colored_map
|
||||
|
||||
|
||||
def mse_loss(x, target_flag):
|
||||
@@ -30,9 +29,8 @@ class UGATITEngineKernel(EngineKernel):
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
|
||||
self.mgc_loss = MyLoss()
|
||||
self.rho_clipper = RhoClipper(0, 1)
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
self.train_generator_first = False
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
@@ -82,6 +80,9 @@ class UGATITEngineKernel(EngineKernel):
|
||||
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
|
||||
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
|
||||
generated["images"][f"{phase}2{phase}"])
|
||||
if self.config.loss.mgc.weight > 0:
|
||||
loss[f"mgc_{phase}"] = self.config.loss.mgc.weight * self.mgc_loss(
|
||||
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
|
||||
for dk in "lg":
|
||||
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
|
||||
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
|
||||
|
||||
@@ -64,7 +64,7 @@ class EngineKernel(object):
|
||||
self.engine = engine
|
||||
|
||||
def build_models(self) -> (dict, dict):
    """Construct the generator and discriminator models; subclasses must override.

    :return: a pair of dicts — presumably (generators, discriminators) keyed by
        name; confirm the exact schema against concrete subclasses.
    :raises NotImplementedError: always; this is an abstract hook.
    """
    # `raise NotImplemented` is a bug: NotImplemented is a comparison sentinel,
    # not an exception, so raising it produces a TypeError. Abstract methods
    # must raise NotImplementedError instead.
    raise NotImplementedError
|
||||
|
||||
def to_save(self):
|
||||
to_save = {}
|
||||
@@ -73,19 +73,19 @@ class EngineKernel(object):
|
||||
return to_save
|
||||
|
||||
def setup_after_g(self):
    """Hook run after the generator update step; subclasses must override.

    :raises NotImplementedError: always; this is an abstract hook.
    """
    # NotImplemented is not an exception (raising it is a TypeError);
    # the correct abstract-method idiom is NotImplementedError.
    raise NotImplementedError
|
||||
|
||||
def setup_before_g(self):
    """Hook run before the generator update step; subclasses must override.

    :raises NotImplementedError: always; this is an abstract hook.
    """
    # NotImplemented is not an exception (raising it is a TypeError);
    # the correct abstract-method idiom is NotImplementedError.
    raise NotImplementedError
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
    """Run the models on *batch* and return generated outputs; subclasses override.

    :param batch: input batch (schema defined by the concrete kernel).
    :param inference: presumably toggles a lighter inference-only path —
        confirm against concrete subclasses.
    :return: dict of generated results.
    :raises NotImplementedError: always; this is an abstract hook.
    """
    # NotImplemented is not an exception (raising it is a TypeError);
    # the correct abstract-method idiom is NotImplementedError.
    raise NotImplementedError
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
    """Compute the generator losses; subclasses must override.

    :param batch: input batch.
    :param generated: outputs of :meth:`forward`.
    :return: dict mapping loss names to loss values.
    :raises NotImplementedError: always; this is an abstract hook.
    """
    # NotImplemented is not an exception (raising it is a TypeError);
    # the correct abstract-method idiom is NotImplementedError.
    raise NotImplementedError
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
    """Compute the discriminator losses; subclasses must override.

    :param batch: input batch.
    :param generated: outputs of :meth:`forward`.
    :return: dict mapping loss names to loss values.
    :raises NotImplementedError: always; this is an abstract hook.
    """
    # NotImplemented is not an exception (raising it is a TypeError);
    # the correct abstract-method idiom is NotImplementedError.
    raise NotImplementedError
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
    """
    Collect images to visualize for the current step; subclasses must override.

    :param batch: input batch — presumably keyed by domain ("a"/"b");
        confirm against concrete subclasses.
    :param generated: dict of images
    :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
    :raises NotImplementedError: always; this is an abstract hook.
    """
    # NotImplemented is not an exception (raising it is a TypeError);
    # the correct abstract-method idiom is NotImplementedError.
    raise NotImplementedError
|
||||
|
||||
def change_engine(self, config, engine: Engine):
    """Hook invoked when this kernel is (re)attached to an engine.

    The base implementation deliberately does nothing; concrete kernels may
    override it to reconfigure state for the new engine.

    :param config: run configuration.
    :param engine: the engine being attached.
    """
    pass
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
import torch
|
||||
import ignite.distributed as idist
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model import MODEL
|
||||
import torch.optim as optim
|
||||
from util.misc import add_spectral_norm
|
||||
|
||||
|
||||
def build_model(cfg):
    """Instantiate a model described by *cfg* and wrap it for distributed use.

    Two private flags are stripped from the config before construction:

    * ``_bn_to_sync_bn`` — convert BatchNorm layers to ``SyncBatchNorm``.
    * ``_add_spectral_norm`` — apply ``add_spectral_norm`` to every submodule.

    :param cfg: an OmegaConf node describing the model to build.
    :return: the constructed model wrapped by ``idist.auto_model``.
    """
    plain_cfg = OmegaConf.to_container(cfg)

    # Pop the private control flags so the registry only sees real model kwargs.
    use_sync_bn = plain_cfg.pop("_bn_to_sync_bn", False)
    use_spectral_norm = plain_cfg.pop("_add_spectral_norm", False)

    net = MODEL.build_with(plain_cfg)
    if use_sync_bn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    if use_spectral_norm:
        net.apply(add_spectral_norm)
    return idist.auto_model(net)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user