base code for pytorch distributed, add cyclegan
This commit is contained in:
0
loss/__init__.py
Normal file
0
loss/__init__.py
Normal file
39
loss/gan.py
Normal file
39
loss/gan.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
|
||||
def __init__(self, loss_type, real_label_val=1.0, fake_label_val=0.0):
|
||||
super().__init__()
|
||||
assert loss_type in ["vanilla", "lsgan", "hinge", "wgan"]
|
||||
self.real_label_val = real_label_val
|
||||
self.fake_label_val = fake_label_val
|
||||
self.loss_type = loss_type
|
||||
|
||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
"""
|
||||
gan loss forward
|
||||
:param prediction: network prediction
|
||||
:param target_is_real: whether the target is real or fake
|
||||
:param is_discriminator: whether the loss for is_discriminator or not. default False
|
||||
:return: Tensor, GAN loss value
|
||||
"""
|
||||
target_val = self.real_label_val if target_is_real else self.fake_label_val
|
||||
target = prediction.new_ones(prediction.size()) * target_val
|
||||
|
||||
if self.loss_type == "vanilla":
|
||||
return F.binary_cross_entropy_with_logits(prediction, target)
|
||||
elif self.loss_type == "lsgan":
|
||||
return F.mse_loss(prediction, target)
|
||||
elif self.loss_type == "hinge":
|
||||
if is_discriminator:
|
||||
prediction = -prediction if target_is_real else prediction
|
||||
loss = F.relu(1 + prediction).mean()
|
||||
else:
|
||||
loss = -prediction.mean()
|
||||
return loss
|
||||
elif self.loss_type == "wgan":
|
||||
loss = -prediction.mean() if target_is_real else prediction.mean()
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
|
||||
Reference in New Issue
Block a user