This commit is contained in:
2020-09-05 22:00:17 +08:00
parent 39c754374c
commit e3c760d0c5
12 changed files with 122 additions and 43 deletions

View File

@@ -98,13 +98,13 @@ class PerceptualLoss(nn.Module):
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
use_input_norm=use_input_norm)
self.criterion = self.set_criterion(criterion)
self.percep_criterion, self.style_criterion = self.set_criterion(criterion)
def set_criterion(self, criterion: str):
assert criterion in ["NL1", "NL2", "L1", "L2"]
norm = F.instance_norm if criterion.startswith("N") else lambda x: x
fn = F.l1_loss if criterion.endswith("L1") else F.mse_loss
return lambda x, t: fn(norm(x), norm(t))
return lambda x, t: fn(norm(x), norm(t)), lambda x, t: fn(x, t)
def forward(self, x, gt):
"""Forward function.
@@ -126,7 +126,7 @@ class PerceptualLoss(nn.Module):
if self.perceptual_loss:
percep_loss = 0
for k in x_features.keys():
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss += self.percep_criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
else:
percep_loss = None
@@ -134,7 +134,7 @@ class PerceptualLoss(nn.Module):
if self.style_loss:
style_loss = 0
for k in x_features.keys():
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
style_loss += self.style_criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
self.layer_weights[k]
else:
style_loss = None