imporved gan loss

This commit is contained in:
2020-10-22 23:19:03 +08:00
parent 376f5caeb7
commit f7b7b78669
2 changed files with 18 additions and 14 deletions

View File

@@ -10,7 +10,23 @@ from loss.gan import GANLoss
def gan_loss(config):
gan_loss_cfg = OmegaConf.to_container(config)
gan_loss_cfg.pop("weight")
return GANLoss(**gan_loss_cfg).to(idist.device())
gl = GANLoss(**gan_loss_cfg).to(idist.device())
def gan_loss_fn(prediction, target_is_real: bool, is_discriminator=False):
if isinstance(prediction, torch.Tensor):
# origin
return gl(prediction, target_is_real, is_discriminator)
elif isinstance(prediction, list) and isinstance(prediction[0], list):
# for multi scale discriminator, e.g. MultiScaleDiscriminator
loss = 0
for p in prediction:
loss += gl(p[-1], target_is_real, is_discriminator)
return loss
elif isinstance(prediction, list) and isinstance(prediction[0], torch.Tensor):
# for discriminator set `need_intermediate_feature` true
return gl(prediction[-1], target_is_real, is_discriminator)
else:
raise NotImplementedError("not support discriminator output")
return gan_loss_fn
def pixel_loss(level):