imporved gan loss
This commit is contained in:
@@ -10,7 +10,23 @@ from loss.gan import GANLoss
|
||||
def gan_loss(config):
|
||||
gan_loss_cfg = OmegaConf.to_container(config)
|
||||
gan_loss_cfg.pop("weight")
|
||||
return GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
gl = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
def gan_loss_fn(prediction, target_is_real: bool, is_discriminator=False):
|
||||
if isinstance(prediction, torch.Tensor):
|
||||
# origin
|
||||
return gl(prediction, target_is_real, is_discriminator)
|
||||
elif isinstance(prediction, list) and isinstance(prediction[0], list):
|
||||
# for multi scale discriminator, e.g. MultiScaleDiscriminator
|
||||
loss = 0
|
||||
for p in prediction:
|
||||
loss += gl(p[-1], target_is_real, is_discriminator)
|
||||
return loss
|
||||
elif isinstance(prediction, list) and isinstance(prediction[0], torch.Tensor):
|
||||
# for discriminator set `need_intermediate_feature` true
|
||||
return gl(prediction[-1], target_is_real, is_discriminator)
|
||||
else:
|
||||
raise NotImplementedError("not support discriminator output")
|
||||
return gan_loss_fn
|
||||
|
||||
|
||||
def pixel_loss(level):
|
||||
|
||||
Reference in New Issue
Block a user