add loss container

This commit is contained in:
2020-10-11 23:09:04 +08:00
parent 6070f08835
commit 436bca88b4
3 changed files with 26 additions and 11 deletions

View File

@@ -100,6 +100,13 @@ class EngineKernel(object):
pass
def _remove_no_grad_loss(loss_dict):
for k in loss_dict:
if not isinstance(loss_dict[k], torch.Tensor):
loss_dict.pop(k)
return loss_dict
def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name)
generators, discriminators = kernel.generators, kernel.discriminators
@@ -147,10 +154,10 @@ def get_trainer(config, kernel: EngineKernel):
if engine.state.iteration % iteration_per_image == 0:
return {
"loss": dict(g=loss_g, d=loss_d),
"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
"img": kernel.intermediate_images(batch, generated)
}
return {"loss": dict(g=loss_g, d=loss_d)}
return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}
trainer = Engine(_step)
trainer.logger = logger