v2
This commit is contained in:
@@ -105,10 +105,12 @@ class MGCLoss(nn.Module):
|
||||
Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
|
||||
"""
|
||||
|
||||
def __init__(self, beta=0.5, lambda_=0.05, device=idist.device()):
|
||||
def __init__(self, mi_to_loss_way="opposite", beta=0.5, lambda_=0.05, device=idist.device()):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.lambda_ = lambda_
|
||||
assert mi_to_loss_way in ["opposite", "reciprocal"]
|
||||
self.mi_to_loss_way = mi_to_loss_way
|
||||
mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
|
||||
self.mu_x = mu_x.flatten().to(device)
|
||||
self.mu_y = mu_y.flatten().to(device)
|
||||
@@ -134,6 +136,8 @@ class MGCLoss(nn.Module):
|
||||
|
||||
def forward(self, fake, real):
|
||||
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
|
||||
if self.mi_to_loss_way == "reciprocal":
|
||||
return 1/rSMI.mean()
|
||||
return -rSMI.mean()
|
||||
|
||||
|
||||
|
||||
14
loss/gan.py
14
loss/gan.py
@@ -1,5 +1,6 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
|
||||
|
||||
class GANLoss(nn.Module):
|
||||
@@ -10,7 +11,7 @@ class GANLoss(nn.Module):
|
||||
self.fake_label_val = fake_label_val
|
||||
self.loss_type = loss_type
|
||||
|
||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
"""
|
||||
gan loss forward
|
||||
:param prediction: network prediction
|
||||
@@ -37,3 +38,14 @@ class GANLoss(nn.Module):
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
|
||||
|
||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
if isinstance(prediction, torch.Tensor):
|
||||
# origin
|
||||
return self.single_forward(prediction, target_is_real, is_discriminator)
|
||||
elif isinstance(prediction, list):
|
||||
# for multi scale discriminator, e.g. MultiScaleDiscriminator
|
||||
loss = 0
|
||||
for p in prediction:
|
||||
loss += self.single_forward(p[-1], target_is_real, is_discriminator)
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user