TAHG 0.0.3

This commit is contained in:
2020-09-01 09:02:04 +08:00
parent 89b54105c7
commit e71e8d95d0
8 changed files with 97 additions and 36 deletions

View File

@@ -88,16 +88,17 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
engine.terminate()
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
if config.interval.tensorboard is None:
return None
if idist.get_rank() == 0:
# Create a logger
tb_logger = TensorboardLogger(log_dir=config.output_dir)
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
event_name=basic_event)
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
event_name=basic_event)
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()