This commit is contained in:
2020-09-05 10:33:35 +08:00
parent 2469bf15fe
commit 39c754374c
21 changed files with 550 additions and 705 deletions

View File

@@ -1,24 +1,22 @@
from itertools import chain
from math import ceil
from omegaconf import read_write, OmegaConf
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
import data
from engine.base.i2i import get_trainer, EngineKernel, build_model
from model.weight_init import generation_init_weights
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
class TAFGEngineKernel(EngineKernel):
def __init__(self, config, logger):
super().__init__(config, logger)
def __init__(self, config):
super().__init__(config)
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
@@ -29,6 +27,11 @@ class TAFGEngineKernel(EngineKernel):
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
def _process_batch(self, batch, inference=False):
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
return batch
def build_models(self) -> (dict, dict):
generators = dict(
@@ -56,6 +59,7 @@ class TAFGEngineKernel(EngineKernel):
def forward(self, batch, inference=False) -> dict:
generator = self.generators["main"]
batch = self._process_batch(batch, inference)
with torch.set_grad_enabled(not inference):
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
@@ -64,6 +68,7 @@ class TAFGEngineKernel(EngineKernel):
return fake
def criterion_generators(self, batch, generated) -> dict:
batch = self._process_batch(batch)
loss = dict()
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
@@ -85,10 +90,15 @@ class TAFGEngineKernel(EngineKernel):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
self.generators["main"].module.style_encoders["b"](batch["b"]),
self.generators["main"].module.style_encoders["b"](generated["b"])
)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
# batch = self._process_batch(batch)
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase])
pred_fake = self.discriminators[phase](generated[phase].detach())
@@ -105,31 +115,14 @@ class TAFGEngineKernel(EngineKernel):
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
batch = self._process_batch(batch)
edge = batch["edge_a"][:, 0:1, :, :]
return dict(
a=[batch[f"edge_a"].expand(-1, 3, -1, -1).detach(), batch["a"].detach(), generated["a"].detach()],
b=[batch["b"].detach(), generated["b"].detach()]
a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
generated["b"].detach()]
)
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
with read_write(config):
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, TAFGEngineKernel(config, logger), len(train_data_loader))
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())
else:
return NotImplemented(f"invalid task: {task}")
def run(task, config, _):
    """Entry point for the TAFG engine: build the kernel and dispatch to the shared runner.

    The third positional argument (a logger in older engine entry points) is
    unused here; ``run_kernel`` obtains its own logger via ``logging.getLogger(config.name)``.
    """
    kernel = TAFGEngineKernel(config)
    run_kernel(task, config, kernel)

153
engine/U-GAT-IT.py Normal file
View File

@@ -0,0 +1,153 @@
from itertools import chain
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
from model.weight_init import generation_init_weights
from loss.gan import GANLoss
from model.GAN.UGATIT import RhoClipper
from model.GAN.residual_generator import GANImageBuffer
from util.image import attention_colored_map
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
def mse_loss(x, target_flag):
    """Mean-squared error of ``x`` against a constant target.

    The target is an all-ones tensor when ``target_flag`` is truthy,
    otherwise an all-zeros tensor of the same shape as ``x``.
    """
    target = torch.ones_like(x) if target_flag else torch.zeros_like(x)
    return F.mse_loss(x, target)
def bce_loss(x, target_flag):
    """Binary cross-entropy (with logits) of ``x`` against a constant target.

    ``x`` is interpreted as raw logits; the target is all ones when
    ``target_flag`` is truthy, otherwise all zeros, matching ``x``'s shape.
    """
    target = torch.ones_like(x) if target_flag else torch.zeros_like(x)
    return F.binary_cross_entropy_with_logits(x, target)
class UGATITEngineKernel(EngineKernel):
    """Training kernel for U-GAT-IT: two generators (a<->b) plus local/global
    discriminators per domain, with CAM-based attention losses."""

    def __init__(self, config):
        # Base __init__ calls build_models(), so self.generators and
        # self.discriminators exist before the image-buffer comprehension below.
        super().__init__(config)
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")  # weight is applied manually in the criteria
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
        self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
        self.rho_clipper = RhoClipper(0, 1)
        # One replay buffer per discriminator; defaults to 50 stored images.
        self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
                              self.discriminators.keys()}

    def build_models(self) -> (dict, dict):
        """Build the two generators and the four (local/global x a/b) discriminators."""
        generators = dict(
            a2b=build_model(self.config.model.generator),
            b2a=build_model(self.config.model.generator)
        )
        discriminators = dict(
            la=build_model(self.config.model.local_discriminator),
            lb=build_model(self.config.model.local_discriminator),
            ga=build_model(self.config.model.global_discriminator),
            gb=build_model(self.config.model.global_discriminator),
        )
        self.logger.debug(discriminators["ga"])
        self.logger.debug(generators["a2b"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_before_d(self):
        """Clamp generator rho parameters to [0, 1] and re-enable D gradients
        before the discriminator update."""
        for generator in self.generators.values():
            generator.apply(self.rho_clipper)
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        """Freeze discriminator gradients before the generator update."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run all six generator passes (translation, cycle, identity).

        Each generator call returns a 3-tuple (image, CAM prediction, attention
        heatmap). Gradients are disabled when ``inference`` is true.
        """
        images = dict()
        heatmap = dict()
        cam_pred = dict()
        with torch.set_grad_enabled(not inference):
            images["a2b"], cam_pred["a2b"], heatmap["a2b"] = self.generators["a2b"](batch["a"])
            images["b2a"], cam_pred["b2a"], heatmap["b2a"] = self.generators["b2a"](batch["b"])
            images["a2b2a"], _, heatmap["a2b2a"] = self.generators["b2a"](images["a2b"])
            images["b2a2b"], _, heatmap["b2a2b"] = self.generators["a2b"](images["b2a"])
            images["a2a"], cam_pred["a2a"], heatmap["a2a"] = self.generators["b2a"](batch["a"])
            images["b2b"], cam_pred["b2b"], heatmap["b2b"] = self.generators["a2b"](batch["b"])
        return dict(images=images, heatmap=heatmap, cam_pred=cam_pred)

    def criterion_generators(self, batch, generated) -> dict:
        """Generator losses: cycle, identity, adversarial (per local/global
        discriminator), adversarial CAM, and generator CAM terms."""
        loss = dict()
        for phase in "ab":
            cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
            loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
            loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
                                                                            generated["images"][f"{phase}2{phase}"])
            for dk in "lg":
                # Fake image that should fool this domain's discriminators.
                generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
                pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
                loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
                loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True)
        # CAM loss: translation CAM logits pushed to 1, identity CAM logits to 0.
        for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
            loss[f"cam_{t[-1]}"] = self.config.loss.cam.weight * (
                bce_loss(generated["cam_pred"][t], True) + bce_loss(generated["cam_pred"][f], False))
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Discriminator losses: real-vs-fake adversarial plus CAM MSE terms,
        with fakes drawn from the per-discriminator replay buffer."""
        loss = dict()
        for phase in "ab":
            for level in "gl":
                generated_image = self.image_buffers[level + phase].query(
                    generated["images"]["a2b" if phase == "b" else "b2a"])
                pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image)
                pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase])
                loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss(
                    pred_fake, False, is_discriminator=True)
                loss[f"cam_{phase}_{level}"] = mse_loss(cam_fake_pred, False) + mse_loss(cam_real_pred, True)
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        attention_a = attention_colored_map(generated["heatmap"]["a2b"].detach(), batch["a"].size()[-2:])
        attention_b = attention_colored_map(generated["heatmap"]["b2a"].detach(), batch["b"].size()[-2:])
        # Detach everything so the visualization keeps no graph references.
        generated = {img: generated["images"][img].detach() for img in generated["images"]}
        return {
            "a": [batch["a"], attention_a, generated["a2b"], generated["a2a"], generated["a2b2a"]],
            "b": [batch["b"], attention_b, generated["b2a"], generated["b2b"], generated["b2a2b"]],
        }
class UGATITTestEngineKernel(TestEngineKernel):
    """Inference-only kernel: builds just the a->b generator and translates a -> b."""

    def __init__(self, config):
        # NOTE(review): this override only forwards to the base class — it
        # could be removed without changing behavior.
        super().__init__(config)

    def build_generators(self) -> dict:
        # Only the a->b direction is needed at test time.
        generators = dict(
            a2b=build_model(self.config.model.generator),
        )
        return generators

    def to_load(self):
        # NOTE(review): identical to the base TestEngineKernel.to_load —
        # presumably kept to mirror the checkpoint key scheme; confirm and drop.
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        """Translate batch["a"] to domain b; generator returns (image, cam, heatmap)."""
        with torch.no_grad():
            fake, _, _ = self.generators["a2b"](batch["a"])
        return {"a": fake.detach()}
def run(task, config, _):
    """Entry point for the U-GAT-IT engine: pick the train/test kernel and dispatch.

    :param task: "train" or "test"
    :param config: experiment configuration
    :param _: unused (legacy logger argument)
    :raises NotImplementedError: for any other task name
    """
    if task == "train":
        kernel = UGATITEngineKernel(config)
    elif task == "test":
        kernel = UGATITTestEngineKernel(config)
    else:
        # Fix: the original used two independent `if`s with no else, so an
        # invalid task reached run_kernel with `kernel` unbound (NameError).
        raise NotImplementedError(f"invalid task: {task}")
    run_kernel(task, config, kernel)

View File

@@ -1,320 +0,0 @@
from itertools import chain
from math import ceil
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import OmegaConf, read_write
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid, fuse_attention_map, attention_colored_map
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer
def build_lr_schedulers(optimizers, config):
    """Piecewise-linear LR schedules for the "g" and "d" optimizers.

    Each schedule holds the optimizer's configured LR from iteration 0 until
    ``start_proportion * max_iteration``, then decays linearly to
    ``config.data.train.scheduler.target_lr`` at ``config.max_iteration``.
    """
    g_milestones_values = [
        (0, config.optimizers.generator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    d_milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    return dict(
        g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
        d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
    )
def get_trainer(config, logger):
    """Build the self-contained U-GAT-IT training Engine (legacy version).

    Constructs generators/discriminators, optimizers, LR schedules, losses and
    all ignite handlers (checkpointing, tensorboard scalars/images), and returns
    the configured ignite ``Engine``.
    """
    generators = dict(
        a2b=build_model(config.model.generator, config.distributed.model),
        b2a=build_model(config.model.generator, config.distributed.model),
    )
    discriminators = dict(
        la=build_model(config.model.local_discriminator, config.distributed.model),
        lb=build_model(config.model.local_discriminator, config.distributed.model),
        ga=build_model(config.model.global_discriminator, config.distributed.model),
        gb=build_model(config.model.global_discriminator, config.distributed.model),
    )
    for m in chain(generators.values(), discriminators.values()):
        generation_init_weights(m)
    logger.debug(discriminators["ga"])
    logger.debug(generators["a2b"])
    # One optimizer over all generator params, one over all discriminator params.
    optimizers = dict(
        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
    )
    logger.info(f"build optimizers:\n{optimizers}")
    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"build lr_schedulers:\n{lr_schedulers}")
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")  # weights are applied manually below
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    # NOTE(review): reads config.loss.cycle.level rather than
    # config.loss.id.level — looks like a copy-paste slip (the replacement
    # kernel uses config.loss.id.level); confirm before relying on this.
    id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()

    def mse_loss(x, target_flag):
        # MSE against a constant all-ones/all-zeros target.
        return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

    def bce_loss(x, target_flag):
        # BCE-with-logits against a constant all-ones/all-zeros target.
        return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

    image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
    rho_clipper = RhoClipper(0, 1)

    def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
                            discriminator_g):
        # Generator losses for one domain; discriminators are frozen here so
        # the adversarial terms only update the generators.
        discriminator_g.requires_grad_(False)
        discriminator_l.requires_grad_(False)
        pred_fake_g, cam_gd_pred = discriminator_g(fake)
        pred_fake_l, cam_ld_pred = discriminator_l(fake)
        return {
            f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
            f"id_{name}": config.loss.id.weight * id_loss(real, identity),
            f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
            f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
            f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
            f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
            f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
        }

    def criterion_discriminator(name, discriminator, real, fake):
        pred_real, cam_real = discriminator(real)
        pred_fake, cam_fake = discriminator(fake)
        # TODO: origin do not divide 2, but I think it better to divide 2.
        loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
        loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
        return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}

    def _step(engine, real):
        """One training iteration: forward all generator passes, update G, then D."""
        real = convert_tensor(real, idist.device())
        fake = dict()
        cam_generator_pred = dict()
        rec = dict()
        identity = dict()
        cam_identity_pred = dict()
        heatmap = dict()
        # Each generator call returns (image, CAM prediction, attention heatmap).
        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
        rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
        rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
        optimizers["g"].zero_grad()
        loss_g = dict()
        for n in ["a", "b"]:
            loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
                                              cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
        sum(loss_g.values()).backward()
        optimizers["g"].step()
        # Clamp rho params after the G step, then unfreeze D for its update.
        for generator in generators.values():
            generator.apply(rho_clipper)
        for discriminator in discriminators.values():
            discriminator.requires_grad_(True)
        optimizers["d"].zero_grad()
        loss_d = dict()
        for k in discriminators.keys():
            n = k[-1]  # "a" or "b"
            loss_d.update(
                criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
        sum(loss_d.values()).backward()
        optimizers["d"].step()
        for h in heatmap:
            heatmap[h] = heatmap[h].detach()
        generated_img = {f"real_{k}": real[k].detach() for k in real}
        generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
        generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
        generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d": {ln: loss_d[ln].mean().item() for ln in loss_d},
            },
            "img": {
                "heatmap": heatmap,
                "generated": generated_img
            }
        }

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update({f"generator_{k}": generators[k] for k in generators})
    to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    def output_transform(output):
        # Flatten the nested {"g": {...}, "d": {...}} loss dict for tensorboard.
        loss = dict()
        for tl in output["loss"]:
            if isinstance(output["loss"][tl], dict):
                for l in output["loss"][tl]:
                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
            else:
                loss[tl] = output["loss"][tl]
        return loss

    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
    if tensorboard_handler is not None:
        tensorboard_handler.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
        )

    @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
    def show_images(engine):
        # Log the latest training images plus a fixed random slice of the
        # test dataset (seeded, so the same slice every time).
        output = engine.state.output
        image_order = dict(
            a=["real_a", "fake_b", "rec_a", "id_a"],
            b=["real_b", "fake_a", "rec_b", "id_b"]
        )
        output["img"]["generated"]["real_a"] = fuse_attention_map(
            output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
        output["img"]["generated"]["real_b"] = fuse_attention_map(
            output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
        for k in "ab":
            tensorboard_handler.writer.add_image(
                f"train/{k}",
                make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]),
                engine.state.iteration
            )
        with torch.no_grad():
            g = torch.Generator()
            g.manual_seed(config.misc.random_seed)
            random_start = torch.randperm(len(engine.state.test_dataset)-11, generator=g).tolist()[0]
            test_images = dict(
                a=[[], [], [], []],
                b=[[], [], [], []]
            )
            for i in range(random_start, random_start+10):
                batch = convert_tensor(engine.state.test_dataset[i], idist.device())
                # NOTE(review): "b" is reshaped with batch["a"].size() — assumes
                # both domains share the same image shape; confirm.
                real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
                fake_b, _, heatmap_a2b = generators["a2b"](real_a)
                fake_a, _, heatmap_b2a = generators["b2a"](real_b)
                rec_a = generators["b2a"](fake_b)[0]
                rec_b = generators["a2b"](fake_a)[0]
                for idx, im in enumerate(
                        [attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]):
                    test_images["a"][idx].append(im.cpu())
                for idx, im in enumerate(
                        [attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]):
                    test_images["b"][idx].append(im.cpu())
            for n in "ab":
                tensorboard_handler.writer.add_image(
                    f"test/{n}",
                    make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
                    engine.state.iteration
                )
    return trainer
def get_tester(config, logger):
    """Build the legacy test Engine: run the a->b generator over a dataloader
    and save (real, fake) image pairs to disk."""
    generator_a2b = build_model(config.model.generator, config.distributed.model)

    def _step(engine, batch):
        # Batch is (image, path); generator returns a tuple, [0] is the image.
        real_a, path = convert_tensor(batch, idist.device())
        with torch.no_grad():
            fake_b = generator_a2b(real_a)[0]
        return {"path": path, "img": [real_a.detach(), fake_b.detach()]}

    tester = Engine(_step)
    tester.logger = logger
    # Passed as to_save so the checkpoint handler can restore the generator.
    to_load = dict(generator_a2b=generator_a2b)
    setup_common_handlers(tester, config, use_profiler=False, to_save=to_load)

    @tester.on(Events.STARTED)
    def mkdir(engine):
        # Create the output directory once, on rank 0 only.
        img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
        engine.state.img_output_dir = Path(img_output_dir)
        if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
            engine.logger.info(f"mkdir {engine.state.img_output_dir}")
            engine.state.img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        # One output file per input image, named after the source file.
        img_tensors = engine.state.output["img"]
        paths = engine.state.output["path"]
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            image_name = Path(paths[i]).name
            torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
                                         nrow=len(img_tensors))
    return tester
def run(task, config, logger):
    """Legacy entry point: train or test the U-GAT-IT engine.

    :param task: "train" or "test"
    :param config: OmegaConf config; ``max_iteration`` is derived and written back
    :param logger: logger for progress messages
    :raises NotImplementedError: for any other task name
    """
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    with read_write(config):
        # Total iterations needed to consume max_pairs samples.
        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        if idist.get_rank() == 0:
            # Only rank 0 logs test images, so only it needs the test dataset.
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "test":
        assert config.resume_from is not None
        test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, logger)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        # Fix: `return NotImplemented(...)` — NotImplemented is a constant,
        # not callable, so the original raised TypeError. Raise the intended error.
        raise NotImplementedError(f"invalid task: {task}")

View File

@@ -1,32 +1,23 @@
from itertools import chain
from math import ceil
from pathlib import Path
import logging
from pathlib import Path
import torch
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from model import MODEL
from omegaconf import read_write, OmegaConf
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_optimizer
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from engine.util.build import build_optimizer
from omegaconf import OmegaConf
def build_model(cfg):
cfg = OmegaConf.to_container(cfg)
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
model = MODEL.build_with(cfg)
if bn_to_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
return idist.auto_model(model)
import data
def build_lr_schedulers(optimizers, config):
@@ -47,10 +38,26 @@ def build_lr_schedulers(optimizers, config):
)
class EngineKernel(object):
def __init__(self, config, logger):
class TestEngineKernel(object):
def __init__(self, config):
self.config = config
self.logger = logger
self.logger = logging.getLogger(config.name)
self.generators = self.build_generators()
def build_generators(self) -> dict:
raise NotImplemented
def to_load(self):
return {f"generator_{k}": self.generators[k] for k in self.generators}
def inference(self, batch):
raise NotImplemented
class EngineKernel(object):
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(config.name)
self.generators, self.discriminators = self.build_models()
def build_models(self) -> (dict, dict):
@@ -87,39 +94,43 @@ class EngineKernel(object):
raise NotImplemented
def get_trainer(config, ek: EngineKernel, iter_per_epoch):
def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name)
generators, discriminators = ek.generators, ek.discriminators
generators, discriminators = kernel.generators, kernel.discriminators
optimizers = dict(
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info("build optimizers", optimizers)
logger.info(f"build optimizers:\n{optimizers}")
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
generated = ek.forward(batch)
generated = kernel.forward(batch)
ek.setup_before_g()
kernel.setup_before_g()
optimizers["g"].zero_grad()
loss_g = ek.criterion_generators(batch, generated)
loss_g = kernel.criterion_generators(batch, generated)
sum(loss_g.values()).backward()
optimizers["g"].step()
ek.setup_before_d()
kernel.setup_before_d()
optimizers["d"].zero_grad()
loss_d = ek.criterion_discriminators(batch, generated)
loss_d = kernel.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
return {
"loss": dict(g=loss_g, d=loss_d),
"img": ek.intermediate_images(batch, generated)
}
if engine.state.iteration % image_per_iteration == 0:
return {
"loss": dict(g=loss_g, d=loss_d),
"img": kernel.intermediate_images(batch, generated)
}
return {"loss": dict(g=loss_g, d=loss_d)}
trainer = Engine(_step)
trainer.logger = logger
@@ -131,33 +142,22 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update(ek.to_save())
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
to_save.update(kernel.to_save())
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=config.handler.clear_cuda_cache,
set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
pairs_per_iteration = config.data.train.dataloader.batch_size
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
if tensorboard_handler is not None:
tensorboard_handler.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
)
basic_image_event = Events.ITERATION_COMPLETED(
every=image_per_iteration)
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
@trainer.on(basic_image_event)
def show_images(engine):
output = engine.state.output
test_images = {}
for k in output["img"]:
image_list = output["img"][k]
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
@@ -174,8 +174,8 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
generated = ek.forward(batch)
images = ek.intermediate_images(batch, generated)
generated = kernel.forward(batch)
images = kernel.intermediate_images(batch, generated)
for k in test_images:
for j in range(len(images[k])):
@@ -187,3 +187,78 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
engine.state.iteration * pairs_per_iteration
)
return trainer
def get_tester(config, kernel: TestEngineKernel):
    """Build the shared test Engine around a TestEngineKernel: run inference
    over a dataloader and save (real, fake) image pairs to disk."""
    logger = logging.getLogger(config.name)

    def _step(engine, batch):
        # Batch is (image, path); the kernel's inference contract is
        # {"a": tensor} in -> {"a": tensor} out.
        real_a, path = convert_tensor(batch, idist.device())
        fake = kernel.inference({"a": real_a})["a"]
        return {"path": path, "img": [real_a.detach(), fake.detach()]}

    tester = Engine(_step)
    tester.logger = logger
    # kernel.to_load() supplies the generator(s) for checkpoint restoration.
    setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load())

    @tester.on(Events.STARTED)
    def mkdir(engine):
        # Create the image output directory once, on rank 0 only.
        img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
        engine.state.img_output_dir = Path(img_output_dir)
        if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
            engine.logger.info(f"mkdir {engine.state.img_output_dir}")
            engine.state.img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        # One output file per input image, named after the source file.
        img_tensors = engine.state.output["img"]
        paths = engine.state.output["path"]
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            image_name = Path(paths[i]).name
            torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
                                         nrow=len(img_tensors))
    return tester
def run_kernel(task, config, kernel):
    """Shared entry point: run a train or test task with the given engine kernel.

    :param task: "train" or "test"
    :param config: OmegaConf config; ``max_iteration`` and (for train)
        ``iterations_per_epoch`` are derived and written back
    :param kernel: an EngineKernel (train) or TestEngineKernel (test)
    :raises NotImplementedError: for any other task name
    """
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger = logging.getLogger(config.name)
    with read_write(config):
        # Effective global batch size across all distributed workers.
        real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
        config.max_iteration = config.max_pairs // real_batch_size + 1
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader)
        # auto_dataloader divides the batch size by world size; pre-scale so
        # each worker keeps the configured per-worker batch size.
        dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size()
        train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs)
        with read_write(config):
            config.iterations_per_epoch = len(train_data_loader)
        trainer = get_trainer(config, kernel)
        if idist.get_rank() == 0:
            # Only rank 0 logs test images, so only it needs the test dataset.
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "test":
        assert config.resume_from is not None
        test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, kernel)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        # Fix: `return NotImplemented(...)` — NotImplemented is a constant,
        # not callable, so the original raised TypeError. Raise the intended error.
        raise NotImplementedError(f"invalid task: {task}")

View File

@@ -1,85 +0,0 @@
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
import ignite.distributed as idist
from ignite.contrib.metrics.gpu_info import GpuInfo
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, global_step_from_engine, OutputHandler, \
WeightsScalarHandler, GradsHistHandler, WeightsHistHandler, GradsScalarHandler
from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
from ignite.metrics import Accuracy, Loss, RunningAverage
from ignite.contrib.engines.common import save_best_model_by_val_score
from ignite.contrib.handlers import ProgressBar
from util.build import build_model, build_optimizer
from util.handler import setup_common_handlers
from data.transform import transform_pipeline
from data.dataset import LMDBDataset
def warmup_trainer(config, logger):
    """Build a supervised classification trainer for the baseline "warmup" stage.

    Creates the model/optimizer from ``config``, wires up running metrics
    (loss, accuracy), a progress bar, rank-0 tensorboard logging of metrics,
    weights and gradients, and the common checkpoint/print handlers.

    :param config: experiment configuration (attribute-style access; expects
        ``config.model``, ``config.distributed.model``, ``config.baseline.optimizers``
        and ``config.output_dir``)
    :param logger: logger attached to the created engine as ``trainer.logger``
    :return: the configured ignite trainer engine
    """
    model = build_model(config.model, config.distributed.model)
    optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
    loss_fn = nn.CrossEntropyLoss()
    # output_transform keeps (loss, y_pred, y) so both the loss average and
    # accuracy metrics below can be computed from the trainer's own output.
    trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True,
                                        output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y))
    trainer.logger = logger
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc")
    ProgressBar(ncols=0).attach(trainer)
    # Tensorboard logging (and GPU stats) only on the main process.
    if idist.get_rank() == 0:
        GpuInfo().attach(trainer, name='gpu')
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="train",
                metric_names='all',
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.EPOCH_COMPLETED
        )
        # Scalar summaries every 10 epochs, full histograms every 25 (histograms
        # are considerably more expensive to write).
        tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))
        tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))

        # Close the writer when the run finishes so buffered events are flushed.
        @trainer.on(Events.COMPLETED)
        def _():
            tb_logger.close()

    to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED, to_save=to_save,
                          save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5,
                          metrics_to_print=["loss", "acc"])
    return trainer
def run(task, config, logger):
    """Entry point dispatching the requested task.

    :param task: one of "warmup", "protonet-wo", "protonet-w"
    :param config: experiment configuration (attribute-style access)
    :param logger: logger used for progress messages and handed to the trainer
    :raises ValueError: if ``task`` is not a recognized task name
    """
    assert torch.backends.cudnn.enabled
    # Fixed input sizes across iterations -> let cudnn autotune its kernels.
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "warmup":
        train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
                                    pipeline=config.baseline.data.dataset.train.pipeline)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
        trainer = warmup_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=400)
        except Exception:
            # Print the traceback instead of letting it propagate so every
            # distributed worker reports its own failure.
            import traceback
            print(traceback.format_exc())
    elif task == "protonet-wo":
        pass
    elif task == "protonet-w":
        pass
    else:
        # Bug fix: the original *returned* a ValueError instance instead of
        # raising it, so an invalid task name failed silently.
        raise ValueError(f"invalid task: {task}")

View File

@@ -1,9 +0,0 @@
from data.dataset import EpisodicDataset, LMDBDataset
def prototypical_trainer(config, logger):
    """Placeholder: build the trainer for prototypical-network training (not implemented yet)."""
def prototypical_dataloader(config):
    """Placeholder: build the episodic dataloader for prototypical training (not implemented yet)."""

0
engine/util/__init__.py Normal file
View File

23
engine/util/build.py Normal file
View File

@@ -0,0 +1,23 @@
import torch
import ignite.distributed as idist
from omegaconf import OmegaConf
from model import MODEL
import torch.optim as optim
def build_model(cfg):
    """Instantiate a model from config and prepare it for (distributed) training.

    :param cfg: OmegaConf node describing the model; the optional private key
        ``_bn_to_sync_bn`` (popped here, default False) requests conversion of
        BatchNorm layers to SyncBatchNorm.
    :return: the model wrapped by ``idist.auto_model`` for the current backend
    """
    options = OmegaConf.to_container(cfg)
    use_sync_bn = options.pop("_bn_to_sync_bn", False)
    net = MODEL.build_with(options)
    if use_sync_bn:
        net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    return idist.auto_model(net)
def build_optimizer(params, cfg):
    """Build a ``torch.optim`` optimizer and wrap it for distributed training.

    :param params: iterable of parameters (or param groups) to optimize
    :param cfg: config mapping with a ``_type`` key naming the ``torch.optim``
        class (e.g. "Adam"); all remaining keys are passed as optimizer kwargs
    :return: the optimizer wrapped by ``idist.auto_optim``
    :raises KeyError: if ``cfg`` does not define ``_type``
    """
    cfg = OmegaConf.to_container(cfg)
    if "_type" not in cfg:
        # Raise explicitly instead of using ``assert``: asserts are stripped
        # under ``python -O`` and would turn this into an opaque failure later.
        raise KeyError("optimizer config must contain a '_type' key naming a torch.optim class")
    optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
    return idist.auto_optim(optimizer)

154
engine/util/handler.py Normal file
View File

@@ -0,0 +1,154 @@
from pathlib import Path
import torch
from torch.utils.data.distributed import DistributedSampler
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
def empty_cuda_cache(_):
    """Release cached GPU memory; usable as an ignite event handler (the engine
    argument is ignored).

    Bug fix: run ``gc.collect()`` *before* ``torch.cuda.empty_cache()`` — the
    original order emptied the cache first, so memory held by unreachable
    Python tensors collected afterwards stayed cached until the next call.
    """
    import gc
    gc.collect()
    torch.cuda.empty_cache()
def step_transform_maker(stype: str, pairs_per_iteration=None):
    """Return a ``global_step_transform(engine, event)`` callable for logging.

    :param stype: step unit — "item" (iterations scaled by the number of data
        pairs consumed per iteration), "iteration", or "epoch"
    :param pairs_per_iteration: scale factor, required when ``stype == "item"``
    :return: a two-argument callable mapping an engine to its global step
    :raises ValueError: for an unknown ``stype``, or when ``stype == "item"``
        and ``pairs_per_iteration`` is missing (previously this surfaced only
        later as a TypeError inside the returned lambda)
    """
    if stype == "item":
        if pairs_per_iteration is None:
            raise ValueError("pairs_per_iteration is required when stype='item'")
        return lambda engine, _: engine.state.iteration * pairs_per_iteration
    if stype == "iteration":
        return lambda engine, _: engine.state.iteration
    if stype == "epoch":
        return lambda engine, _: engine.state.epoch
    # Explicit error instead of ``assert``: asserts vanish under ``python -O``.
    raise ValueError(f"invalid step type: {stype}")
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
                          to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
    """
    Helper method to setup trainer with common handlers.

    1. TerminateOnNan
    2. BasicTimeProfiler
    3. Print
    4. Checkpoint

    :param trainer: the ignite engine the handlers are attached to
    :param config: configuration; reads ``iterations_per_epoch``, ``output_dir``,
        ``handler.checkpoint.*``, ``resume_from`` and ``name``
    :param stop_on_nan: terminate training when the output contains NaN
    :param clear_cuda_cache: empty the CUDA cache after every epoch
    :param use_profiler: attach a BasicTimeProfiler and print its results
    :param to_save: dict of state-holding objects to checkpoint (and to restore
        from ``config.resume_from``); None disables checkpointing
    :param end_event: optional event on which the engine terminates early
    :param set_epoch_for_dist_sampler: call ``set_epoch`` on a DistributedSampler
        each epoch so shuffling differs across epochs
    :return: None
    """
    if set_epoch_for_dist_sampler:
        @trainer.on(Events.EPOCH_STARTED)
        def distrib_set_epoch(engine):
            # epoch - 1 because ignite epochs are 1-based but the sampler
            # expects a 0-based epoch index.
            if isinstance(trainer.state.dataloader.sampler, DistributedSampler):
                trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
                trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)

    trainer.logger.info(f"data loader length: {config.iterations_per_epoch} iterations per epoch")

    # One-time GPU memory report after the first epoch (memory use has
    # stabilized by then).
    @trainer.on(Events.EPOCH_COMPLETED(once=1))
    def print_info(engine):
        engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
    if use_profiler:
        # Create an object of the profiler and attach an engine to it
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)

        @trainer.on(Events.EPOCH_COMPLETED(once=1) | Events.COMPLETED)
        @idist.one_rank_only()
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())

    ProgressBar(ncols=0).attach(trainer, "all")
    if to_save is not None:
        checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
                                        n_saved=config.handler.checkpoint.n_saved, filename_prefix=config.name)
        if config.resume_from is not None:
            # Restore state at engine start so all objects in ``to_save`` pick
            # up the checkpoint before the first iteration runs.
            @trainer.on(Events.STARTED)
            def resume(engine):
                checkpoint_path = Path(config.resume_from)
                if not checkpoint_path.exists():
                    raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
                # map_location="cpu" so loading works regardless of the device
                # layout the checkpoint was saved from.
                ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
                trainer.logger.info(f"load state_dict for {ckp.keys()}")
                Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
                engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED(every=config.handler.checkpoint.epoch_interval) | Events.COMPLETED,
            checkpoint_handler
        )
        trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
    if end_event is not None:
        trainer.logger.debug(f"engine will stop on {end_event}")

        @trainer.on(end_event)
        def terminate(engine):
            engine.terminate()
def setup_tensorboard_handler(trainer: Engine, config, optimizers, step_type="item"):
    """Attach tensorboard logging (metrics, per-loss scalars, optimizer params)
    to ``trainer`` on the rank-0 process.

    :param trainer: engine whose output/metrics are logged
    :param config: configuration; reads ``handler.tensorboard`` (None disables
        logging entirely), ``output_dir``, ``iterations_per_epoch`` and
        ``data.train.dataloader.batch_size``
    :param optimizers: a single optimizer or a dict name -> optimizer whose
        hyper-parameters (e.g. lr) are logged each iteration
    :param step_type: global-step unit passed to ``step_transform_maker``
        ("item", "iteration" or "epoch")
    :return: the TensorboardLogger on rank 0, otherwise None
    """
    if config.handler.tensorboard is None:
        return None
    if idist.get_rank() == 0:
        # Create a logger
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer
        # Global batch size across all workers: one iteration consumes this
        # many data pairs, used when step_type == "item".
        pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
        global_step_transform = step_transform_maker(step_type, pairs_per_iteration)
        # Log `config.handler.tensorboard.scalar` times per epoch (at least
        # every iteration when the epoch is shorter than that).
        basic_event = Events.ITERATION_COMPLETED(
            every=max(config.iterations_per_epoch // config.handler.tensorboard.scalar, 1))
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="metric", metric_names="all",
                global_step_transform=global_step_transform
            ),
            event_name=basic_event
        )

        # Custom loss logging: engine.state.output["loss"] may nest one level
        # of dicts (e.g. per-component losses under a total-loss key).
        @trainer.on(basic_event)
        def log_loss(engine):
            global_step = global_step_transform(engine, None)
            output_loss = engine.state.output["loss"]
            for total_loss in output_loss:
                if isinstance(output_loss[total_loss], dict):
                    for ln in output_loss[total_loss]:
                        tb_writer.add_scalar(f"train_{total_loss}/{ln}", output_loss[total_loss][ln], global_step)
                else:
                    tb_writer.add_scalar(f"train/{total_loss}", output_loss[total_loss], global_step)

        if isinstance(optimizers, dict):
            for name in optimizers:
                tb_logger.attach(
                    trainer,
                    log_handler=OptimizerParamsHandler(optimizers[name], tag=f"optimizer_{name}"),
                    event_name=Events.ITERATION_STARTED
                )
        else:
            tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizers, tag=f"optimizer"),
                             event_name=Events.ITERATION_STARTED)

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            # We need to close the logger once we are done so events are flushed.
            tb_logger.close()

        return tb_logger
    return None