change
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models.vgg as vgg
|
||||
|
||||
|
||||
@@ -97,12 +98,13 @@ class PerceptualLoss(nn.Module):
|
||||
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
|
||||
use_input_norm=use_input_norm)
|
||||
|
||||
if criterion == 'L1':
|
||||
self.criterion = torch.nn.L1Loss()
|
||||
elif criterion == "L2":
|
||||
self.criterion = torch.nn.MSELoss()
|
||||
else:
|
||||
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
|
||||
self.criterion = self.set_criterion(criterion)
|
||||
|
||||
def set_criterion(self, criterion: str):
|
||||
assert criterion in ["NL1", "NL2", "L1", "L2"]
|
||||
norm = F.instance_norm if criterion.startswith("N") else lambda x: x
|
||||
fn = F.l1_loss if criterion.endswith("L1") else F.mse_loss
|
||||
return lambda x, t: fn(norm(x), norm(t))
|
||||
|
||||
def forward(self, x, gt):
|
||||
"""Forward function.
|
||||
@@ -124,8 +126,7 @@ class PerceptualLoss(nn.Module):
|
||||
if self.perceptual_loss:
|
||||
percep_loss = 0
|
||||
for k in x_features.keys():
|
||||
percep_loss += self.criterion(
|
||||
x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
else:
|
||||
percep_loss = None
|
||||
|
||||
@@ -133,9 +134,8 @@ class PerceptualLoss(nn.Module):
|
||||
if self.style_loss:
|
||||
style_loss = 0
|
||||
for k in x_features.keys():
|
||||
style_loss += self.criterion(
|
||||
self._gram_mat(x_features[k]),
|
||||
self._gram_mat(gt_features[k])) * self.layer_weights[k]
|
||||
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
|
||||
self.layer_weights[k]
|
||||
else:
|
||||
style_loss = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user