This commit is contained in:
2020-10-23 16:14:37 +08:00
parent f7b7b78669
commit 0bec02bf6d
7 changed files with 287 additions and 26 deletions

View File

@@ -4,29 +4,20 @@ import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
def gan_loss(config):
gan_loss_cfg = OmegaConf.to_container(config)
gan_loss_cfg.pop("weight")
gl = GANLoss(**gan_loss_cfg).to(idist.device())
def gan_loss_fn(prediction, target_is_real: bool, is_discriminator=False):
if isinstance(prediction, torch.Tensor):
# origin
return gl(prediction, target_is_real, is_discriminator)
elif isinstance(prediction, list) and isinstance(prediction[0], list):
# for multi scale discriminator, e.g. MultiScaleDiscriminator
loss = 0
for p in prediction:
loss += gl(p[-1], target_is_real, is_discriminator)
return loss
elif isinstance(prediction, list) and isinstance(prediction[0], torch.Tensor):
# for discriminator set `need_intermediate_feature` true
return gl(prediction[-1], target_is_real, is_discriminator)
else:
raise NotImplementedError("not support discriminator output")
return gan_loss_fn
return GANLoss(**gan_loss_cfg).to(idist.device())
def perceptual_loss(config):
perceptual_loss_cfg = OmegaConf.to_container(config)
perceptual_loss_cfg.pop("weight")
return PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
def pixel_loss(level):