base code for pytorch distributed, add cyclegan
This commit is contained in:
0
engine/__init__.py
Normal file
0
engine/__init__.py
Normal file
274
engine/cyclegan.py
Normal file
274
engine/cyclegan.py
Normal file
@@ -0,0 +1,274 @@
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torchvision.utils
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
||||
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
|
||||
from ignite.utils import convert_tensor
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import data
|
||||
from model import MODEL
|
||||
from loss.gan import GANLoss
|
||||
from util.distributed import auto_model
|
||||
from util.image import make_2d_grid
|
||||
from util.handler import Resumer
|
||||
|
||||
|
||||
def _build_model(cfg, distributed_args=None):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
model_distributed_config = cfg.pop("_distributed", dict())
|
||||
model = MODEL.build_with(cfg)
|
||||
|
||||
if model_distributed_config.get("bn_to_syncbn"):
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
|
||||
distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
|
||||
return auto_model(model, **distributed_args)
|
||||
|
||||
|
||||
def _build_optimizer(params, cfg):
|
||||
assert "_type" in cfg
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
|
||||
return idist.auto_optim(optimizer)
|
||||
|
||||
|
||||
def get_trainer(config, logger):
|
||||
generator_a = _build_model(config.model.generator, config.distributed.model)
|
||||
generator_b = _build_model(config.model.generator, config.distributed.model)
|
||||
discriminator_a = _build_model(config.model.discriminator, config.distributed.model)
|
||||
discriminator_b = _build_model(config.model.discriminator, config.distributed.model)
|
||||
logger.debug(discriminator_a)
|
||||
logger.debug(generator_a)
|
||||
|
||||
optimizer_g = _build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
|
||||
config.optimizers.generator)
|
||||
optimizer_d = _build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
|
||||
config.optimizers.discriminator)
|
||||
|
||||
milestones_values = [
|
||||
(config.data.train.scheduler.start, config.optimizers.generator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr),
|
||||
]
|
||||
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
|
||||
|
||||
milestones_values = [
|
||||
(config.data.train.scheduler.start, config.optimizers.discriminator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr),
|
||||
]
|
||||
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
cycle_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
|
||||
id_loss = nn.L1Loss() if config.loss.cycle == 1 else nn.MSELoss()
|
||||
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real_a, real_b = batch["a"], batch["b"]
|
||||
|
||||
optimizer_g.zero_grad()
|
||||
fake_b = generator_a(real_a) # G_A(A)
|
||||
rec_a = generator_b(fake_b) # G_B(G_A(A))
|
||||
fake_a = generator_b(real_b) # G_B(B)
|
||||
rec_b = generator_a(fake_a) # G_A(G_B(B))
|
||||
|
||||
loss_g = dict(
|
||||
id_a=config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
|
||||
id_b=config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
|
||||
cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
|
||||
cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
|
||||
gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
|
||||
gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
|
||||
)
|
||||
sum(loss_g.values()).backward()
|
||||
optimizer_g.step()
|
||||
|
||||
optimizer_d.zero_grad()
|
||||
loss_d_a = dict(
|
||||
real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
|
||||
fake=gan_loss(discriminator_a(fake_b.detach()), False, is_discriminator=True),
|
||||
)
|
||||
loss_d_b = dict(
|
||||
real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
|
||||
fake=gan_loss(discriminator_b(fake_a.detach()), False, is_discriminator=True),
|
||||
)
|
||||
loss_d = sum(loss_d_a.values()) / 2 + sum(loss_d_b.values()) / 2
|
||||
loss_d.backward()
|
||||
optimizer_d.step()
|
||||
|
||||
return {
|
||||
"loss": {
|
||||
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
||||
"d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
|
||||
"d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
|
||||
},
|
||||
"img": [
|
||||
real_a.detach(),
|
||||
fake_b.detach(),
|
||||
rec_a.detach(),
|
||||
real_b.detach(),
|
||||
fake_a.detach(),
|
||||
rec_b.detach()
|
||||
]
|
||||
}
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=10))
|
||||
def print_log(engine):
|
||||
engine.logger.info(f"iter:[{engine.state.iteration}/{config.max_iteration}]"
|
||||
f"loss_g={engine.state.metrics['loss_g']:.3f} "
|
||||
f"loss_d_a={engine.state.metrics['loss_d_a']:.3f} "
|
||||
f"loss_d_b={engine.state.metrics['loss_d_b']:.3f} ")
|
||||
|
||||
to_save = dict(
|
||||
generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
|
||||
discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
|
||||
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
|
||||
)
|
||||
|
||||
trainer.add_event_handler(Events.STARTED, Resumer(to_save, config.resume_from))
|
||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir), n_saved=None)
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=config.checkpoints.interval), checkpoint_handler)
|
||||
|
||||
if idist.get_rank() == 0:
|
||||
# Create a logger
|
||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||
tb_writer = tb_logger.writer
|
||||
|
||||
# Attach the logger to the trainer to log training loss at each iteration
|
||||
def global_step_transform(*args, **kwargs):
|
||||
return trainer.state.iteration
|
||||
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OutputHandler(
|
||||
tag="loss",
|
||||
metric_names=["loss_g", "loss_d_a", "loss_d_b"],
|
||||
global_step_transform=global_step_transform,
|
||||
),
|
||||
event_name=Events.ITERATION_COMPLETED(every=50)
|
||||
)
|
||||
# Attach the logger to the trainer to log optimizer's parameters, e.g. learning rate at each iteration
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
|
||||
event_name=Events.ITERATION_STARTED(every=50)
|
||||
)
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
|
||||
def show_images(engine):
|
||||
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)
|
||||
|
||||
# Create an object of the profiler and attach an engine to it
|
||||
profiler = BasicTimeProfiler()
|
||||
profiler.attach(trainer)
|
||||
|
||||
@trainer.on(Events.EPOCH_COMPLETED(once=1))
|
||||
@idist.one_rank_only()
|
||||
def log_intermediate_results():
|
||||
profiler.print_results(profiler.get_results())
|
||||
|
||||
@trainer.on(Events.COMPLETED)
|
||||
@idist.one_rank_only()
|
||||
def _():
|
||||
profiler.write_results(f"{config.output_dir}/time_profiling.csv")
|
||||
# We need to close the logger with we are done
|
||||
tb_logger.close()
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
def get_tester(config, logger):
|
||||
generator_a = _build_model(config.model.generator, config.distributed.model)
|
||||
generator_b = _build_model(config.model.generator, config.distributed.model)
|
||||
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real_a, real_b = batch["a"], batch["b"]
|
||||
with torch.no_grad():
|
||||
fake_b = generator_a(real_a) # G_A(A)
|
||||
rec_a = generator_b(fake_b) # G_B(G_A(A))
|
||||
fake_a = generator_b(real_b) # G_B(B)
|
||||
rec_b = generator_a(fake_a) # G_A(G_B(B))
|
||||
return [
|
||||
real_a.detach(),
|
||||
fake_b.detach(),
|
||||
rec_a.detach(),
|
||||
real_b.detach(),
|
||||
fake_a.detach(),
|
||||
rec_b.detach()
|
||||
]
|
||||
|
||||
tester = Engine(_step)
|
||||
tester.logger = logger
|
||||
if idist.get_rank == 0:
|
||||
ProgressBar(ncols=0).attach(tester)
|
||||
to_load = dict(generator_a=generator_a, generator_b=generator_b)
|
||||
tester.add_event_handler(Events.STARTED, Resumer(to_load, config.resume_from))
|
||||
|
||||
@tester.on(Events.STARTED)
|
||||
@idist.one_rank_only()
|
||||
def mkdir(engine):
|
||||
img_output_dir = Path(config.output_dir) / "test_images"
|
||||
if not img_output_dir.exists():
|
||||
engine.logger.info(f"mkdir {img_output_dir}")
|
||||
img_output_dir.mkdir()
|
||||
|
||||
@tester.on(Events.ITERATION_COMPLETED)
|
||||
def save_images(engine):
|
||||
img_tensors = engine.state.output
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors],
|
||||
Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
|
||||
nrow=len(img_tensors))
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
def run(task, config, logger):
|
||||
logger.info(f"start task {task}")
|
||||
if task == "train":
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||
trainer = get_trainer(config, logger)
|
||||
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, logger)
|
||||
try:
|
||||
tester.run(test_data_loader, max_epochs=1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
return NotImplemented(f"invalid task: {task}")
|
||||
Reference in New Issue
Block a user