add U-GAT-IT
This commit is contained in:
@@ -39,6 +39,7 @@ def setup_common_handlers(
|
||||
:param checkpoint_kwargs:
|
||||
:return:
|
||||
"""
|
||||
|
||||
@trainer.on(Events.STARTED)
|
||||
@idist.one_rank_only()
|
||||
def print_dataloader_size(engine):
|
||||
@@ -79,6 +80,8 @@ def setup_common_handlers(
|
||||
engine.logger.info(print_str)
|
||||
|
||||
if to_save is not None:
|
||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir, require_empty=False),
|
||||
**checkpoint_kwargs)
|
||||
if resume_from is not None:
|
||||
@trainer.on(Events.STARTED)
|
||||
def resume(engine):
|
||||
@@ -89,5 +92,4 @@ def setup_common_handlers(
|
||||
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
|
||||
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||
if save_interval_event is not None:
|
||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs)
|
||||
trainer.add_event_handler(save_interval_event, checkpoint_handler)
|
||||
|
||||
85
util/misc.py
Normal file
85
util/misc.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def setup_logger(
|
||||
name: Optional[str] = None,
|
||||
level: int = logging.INFO,
|
||||
logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
|
||||
filepath: Optional[str] = None,
|
||||
file_level: int = logging.DEBUG,
|
||||
distributed_rank: Optional[int] = None,
|
||||
) -> logging.Logger:
|
||||
"""Setups logger: name, level, format etc.
|
||||
|
||||
Args:
|
||||
name (str, optional): new name for the logger. If None, the standard logger is used.
|
||||
level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
|
||||
logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
|
||||
filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
|
||||
file_level (int): Optional logging level for logging file.
|
||||
distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
|
||||
If None, distributed_rank is initialized to the rank of process.
|
||||
|
||||
Returns:
|
||||
logging.Logger
|
||||
|
||||
For example, to improve logs readability when training with a trainer and evaluator:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ignite.utils import setup_logger
|
||||
|
||||
trainer = ...
|
||||
evaluator = ...
|
||||
|
||||
trainer.logger = setup_logger("trainer")
|
||||
evaluator.logger = setup_logger("evaluator")
|
||||
|
||||
trainer.run(data, max_epochs=10)
|
||||
|
||||
# Logs will look like
|
||||
# 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5.
|
||||
# 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:5:23
|
||||
# 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
|
||||
# 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
|
||||
# ...
|
||||
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
|
||||
# don't propagate to ancestors
|
||||
# the problem here is to attach handlers to loggers
|
||||
# should we provide a default configuration less open ?
|
||||
if name is not None:
|
||||
logger.propagate = False
|
||||
|
||||
# Remove previous handlers
|
||||
if logger.hasHandlers():
|
||||
for h in list(logger.handlers):
|
||||
logger.removeHandler(h)
|
||||
|
||||
formatter = logging.Formatter(logger_format)
|
||||
|
||||
if distributed_rank is None:
|
||||
import ignite.distributed as idist
|
||||
|
||||
distributed_rank = idist.get_rank()
|
||||
|
||||
if distributed_rank > 0:
|
||||
logger.addHandler(logging.NullHandler())
|
||||
else:
|
||||
logger.setLevel(level)
|
||||
|
||||
ch = logging.StreamHandler()
|
||||
ch.setLevel(level)
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
|
||||
if filepath is not None:
|
||||
fh = logging.FileHandler(filepath)
|
||||
fh.setLevel(file_level)
|
||||
fh.setFormatter(formatter)
|
||||
logger.addHandler(fh)
|
||||
|
||||
return logger
|
||||
@@ -67,7 +67,11 @@ class _Registry:
|
||||
if default_args is not None:
|
||||
for name, value in default_args.items():
|
||||
args.setdefault(name, value)
|
||||
return obj_cls(**args)
|
||||
try:
|
||||
obj = obj_cls(**args)
|
||||
except TypeError as e:
|
||||
raise TypeError(f"invalid argument in {args} when try to build {obj_cls}\n") from e
|
||||
return obj
|
||||
|
||||
|
||||
class ModuleRegistry(_Registry):
|
||||
|
||||
Reference in New Issue
Block a user