improve run.sh
This commit is contained in:
@@ -3,13 +3,13 @@ from pathlib import Path
|
||||
import torch
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
||||
from ignite.contrib.handlers import BasicTimeProfiler
|
||||
|
||||
|
||||
def setup_common_handlers(
|
||||
trainer,
|
||||
trainer: Engine,
|
||||
output_dir=None,
|
||||
stop_on_nan=True,
|
||||
use_profiler=True,
|
||||
@@ -39,6 +39,11 @@ def setup_common_handlers(
|
||||
:param checkpoint_kwargs:
|
||||
:return:
|
||||
"""
|
||||
@trainer.on(Events.STARTED)
|
||||
@idist.one_rank_only()
|
||||
def print_dataloader_size(engine):
|
||||
engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
|
||||
|
||||
if stop_on_nan:
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
||||
|
||||
@@ -68,6 +73,8 @@ def setup_common_handlers(
|
||||
def print_interval(engine):
|
||||
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
|
||||
for m in metrics_to_print:
|
||||
if m not in engine.state.metrics:
|
||||
continue
|
||||
print_str += f"{m}={engine.state.metrics[m]:.3f} "
|
||||
engine.logger.info(print_str)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user