TAFG 0.01
This commit is contained in:
@@ -145,6 +145,7 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
loss[tl] = output["loss"][tl]
|
||||
return loss
|
||||
|
||||
pairs_per_iteration = config.data.train.dataloader.batch_size
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
||||
if tensorboard_handler is not None:
|
||||
tensorboard_handler.attach(
|
||||
@@ -159,7 +160,8 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
test_images = {}
|
||||
for k in output["img"]:
|
||||
image_list = output["img"][k]
|
||||
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list), engine.state.iteration)
|
||||
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
|
||||
engine.state.iteration * pairs_per_iteration)
|
||||
test_images[k] = []
|
||||
for i in range(len(image_list)):
|
||||
test_images[k].append([])
|
||||
@@ -182,6 +184,6 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{k}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[k]]),
|
||||
engine.state.iteration
|
||||
engine.state.iteration * pairs_per_iteration
|
||||
)
|
||||
return trainer
|
||||
|
||||
Reference in New Issue
Block a user