UGATIT version 0.1
engine/UGATIT.py
@@ -4,7 +4,6 @@ from math import ceil
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
 import ignite.distributed as idist
 from ignite.engine import Events, Engine
@@ -20,11 +19,28 @@ from loss.gan import GANLoss
 from model.weight_init import generation_init_weights
 from model.GAN.residual_generator import GANImageBuffer
 from model.GAN.UGATIT import RhoClipper
-from util.image import make_2d_grid, fuse_attention_map
+from util.image import make_2d_grid, fuse_attention_map, attention_colored_map
 from util.handler import setup_common_handlers, setup_tensorboard_handler
 from util.build import build_model, build_optimizer
 
 
+def build_lr_schedulers(optimizers, config):
+    g_milestones_values = [
+        (0, config.optimizers.generator.lr),
+        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
+        (config.max_iteration, config.data.train.scheduler.target_lr)
+    ]
+    d_milestones_values = [
+        (0, config.optimizers.discriminator.lr),
+        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
+        (config.max_iteration, config.data.train.scheduler.target_lr)
+    ]
+    return dict(
+        g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
+        d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
+    )
+
+
 def get_trainer(config, logger):
     generators = dict(
         a2b=build_model(config.model.generator, config.distributed.model),
@@ -42,23 +58,14 @@ def get_trainer(config, logger):
     logger.debug(discriminators["ga"])
     logger.debug(generators["a2b"])
 
-    optimizer_g = build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator)
-    optimizer_d = build_optimizer(chain(*[m.parameters() for m in discriminators.values()]),
-                                  config.optimizers.discriminator)
-
-    milestones_values = [
-        (0, config.optimizers.generator.lr),
-        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
-        (config.max_iteration, config.data.train.scheduler.target_lr)
-    ]
-    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
-
-    milestones_values = [
-        (0, config.optimizers.discriminator.lr),
-        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
-        (config.max_iteration, config.data.train.scheduler.target_lr)
-    ]
-    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
+    optimizers = dict(
+        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
+        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
+    )
+    logger.info(f"build optimizers:\n{optimizers}")
+
+    lr_schedulers = build_lr_schedulers(optimizers, config)
+    logger.info(f"build lr_schedulers:\n{lr_schedulers}")
 
     gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
     gan_loss_cfg.pop("weight")
@@ -116,26 +123,26 @@ def get_trainer(config, logger):
         identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
         identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
 
-        optimizer_g.zero_grad()
+        optimizers["g"].zero_grad()
         loss_g = dict()
         for n in ["a", "b"]:
             loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
                                               cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
         sum(loss_g.values()).backward()
-        optimizer_g.step()
+        optimizers["g"].step()
         for generator in generators.values():
             generator.apply(rho_clipper)
         for discriminator in discriminators.values():
             discriminator.requires_grad_(True)
 
-        optimizer_d.zero_grad()
+        optimizers["d"].zero_grad()
         loss_d = dict()
         for k in discriminators.keys():
             n = k[-1]  # "a" or "b"
             loss_d.update(
                 criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
         sum(loss_d.values()).backward()
-        optimizer_d.step()
+        optimizers["d"].step()
 
         for h in heatmap:
             heatmap[h] = heatmap[h].detach()
@@ -157,19 +164,19 @@ def get_trainer(config, logger):
 
     trainer = Engine(_step)
     trainer.logger = logger
-    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_g)
-    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler_d)
+    for lr_shd in lr_schedulers.values():
+        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
 
     RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
     RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
 
-    to_save = dict(optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer, lr_scheduler_d=lr_scheduler_d,
-                   lr_scheduler_g=lr_scheduler_g)
+    to_save = dict(trainer=trainer)
+    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
+    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
+    to_save.update({f"generator_{k}": generators[k] for k in generators})
+    to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
 
-    setup_common_handlers(trainer, config, to_save=to_save, metrics_to_print=["loss_g", "loss_d"],
-                          clear_cuda_cache=False, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
+    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True,
+                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
 
     def output_transform(output):
         loss = dict()
@@ -185,46 +192,36 @@ def get_trainer(config, logger):
     if tensorboard_handler is not None:
         tensorboard_handler.attach(
             trainer,
-            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
+            log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
             event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
         )
 
     @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
     def show_images(engine):
         output = engine.state.output
-        image_a_order = ["real_a", "fake_b", "rec_a", "id_a"]
-        image_b_order = ["real_b", "fake_a", "rec_b", "id_b"]
+        image_order = dict(
+            a=["real_a", "fake_b", "rec_a", "id_a"],
+            b=["real_b", "fake_a", "rec_b", "id_b"]
+        )
 
         output["img"]["generated"]["real_a"] = fuse_attention_map(
             output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
         output["img"]["generated"]["real_b"] = fuse_attention_map(
             output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
 
-        tensorboard_handler.writer.add_image(
-            "train/a",
-            make_2d_grid([output["img"]["generated"][o] for o in image_a_order]),
-            engine.state.iteration
-        )
-        tensorboard_handler.writer.add_image(
-            "train/b",
-            make_2d_grid([output["img"]["generated"][o] for o in image_b_order]),
-            engine.state.iteration
-        )
+        for k in "ab":
+            tensorboard_handler.writer.add_image(
+                f"train/{k}",
+                make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]),
+                engine.state.iteration
+            )
 
         with torch.no_grad():
            g = torch.Generator()
            g.manual_seed(config.misc.random_seed)
            indices = torch.randperm(len(engine.state.test_dataset), generator=g).tolist()[:10]
 
-            empty_grid = torch.zeros(0, config.model.generator.in_channels, config.model.generator.img_size,
-                                     config.model.generator.img_size)
-            fake = dict(a=empty_grid.clone(), b=empty_grid.clone())
-            rec = dict(a=empty_grid.clone(), b=empty_grid.clone())
-            heatmap = dict(a2b=torch.zeros(0, 1, config.model.generator.img_size,
-                                           config.model.generator.img_size),
-                           b2a=torch.zeros(0, 1, config.model.generator.img_size,
-                                           config.model.generator.img_size))
-            real = dict(a=empty_grid.clone(), b=empty_grid.clone())
+            test_images = dict(
+                a=[[], [], [], []],
+                b=[[], [], [], []]
+            )
             for i in indices:
                 batch = convert_tensor(engine.state.test_dataset[i], idist.device())
@@ -234,27 +231,18 @@ def get_trainer(config, logger):
                 rec_a = generators["b2a"](fake_b)[0]
                 rec_b = generators["a2b"](fake_a)[0]
 
-                fake["a"] = torch.cat([fake["a"], fake_a.cpu()])
-                fake["b"] = torch.cat([fake["b"], fake_b.cpu()])
-                real["a"] = torch.cat([real["a"], real_a.cpu()])
-                real["b"] = torch.cat([real["b"], real_b.cpu()])
-                rec["a"] = torch.cat([rec["a"], rec_a.cpu()])
-                rec["b"] = torch.cat([rec["b"], rec_b.cpu()])
-
-                heatmap["a2b"] = torch.cat(
-                    [heatmap["a2b"], torch.nn.functional.interpolate(heatmap_a2b, real_a.size()[-2:]).cpu()])
-                heatmap["b2a"] = torch.cat(
-                    [heatmap["b2a"], torch.nn.functional.interpolate(heatmap_b2a, real_a.size()[-2:]).cpu()])
-            tensorboard_handler.writer.add_image(
-                "test/a",
-                make_2d_grid([heatmap["a2b"].expand_as(real["a"]), real["a"], fake["b"], rec["a"]]),
-                engine.state.iteration
-            )
-            tensorboard_handler.writer.add_image(
-                "test/b",
-                make_2d_grid([heatmap["b2a"].expand_as(real["a"]), real["b"], fake["a"], rec["b"]]),
-                engine.state.iteration
-            )
+                for idx, im in enumerate(
+                        [attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]):
+                    test_images["a"][idx].append(im.cpu())
+                for idx, im in enumerate(
+                        [attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]):
+                    test_images["b"][idx].append(im.cpu())
+            for n in "ab":
+                tensorboard_handler.writer.add_image(
+                    f"test/{n}",
+                    make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
+                    engine.state.iteration
+                )
 
     return trainer
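
Note on the schedule built by build_lr_schedulers: both milestone lists encode the same shape, holding the configured lr flat for the first start_proportion of training, then decaying it linearly to target_lr at max_iteration. A minimal sketch of that shape with made-up stand-ins for the config values (the import path assumes a recent pytorch-ignite; the project's own import of PiecewiseLinear is not shown in this diff):

    from ignite.handlers import PiecewiseLinear

    # Stand-ins for config.max_iteration, config.data.train.scheduler.* and the
    # configured optimizer lr -- illustrative values only.
    max_iteration = 100
    start_proportion = 0.5
    lr, target_lr = 2e-4, 0.0

    milestones_values = [
        (0, lr),                                      # start at lr
        (int(start_proportion * max_iteration), lr),  # hold flat until the knee
        (max_iteration, target_lr),                   # then decay linearly
    ]

    # simulate_values dry-runs the scheduler without a real optimizer or engine
    # and returns [iteration, value] pairs -- handy for sanity-checking the shape.
    values = PiecewiseLinear.simulate_values(
        num_events=max_iteration, param_name="lr", milestones_values=milestones_values)
    for i in (0, 50, 75, 99):
        print(values[i])  # flat at 2e-4 until the knee, then linear decay toward 0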
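RhoClipper is imported from model.GAN.UGATIT and not shown in this diff. In the reference UGATIT code it is a small callable handed to Module.apply after every generator step, clamping the learnable rho of each AdaILN/ILN layer back into [0, 1]; a self-contained sketch along those lines:

    import torch
    import torch.nn as nn

    class RhoClipper:
        """Clamp the learnable `rho` of any submodule that has one."""
        def __init__(self, clip_min=0.0, clip_max=1.0):
            assert clip_min < clip_max
            self.clip_min = clip_min
            self.clip_max = clip_max

        def __call__(self, module):
            if hasattr(module, "rho"):
                # In-place clamp, keeping rho a valid mixing weight.
                module.rho.data.clamp_(self.clip_min, self.clip_max)

    # Stand-in for an (Ada)ILN layer owning a rho parameter.
    class WithRho(nn.Module):
        def __init__(self):
            super().__init__()
            self.rho = nn.Parameter(torch.tensor(1.7))

    m = nn.Sequential(WithRho())
    m.apply(RhoClipper())   # Module.apply visits every submodule, as in the diff
    print(m[0].rho.item())  # 1.0 -- clamped back into range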
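Likewise, GANImageBuffer (behind image_buffers[k].query(fake[n].detach())) is outside this diff. The usual technique here is the CycleGAN-style image pool: keep a history of past generated images and answer each query with a 50/50 mix of current and stored fakes, so the discriminator does not overfit to the generator's latest output. A generic sketch of that pattern, not the project's actual class:

    import random
    import torch

    class ImagePool:
        """History buffer: store up to `size` past fakes and answer each
        query with a mix of current and stored images."""
        def __init__(self, size=50):
            self.size = size
            self.images = []

        def query(self, images):
            out = []
            for img in images:  # iterate over the batch dimension
                img = img.unsqueeze(0)
                if len(self.images) < self.size:
                    self.images.append(img)  # buffer not full: keep and return it
                    out.append(img)
                elif random.random() < 0.5:
                    idx = random.randrange(self.size)
                    out.append(self.images[idx].clone())  # return an old fake ...
                    self.images[idx] = img                # ... and stash the new one
                else:
                    out.append(img)  # return the current fake unchanged
            return torch.cat(out)

    pool = ImagePool(size=50)
    fake_batch = torch.randn(4, 3, 64, 64)
    mixed = pool.query(fake_batch.detach())  # same shape, possibly older images
    print(mixed.shape)                       # torch.Size([4, 3, 64, 64])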
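Finally, the to_save refactor replaces hard-coded keys with prefixed dict merges, so every optimizer, scheduler, generator, and discriminator is checkpointed under a stable name. setup_common_handlers is project code not shown here; assuming it forwards to_save to ignite's Checkpoint (the usual pattern), the equivalent wiring is roughly:

    import torch.nn as nn
    from torch.optim import Adam
    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, DiskSaver

    # Stand-ins for the real models/optimizers built by build_model/build_optimizer.
    generators = dict(a2b=nn.Linear(8, 8), b2a=nn.Linear(8, 8))
    optimizers = dict(g=Adam(generators["a2b"].parameters(), lr=2e-4))
    trainer = Engine(lambda engine, batch: None)

    # Same merging pattern as the diff: one flat dict of everything that has a
    # state_dict, each entry under a prefixed, stable key.
    to_save = dict(trainer=trainer)
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update({f"generator_{k}": generators[k] for k in generators})

    handler = Checkpoint(to_save, DiskSaver("checkpoints", create_dir=True), n_saved=2)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
    # Checkpoint.load_objects(to_save=to_save, checkpoint=...) restores everything.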