move common handler setup to util

This commit is contained in:
2020-08-08 07:11:47 +08:00
parent 888a052f05
commit 8abd35467c
4 changed files with 183 additions and 85 deletions

View File

@@ -1,21 +1,86 @@
from pathlib import Path
import torch
from ignite.engine import Engine
from ignite.handlers import Checkpoint
import ignite.distributed as idist
from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler
class Resumer:
def __init__(self, to_load, checkpoint_path):
self.to_load = to_load
if checkpoint_path is not None:
checkpoint_path = Path(checkpoint_path)
if not checkpoint_path.exists():
raise ValueError(f"Checkpoint '{checkpoint_path}' is not found")
self.checkpoint_path = checkpoint_path
def setup_common_handlers(
trainer,
output_dir=None,
stop_on_nan=True,
use_profiler=True,
print_interval_event=None,
metrics_to_print=None,
to_save=None,
resume_from=None,
save_interval_event=None,
**checkpoint_kwargs
):
"""
Helper method to setup trainer with common handlers.
1. TerminateOnNan
2. BasicTimeProfiler
3. Print
4. Checkpoint
:param trainer: trainer engine. Output of trainer's `update_function` should be a dictionary
or sequence or a single tensor.
:param output_dir: output path to indicate where `to_save` objects are stored. Argument is mutually
:param stop_on_nan: if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer.
:param use_profiler:
:param print_interval_event:
:param metrics_to_print:
:param to_save:
:param resume_from:
:param save_interval_event:
:param checkpoint_kwargs:
:return:
"""
if stop_on_nan:
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
def __call__(self, engine: Engine):
if self.checkpoint_path is not None:
ckp = torch.load(self.checkpoint_path.as_posix(), map_location="cpu")
Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}")
if use_profiler:
# Create an object of the profiler and attach an engine to it
profiler = BasicTimeProfiler()
profiler.attach(trainer)
@trainer.on(Events.EPOCH_COMPLETED(once=1))
@idist.one_rank_only()
def log_intermediate_results():
profiler.print_results(profiler.get_results())
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()
def _():
profiler.print_results(profiler.get_results())
# profiler.write_results(f"{output_dir}/time_profiling.csv")
if metrics_to_print is not None:
if print_interval_event is None:
raise ValueError(
"If metrics_to_print argument is provided then print_interval_event arguments should be also defined"
)
@trainer.on(print_interval_event)
def print_interval(engine):
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
for m in metrics_to_print:
print_str += f"{m}={engine.state.metrics[m]:.3f} "
engine.logger.info(print_str)
if to_save is not None:
if resume_from is not None:
@trainer.on(Events.STARTED)
def resume(engine):
checkpoint_path = Path(resume_from)
if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
if save_interval_event is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs)
trainer.add_event_handler(save_interval_event, checkpoint_handler)