rewrite
This commit is contained in:
111
loss/I2I/minimal_geometry_distortion_constraint_loss.py
Normal file
111
loss/I2I/minimal_geometry_distortion_constraint_loss.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def gaussian_radial_basis_function(x, mu, sigma):
|
||||
# (kernel_size) -> (batch_size, kernel_size, c*h*w)
|
||||
mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
|
||||
mu = mu.to(x.device)
|
||||
# (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
|
||||
x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1)
|
||||
return torch.exp((x - mu).pow(2) / (2 * sigma ** 2))
|
||||
|
||||
|
||||
class MyLoss(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyLoss, self).__init__()
|
||||
|
||||
def forward(self, fakeI, realI):
|
||||
def batch_ERSMI(I1, I2):
|
||||
batch_size = I1.shape[0]
|
||||
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
|
||||
if I2.shape[1] == 1 and I1.shape[1] != 1:
|
||||
I2 = I2.repeat(1, 3, 1, 1)
|
||||
|
||||
def kernel_F(y, mu_list, sigma):
|
||||
tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1).cuda() # [81, 784]
|
||||
tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
|
||||
tmp_y = tmp_mu - tmp_y
|
||||
mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
|
||||
return mat_L
|
||||
|
||||
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).cuda()
|
||||
|
||||
x_mu_list = mu.repeat(9).view(-1, 81)
|
||||
y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
|
||||
|
||||
mat_K = kernel_F(I1, x_mu_list, 1)
|
||||
mat_L = kernel_F(I2, y_mu_list, 1)
|
||||
|
||||
H1 = ((mat_K.matmul(mat_K.transpose(1, 2))).mul(mat_L.matmul(mat_L.transpose(1, 2))) / (
|
||||
img_size ** 2)).cuda()
|
||||
H2 = ((mat_K.mul(mat_L)).matmul((mat_K.mul(mat_L)).transpose(1, 2)) / img_size).cuda()
|
||||
h2 = ((mat_K.sum(2).view(batch_size, -1, 1)).mul(mat_L.sum(2).view(batch_size, -1, 1)) / (
|
||||
img_size ** 2)).cuda()
|
||||
H2 = 0.5 * H1 + 0.5 * H2
|
||||
tmp = H2 + 0.05 * torch.eye(len(H2[0])).cuda()
|
||||
alpha = (tmp.inverse())
|
||||
|
||||
alpha = alpha.matmul(h2)
|
||||
ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
|
||||
alpha) - 1).squeeze()
|
||||
ersmi = -ersmi.mean()
|
||||
return ersmi
|
||||
|
||||
batch_loss = batch_ERSMI(fakeI, realI)
|
||||
return batch_loss
|
||||
|
||||
|
||||
class MGCLoss(nn.Module):
|
||||
"""
|
||||
Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
|
||||
"""
|
||||
|
||||
def __init__(self, beta=0.5, lambda_=0.05):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.lambda_ = lambda_
|
||||
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0])
|
||||
self.mu_x = mu.repeat(9)
|
||||
self.mu_y = mu.unsqueeze(0).t().repeat(1, 9).view(-1)
|
||||
|
||||
@staticmethod
|
||||
def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_):
|
||||
assert img1.size() == img2.size()
|
||||
|
||||
num_pixel = img1.size(1) * img1.size(2) * img2.size(3)
|
||||
|
||||
mat_k = gaussian_radial_basis_function(img1, mu_x, sigma=1)
|
||||
mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)
|
||||
|
||||
mat_k_mul_mat_l = mat_k * mat_l
|
||||
h_hat = (1 - beta) * (mat_k_mul_mat_l.matmul(mat_k_mul_mat_l.transpose(1, 2))) / num_pixel
|
||||
h_hat += beta * (mat_k.matmul(mat_k.transpose(1, 2)) * mat_l.matmul(mat_l.transpose(1, 2))) / (num_pixel ** 2)
|
||||
small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)
|
||||
|
||||
R = torch.eye(h_hat.size(1)).to(img1.device)
|
||||
alpha = (h_hat + lambda_ * R).inverse().matmul(small_h_hat)
|
||||
|
||||
rSMI = (2 * alpha.transpose(1, 2).matmul(small_h_hat)) - alpha.transpose(1, 2).matmul(h_hat).matmul(alpha) - 1
|
||||
return rSMI
|
||||
|
||||
def forward(self, fake, real):
|
||||
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_)
|
||||
return -rSMI.squeeze().mean()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mg = MGCLoss().to("cuda")
|
||||
|
||||
|
||||
def norm(x):
|
||||
x -= x.min()
|
||||
x /= x.max()
|
||||
return (x - 0.5) * 2
|
||||
|
||||
|
||||
x1 = norm(torch.randn(5, 3, 256, 256))
|
||||
x2 = norm(x1 * 2 + 1)
|
||||
x3 = norm(torch.randn(5, 3, 256, 256))
|
||||
x4 = norm(torch.exp(x3))
|
||||
print(mg(x1, x1), mg(x1, x2), mg(x1, x3), mg(x1, x4))
|
||||
Reference in New Issue
Block a user