TAFG 0.01

This commit is contained in:
2020-09-03 09:34:38 +08:00
parent 14d4247112
commit 2469bf15fe
6 changed files with 37 additions and 388 deletions

View File

@@ -145,6 +145,7 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
loss[tl] = output["loss"][tl]
return loss
pairs_per_iteration = config.data.train.dataloader.batch_size
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
if tensorboard_handler is not None:
tensorboard_handler.attach(
@@ -159,7 +160,8 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
test_images = {}
for k in output["img"]:
image_list = output["img"][k]
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list), engine.state.iteration)
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
engine.state.iteration * pairs_per_iteration)
test_images[k] = []
for i in range(len(image_list)):
test_images[k].append([])
@@ -182,6 +184,6 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
tensorboard_handler.writer.add_image(
f"test/{k}",
make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
engine.state.iteration
engine.state.iteration * pairs_per_iteration
)
return trainer