This commit is contained in:
2020-10-23 16:14:37 +08:00
parent f7b7b78669
commit 0bec02bf6d
7 changed files with 287 additions and 26 deletions

View File

@@ -1,4 +1,5 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
@@ -10,7 +11,7 @@ class GANLoss(nn.Module):
self.fake_label_val = fake_label_val
self.loss_type = loss_type
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
"""
gan loss forward
:param prediction: network prediction
@@ -37,3 +38,20 @@ class GANLoss(nn.Module):
return loss
else:
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
if isinstance(prediction, torch.Tensor):
# origin
return self.single_forward(prediction, target_is_real, is_discriminator)
elif isinstance(prediction, list):
# for multi scale discriminator, e.g. MultiScaleDiscriminator
loss = 0
for p in prediction:
loss += self.single_forward(p[-1], target_is_real, is_discriminator)
return loss
elif isinstance(prediction, tuple):
# for single discriminator set `need_intermediate_feature` true
return self.single_forward(prediction[-1], target_is_real, is_discriminator)
else:
raise NotImplementedError(f"not support discriminator output: {prediction}")