This commit is contained in:
2020-09-05 10:33:35 +08:00
parent 2469bf15fe
commit 39c754374c
21 changed files with 550 additions and 705 deletions

View File

@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.vgg as vgg
@@ -97,12 +98,13 @@ class PerceptualLoss(nn.Module):
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
use_input_norm=use_input_norm)
if criterion == 'L1':
self.criterion = torch.nn.L1Loss()
elif criterion == "L2":
self.criterion = torch.nn.MSELoss()
else:
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
self.criterion = self.set_criterion(criterion)
def set_criterion(self, criterion: str):
assert criterion in ["NL1", "NL2", "L1", "L2"]
norm = F.instance_norm if criterion.startswith("N") else lambda x: x
fn = F.l1_loss if criterion.endswith("L1") else F.mse_loss
return lambda x, t: fn(norm(x), norm(t))
def forward(self, x, gt):
"""Forward function.
@@ -124,8 +126,7 @@ class PerceptualLoss(nn.Module):
if self.perceptual_loss:
percep_loss = 0
for k in x_features.keys():
percep_loss += self.criterion(
x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
else:
percep_loss = None
@@ -133,9 +134,8 @@ class PerceptualLoss(nn.Module):
if self.style_loss:
style_loss = 0
for k in x_features.keys():
style_loss += self.criterion(
self._gram_mat(x_features[k]),
self._gram_mat(gt_features[k])) * self.layer_weights[k]
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
self.layer_weights[k]
else:
style_loss = None