TSIT
This commit is contained in:
@@ -16,7 +16,7 @@ class GauGANEngineKernel(EngineKernel):
|
||||
|
||||
self.gan_loss = gan_loss(config.loss.gan)
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
|
||||
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "exponential_decline"))
|
||||
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
|
||||
self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
|
||||
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
|
||||
@@ -38,7 +38,7 @@ def feature_match_loss(level, weight_policy):
|
||||
|
||||
def fm_loss(generated_features, target_features):
|
||||
num_scale = len(generated_features)
|
||||
loss = 0
|
||||
loss = torch.zeros(1, device=idist.device())
|
||||
for s_i in range(num_scale):
|
||||
for i in range(len(generated_features[s_i]) - 1):
|
||||
weight = 1 if weight_policy == "same" else 2 ** i
|
||||
|
||||
Reference in New Issue
Block a user