move common handler setup to util
This commit is contained in:
@@ -1,21 +1,86 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from ignite.engine import Engine
|
||||
from ignite.handlers import Checkpoint
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events
|
||||
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
||||
from ignite.contrib.handlers import BasicTimeProfiler
|
||||
|
||||
|
||||
class Resumer:
|
||||
def __init__(self, to_load, checkpoint_path):
|
||||
self.to_load = to_load
|
||||
if checkpoint_path is not None:
|
||||
checkpoint_path = Path(checkpoint_path)
|
||||
if not checkpoint_path.exists():
|
||||
raise ValueError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||
self.checkpoint_path = checkpoint_path
|
||||
def setup_common_handlers(
|
||||
trainer,
|
||||
output_dir=None,
|
||||
stop_on_nan=True,
|
||||
use_profiler=True,
|
||||
print_interval_event=None,
|
||||
metrics_to_print=None,
|
||||
to_save=None,
|
||||
resume_from=None,
|
||||
save_interval_event=None,
|
||||
**checkpoint_kwargs
|
||||
):
|
||||
"""
|
||||
Helper method to setup trainer with common handlers.
|
||||
1. TerminateOnNan
|
||||
2. BasicTimeProfiler
|
||||
3. Print
|
||||
4. Checkpoint
|
||||
:param trainer: trainer engine. Output of trainer's `update_function` should be a dictionary
|
||||
or sequence or a single tensor.
|
||||
:param output_dir: output path to indicate where `to_save` objects are stored. Argument is mutually
|
||||
:param stop_on_nan: if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer.
|
||||
:param use_profiler:
|
||||
:param print_interval_event:
|
||||
:param metrics_to_print:
|
||||
:param to_save:
|
||||
:param resume_from:
|
||||
:param save_interval_event:
|
||||
:param checkpoint_kwargs:
|
||||
:return:
|
||||
"""
|
||||
if stop_on_nan:
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
||||
|
||||
def __call__(self, engine: Engine):
|
||||
if self.checkpoint_path is not None:
|
||||
ckp = torch.load(self.checkpoint_path.as_posix(), map_location="cpu")
|
||||
Checkpoint.load_objects(to_load=self.to_load, checkpoint=ckp)
|
||||
engine.logger.info(f"resume from a checkpoint {self.checkpoint_path}")
|
||||
if use_profiler:
|
||||
# Create an object of the profiler and attach an engine to it
|
||||
profiler = BasicTimeProfiler()
|
||||
profiler.attach(trainer)
|
||||
|
||||
@trainer.on(Events.EPOCH_COMPLETED(once=1))
|
||||
@idist.one_rank_only()
|
||||
def log_intermediate_results():
|
||||
profiler.print_results(profiler.get_results())
|
||||
|
||||
@trainer.on(Events.COMPLETED)
|
||||
@idist.one_rank_only()
|
||||
def _():
|
||||
profiler.print_results(profiler.get_results())
|
||||
# profiler.write_results(f"{output_dir}/time_profiling.csv")
|
||||
|
||||
if metrics_to_print is not None:
|
||||
if print_interval_event is None:
|
||||
raise ValueError(
|
||||
"If metrics_to_print argument is provided then print_interval_event arguments should be also defined"
|
||||
)
|
||||
|
||||
@trainer.on(print_interval_event)
|
||||
def print_interval(engine):
|
||||
print_str = f"epoch:{engine.state.epoch} iter:{engine.state.iteration}\t"
|
||||
for m in metrics_to_print:
|
||||
print_str += f"{m}={engine.state.metrics[m]:.3f} "
|
||||
engine.logger.info(print_str)
|
||||
|
||||
if to_save is not None:
|
||||
if resume_from is not None:
|
||||
@trainer.on(Events.STARTED)
|
||||
def resume(engine):
|
||||
checkpoint_path = Path(resume_from)
|
||||
if not checkpoint_path.exists():
|
||||
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
|
||||
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
|
||||
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||
if save_interval_event is not None:
|
||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs)
|
||||
trainer.add_event_handler(save_interval_event, checkpoint_handler)
|
||||
|
||||
Reference in New Issue
Block a user