move the shared handler setup into handler.py
engine/UGATIT.py
@@ -1,20 +1,18 @@
 from itertools import chain
 from pathlib import Path
 from math import ceil
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.utils
 
 import ignite.distributed as idist
 from ignite.engine import Events, Engine
-from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
-from ignite.metrics import RunningAverage
-from ignite.contrib.handlers import ProgressBar
 from ignite.utils import convert_tensor
-from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
+from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
+from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
 
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, read_write
 
 import data
 from loss.gan import GANLoss
@@ -22,7 +20,7 @@ from model.weight_init import generation_init_weights
 from model.GAN.residual_generator import GANImageBuffer
 from model.GAN.UGATIT import RhoClipper
 from util.image import make_2d_grid
-from util.handler import setup_common_handlers
+from util.handler import setup_common_handlers, setup_tensorboard_handler
 from util.build import build_model, build_optimizer
 
 
@@ -49,14 +47,14 @@ def get_trainer(config, logger):
 
     milestones_values = [
         (0, config.optimizers.generator.lr),
-        (config.data.train.scheduler.start, config.optimizers.generator.lr),
+        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
         (config.max_iteration, config.data.train.scheduler.target_lr)
     ]
     lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
 
     milestones_values = [
         (0, config.optimizers.discriminator.lr),
-        (config.data.train.scheduler.start, config.optimizers.discriminator.lr),
+        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
     ]
     lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
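Note: with this change the warm-up boundary scales with the run length instead of being a fixed iteration count. PiecewiseLinear interpolates the parameter linearly between consecutive (event_index, value) milestones, so the learning rate stays flat until the proportional milestone and then decays to target_lr at max_iteration. A minimal standalone sketch of those semantics, with made-up numbers standing in for the config fields:

# Sketch of PiecewiseLinear milestone semantics; values are hypothetical,
# not the project's actual config.
import torch
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

max_iteration = 1000
start_proportion = 0.5  # hypothetical: decay begins halfway through training

lr_scheduler = PiecewiseLinear(
    optimizer,
    param_name="lr",
    milestones_values=[
        (0, 1e-4),                                      # flat segment ...
        (int(start_proportion * max_iteration), 1e-4),  # ... until iteration 500
        (max_iteration, 0.0),                           # then linear decay to target_lr
    ],
)
# Typically attached with:
# trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)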
@@ -66,18 +64,18 @@ def get_trainer(config, logger):
     gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
     cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
     id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
-    bce_loss = nn.BCEWithLogitsLoss().to(idist.device())
-    mse_loss = lambda x, t: F.mse_loss(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))
-    bce_loss = lambda x, t: F.binary_cross_entropy_with_logits(x, x.new_ones(x.size()) if t else x.new_zeros(x.size()))
-
-    image_buffers = {
-        k: GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50) for k in
-        discriminators.keys()}
+    def mse_loss(x, target_flag):
+        return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
+
+    def bce_loss(x, target_flag):
+        return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
+
+    image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
     rho_clipper = RhoClipper(0, 1)
 
-    def cal_generator_loss(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
-                           discriminator_g):
+    def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
+                            discriminator_g):
         discriminator_g.requires_grad_(False)
         discriminator_l.requires_grad_(False)
         pred_fake_g, cam_gd_pred = discriminator_g(fake)
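Note: the lambdas are replaced by named functions, and x.new_ones(x.size()) by torch.ones_like(x); both build a target tensor matching x in shape, dtype, and device, so behavior is unchanged. A self-contained check:

import torch
import torch.nn.functional as F

def mse_loss(x, target_flag):
    # ones_like/zeros_like build the target on x's device with x's dtype,
    # just as x.new_ones(x.size()) / x.new_zeros(x.size()) did.
    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

logits = torch.randn(4, 1)
old_style = F.mse_loss(logits, logits.new_ones(logits.size()))
assert torch.allclose(mse_loss(logits, True), old_style)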
@@ -92,7 +90,7 @@ def get_trainer(config, logger):
             f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
         }
 
-    def cal_discriminator_loss(name, discriminator, real, fake):
+    def criterion_discriminator(name, discriminator, real, fake):
         pred_real, cam_real = discriminator(real)
         pred_fake, cam_fake = discriminator(fake)
         # TODO: the original implementation does not divide by 2, but I think it is better to do so.
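Note: the TODO concerns the CycleGAN/LSGAN convention of averaging the real and fake terms, which halves the discriminator's effective learning rate relative to the generator. The line computing loss_gan falls outside this hunk; a hypothetical sketch of that objective with the halving applied:

# Hypothetical reconstruction of the elided loss_gan computation, written
# with the "divide by 2" the TODO argues for; not the project's actual line.
import torch
import torch.nn.functional as F

def lsgan_discriminator_loss(pred_real, pred_fake):
    loss_real = F.mse_loss(pred_real, torch.ones_like(pred_real))
    loss_fake = F.mse_loss(pred_fake, torch.zeros_like(pred_fake))
    # Averaging slows the discriminator relative to the generator,
    # the convention CycleGAN uses.
    return (loss_real + loss_fake) / 2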
@@ -100,9 +98,8 @@ def get_trainer(config, logger):
         loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
         return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}
 
-    def _step(engine, batch):
-        batch = convert_tensor(batch, idist.device())
-        real_a, real_b = batch["a"], batch["b"]
+    def _step(engine, real):
+        real = convert_tensor(real, idist.device())
 
         fake = dict()
         cam_generator_pred = dict()
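Note: keeping the batch as a dict works because ignite's convert_tensor recurses into mappings and sequences, calling .to(device) on every tensor it finds, so the "a"/"b" domain keys survive the device move. A minimal sketch:

import torch
from ignite.utils import convert_tensor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {"a": torch.randn(2, 3), "b": torch.randn(2, 3)}

# convert_tensor walks the mapping and moves each tensor to the device,
# returning a dict with the same keys.
batch = convert_tensor(batch, device)
assert batch["a"].device.type == device.type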
@@ -111,18 +108,18 @@ def get_trainer(config, logger):
         cam_identity_pred = dict()
         heatmap = dict()
 
-        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real_a)
-        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real_b)
+        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
+        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
         rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
         rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
-        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real_a)
-        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real_b)
+        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
+        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
 
         optimizer_g.zero_grad()
         loss_g = dict()
         for n in ["a", "b"]:
-            loss_g.update(cal_generator_loss(n, batch[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
-                                             cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
+            loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
+                                              cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
         sum(loss_g.values()).backward()
         optimizer_g.step()
         for generator in generators.values():
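Note: collecting the generator losses in a dict keyed by term name keeps each component available for logging while sum(loss_g.values()).backward() backpropagates their total in a single pass. A generic sketch of the pattern:

import torch

pred = torch.randn(4, requires_grad=True)

# Named terms stay inspectable for logging; one backward pass covers the
# sum, mirroring sum(loss_g.values()).backward() above.
loss = {"gan": pred.pow(2).mean(), "cycle": pred.abs().mean()}
sum(loss.values()).backward()
print({name: value.item() for name, value in loss.items()})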
@@ -135,13 +132,14 @@ def get_trainer(config, logger):
         for k in discriminators.keys():
             n = k[-1]  # "a" or "b"
             loss_d.update(
-                cal_discriminator_loss(k, discriminators[k], batch[n], image_buffers[k].query(fake[n].detach())))
+                criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
         sum(loss_d.values()).backward()
         optimizer_d.step()
 
         for h in heatmap:
             heatmap[h] = heatmap[h].detach()
-        generated_img = {f"fake_{k}": fake[k].detach() for k in fake}
+        generated_img = {f"real_{k}": real[k].detach() for k in real}
+        generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
         generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
         generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
 
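Note: image_buffers[k].query(...) feeds the discriminator a mix of fresh and historical fakes. GANImageBuffer is project code; presumably it follows the history-buffer scheme CycleGAN popularized (return the incoming image, or swap it with a stored older one, each with probability 0.5), which stabilizes discriminator training. A hedged sketch of that scheme, not the project's actual class:

import random
import torch

class ImageHistoryBuffer:
    """CycleGAN-style buffer: half the time the discriminator sees a
    previously generated image instead of the newest one."""

    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.images = []

    def query(self, batch: torch.Tensor) -> torch.Tensor:
        out = []
        for img in batch:
            img = img.unsqueeze(0)
            if len(self.images) < self.max_size:
                # Buffer not full yet: store and pass the image through.
                self.images.append(img)
                out.append(img)
            elif random.random() < 0.5:
                # Return a stored image and replace it with the new one.
                i = random.randrange(self.max_size)
                out.append(self.images[i])
                self.images[i] = img
            else:
                out.append(img)
        return torch.cat(out, dim=0)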
@@ -169,64 +167,41 @@ def get_trainer(config, logger):
     to_save.update({f"generator_{k}": generators[k] for k in generators})
     to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
 
-    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
-                          filename_prefix=config.name, to_save=to_save,
-                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
-                          metrics_to_print=["loss_g", "loss_d"],
-                          save_interval_event=Events.ITERATION_COMPLETED(
-                              every=config.checkpoints.interval) | Events.COMPLETED)
+    setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
+                          clear_cuda_cache=True, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
 
-    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
-    def terminate(engine):
-        engine.terminate()
+    def output_transform(output):
+        loss = dict()
+        for tl in output["loss"]:
+            if isinstance(output["loss"][tl], dict):
+                for l in output["loss"][tl]:
+                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
+            else:
+                loss[tl] = output["loss"][tl]
+        return loss
 
-    if idist.get_rank() == 0:
-        # Create a logger
-        tb_logger = TensorboardLogger(log_dir=config.output_dir)
-        tb_writer = tb_logger.writer
-
-        # Attach the logger to the trainer to log training loss at each iteration
-        def global_step_transform(*args, **kwargs):
-            return trainer.state.iteration
-
-        def output_transform(output):
-            loss = dict()
-            for tl in output["loss"]:
-                if isinstance(output["loss"][tl], dict):
-                    for l in output["loss"][tl]:
-                        loss[f"{tl}_{l}"] = output["loss"][tl][l]
-                else:
-                    loss[tl] = output["loss"][tl]
-            return loss
-
-        tb_logger.attach(
-            trainer,
-            log_handler=OutputHandler(
-                tag="loss",
-                metric_names=["loss_g", "loss_d"],
-                global_step_transform=global_step_transform,
-                output_transform=output_transform
-            ),
-            event_name=Events.ITERATION_COMPLETED(every=50)
-        )
-        tb_logger.attach(
+    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
+    if tensorboard_handler is not None:
+        tensorboard_handler.attach(
             trainer,
             log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
-            event_name=Events.ITERATION_STARTED(every=50)
+            event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
         )
 
-        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
+        @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
         def show_images(engine):
-            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]["generated"].values()),
-                                engine.state.iteration)
-            tb_writer.add_image("train/heatmap", make_2d_grid(engine.state.output["img"]["heatmap"].values()),
-                                engine.state.iteration)
-
-        @trainer.on(Events.COMPLETED)
-        @idist.one_rank_only()
-        def _():
-            # We need to close the logger when we are done
-            tb_logger.close()
+            image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
+            image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
+            tensorboard_handler.writer.add_image(
+                "train/a",
+                make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_a_order]),
+                engine.state.iteration
+            )
+            tensorboard_handler.writer.add_image(
+                "train/b",
+                make_2d_grid([engine.state.output["img"]["generated"][o] for o in image_b_order]),
+                engine.state.iteration
+            )
 
     return trainer
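Note: output_transform flattens the nested loss dict produced by _step into single-level keys (loss_g plus gan_a becomes loss_g_gan_a) so the TensorBoard handler can log each term as a separate scalar. A standalone run with a hypothetical payload shaped like engine.state.output:

# Hypothetical payload; the real one is built inside _step.
output = {"loss": {
    "loss_g": {"gan_a": 0.5, "cycle_a": 1.2},
    "loss_d": {"gan_la": 0.3},
    "total": 2.0,  # a non-dict entry passes through unchanged
}}

def output_transform(output):
    loss = dict()
    for tl in output["loss"]:
        if isinstance(output["loss"][tl], dict):
            for l in output["loss"][tl]:
                loss[f"{tl}_{l}"] = output["loss"][tl][l]
        else:
            loss[tl] = output["loss"][tl]
    return loss

print(output_transform(output))
# {'loss_g_gan_a': 0.5, 'loss_g_cycle_a': 1.2, 'loss_d_gan_la': 0.3, 'total': 2.0}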
@@ -235,13 +210,16 @@ def run(task, config, logger):
     assert torch.backends.cudnn.enabled
     torch.backends.cudnn.benchmark = True
     logger.info(f"start task {task}")
+    with read_write(config):
+        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
+
     if task == "train":
         train_dataset = data.DATASET.build_with(config.data.train.dataset)
         logger.info(f"train with dataset:\n{train_dataset}")
         train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
         trainer = get_trainer(config, logger)
         try:
-            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
+            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
         except Exception:
             import traceback
             print(traceback.format_exc())
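Note: the old epoch count, max_iteration // len(loader) + 1, schedules one epoch too many whenever max_iteration is an exact multiple of the loader length; ceil requests exactly enough epochs, and the end_event passed to setup_common_handlers still terminates at the precise iteration. A quick check:

from math import ceil

max_iteration, loader_len = 1000, 100

# Old formula overshoots by a full epoch on exact multiples:
assert max_iteration // loader_len + 1 == 11
# ceil schedules exactly the epochs needed to reach max_iteration:
assert ceil(max_iteration / loader_len) == 10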