base code for pytorch distributed, add cyclegan

This commit is contained in:
2020-08-07 09:48:09 +08:00
commit f7843de45d
32 changed files with 1444 additions and 0 deletions

0
loss/__init__.py Normal file
View File

39
loss/gan.py Normal file
View File

@@ -0,0 +1,39 @@
import torch.nn as nn
import torch.nn.functional as F
class GANLoss(nn.Module):
def __init__(self, loss_type, real_label_val=1.0, fake_label_val=0.0):
super().__init__()
assert loss_type in ["vanilla", "lsgan", "hinge", "wgan"]
self.real_label_val = real_label_val
self.fake_label_val = fake_label_val
self.loss_type = loss_type
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
"""
gan loss forward
:param prediction: network prediction
:param target_is_real: whether the target is real or fake
:param is_discriminator: whether the loss for is_discriminator or not. default False
:return: Tensor, GAN loss value
"""
target_val = self.real_label_val if target_is_real else self.fake_label_val
target = prediction.new_ones(prediction.size()) * target_val
if self.loss_type == "vanilla":
return F.binary_cross_entropy_with_logits(prediction, target)
elif self.loss_type == "lsgan":
return F.mse_loss(prediction, target)
elif self.loss_type == "hinge":
if is_discriminator:
prediction = -prediction if target_is_real else prediction
loss = F.relu(1 + prediction).mean()
else:
loss = -prediction.mean()
return loss
elif self.loss_type == "wgan":
loss = -prediction.mean() if target_is_real else prediction.mean()
return loss
else:
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')