almost 0.1

This commit is contained in:
2020-09-06 10:34:52 +08:00
parent e3c760d0c5
commit ab545843bf
15 changed files with 308 additions and 680 deletions

101
engine/CyCleGAN.py Normal file
View File

@@ -0,0 +1,101 @@
from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.gan import GANLoss
from model.GAN.base import GANImageBuffer
from model.weight_init import generation_init_weights
class TAFGEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
def build_models(self) -> (dict, dict):
generators = dict(
a2b=build_model(self.config.model.generator),
b2a=build_model(self.config.model.generator)
)
discriminators = dict(
a=build_model(self.config.model.discriminator),
b=build_model(self.config.model.discriminator)
)
self.logger.debug(discriminators["a"])
self.logger.debug(generators["a2b"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators
def setup_after_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
def setup_before_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)
def forward(self, batch, inference=False) -> dict:
images = dict()
with torch.set_grad_enabled(not inference):
images["a2b"] = self.generators["a2b"](batch["a"])
images["b2a"] = self.generators["b2a"](batch["b"])
images["a2b2a"] = self.generators["b2a"](images["a2b"])
images["b2a2b"] = self.generators["a2b"](images["b2a"])
if self.config.loss.id.weight > 0:
images["a2a"] = self.generators["b2a"](batch["a"])
images["b2b"] = self.generators["a2b"](batch["b"])
return images
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
for phase in ["a2b", "b2a"]:
loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
generated[f"{phase}2{phase[0]}"], batch[phase[0]])
loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
self.discriminators[phase[-1]](generated[phase]), True)
if self.config.loss.id.weight > 0:
loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
for phase in "ab":
generated_image = self.image_buffers[phase].query(generated["b2a" if phase == "a" else "a2b"].detach())
loss[f"gan_{phase}"] = (self.gan_loss(self.discriminators[phase](generated_image), False,
is_discriminator=True) +
self.gan_loss(self.discriminators[phase](batch[phase]), True,
is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
return dict(
a=[batch["a"].detach(), generated["a2b"].detach(), generated["a2b2a"].detach()],
b=[batch["b"].detach(), generated["b2a"].detach(), generated["b2a2b"].detach()],
)
def run(task, config, _):
kernel = TAFGEngineKernel(config)
run_kernel(task, config, kernel)

View File

@@ -5,6 +5,9 @@ from omegaconf import OmegaConf
import torch
import torch.nn as nn
import ignite.distributed as idist
from ignite.engine import Events
from omegaconf import read_write, OmegaConf
from model.weight_init import generation_init_weights
from loss.I2I.perceptual_loss import PerceptualLoss
@@ -49,7 +52,7 @@ class TAFGEngineKernel(EngineKernel):
return generators, discriminators
def setup_before_d(self):
def setup_after_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
@@ -89,7 +92,7 @@ class TAFGEngineKernel(EngineKernel):
for j in range(num_intermediate_outputs):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
# self.generators["main"].module.style_encoders["b"](batch["b"]),
# self.generators["main"].module.style_encoders["b"](generated["b"])
@@ -122,6 +125,12 @@ class TAFGEngineKernel(EngineKernel):
generated["b"].detach()]
)
def change_engine(self, config, trainer):
@trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
def change_config(engine):
with read_write(config):
config.loss.perceptual.weight = 5
def run(task, config, _):
kernel = TAFGEngineKernel(config)

View File

@@ -1,5 +1,3 @@
from itertools import chain
from omegaconf import OmegaConf
import torch
@@ -7,10 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
from model.weight_init import generation_init_weights
from loss.gan import GANLoss
from model.GAN.UGATIT import RhoClipper
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.base import GANImageBuffer
from util.image import attention_colored_map
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
@@ -36,6 +33,7 @@ class UGATITEngineKernel(EngineKernel):
self.rho_clipper = RhoClipper(0, 1)
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
self.train_generator_first = False
def build_models(self) -> (dict, dict):
generators = dict(
@@ -51,12 +49,9 @@ class UGATITEngineKernel(EngineKernel):
self.logger.debug(discriminators["ga"])
self.logger.debug(generators["a2b"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators
def setup_before_d(self):
def setup_after_g(self):
for generator in self.generators.values():
generator.apply(self.rho_clipper)
for discriminator in self.discriminators.values():
@@ -101,8 +96,7 @@ class UGATITEngineKernel(EngineKernel):
loss = dict()
for phase in "ab":
for level in "gl":
generated_image = self.image_buffers[level + phase].query(
generated["images"]["a2b" if phase == "b" else "b2a"])
generated_image = generated["images"]["b2a" if phase == "a" else "a2b"].detach()
pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image)
pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase])
loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss(

View File

@@ -1,23 +1,21 @@
from itertools import chain
import logging
from itertools import chain
from pathlib import Path
import ignite.distributed as idist
import torch
import torchvision
import ignite.distributed as idist
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from math import ceil
from omegaconf import read_write, OmegaConf
from util.image import make_2d_grid
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from engine.util.build import build_optimizer
import data
from engine.util.build import build_optimizer
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from util.image import make_2d_grid
def build_lr_schedulers(optimizers, config):
@@ -59,6 +57,7 @@ class EngineKernel(object):
self.config = config
self.logger = logging.getLogger(config.name)
self.generators, self.discriminators = self.build_models()
self.train_generator_first = True
def build_models(self) -> (dict, dict):
raise NotImplemented
@@ -69,7 +68,7 @@ class EngineKernel(object):
to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
return to_save
def setup_before_d(self):
def setup_after_g(self):
raise NotImplemented
def setup_before_g(self):
@@ -93,6 +92,9 @@ class EngineKernel(object):
"""
raise NotImplemented
def change_engine(self, config, engine: Engine):
pass
def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name)
@@ -106,26 +108,37 @@ def get_trainer(config, kernel: EngineKernel):
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
iteration_per_image = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
def train_generators(batch, generated):
kernel.setup_before_g()
optimizers["g"].zero_grad()
loss_g = kernel.criterion_generators(batch, generated)
sum(loss_g.values()).backward()
optimizers["g"].step()
kernel.setup_after_g()
return loss_g
def train_discriminators(batch, generated):
optimizers["d"].zero_grad()
loss_d = kernel.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
return loss_d
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
generated = kernel.forward(batch)
kernel.setup_before_g()
optimizers["g"].zero_grad()
loss_g = kernel.criterion_generators(batch, generated)
sum(loss_g.values()).backward()
optimizers["g"].step()
if kernel.train_generator_first:
loss_g = train_generators(batch, generated)
loss_d = train_discriminators(batch, generated)
else:
loss_d = train_discriminators(batch, generated)
loss_g = train_generators(batch, generated)
kernel.setup_before_d()
optimizers["d"].zero_grad()
loss_d = kernel.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
if engine.state.iteration % image_per_iteration == 0:
if engine.state.iteration % iteration_per_image == 0:
return {
"loss": dict(g=loss_g, d=loss_d),
"img": kernel.intermediate_images(batch, generated)
@@ -137,6 +150,8 @@ def get_trainer(config, kernel: EngineKernel):
for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
kernel.change_engine(config, trainer)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
@@ -150,7 +165,7 @@ def get_trainer(config, kernel: EngineKernel):
tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
if tensorboard_handler is not None:
basic_image_event = Events.ITERATION_COMPLETED(
every=image_per_iteration)
every=iteration_per_image)
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
@trainer.on(basic_image_event)
@@ -227,7 +242,7 @@ def run_kernel(task, config, kernel):
logger = logging.getLogger(config.name)
with read_write(config):
real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
config.max_iteration = config.max_pairs // real_batch_size + 1
config.max_iteration = ceil(config.max_pairs / real_batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
@@ -243,7 +258,7 @@ def run_kernel(task, config, kernel):
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())

View File

@@ -1,268 +0,0 @@
import itertools
from pathlib import Path
import torch
import torch.nn as nn
import torchvision.utils
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
from omegaconf import OmegaConf
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from util.image import make_2d_grid
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer
def get_trainer(config, logger):
generator_a = build_model(config.model.generator, config.distributed.model)
generator_b = build_model(config.model.generator, config.distributed.model)
discriminator_a = build_model(config.model.discriminator, config.distributed.model)
discriminator_b = build_model(config.model.discriminator, config.distributed.model)
for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
generation_init_weights(m)
logger.info(discriminator_a)
logger.info(generator_a)
optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
config.optimizers.generator)
optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
config.optimizers.discriminator)
milestones_values = [
(0, config.optimizers.generator.lr),
(100, config.optimizers.generator.lr),
(200, config.data.train.scheduler.target_lr)
]
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
milestones_values = [
(0, config.optimizers.discriminator.lr),
(100, config.optimizers.discriminator.lr),
(200, config.data.train.scheduler.target_lr)
]
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real_a, real_b = batch["a"], batch["b"]
fake_b = generator_a(real_a) # G_A(A)
rec_a = generator_b(fake_b) # G_B(G_A(A))
fake_a = generator_b(real_b) # G_B(B)
rec_b = generator_a(fake_a) # G_A(G_B(B))
optimizer_g.zero_grad()
discriminator_a.requires_grad_(False)
discriminator_b.requires_grad_(False)
loss_g = dict(
cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
)
if config.loss.id.weight > 0:
loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
sum(loss_g.values()).backward()
optimizer_g.step()
discriminator_a.requires_grad_(True)
discriminator_b.requires_grad_(True)
optimizer_d.zero_grad()
loss_d_a = dict(
real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
)
loss_d_b = dict(
real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
)
(sum(loss_d_a.values()) * 0.5).backward()
(sum(loss_d_b.values()) * 0.5).backward()
optimizer_d.step()
return {
"loss": {
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
"d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
"d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
},
"img": [
real_a.detach(),
fake_b.detach(),
rec_a.detach(),
real_b.detach(),
fake_a.detach(),
rec_b.detach()
]
}
trainer = Engine(_step)
trainer.logger = logger
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
to_save = dict(
generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
)
setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
filename_prefix=config.name, to_save=to_save,
print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
save_interval_event=Events.ITERATION_COMPLETED(
every=config.checkpoints.interval) | Events.COMPLETED)
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
def terminate(engine):
engine.terminate()
if idist.get_rank() == 0:
# Create a logger
tb_logger = TensorboardLogger(log_dir=config.output_dir)
tb_writer = tb_logger.writer
# Attach the logger to the trainer to log training loss at each iteration
def global_step_transform(*args, **kwargs):
return trainer.state.iteration
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
tb_logger.attach(
trainer,
log_handler=OutputHandler(
tag="loss",
metric_names=["loss_g", "loss_d_a", "loss_d_b"],
global_step_transform=global_step_transform,
output_transform=output_transform
),
event_name=Events.ITERATION_COMPLETED(every=50)
)
tb_logger.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=50)
)
@trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
def show_images(engine):
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()
def _():
# We need to close the logger with we are done
tb_logger.close()
return trainer
def get_tester(config, logger):
generator_a = build_model(config.model.generator, config.distributed.model)
generator_b = build_model(config.model.generator, config.distributed.model)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real_a, real_b = batch["a"], batch["b"]
with torch.no_grad():
fake_b = generator_a(real_a) # G_A(A)
rec_a = generator_b(fake_b) # G_B(G_A(A))
fake_a = generator_b(real_b) # G_B(B)
rec_b = generator_a(fake_a) # G_A(G_B(B))
return [
real_a.detach(),
fake_b.detach(),
rec_a.detach(),
real_b.detach(),
fake_a.detach(),
rec_b.detach()
]
tester = Engine(_step)
tester.logger = logger
if idist.get_rank == 0:
ProgressBar(ncols=0).attach(tester)
to_load = dict(generator_a=generator_a, generator_b=generator_b)
setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)
@tester.on(Events.STARTED)
@idist.one_rank_only()
def mkdir(engine):
img_output_dir = Path(config.output_dir) / "test_images"
if not img_output_dir.exists():
engine.logger.info(f"mkdir {img_output_dir}")
img_output_dir.mkdir()
@tester.on(Events.ITERATION_COMPLETED)
def save_images(engine):
img_tensors = engine.state.output
batch_size = img_tensors[0].size(0)
for i in range(batch_size):
torchvision.utils.save_image([img[i] for img in img_tensors],
Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
nrow=len(img_tensors))
return tester
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger)
try:
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
except Exception:
import traceback
print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test.dataset)
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, logger)
try:
tester.run(test_data_loader, max_epochs=1)
except Exception:
import traceback
print(traceback.format_exc())
else:
return NotImplemented(f"invalid task: {task}")