move common handler setup to util

This commit is contained in:
2020-08-08 07:11:47 +08:00
parent 888a052f05
commit 8abd35467c
4 changed files with 183 additions and 85 deletions

View File

@@ -3,59 +3,41 @@ from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
from omegaconf import OmegaConf
import data
from model import MODEL
from loss.gan import GANLoss
from util.distributed import auto_model
from model.weight_init import generation_init_weights
from model.residual_generator import GANImageBuffer
from util.image import make_2d_grid
from util.handler import Resumer
def _build_model(cfg, distributed_args=None):
    """Instantiate a model from ``cfg`` and wrap it for distributed use.

    The optional ``_distributed`` sub-config is split off before building;
    if it requests ``bn_to_syncbn``, every BatchNorm layer is converted to
    ``SyncBatchNorm``.  ``distributed_args`` are forwarded to ``auto_model``
    unless they are ``None`` or only a single process is running.
    """
    plain_cfg = OmegaConf.to_container(cfg)
    dist_cfg = plain_cfg.pop("_distributed", {})
    net = MODEL.build_with(plain_cfg)
    if dist_cfg.get("bn_to_syncbn"):
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    if distributed_args is None or idist.get_world_size() == 1:
        distributed_args = {}
    return auto_model(net, **distributed_args)
def _build_optimizer(params, cfg):
assert "_type" in cfg
cfg = OmegaConf.to_container(cfg)
optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
return idist.auto_optim(optimizer)
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer
def get_trainer(config, logger):
generator_a = _build_model(config.model.generator, config.distributed.model)
generator_b = _build_model(config.model.generator, config.distributed.model)
discriminator_a = _build_model(config.model.discriminator, config.distributed.model)
discriminator_b = _build_model(config.model.discriminator, config.distributed.model)
generator_a = build_model(config.model.generator, config.distributed.model)
generator_b = build_model(config.model.generator, config.distributed.model)
discriminator_a = build_model(config.model.discriminator, config.distributed.model)
discriminator_b = build_model(config.model.discriminator, config.distributed.model)
for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
generation_init_weights(m)
logger.debug(discriminator_a)
logger.debug(generator_a)
optimizer_g = _build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
config.optimizers.generator)
optimizer_d = _build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
config.optimizers.discriminator)
optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
config.optimizers.generator)
optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
config.optimizers.discriminator)
milestones_values = [
(config.data.train.scheduler.start, config.optimizers.generator.lr),
@@ -75,16 +57,21 @@ def get_trainer(config, logger):
cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real_a, real_b = batch["a"], batch["b"]
optimizer_g.zero_grad()
fake_b = generator_a(real_a) # G_A(A)
rec_a = generator_b(fake_b) # G_B(G_A(A))
fake_a = generator_b(real_b) # G_B(B)
rec_b = generator_a(fake_a) # G_A(G_B(B))
optimizer_g.zero_grad()
discriminator_a.requires_grad_(False)
discriminator_b.requires_grad_(False)
loss_g = dict(
id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
@@ -96,17 +83,19 @@ def get_trainer(config, logger):
sum(loss_g.values()).backward()
optimizer_g.step()
discriminator_a.requires_grad_(True)
discriminator_b.requires_grad_(True)
optimizer_d.zero_grad()
loss_d_a = dict(
real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
fake=gan_loss(discriminator_a(fake_b.detach()), False, is_discriminator=True),
fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
)
(sum(loss_d_a.values()) * 0.5).backward()
loss_d_b = dict(
real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
fake=gan_loss(discriminator_b(fake_a.detach()), False, is_discriminator=True),
fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
)
loss_d = sum(loss_d_a.values()) / 2 + sum(loss_d_b.values()) / 2
loss_d.backward()
(sum(loss_d_b.values()) * 0.5).backward()
optimizer_d.step()
return {
@@ -129,27 +118,25 @@ def get_trainer(config, logger):
trainer.logger = logger
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
@trainer.on(Events.ITERATION_COMPLETED(every=10))
def print_log(engine):
engine.logger.info(f"iter:[{engine.state.iteration}/{config.max_iteration}]"
f"loss_g={engine.state.metrics['loss_g']:.3f} "
f"loss_d_a={engine.state.metrics['loss_d_a']:.3f} "
f"loss_d_b={engine.state.metrics['loss_d_b']:.3f} ")
to_save = dict(
generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
)
trainer.add_event_handler(Events.STARTED, Resumer(to_save, config.resume_from))
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir), n_saved=None)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.checkpoints.interval), checkpoint_handler)
setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.ITERATION_COMPLETED(every=10),
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"], to_save=to_save,
resume_from=config.resume_from, n_saved=5, filename_prefix=config.name,
save_interval_event=Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
def terminate(engine):
engine.terminate()
if idist.get_rank() == 0:
# Create a logger
@@ -169,7 +156,6 @@ def get_trainer(config, logger):
),
event_name=Events.ITERATION_COMPLETED(every=50)
)
# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
tb_logger.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
@@ -180,28 +166,18 @@ def get_trainer(config, logger):
def show_images(engine):
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)
# Create an object of the profiler and attach an engine to it
profiler = BasicTimeProfiler()
profiler.attach(trainer)
@trainer.on(Events.EPOCH_COMPLETED(once=1))
@idist.one_rank_only()
def log_intermediate_results():
profiler.print_results(profiler.get_results())
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()
def _():
profiler.write_results(f"{config.output_dir}/time_profiling.csv")
# We need to close the logger when we are done
tb_logger.close()
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()
def _():
# We need to close the logger when we are done
tb_logger.close()
return trainer
def get_tester(config, logger):
generator_a = _build_model(config.model.generator, config.distributed.model)
generator_b = _build_model(config.model.generator, config.distributed.model)
generator_a = build_model(config.model.generator, config.distributed.model)
generator_b = build_model(config.model.generator, config.distributed.model)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
@@ -225,7 +201,7 @@ def get_tester(config, logger):
if idist.get_rank == 0:
ProgressBar(ncols=0).attach(tester)
to_load = dict(generator_a=generator_a, generator_b=generator_b)
tester.add_event_handler(Events.STARTED, Resumer(to_load, config.resume_from))
setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)
@tester.on(Events.STARTED)
@idist.one_rank_only()
@@ -248,15 +224,16 @@ def get_tester(config, logger):
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger)
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
try:
trainer.run(train_data_loader, max_epochs=1)
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
except Exception:
import traceback
print(traceback.format_exc())