TAHG 0.0.3
This commit is contained in:
@@ -88,16 +88,17 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
engine.terminate()
|
||||
|
||||
|
||||
def setup_tensorboard_handler(trainer: Engine, config, output_transform):
|
||||
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
|
||||
if config.interval.tensorboard is None:
|
||||
return None
|
||||
if idist.get_rank() == 0:
|
||||
# Create a logger
|
||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
|
||||
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
|
||||
event_name=basic_event)
|
||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
|
||||
event_name=Events.ITERATION_COMPLETED(every=config.interval.tensorboard.scalar))
|
||||
event_name=basic_event)
|
||||
|
||||
@trainer.on(Events.COMPLETED)
|
||||
@idist.one_rank_only()
|
||||
|
||||
Reference in New Issue
Block a user