v2
This commit is contained in:
@@ -23,3 +23,19 @@ def mse_loss(x, target_flag):
|
||||
|
||||
def bce_loss(x, target_flag):
|
||||
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
|
||||
def feature_match_loss(level, weight_policy):
|
||||
compare_loss = pixel_loss(level)
|
||||
assert weight_policy in ["same", "exponential_decline"]
|
||||
|
||||
def fm_loss(generated_features, target_features):
|
||||
num_scale = len(generated_features)
|
||||
loss = 0
|
||||
for s_i in range(num_scale):
|
||||
for i in range(len(generated_features[s_i]) - 1):
|
||||
weight = 1 if weight_policy == "same" else 2 ** i
|
||||
loss += weight * compare_loss(generated_features[s_i][i], target_features[s_i][i].detach()) / num_scale
|
||||
return loss
|
||||
|
||||
return fm_loss
|
||||
|
||||
Reference in New Issue
Block a user