add U-GAT-IT

This commit is contained in:
2020-08-21 16:14:30 +08:00
parent 323bf2f6ab
commit 1a1cb9b00f
18 changed files with 815 additions and 55 deletions

View File

@@ -39,6 +39,7 @@ def setup_common_handlers(
:param checkpoint_kwargs:
:return:
"""
@trainer.on(Events.STARTED)
@idist.one_rank_only()
def print_dataloader_size(engine):
@@ -79,6 +80,8 @@ def setup_common_handlers(
engine.logger.info(print_str)
if to_save is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir, require_empty=False),
**checkpoint_kwargs)
if resume_from is not None:
@trainer.on(Events.STARTED)
def resume(engine):
@@ -89,5 +92,4 @@ def setup_common_handlers(
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
if save_interval_event is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=output_dir), **checkpoint_kwargs)
trainer.add_event_handler(save_interval_event, checkpoint_handler)

85
util/misc.py Normal file
View File

@@ -0,0 +1,85 @@
import logging
from typing import Optional
def setup_logger(
name: Optional[str] = None,
level: int = logging.INFO,
logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
filepath: Optional[str] = None,
file_level: int = logging.DEBUG,
distributed_rank: Optional[int] = None,
) -> logging.Logger:
"""Setups logger: name, level, format etc.
Args:
name (str, optional): new name for the logger. If None, the standard logger is used.
level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
file_level (int): Optional logging level for logging file.
distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
If None, distributed_rank is initialized to the rank of process.
Returns:
logging.Logger
For example, to improve logs readability when training with a trainer and evaluator:
.. code-block:: python
from ignite.utils import setup_logger
trainer = ...
evaluator = ...
trainer.logger = setup_logger("trainer")
evaluator.logger = setup_logger("evaluator")
trainer.run(data, max_epochs=10)
# Logs will look like
# 2020-01-21 12:46:07,356 trainer INFO: Engine run starting with max_epochs=5.
# 2020-01-21 12:46:07,358 trainer INFO: Epoch[1] Complete. Time taken: 00:5:23
# 2020-01-21 12:46:07,358 evaluator INFO: Engine run starting with max_epochs=1.
# 2020-01-21 12:46:07,358 evaluator INFO: Epoch[1] Complete. Time taken: 00:01:02
# ...
"""
logger = logging.getLogger(name)
# don't propagate to ancestors
# the problem here is to attach handlers to loggers
# should we provide a default configuration less open ?
if name is not None:
logger.propagate = False
# Remove previous handlers
if logger.hasHandlers():
for h in list(logger.handlers):
logger.removeHandler(h)
formatter = logging.Formatter(logger_format)
if distributed_rank is None:
import ignite.distributed as idist
distributed_rank = idist.get_rank()
if distributed_rank > 0:
logger.addHandler(logging.NullHandler())
else:
logger.setLevel(level)
ch = logging.StreamHandler()
ch.setLevel(level)
ch.setFormatter(formatter)
logger.addHandler(ch)
if filepath is not None:
fh = logging.FileHandler(filepath)
fh.setLevel(file_level)
fh.setFormatter(formatter)
logger.addHandler(fh)
return logger

View File

@@ -67,7 +67,11 @@ class _Registry:
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_cls(**args)
try:
obj = obj_cls(**args)
except TypeError as e:
raise TypeError(f"invalid argument in {args} when try to build {obj_cls}\n") from e
return obj
class ModuleRegistry(_Registry):