add U-GAT-IT

This commit is contained in:
2020-08-21 16:14:30 +08:00
parent 323bf2f6ab
commit 1a1cb9b00f
18 changed files with 815 additions and 55 deletions

20
main.py
View File

@@ -5,7 +5,8 @@ import torch
import ignite
import ignite.distributed as idist
from ignite.utils import manual_seed, setup_logger
from ignite.utils import manual_seed
from util.misc import setup_logger
import fire
from omegaconf import OmegaConf
@@ -21,14 +22,12 @@ def log_basic_info(logger, config):
def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
logger = setup_logger(name=config.name, distributed_rank=local_rank, **config.log.logger)
log_basic_info(logger, config)
if setup_random_seed:
manual_seed(config.misc.random_seed + idist.get_rank())
if setup_output_dir:
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
config.output_dir = str(output_dir)
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
config.output_dir = str(output_dir)
if setup_output_dir and config.resume_from is None:
if output_dir.exists():
# assert not any(output_dir.iterdir()), "output_dir must be empty"
contains = list(output_dir.iterdir())
@@ -37,11 +36,14 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
else:
if idist.get_rank() == 0:
output_dir.mkdir(parents=True)
logger.info(f"mkdir -p {output_dir}")
logger.info(f"output path: {config.output_dir}")
print(f"mkdir -p {output_dir}")
if backup_config and idist.get_rank() == 0:
with open(output_dir / "config.yml", "w+") as f:
print(config.pretty(), file=f)
logger = setup_logger(name=config.name, distributed_rank=local_rank, filepath=output_dir / "train.log")
logger.info(f"output path: {config.output_dir}")
log_basic_info(logger, config)
OmegaConf.set_readonly(config, True)