add loss container
This commit is contained in:
@@ -100,6 +100,13 @@ class EngineKernel(object):
|
||||
pass
|
||||
|
||||
|
||||
def _remove_no_grad_loss(loss_dict):
|
||||
for k in loss_dict:
|
||||
if not isinstance(loss_dict[k], torch.Tensor):
|
||||
loss_dict.pop(k)
|
||||
return loss_dict
|
||||
|
||||
|
||||
def get_trainer(config, kernel: EngineKernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
generators, discriminators = kernel.generators, kernel.discriminators
|
||||
@@ -147,10 +154,10 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
|
||||
if engine.state.iteration % iteration_per_image == 0:
|
||||
return {
|
||||
"loss": dict(g=loss_g, d=loss_d),
|
||||
"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
|
||||
"img": kernel.intermediate_images(batch, generated)
|
||||
}
|
||||
return {"loss": dict(g=loss_g, d=loss_d)}
|
||||
return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
|
||||
Reference in New Issue
Block a user