change
This commit is contained in:
154
engine/util/handler.py
Normal file
154
engine/util/handler.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
||||
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
|
||||
|
||||
|
||||
def empty_cuda_cache(_):
|
||||
torch.cuda.empty_cache()
|
||||
import gc
|
||||
gc.collect()
|
||||
|
||||
|
||||
def step_transform_maker(stype: str, pairs_per_iteration=None):
|
||||
assert stype in ["item", "iteration", "epoch"]
|
||||
if stype == "item":
|
||||
return lambda engine, _: engine.state.iteration * pairs_per_iteration
|
||||
if stype == "iteration":
|
||||
return lambda engine, _: engine.state.iteration
|
||||
if stype == "epoch":
|
||||
return lambda engine, _: engine.state.epoch
|
||||
|
||||
|
||||
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
|
||||
to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
|
||||
"""
|
||||
Helper method to setup trainer with common handlers.
|
||||
1. TerminateOnNan
|
||||
2. BasicTimeProfiler
|
||||
3. Print
|
||||
4. Checkpoint
|
||||
:param trainer:
|
||||
:param config:
|
||||
:param stop_on_nan:
|
||||
:param clear_cuda_cache:
|
||||
:param use_profiler:
|
||||
:param to_save:
|
||||
:param end_event:
|
||||
:param set_epoch_for_dist_sampler:
|
||||
:return:
|
||||
"""
|
||||
if set_epoch_for_dist_sampler:
|
||||
@trainer.on(Events.EPOCH_STARTED)
|
||||
def distrib_set_epoch(engine):
|
||||
if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
|
||||
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
|
||||
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
|
||||
|
||||
trainer.logger.info(f"data loader length: {config.iterations_per_epoch} iterations per epoch")
|
||||
|
||||
@trainer.on(Events.EPOCH_COMPLETED(once=1))
|
||||
def print_info(engine):
|
||||
engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
|
||||
|
||||
if stop_on_nan:
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
||||
|
||||
if torch.cuda.is_available() and clear_cuda_cache:
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
|
||||
|
||||
if use_profiler:
|
||||
# Create an object of the profiler and attach an engine to it
|
||||
profiler = BasicTimeProfiler()
|
||||
profiler.attach(trainer)
|
||||
|
||||
@trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
|
||||
@idist.one_rank_only()
|
||||
def log_intermediate_results():
|
||||
profiler.print_results(profiler.get_results())
|
||||
|
||||
ProgressBar(ncols=0).attach(trainer, "all")
|
||||
|
||||
if to_save is not None:
|
||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
|
||||
n_saved=config.handler.checkpoint.n_saved, filename_prefix=config.name)
|
||||
if config.resume_from is not None:
|
||||
@trainer.on(Events.STARTED)
|
||||
def resume(engine):
|
||||
checkpoint_path = Path(config.resume_from)
|
||||
if not checkpoint_path.exists():
|
||||
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
|
||||
trainer.logger.info(f"load state_dict for {ckp.keys()}")
|
||||
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
|
||||
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||
trainer.add_event_handler(
|
||||
Events.EPOCH_COMPLETED(every=config.handler.checkpoint.epoch_interval) | Events.COMPLETED,
|
||||
checkpoint_handler
|
||||
)
|
||||
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
|
||||
if end_event is not None:
|
||||
trainer.logger.debug(f"engine will stop on {end_event}")
|
||||
|
||||
@trainer.on(end_event)
|
||||
def terminate(engine):
|
||||
engine.terminate()
|
||||
|
||||
|
||||
def setup_tensorboard_handler(trainer: Engine, config, optimizers, step_type="item"):
|
||||
if config.handler.tensorboard is None:
|
||||
return None
|
||||
if idist.get_rank() == 0:
|
||||
# Create a logger
|
||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||
tb_writer = tb_logger.writer
|
||||
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
|
||||
global_step_transform = step_transform_maker(step_type, pairs_per_iteration)
|
||||
|
||||
basic_event = Events.ITERATION_COMPLETED(
|
||||
every=max(config.iterations_per_epoch // config.handler.tensorboard.scalar, 1))
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OutputHandler(
|
||||
tag="metric", metric_names="all",
|
||||
global_step_transform=global_step_transform
|
||||
),
|
||||
event_name=basic_event
|
||||
)
|
||||
|
||||
@trainer.on(basic_event)
|
||||
def log_loss(engine):
|
||||
global_step = global_step_transform(engine, None)
|
||||
output_loss = engine.state.output["loss"]
|
||||
for total_loss in output_loss:
|
||||
if isinstance(output_loss[total_loss], dict):
|
||||
for ln in output_loss[total_loss]:
|
||||
tb_writer.add_scalar(f"train_{total_loss}/{ln}", output_loss[total_loss][ln], global_step)
|
||||
else:
|
||||
tb_writer.add_scalar(f"train/{total_loss}", output_loss[total_loss], global_step)
|
||||
|
||||
if isinstance(optimizers, dict):
|
||||
for name in optimizers:
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizers[name], tag=f"optimizer_{name}"),
|
||||
event_name=Events.ITERATION_STARTED
|
||||
)
|
||||
else:
|
||||
tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizers, tag=f"optimizer"),
|
||||
event_name=Events.ITERATION_STARTED)
|
||||
|
||||
@trainer.on(Events.COMPLETED)
|
||||
@idist.one_rank_only()
|
||||
def _():
|
||||
# We need to close the logger with we are done
|
||||
tb_logger.close()
|
||||
|
||||
return tb_logger
|
||||
return None
|
||||
Reference in New Issue
Block a user