update
This commit is contained in:
@@ -98,13 +98,13 @@ class PerceptualLoss(nn.Module):
|
||||
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
|
||||
use_input_norm=use_input_norm)
|
||||
|
||||
self.criterion = self.set_criterion(criterion)
|
||||
self.percep_criterion, self.style_criterion = self.set_criterion(criterion)
|
||||
|
||||
def set_criterion(self, criterion: str):
|
||||
assert criterion in ["NL1", "NL2", "L1", "L2"]
|
||||
norm = F.instance_norm if criterion.startswith("N") else lambda x: x
|
||||
fn = F.l1_loss if criterion.endswith("L1") else F.mse_loss
|
||||
return lambda x, t: fn(norm(x), norm(t))
|
||||
return lambda x, t: fn(norm(x), norm(t)), lambda x, t: fn(x, t)
|
||||
|
||||
def forward(self, x, gt):
|
||||
"""Forward function.
|
||||
@@ -126,7 +126,7 @@ class PerceptualLoss(nn.Module):
|
||||
if self.perceptual_loss:
|
||||
percep_loss = 0
|
||||
for k in x_features.keys():
|
||||
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
percep_loss += self.percep_criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
else:
|
||||
percep_loss = None
|
||||
|
||||
@@ -134,7 +134,7 @@ class PerceptualLoss(nn.Module):
|
||||
if self.style_loss:
|
||||
style_loss = 0
|
||||
for k in x_features.keys():
|
||||
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
|
||||
style_loss += self.style_criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
|
||||
self.layer_weights[k]
|
||||
else:
|
||||
style_loss = None
|
||||
|
||||
Reference in New Issue
Block a user