From 39c754374c707a339017dd4734ae603dff972646 Mon Sep 17 00:00:00 2001 From: budui Date: Sat, 5 Sep 2020 10:33:35 +0800 Subject: [PATCH] change --- .idea/deployment.xml | 9 +- configs/synthesizers/TAFG.yml | 53 ++--- configs/synthesizers/UGATIT.yml | 45 ++--- data/dataset.py | 23 ++- engine/TAFG.py | 57 +++--- engine/U-GAT-IT.py | 153 +++++++++++++++ engine/UGATIT.py | 320 ------------------------------- engine/base/i2i.py | 181 ++++++++++++----- engine/crossdomain.py | 85 -------- engine/fewshot.py | 9 - engine/util/__init__.py | 0 {util => engine/util}/build.py | 14 +- {util => engine/util}/handler.py | 70 +++++-- loss/I2I/perceptual_loss.py | 22 +-- main.py | 8 +- model/GAN/{TAHG.py => TAFG.py} | 57 ++++-- model/GAN/base.py | 25 ++- model/GAN/residual_generator.py | 28 +-- model/__init__.py | 2 +- util/distributed.py | 66 ------- util/image.py | 28 ++- 21 files changed, 550 insertions(+), 705 deletions(-) create mode 100644 engine/U-GAT-IT.py delete mode 100644 engine/UGATIT.py delete mode 100644 engine/crossdomain.py delete mode 100644 engine/fewshot.py create mode 100644 engine/util/__init__.py rename {util => engine/util}/build.py (57%) rename {util => engine/util}/handler.py (57%) rename model/GAN/{TAHG.py => TAFG.py} (84%) delete mode 100644 util/distributed.py diff --git a/.idea/deployment.xml b/.idea/deployment.xml index f335efb..4c3f8f7 100644 --- a/.idea/deployment.xml +++ b/.idea/deployment.xml @@ -1,6 +1,6 @@ - + @@ -16,6 +16,13 @@ + + + + + + + diff --git a/configs/synthesizers/TAFG.yml b/configs/synthesizers/TAFG.yml index b042ee4..2624240 100644 --- a/configs/synthesizers/TAFG.yml +++ b/configs/synthesizers/TAFG.yml @@ -3,31 +3,32 @@ engine: TAFG result_dir: ./result max_pairs: 1000000 +handler: + clear_cuda_cache: True + set_epoch_for_dist_sampler: True + checkpoint: + epoch_interval: 1 # checkpoint once per `epoch_interval` epoch + n_saved: 2 + tensorboard: + scalar: 100 # log scalar `scalar` times per epoch + image: 2 # log image `image` times per epoch + + misc: random_seed: 324 -checkpoint: - epoch_interval: 1 # one checkpoint every 1 epoch - n_saved: 2 - -interval: - print_per_iteration: 10 # print once per 10 iteration - tensorboard: - scalar: 100 - image: 2 - model: generator: - _type: TAHG-Generator + _type: TAFG-Generator _bn_to_sync_bn: False style_in_channels: 3 - content_in_channels: 1 - num_blocks: 4 + content_in_channels: 24 + num_blocks: 8 discriminator: _type: MultiScaleDiscriminator num_scale: 2 discriminator_cfg: - _type: base-PatchDiscriminator + _type: pix2pixHD in_channels: 3 base_channels: 64 use_spectral: True @@ -46,7 +47,7 @@ loss: "11": 0.125 "20": 0.25 "29": 1 - criterion: 'L1' + criterion: 'NL1' style_loss: False perceptual_loss: True weight: 5 @@ -63,10 +64,13 @@ loss: weight: 0 fm: level: 1 - weight: 10 + weight: 1 recon: level: 1 - weight: 5 + weight: 10 + style_recon: + level: 1 + weight: 10 optimizers: generator: @@ -87,7 +91,7 @@ data: target_lr: 0 buffer_size: 50 dataloader: - batch_size: 256 + batch_size: 24 shuffle: True num_workers: 2 pin_memory: True @@ -98,13 +102,13 @@ data: root_b: "/data/i2i/VoxCeleb2Anime/trainB" edges_path: "/data/i2i/VoxCeleb2Anime/edges" landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks" - edge_type: "landmark_canny" - size: [128, 128] + edge_type: "landmark_hed" + size: [ 128, 128 ] random_pair: True pipeline: - Load - Resize: - size: [128, 128] + size: [ 128, 128 ] - ToTensor - Normalize: mean: [ 0.5, 0.5, 0.5 ] @@ -121,13 +125,14 @@ data: root_a: "/data/i2i/VoxCeleb2Anime/testA" root_b: "/data/i2i/VoxCeleb2Anime/testB" edges_path: "/data/i2i/VoxCeleb2Anime/edges" - edge_type: "hed" + landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks" + edge_type: "landmark_hed" random_pair: False - size: [128, 128] + size: [ 128, 128 ] pipeline: - Load - Resize: - size: [128, 128] + size: [ 128, 128 ] - ToTensor - Normalize: mean: [ 0.5, 0.5, 0.5 ] diff --git a/configs/synthesizers/UGATIT.yml b/configs/synthesizers/UGATIT.yml index 96f1981..c04e250 100644 --- a/configs/synthesizers/UGATIT.yml +++ b/configs/synthesizers/UGATIT.yml @@ -1,24 +1,20 @@ name: selfie2anime -engine: UGATIT +engine: U-GAT-IT result_dir: ./result max_pairs: 1000000 -distributed: - model: - # broadcast_buffers: False - misc: random_seed: 324 -checkpoint: - epoch_interval: 1 # one checkpoint every 1 epoch - n_saved: 2 - -interval: - print_per_iteration: 10 # print once per 10 iteration +handler: + clear_cuda_cache: True + set_epoch_for_dist_sampler: True + checkpoint: + epoch_interval: 1 # checkpoint once per `epoch_interval` epoch + n_saved: 2 tensorboard: - scalar: 10 - image: 500 + scalar: 100 # log scalar `scalar` times per epoch + image: 2 # log image `image` times per epoch model: generator: @@ -59,12 +55,12 @@ optimizers: generator: _type: Adam lr: 0.0001 - betas: [0.5, 0.999] + betas: [ 0.5, 0.999 ] weight_decay: 0.0001 discriminator: _type: Adam lr: 1e-4 - betas: [0.5, 0.999] + betas: [ 0.5, 0.999 ] weight_decay: 0.0001 data: @@ -74,7 +70,7 @@ data: target_lr: 0 buffer_size: 50 dataloader: - batch_size: 4 + batch_size: 24 shuffle: True num_workers: 2 pin_memory: True @@ -87,14 +83,14 @@ data: pipeline: - Load - Resize: - size: [286, 286] + size: [ 286, 286 ] - RandomCrop: - size: [256, 256] + size: [ 256, 256 ] - RandomHorizontalFlip - ToTensor - Normalize: - mean: [0.5, 0.5, 0.5] - std: [0.5, 0.5, 0.5] + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] test: dataloader: batch_size: 8 @@ -110,11 +106,11 @@ data: pipeline: - Load - Resize: - size: [256, 256] + size: [ 256, 256 ] - ToTensor - Normalize: - mean: [0.5, 0.5, 0.5] - std: [0.5, 0.5, 0.5] + mean: [ 0.5, 0.5, 0.5 ] + std: [ 0.5, 0.5, 0.5 ] video_dataset: _type: SingleFolderDataset root: "/data/i2i/VoxCeleb2Anime/test_video_frames/" @@ -124,6 +120,3 @@ data: - Resize: size: [ 256, 256 ] - ToTensor - - Normalize: - mean: [ 0.5, 0.5, 0.5 ] - std: [ 0.5, 0.5, 0.5 ] diff --git a/data/dataset.py b/data/dataset.py index 6e62cdd..a26db32 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -177,6 +177,13 @@ class GenerationUnpairedDataset(Dataset): return f"\nPipeline:\n{self.A.pipeline}" +def normalize_tensor(tensor): + tensor = tensor.float() + tensor -= tensor.min() + tensor /= tensor.max() + return tensor + + @DATASET.register_module() class GenerationUnpairedDatasetWithEdge(Dataset): def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)): @@ -200,17 +207,19 @@ class GenerationUnpairedDatasetWithEdge(Dataset): edge_type = self.edge_type use_landmark = False edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png" - origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size)) + origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR)) if not use_landmark: return origin_edge else: - landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.{edge_type}.txt" + landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt" key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size) - dist_tensor = torch.from_numpy(dlib_landmark.dist_tensor(key_points)) - part_labels = torch.from_numpy(part_labels) - edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face - edges = part_edge + edges - return torch.cat([edges, dist_tensor, part_labels], dim=0) + + dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size))) + part_labels = normalize_tensor(torch.from_numpy(part_labels)) + part_edge = torch.from_numpy(part_edge).unsqueeze(0).float() + # edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face + # edges = part_edge + edges + return torch.cat([origin_edge, part_edge, dist_tensor, part_labels]) def __getitem__(self, idx): a_idx = idx % len(self.A) diff --git a/engine/TAFG.py b/engine/TAFG.py index 72d3d89..cd02502 100644 --- a/engine/TAFG.py +++ b/engine/TAFG.py @@ -1,24 +1,22 @@ from itertools import chain -from math import ceil -from omegaconf import read_write, OmegaConf +from omegaconf import OmegaConf import torch import torch.nn as nn -import torch.nn.functional as F import ignite.distributed as idist -import data -from engine.base.i2i import get_trainer, EngineKernel, build_model from model.weight_init import generation_init_weights - from loss.I2I.perceptual_loss import PerceptualLoss from loss.gan import GANLoss +from engine.base.i2i import EngineKernel, run_kernel +from engine.util.build import build_model + class TAFGEngineKernel(EngineKernel): - def __init__(self, config, logger): - super().__init__(config, logger) + def __init__(self, config): + super().__init__(config) perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual) perceptual_loss_cfg.pop("weight") self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device()) @@ -29,6 +27,11 @@ class TAFGEngineKernel(EngineKernel): self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss() self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss() + self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss() + + def _process_batch(self, batch, inference=False): + # batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size()) + return batch def build_models(self) -> (dict, dict): generators = dict( @@ -56,6 +59,7 @@ class TAFGEngineKernel(EngineKernel): def forward(self, batch, inference=False) -> dict: generator = self.generators["main"] + batch = self._process_batch(batch, inference) with torch.set_grad_enabled(not inference): fake = dict( a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"), @@ -64,6 +68,7 @@ class TAFGEngineKernel(EngineKernel): return fake def criterion_generators(self, batch, generated) -> dict: + batch = self._process_batch(batch) loss = dict() loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"]) loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight @@ -85,10 +90,15 @@ class TAFGEngineKernel(EngineKernel): loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight + loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss( + self.generators["main"].module.style_encoders["b"](batch["b"]), + self.generators["main"].module.style_encoders["b"](generated["b"]) + ) return loss def criterion_discriminators(self, batch, generated) -> dict: loss = dict() + # batch = self._process_batch(batch) for phase in self.discriminators.keys(): pred_real = self.discriminators[phase](batch[phase]) pred_fake = self.discriminators[phase](generated[phase].detach()) @@ -105,31 +115,14 @@ class TAFGEngineKernel(EngineKernel): :param generated: dict of images :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} """ + batch = self._process_batch(batch) + edge = batch["edge_a"][:, 0:1, :, :] return dict( - a=[batch[f"edge_a"].expand(-1, 3, -1, -1).detach(), batch["a"].detach(), generated["a"].detach()], - b=[batch["b"].detach(), generated["b"].detach()] + a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(), + generated["b"].detach()] ) -def run(task, config, logger): - assert torch.backends.cudnn.enabled - torch.backends.cudnn.benchmark = True - logger.info(f"start task {task}") - with read_write(config): - config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size) - - if task == "train": - train_dataset = data.DATASET.build_with(config.data.train.dataset) - logger.info(f"train with dataset:\n{train_dataset}") - train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader) - trainer = get_trainer(config, TAFGEngineKernel(config, logger), len(train_data_loader)) - if idist.get_rank() == 0: - test_dataset = data.DATASET.build_with(config.data.test.dataset) - trainer.state.test_dataset = test_dataset - try: - trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader))) - except Exception: - import traceback - print(traceback.format_exc()) - else: - return NotImplemented(f"invalid task: {task}") +def run(task, config, _): + kernel = TAFGEngineKernel(config) + run_kernel(task, config, kernel) diff --git a/engine/U-GAT-IT.py b/engine/U-GAT-IT.py new file mode 100644 index 0000000..1a089ec --- /dev/null +++ b/engine/U-GAT-IT.py @@ -0,0 +1,153 @@ +from itertools import chain + +from omegaconf import OmegaConf + +import torch +import torch.nn as nn +import torch.nn.functional as F +import ignite.distributed as idist + +from model.weight_init import generation_init_weights +from loss.gan import GANLoss +from model.GAN.UGATIT import RhoClipper +from model.GAN.residual_generator import GANImageBuffer +from util.image import attention_colored_map +from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel +from engine.util.build import build_model + + +def mse_loss(x, target_flag): + return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) + + +def bce_loss(x, target_flag): + return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) + + +class UGATITEngineKernel(EngineKernel): + def __init__(self, config): + super().__init__(config) + + gan_loss_cfg = OmegaConf.to_container(config.loss.gan) + gan_loss_cfg.pop("weight") + self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) + self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() + self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss() + self.rho_clipper = RhoClipper(0, 1) + self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in + self.discriminators.keys()} + + def build_models(self) -> (dict, dict): + generators = dict( + a2b=build_model(self.config.model.generator), + b2a=build_model(self.config.model.generator) + ) + discriminators = dict( + la=build_model(self.config.model.local_discriminator), + lb=build_model(self.config.model.local_discriminator), + ga=build_model(self.config.model.global_discriminator), + gb=build_model(self.config.model.global_discriminator), + ) + self.logger.debug(discriminators["ga"]) + self.logger.debug(generators["a2b"]) + + for m in chain(generators.values(), discriminators.values()): + generation_init_weights(m) + + return generators, discriminators + + def setup_before_d(self): + for generator in self.generators.values(): + generator.apply(self.rho_clipper) + for discriminator in self.discriminators.values(): + discriminator.requires_grad_(True) + + def setup_before_g(self): + for discriminator in self.discriminators.values(): + discriminator.requires_grad_(False) + + def forward(self, batch, inference=False) -> dict: + images = dict() + heatmap = dict() + cam_pred = dict() + + with torch.set_grad_enabled(not inference): + images["a2b"], cam_pred["a2b"], heatmap["a2b"] = self.generators["a2b"](batch["a"]) + images["b2a"], cam_pred["b2a"], heatmap["b2a"] = self.generators["b2a"](batch["b"]) + images["a2b2a"], _, heatmap["a2b2a"] = self.generators["b2a"](images["a2b"]) + images["b2a2b"], _, heatmap["b2a2b"] = self.generators["a2b"](images["b2a"]) + images["a2a"], cam_pred["a2a"], heatmap["a2a"] = self.generators["b2a"](batch["a"]) + images["b2b"], cam_pred["b2b"], heatmap["b2b"] = self.generators["a2b"](batch["b"]) + return dict(images=images, heatmap=heatmap, cam_pred=cam_pred) + + def criterion_generators(self, batch, generated) -> dict: + loss = dict() + for phase in "ab": + cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"] + loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase]) + loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase], + generated["images"][f"{phase}2{phase}"]) + for dk in "lg": + generated_image = generated["images"]["a2b" if phase == "b" else "b2a"] + pred_fake, cam_pred = self.discriminators[dk + phase](generated_image) + loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True) + loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True) + for t, f in [("a2b", "b2b"), ("b2a", "a2a")]: + loss[f"cam_{t[-1]}"] = self.config.loss.cam.weight * ( + bce_loss(generated["cam_pred"][t], True) + bce_loss(generated["cam_pred"][f], False)) + return loss + + def criterion_discriminators(self, batch, generated) -> dict: + loss = dict() + for phase in "ab": + for level in "gl": + generated_image = self.image_buffers[level + phase].query( + generated["images"]["a2b" if phase == "b" else "b2a"]) + pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image) + pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase]) + loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss( + pred_fake, False, is_discriminator=True) + loss[f"cam_{phase}_{level}"] = mse_loss(cam_fake_pred, False) + mse_loss(cam_real_pred, True) + return loss + + def intermediate_images(self, batch, generated) -> dict: + """ + returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} + :param batch: + :param generated: dict of images + :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]} + """ + attention_a = attention_colored_map(generated["heatmap"]["a2b"].detach(), batch["a"].size()[-2:]) + attention_b = attention_colored_map(generated["heatmap"]["b2a"].detach(), batch["b"].size()[-2:]) + generated = {img: generated["images"][img].detach() for img in generated["images"]} + return { + "a": [batch["a"], attention_a, generated["a2b"], generated["a2a"], generated["a2b2a"]], + "b": [batch["b"], attention_b, generated["b2a"], generated["b2b"], generated["b2a2b"]], + } + + +class UGATITTestEngineKernel(TestEngineKernel): + def __init__(self, config): + super().__init__(config) + + def build_generators(self) -> dict: + generators = dict( + a2b=build_model(self.config.model.generator), + ) + return generators + + def to_load(self): + return {f"generator_{k}": self.generators[k] for k in self.generators} + + def inference(self, batch): + with torch.no_grad(): + fake, _, _ = self.generators["a2b"](batch["a"]) + return {"a": fake.detach()} + + +def run(task, config, _): + if task == "train": + kernel = UGATITEngineKernel(config) + if task == "test": + kernel = UGATITTestEngineKernel(config) + run_kernel(task, config, kernel) diff --git a/engine/UGATIT.py b/engine/UGATIT.py deleted file mode 100644 index 256bdb4..0000000 --- a/engine/UGATIT.py +++ /dev/null @@ -1,320 +0,0 @@ -from itertools import chain -from math import ceil -from pathlib import Path - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -import ignite.distributed as idist -from ignite.engine import Events, Engine -from ignite.metrics import RunningAverage -from ignite.utils import convert_tensor -from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler -from ignite.contrib.handlers.param_scheduler import PiecewiseLinear - -from omegaconf import OmegaConf, read_write - -import data -from loss.gan import GANLoss -from model.weight_init import generation_init_weights -from model.GAN.residual_generator import GANImageBuffer -from model.GAN.UGATIT import RhoClipper -from util.image import make_2d_grid, fuse_attention_map, attention_colored_map -from util.handler import setup_common_handlers, setup_tensorboard_handler -from util.build import build_model, build_optimizer - - -def build_lr_schedulers(optimizers, config): - g_milestones_values = [ - (0, config.optimizers.generator.lr), - (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr), - (config.max_iteration, config.data.train.scheduler.target_lr) - ] - d_milestones_values = [ - (0, config.optimizers.discriminator.lr), - (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr), - (config.max_iteration, config.data.train.scheduler.target_lr) - ] - return dict( - g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values), - d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values) - ) - - -def get_trainer(config, logger): - generators = dict( - a2b=build_model(config.model.generator, config.distributed.model), - b2a=build_model(config.model.generator, config.distributed.model), - ) - discriminators = dict( - la=build_model(config.model.local_discriminator, config.distributed.model), - lb=build_model(config.model.local_discriminator, config.distributed.model), - ga=build_model(config.model.global_discriminator, config.distributed.model), - gb=build_model(config.model.global_discriminator, config.distributed.model), - ) - for m in chain(generators.values(), discriminators.values()): - generation_init_weights(m) - - logger.debug(discriminators["ga"]) - logger.debug(generators["a2b"]) - - optimizers = dict( - g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator), - d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator), - ) - logger.info(f"build optimizers:\n{optimizers}") - - lr_schedulers = build_lr_schedulers(optimizers, config) - logger.info(f"build lr_schedulers:\n{lr_schedulers}") - - gan_loss_cfg = OmegaConf.to_container(config.loss.gan) - gan_loss_cfg.pop("weight") - gan_loss = GANLoss(**gan_loss_cfg).to(idist.device()) - cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() - id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss() - - def mse_loss(x, target_flag): - return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) - - def bce_loss(x, target_flag): - return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x)) - - image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()} - rho_clipper = RhoClipper(0, 1) - - def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l, - discriminator_g): - discriminator_g.requires_grad_(False) - discriminator_l.requires_grad_(False) - pred_fake_g, cam_gd_pred = discriminator_g(fake) - pred_fake_l, cam_ld_pred = discriminator_l(fake) - return { - f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec), - f"id_{name}": config.loss.id.weight * id_loss(real, identity), - f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)), - f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True), - f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True), - f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True), - f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True), - } - - def criterion_discriminator(name, discriminator, real, fake): - pred_real, cam_real = discriminator(real) - pred_fake, cam_fake = discriminator(fake) - # TODO: origin do not divide 2, but I think it better to divide 2. - loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True) - loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False) - return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam} - - def _step(engine, real): - real = convert_tensor(real, idist.device()) - - fake = dict() - cam_generator_pred = dict() - rec = dict() - identity = dict() - cam_identity_pred = dict() - heatmap = dict() - - fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"]) - fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"]) - rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"]) - rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"]) - identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"]) - identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"]) - - optimizers["g"].zero_grad() - loss_g = dict() - for n in ["a", "b"]: - loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n], - cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n])) - sum(loss_g.values()).backward() - optimizers["g"].step() - for generator in generators.values(): - generator.apply(rho_clipper) - for discriminator in discriminators.values(): - discriminator.requires_grad_(True) - - optimizers["d"].zero_grad() - loss_d = dict() - for k in discriminators.keys(): - n = k[-1] # "a" or "b" - loss_d.update( - criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach()))) - sum(loss_d.values()).backward() - optimizers["d"].step() - - for h in heatmap: - heatmap[h] = heatmap[h].detach() - generated_img = {f"real_{k}": real[k].detach() for k in real} - generated_img.update({f"fake_{k}": fake[k].detach() for k in fake}) - generated_img.update({f"id_{k}": identity[k].detach() for k in identity}) - generated_img.update({f"rec_{k}": rec[k].detach() for k in rec}) - - return { - "loss": { - "g": {ln: loss_g[ln].mean().item() for ln in loss_g}, - "d": {ln: loss_d[ln].mean().item() for ln in loss_d}, - }, - "img": { - "heatmap": heatmap, - "generated": generated_img - } - } - - trainer = Engine(_step) - trainer.logger = logger - for lr_shd in lr_schedulers.values(): - trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd) - - RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g") - RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d") - - to_save = dict(trainer=trainer) - to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers}) - to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers}) - to_save.update({f"generator_{k}": generators[k] for k in generators}) - to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators}) - setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True, - end_event=Events.ITERATION_COMPLETED(once=config.max_iteration)) - - def output_transform(output): - loss = dict() - for tl in output["loss"]: - if isinstance(output["loss"][tl], dict): - for l in output["loss"][tl]: - loss[f"{tl}_{l}"] = output["loss"][tl][l] - else: - loss[tl] = output["loss"][tl] - return loss - - tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform) - if tensorboard_handler is not None: - tensorboard_handler.attach( - trainer, - log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"), - event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar) - ) - - @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image)) - def show_images(engine): - output = engine.state.output - image_order = dict( - a=["real_a", "fake_b", "rec_a", "id_a"], - b=["real_b", "fake_a", "rec_b", "id_b"] - ) - output["img"]["generated"]["real_a"] = fuse_attention_map( - output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"]) - output["img"]["generated"]["real_b"] = fuse_attention_map( - output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"]) - for k in "ab": - tensorboard_handler.writer.add_image( - f"train/{k}", - make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]), - engine.state.iteration - ) - - with torch.no_grad(): - g = torch.Generator() - g.manual_seed(config.misc.random_seed) - random_start = torch.randperm(len(engine.state.test_dataset)-11, generator=g).tolist()[0] - test_images = dict( - a=[[], [], [], []], - b=[[], [], [], []] - ) - for i in range(random_start, random_start+10): - batch = convert_tensor(engine.state.test_dataset[i], idist.device()) - - real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size()) - fake_b, _, heatmap_a2b = generators["a2b"](real_a) - fake_a, _, heatmap_b2a = generators["b2a"](real_b) - rec_a = generators["b2a"](fake_b)[0] - rec_b = generators["a2b"](fake_a)[0] - - for idx, im in enumerate( - [attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]): - test_images["a"][idx].append(im.cpu()) - for idx, im in enumerate( - [attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]): - test_images["b"][idx].append(im.cpu()) - for n in "ab": - tensorboard_handler.writer.add_image( - f"test/{n}", - make_2d_grid([torch.cat(ti) for ti in test_images[n]]), - engine.state.iteration - ) - - return trainer - - -def get_tester(config, logger): - generator_a2b = build_model(config.model.generator, config.distributed.model) - - def _step(engine, batch): - real_a, path = convert_tensor(batch, idist.device()) - with torch.no_grad(): - fake_b = generator_a2b(real_a)[0] - return {"path": path, "img": [real_a.detach(), fake_b.detach()]} - - tester = Engine(_step) - tester.logger = logger - - to_load = dict(generator_a2b=generator_a2b) - setup_common_handlers(tester, config, use_profiler=False, to_save=to_load) - - @tester.on(Events.STARTED) - def mkdir(engine): - img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images" - engine.state.img_output_dir = Path(img_output_dir) - if idist.get_rank() == 0 and not engine.state.img_output_dir.exists(): - engine.logger.info(f"mkdir {engine.state.img_output_dir}") - engine.state.img_output_dir.mkdir() - - @tester.on(Events.ITERATION_COMPLETED) - def save_images(engine): - img_tensors = engine.state.output["img"] - paths = engine.state.output["path"] - batch_size = img_tensors[0].size(0) - for i in range(batch_size): - image_name = Path(paths[i]).name - torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name, - nrow=len(img_tensors)) - - return tester - - -def run(task, config, logger): - assert torch.backends.cudnn.enabled - torch.backends.cudnn.benchmark = True - logger.info(f"start task {task}") - with read_write(config): - config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size) - - if task == "train": - train_dataset = data.DATASET.build_with(config.data.train.dataset) - logger.info(f"train with dataset:\n{train_dataset}") - train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader) - trainer = get_trainer(config, logger) - if idist.get_rank() == 0: - test_dataset = data.DATASET.build_with(config.data.test.dataset) - trainer.state.test_dataset = test_dataset - try: - trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader))) - except Exception: - import traceback - print(traceback.format_exc()) - elif task == "test": - assert config.resume_from is not None - test_dataset = data.DATASET.build_with(config.data.test.video_dataset) - logger.info(f"test with dataset:\n{test_dataset}") - test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader) - tester = get_tester(config, logger) - try: - tester.run(test_data_loader, max_epochs=1) - except Exception: - import traceback - print(traceback.format_exc()) - else: - return NotImplemented(f"invalid task: {task}") diff --git a/engine/base/i2i.py b/engine/base/i2i.py index 3592699..94e1154 100644 --- a/engine/base/i2i.py +++ b/engine/base/i2i.py @@ -1,32 +1,23 @@ from itertools import chain -from math import ceil -from pathlib import Path import logging +from pathlib import Path import torch +import torchvision import ignite.distributed as idist from ignite.engine import Events, Engine from ignite.metrics import RunningAverage from ignite.utils import convert_tensor -from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler from ignite.contrib.handlers.param_scheduler import PiecewiseLinear -from model import MODEL +from omegaconf import read_write, OmegaConf + from util.image import make_2d_grid -from util.handler import setup_common_handlers, setup_tensorboard_handler -from util.build import build_optimizer +from engine.util.handler import setup_common_handlers, setup_tensorboard_handler +from engine.util.build import build_optimizer -from omegaconf import OmegaConf - - -def build_model(cfg): - cfg = OmegaConf.to_container(cfg) - bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False) - model = MODEL.build_with(cfg) - if bn_to_sync_bn: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - return idist.auto_model(model) +import data def build_lr_schedulers(optimizers, config): @@ -47,10 +38,26 @@ def build_lr_schedulers(optimizers, config): ) -class EngineKernel(object): - def __init__(self, config, logger): +class TestEngineKernel(object): + def __init__(self, config): self.config = config - self.logger = logger + self.logger = logging.getLogger(config.name) + self.generators = self.build_generators() + + def build_generators(self) -> dict: + raise NotImplemented + + def to_load(self): + return {f"generator_{k}": self.generators[k] for k in self.generators} + + def inference(self, batch): + raise NotImplemented + + +class EngineKernel(object): + def __init__(self, config): + self.config = config + self.logger = logging.getLogger(config.name) self.generators, self.discriminators = self.build_models() def build_models(self) -> (dict, dict): @@ -87,39 +94,43 @@ class EngineKernel(object): raise NotImplemented -def get_trainer(config, ek: EngineKernel, iter_per_epoch): +def get_trainer(config, kernel: EngineKernel): logger = logging.getLogger(config.name) - generators, discriminators = ek.generators, ek.discriminators + generators, discriminators = kernel.generators, kernel.discriminators optimizers = dict( g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator), d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator), ) - logger.info("build optimizers", optimizers) + logger.info(f"build optimizers:\n{optimizers}") lr_schedulers = build_lr_schedulers(optimizers, config) logger.info(f"build lr_schedulers:\n{lr_schedulers}") + image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1) + def _step(engine, batch): batch = convert_tensor(batch, idist.device()) - generated = ek.forward(batch) + generated = kernel.forward(batch) - ek.setup_before_g() + kernel.setup_before_g() optimizers["g"].zero_grad() - loss_g = ek.criterion_generators(batch, generated) + loss_g = kernel.criterion_generators(batch, generated) sum(loss_g.values()).backward() optimizers["g"].step() - ek.setup_before_d() + kernel.setup_before_d() optimizers["d"].zero_grad() - loss_d = ek.criterion_discriminators(batch, generated) + loss_d = kernel.criterion_discriminators(batch, generated) sum(loss_d.values()).backward() optimizers["d"].step() - return { - "loss": dict(g=loss_g, d=loss_d), - "img": ek.intermediate_images(batch, generated) - } + if engine.state.iteration % image_per_iteration == 0: + return { + "loss": dict(g=loss_g, d=loss_d), + "img": kernel.intermediate_images(batch, generated) + } + return {"loss": dict(g=loss_g, d=loss_d)} trainer = Engine(_step) trainer.logger = logger @@ -131,33 +142,22 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch): to_save = dict(trainer=trainer) to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers}) to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers}) - to_save.update(ek.to_save()) - setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True, + to_save.update(kernel.to_save()) + setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=config.handler.clear_cuda_cache, + set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler, end_event=Events.ITERATION_COMPLETED(once=config.max_iteration)) - def output_transform(output): - loss = dict() - for tl in output["loss"]: - if isinstance(output["loss"][tl], dict): - for l in output["loss"][tl]: - loss[f"{tl}_{l}"] = output["loss"][tl][l] - else: - loss[tl] = output["loss"][tl] - return loss - - pairs_per_iteration = config.data.train.dataloader.batch_size - tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch) + tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item") if tensorboard_handler is not None: - tensorboard_handler.attach( - trainer, - log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"), - event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1)) - ) + basic_image_event = Events.ITERATION_COMPLETED( + every=image_per_iteration) + pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size() - @trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1))) + @trainer.on(basic_image_event) def show_images(engine): output = engine.state.output test_images = {} + for k in output["img"]: image_list = output["img"][k] tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list), @@ -174,8 +174,8 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch): batch = convert_tensor(engine.state.test_dataset[i], idist.device()) for k in batch: batch[k] = batch[k].view(1, *batch[k].size()) - generated = ek.forward(batch) - images = ek.intermediate_images(batch, generated) + generated = kernel.forward(batch) + images = kernel.intermediate_images(batch, generated) for k in test_images: for j in range(len(images[k])): @@ -187,3 +187,78 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch): engine.state.iteration * pairs_per_iteration ) return trainer + + +def get_tester(config, kernel: TestEngineKernel): + logger = logging.getLogger(config.name) + + def _step(engine, batch): + real_a, path = convert_tensor(batch, idist.device()) + fake = kernel.inference({"a": real_a})["a"] + return {"path": path, "img": [real_a.detach(), fake.detach()]} + + tester = Engine(_step) + tester.logger = logger + + setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load()) + + @tester.on(Events.STARTED) + def mkdir(engine): + img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images" + engine.state.img_output_dir = Path(img_output_dir) + if idist.get_rank() == 0 and not engine.state.img_output_dir.exists(): + engine.logger.info(f"mkdir {engine.state.img_output_dir}") + engine.state.img_output_dir.mkdir() + + @tester.on(Events.ITERATION_COMPLETED) + def save_images(engine): + img_tensors = engine.state.output["img"] + paths = engine.state.output["path"] + batch_size = img_tensors[0].size(0) + for i in range(batch_size): + image_name = Path(paths[i]).name + torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name, + nrow=len(img_tensors)) + + return tester + + +def run_kernel(task, config, kernel): + assert torch.backends.cudnn.enabled + torch.backends.cudnn.benchmark = True + logger = logging.getLogger(config.name) + with read_write(config): + real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size() + config.max_iteration = config.max_pairs // real_batch_size + 1 + + if task == "train": + train_dataset = data.DATASET.build_with(config.data.train.dataset) + logger.info(f"train with dataset:\n{train_dataset}") + dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader) + dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size() + train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs) + with read_write(config): + config.iterations_per_epoch = len(train_data_loader) + + trainer = get_trainer(config, kernel) + if idist.get_rank() == 0: + test_dataset = data.DATASET.build_with(config.data.test.dataset) + trainer.state.test_dataset = test_dataset + try: + trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1) + except Exception: + import traceback + print(traceback.format_exc()) + elif task == "test": + assert config.resume_from is not None + test_dataset = data.DATASET.build_with(config.data.test.video_dataset) + logger.info(f"test with dataset:\n{test_dataset}") + test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader) + tester = get_tester(config, kernel) + try: + tester.run(test_data_loader, max_epochs=1) + except Exception: + import traceback + print(traceback.format_exc()) + else: + return NotImplemented(f"invalid task: {task}") diff --git a/engine/crossdomain.py b/engine/crossdomain.py deleted file mode 100644 index cf732d3..0000000 --- a/engine/crossdomain.py +++ /dev/null @@ -1,85 +0,0 @@ -import torch -import torch.nn as nn -from torchvision.datasets import ImageFolder - -import ignite.distributed as idist -from ignite.contrib.metrics.gpu_info import GpuInfo -from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, global_step_from_engine, OutputHandler, \ - WeightsScalarHandler, GradsHistHandler, WeightsHistHandler, GradsScalarHandler -from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events -from ignite.metrics import Accuracy, Loss, RunningAverage -from ignite.contrib.engines.common import save_best_model_by_val_score -from ignite.contrib.handlers import ProgressBar - -from util.build import build_model, build_optimizer -from util.handler import setup_common_handlers -from data.transform import transform_pipeline -from data.dataset import LMDBDataset - - -def warmup_trainer(config, logger): - model = build_model(config.model, config.distributed.model) - optimizer = build_optimizer(model.parameters(), config.baseline.optimizers) - loss_fn = nn.CrossEntropyLoss() - trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True, - output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y)) - trainer.logger = logger - RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss") - Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc") - ProgressBar(ncols=0).attach(trainer) - - if idist.get_rank() == 0: - GpuInfo().attach(trainer, name='gpu') - - tb_logger = TensorboardLogger(log_dir=config.output_dir) - tb_logger.attach( - trainer, - log_handler=OutputHandler( - tag="train", - metric_names='all', - global_step_transform=global_step_from_engine(trainer), - ), - event_name=Events.EPOCH_COMPLETED - ) - tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), - event_name=Events.EPOCH_COMPLETED(every=10)) - - tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25)) - - tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), - event_name=Events.EPOCH_COMPLETED(every=10)) - - tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25)) - - @trainer.on(Events.COMPLETED) - def _(): - tb_logger.close() - - to_save = dict(model=model, optimizer=optimizer, trainer=trainer) - setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED, to_save=to_save, - save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5, - metrics_to_print=["loss", "acc"]) - return trainer - - -def run(task, config, logger): - assert torch.backends.cudnn.enabled - torch.backends.cudnn.benchmark = True - logger.info(f"start task {task}") - if task == "warmup": - train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path, - pipeline=config.baseline.data.dataset.train.pipeline) - logger.info(f"train with dataset:\n{train_dataset}") - train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader) - trainer = warmup_trainer(config, logger) - try: - trainer.run(train_data_loader, max_epochs=400) - except Exception: - import traceback - print(traceback.format_exc()) - elif task == "protonet-wo": - pass - elif task == "protonet-w": - pass - else: - return ValueError(f"invalid task: {task}") diff --git a/engine/fewshot.py b/engine/fewshot.py deleted file mode 100644 index b449c6d..0000000 --- a/engine/fewshot.py +++ /dev/null @@ -1,9 +0,0 @@ -from data.dataset import EpisodicDataset, LMDBDataset - - -def prototypical_trainer(config, logger): - pass - - -def prototypical_dataloader(config): - pass diff --git a/engine/util/__init__.py b/engine/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/build.py b/engine/util/build.py similarity index 57% rename from util/build.py rename to engine/util/build.py index 0e53b98..6e59e22 100644 --- a/util/build.py +++ b/engine/util/build.py @@ -1,23 +1,19 @@ import torch -import torch.optim as optim import ignite.distributed as idist from omegaconf import OmegaConf from model import MODEL -from util.distributed import auto_model +import torch.optim as optim -def build_model(cfg, distributed_args=None): +def build_model(cfg): cfg = OmegaConf.to_container(cfg) - model_distributed_config = cfg.pop("_distributed", dict()) + bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False) model = MODEL.build_with(cfg) - - if model_distributed_config.get("bn_to_syncbn"): + if bn_to_sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - - distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args - return auto_model(model, **distributed_args) + return idist.auto_model(model) def build_optimizer(params, cfg): diff --git a/util/handler.py b/engine/util/handler.py similarity index 57% rename from util/handler.py rename to engine/util/handler.py index d7cf1c2..b2334cc 100644 --- a/util/handler.py +++ b/engine/util/handler.py @@ -7,7 +7,7 @@ import ignite.distributed as idist from ignite.engine import Events, Engine from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar -from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler +from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler def empty_cuda_cache(_): @@ -16,6 +16,16 @@ def empty_cuda_cache(_): gc.collect() +def step_transform_maker(stype: str, pairs_per_iteration=None): + assert stype in ["item", "iteration", "epoch"] + if stype == "item": + return lambda engine, _: engine.state.iteration * pairs_per_iteration + if stype == "iteration": + return lambda engine, _: engine.state.iteration + if stype == "epoch": + return lambda engine, _: engine.state.epoch + + def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True, to_save=None, end_event=None, set_epoch_for_dist_sampler=False): """ @@ -41,9 +51,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_ trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler") trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1) - @trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1)) + trainer.logger.info(f"data loader length: {config.iterations_per_epoch} iterations per epoch") + + @trainer.on(Events.EPOCH_COMPLETED(once=1)) def print_info(engine): - engine.logger.info(f"data loader length: {len(engine.state.dataloader)}") engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}") if stop_on_nan: @@ -66,7 +77,7 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_ if to_save is not None: checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False), - n_saved=config.checkpoint.n_saved, filename_prefix=config.name) + n_saved=config.handler.checkpoint.n_saved, filename_prefix=config.name) if config.resume_from is not None: @trainer.on(Events.STARTED) def resume(engine): @@ -77,8 +88,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_ trainer.logger.info(f"load state_dict for {ckp.keys()}") Checkpoint.load_objects(to_load=to_save, checkpoint=ckp) engine.logger.info(f"resume from a checkpoint {checkpoint_path}") - trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED, - checkpoint_handler) + trainer.add_event_handler( + Events.EPOCH_COMPLETED(every=config.handler.checkpoint.epoch_interval) | Events.COMPLETED, + checkpoint_handler + ) trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically") if end_event is not None: trainer.logger.debug(f"engine will stop on {end_event}") @@ -88,17 +101,48 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_ engine.terminate() -def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch): - if config.interval.tensorboard is None: +def setup_tensorboard_handler(trainer: Engine, config, optimizers, step_type="item"): + if config.handler.tensorboard is None: return None if idist.get_rank() == 0: # Create a logger tb_logger = TensorboardLogger(log_dir=config.output_dir) - basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1)) - tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"), - event_name=basic_event) - tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform), - event_name=basic_event) + tb_writer = tb_logger.writer + pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size() + global_step_transform = step_transform_maker(step_type, pairs_per_iteration) + + basic_event = Events.ITERATION_COMPLETED( + every=max(config.iterations_per_epoch // config.handler.tensorboard.scalar, 1)) + tb_logger.attach( + trainer, + log_handler=OutputHandler( + tag="metric", metric_names="all", + global_step_transform=global_step_transform + ), + event_name=basic_event + ) + + @trainer.on(basic_event) + def log_loss(engine): + global_step = global_step_transform(engine, None) + output_loss = engine.state.output["loss"] + for total_loss in output_loss: + if isinstance(output_loss[total_loss], dict): + for ln in output_loss[total_loss]: + tb_writer.add_scalar(f"train_{total_loss}/{ln}", output_loss[total_loss][ln], global_step) + else: + tb_writer.add_scalar(f"train/{total_loss}", output_loss[total_loss], global_step) + + if isinstance(optimizers, dict): + for name in optimizers: + tb_logger.attach( + trainer, + log_handler=OptimizerParamsHandler(optimizers[name], tag=f"optimizer_{name}"), + event_name=Events.ITERATION_STARTED + ) + else: + tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizers, tag=f"optimizer"), + event_name=Events.ITERATION_STARTED) @trainer.on(Events.COMPLETED) @idist.one_rank_only() diff --git a/loss/I2I/perceptual_loss.py b/loss/I2I/perceptual_loss.py index a390063..4436089 100644 --- a/loss/I2I/perceptual_loss.py +++ b/loss/I2I/perceptual_loss.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torchvision.models.vgg as vgg @@ -97,12 +98,13 @@ class PerceptualLoss(nn.Module): self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type, use_input_norm=use_input_norm) - if criterion == 'L1': - self.criterion = torch.nn.L1Loss() - elif criterion == "L2": - self.criterion = torch.nn.MSELoss() - else: - raise NotImplementedError(f'{criterion} criterion has not been supported in this version.') + self.criterion = self.set_criterion(criterion) + + def set_criterion(self, criterion: str): + assert criterion in ["NL1", "NL2", "L1", "L2"] + norm = F.instance_norm if criterion.startswith("N") else lambda x: x + fn = F.l1_loss if criterion.endswith("L1") else F.mse_loss + return lambda x, t: fn(norm(x), norm(t)) def forward(self, x, gt): """Forward function. @@ -124,8 +126,7 @@ class PerceptualLoss(nn.Module): if self.perceptual_loss: percep_loss = 0 for k in x_features.keys(): - percep_loss += self.criterion( - x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] else: percep_loss = None @@ -133,9 +134,8 @@ class PerceptualLoss(nn.Module): if self.style_loss: style_loss = 0 for k in x_features.keys(): - style_loss += self.criterion( - self._gram_mat(x_features[k]), - self._gram_mat(gt_features[k])) * self.layer_weights[k] + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \ + self.layer_weights[k] else: style_loss = None diff --git a/main.py b/main.py index eca18d9..7e1c1d5 100644 --- a/main.py +++ b/main.py @@ -33,10 +33,10 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals if setup_output_dir and config.resume_from is None: if output_dir.exists(): - # assert not any(output_dir.iterdir()), "output_dir must be empty" - contains = list(output_dir.iterdir()) - assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \ - f"output_dir must by empty or only contains config.yml, but now got {len(contains)} files" + assert len(list(output_dir.glob("events*"))) == 0 + assert len(list(output_dir.glob("*.pt"))) == 0 + if (output_dir / "train.log").exists() and idist.get_rank() == 0: + (output_dir / "train.log").unlink() else: if idist.get_rank() == 0: output_dir.mkdir(parents=True) diff --git a/model/GAN/TAHG.py b/model/GAN/TAFG.py similarity index 84% rename from model/GAN/TAHG.py rename to model/GAN/TAFG.py index afb619a..a3cf097 100644 --- a/model/GAN/TAHG.py +++ b/model/GAN/TAFG.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from .residual_generator import ResidualBlock +from .base import ResidualBlock from model.registry import MODEL from torchvision.models import vgg19 from model.normalization import select_norm_layer @@ -148,48 +148,65 @@ class Fusion(nn.Module): return self.end_fc(x) -@MODEL.register_module("TAHG-Generator") +class StyleGenerator(nn.Module): + def __init__(self, style_in_channels, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"): + super().__init__() + self.num_blocks = num_blocks + self.style_encoder = VGG19StyleEncoder( + style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE") + self.fc = nn.Sequential( + nn.Linear(style_dim, style_dim), + nn.ReLU(True), + ) + res_block_channels = 2 ** 2 * base_channels + self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3, + norm_type="NONE") + + def forward(self, x): + styles = self.fusion(self.fc(self.style_encoder(x))) + return styles + + +@MODEL.register_module("TAFG-Generator") class Generator(nn.Module): def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"): super(Generator, self).__init__() self.num_blocks = num_blocks self.style_encoders = nn.ModuleDict({ - "a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, - padding_mode=padding_mode, norm_type="NONE"), - "b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, - padding_mode=padding_mode, norm_type="NONE") + "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks, + base_channels=base_channels, padding_mode=padding_mode), + "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks, + base_channels=base_channels, padding_mode=padding_mode), }) self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks, padding_mode=padding_mode, norm_type="IN") res_block_channels = 2 ** 2 * base_channels - self.adain_res = nn.ModuleList([ + self.adain_resnet_a = nn.ModuleList([ + ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks) + ]) + self.adain_resnet_b = nn.ModuleList([ ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks) ]) - self.decoders = nn.ModuleDict({ - "a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode), - "b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode) - }) - self.fc = nn.Sequential( - nn.Linear(style_dim, style_dim), - nn.ReLU(True), - ) - self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3, - norm_type="NONE") + self.decoders = nn.ModuleDict({ + "a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode), + "b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode) + }) def forward(self, content_img, style_img, which_decoder: str = "a"): x = self.content_encoder(content_img) - styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img))) + styles = self.style_encoders[which_decoder](style_img) styles = torch.chunk(styles, self.num_blocks * 2, dim=1) - for i, ar in enumerate(self.adain_res): + resnet = self.adain_resnet_a if which_decoder == "a" else self.adain_resnet_b + for i, ar in enumerate(resnet): ar.norm1.set_style(styles[2 * i]) ar.norm2.set_style(styles[2 * i + 1]) x = ar(x) return self.decoders[which_decoder](x) -@MODEL.register_module("TAHG-Discriminator") +@MODEL.register_module("TAFG-Discriminator") class Discriminator(nn.Module): def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN", padding_mode="reflect"): diff --git a/model/GAN/base.py b/model/GAN/base.py index bd70ac2..3e15c3f 100644 --- a/model/GAN/base.py +++ b/model/GAN/base.py @@ -7,7 +7,7 @@ from model import MODEL # based SPADE or pix2pixHD Discriminator -@MODEL.register_module("base-PatchDiscriminator") +@MODEL.register_module("pix2pixHD-PatchDiscriminator") class PatchDiscriminator(nn.Module): def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN", need_intermediate_feature=False): @@ -59,3 +59,26 @@ class PatchDiscriminator(nn.Module): for layer in self.conv_blocks: x = layer(x) return x + + +@MODEL.register_module() +class ResidualBlock(nn.Module): + def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None): + super(ResidualBlock, self).__init__() + if use_bias is None: + # Only for IN, use bias since it does not have affine parameters. + use_bias = norm_type == "IN" + norm_layer = select_norm_layer(norm_type) + self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, + bias=use_bias) + self.norm1 = norm_layer(num_channels) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, + bias=use_bias) + self.norm2 = norm_layer(num_channels) + + def forward(self, x): + res = x + x = self.relu1(self.norm1(self.conv1(x))) + x = self.norm2(self.conv2(x)) + return x + res diff --git a/model/GAN/residual_generator.py b/model/GAN/residual_generator.py index de8ae9b..7a0f0d3 100644 --- a/model/GAN/residual_generator.py +++ b/model/GAN/residual_generator.py @@ -58,27 +58,29 @@ class GANImageBuffer(object): return return_images -@MODEL.register_module() class ResidualBlock(nn.Module): - def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None): + def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None): super(ResidualBlock, self).__init__() + if use_bias is None: # Only for IN, use bias since it does not have affine parameters. use_bias = norm_type == "IN" norm_layer = select_norm_layer(norm_type) - self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, - bias=use_bias) - self.norm1 = norm_layer(num_channels) - self.relu1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, - bias=use_bias) - self.norm2 = norm_layer(num_channels) + models = [nn.Sequential( + nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias), + norm_layer(num_channels), + nn.ReLU(inplace=True), + )] + if use_dropout: + models.append(nn.Dropout(0.5)) + models.append(nn.Sequential( + nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias), + norm_layer(num_channels), + )) + self.block = nn.Sequential(*models) def forward(self, x): - res = x - x = self.relu1(self.norm1(self.conv1(x))) - x = self.norm2(self.conv2(x)) - return x + res + return x + self.block(x) @MODEL.register_module() diff --git a/model/__init__.py b/model/__init__.py index 2b43540..6331c07 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -1,6 +1,6 @@ from model.registry import MODEL import model.GAN.residual_generator -import model.GAN.TAHG +import model.GAN.TAFG import model.GAN.UGATIT import model.fewshot import model.GAN.wrapper diff --git a/util/distributed.py b/util/distributed.py deleted file mode 100644 index fd10615..0000000 --- a/util/distributed.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch -import torch.nn as nn -from ignite.distributed import utils as idist -from ignite.distributed.comp_models import native as idist_native -from ignite.utils import setup_logger - - -def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module: - """Helper method to adapt provided model for non-distributed and distributed configurations (supporting - all available backends from :meth:`~ignite.distributed.utils.available_backends()`). - - Internally, we perform to following: - - - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device. - - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1. - - wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available. - - Examples: - - .. code-block:: python - - import ignite.distribted as idist - - model = idist.auto_model(model) - - In addition with NVidia/Apex, it can be used in the following way: - - .. code-block:: python - - import ignite.distribted as idist - - model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) - model = idist.auto_model(model) - - Args: - model (torch.nn.Module): model to adapt. - - Returns: - torch.nn.Module - - .. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel - .. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel - """ - logger = setup_logger(__name__ + ".auto_model") - - # Put model's parameters to device if its parameters are not on the device - device = idist.device() - if not all([p.device == device for p in model.parameters()]): - model.to(device) - - # distributed data parallel model - if idist.get_world_size() > 1: - if idist.backend() == idist_native.NCCL: - lrank = idist.get_local_rank() - logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank)) - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs) - elif idist.backend() == idist_native.GLOO: - logger.info("Apply torch DistributedDataParallel on model") - model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs) - - # not distributed but multiple GPUs reachable so data parallel model - elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type: - logger.info("Apply torch DataParallel on model") - model = torch.nn.parallel.DataParallel(model, **additional_kwargs) - - return model diff --git a/util/image.py b/util/image.py index 524d382..b43e288 100644 --- a/util/image.py +++ b/util/image.py @@ -1,26 +1,34 @@ import torchvision.utils -from matplotlib.pyplot import get_cmap import torch import warnings -from torch.nn.functional import interpolate +import numpy as np +import cv2 -def attention_colored_map(attentions, size=None, cmap_name="jet"): +def attention_colored_map(attentions, size=None): assert attentions.dim() == 4 and attentions.size(1) == 1 + device = attentions.device min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1) attentions -= min_attentions attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1) - if size is not None and attentions.size()[-2:] != size: + attentions = attentions.detach().cpu().numpy() + attentions = (attentions * 255).astype(np.uint8) + need_resize = False + if size is not None and attentions.shape[-2:] != size: assert len(size) == 2, "for interpolate, size must be (x, y), have two dim" - attentions = interpolate(attentions, size, mode="bilinear", align_corners=False) - cmap = get_cmap(cmap_name) - ca = cmap(attentions.squeeze(1).cpu())[:, :, :, :3] - return torch.from_numpy(ca).permute(0, 3, 1, 2).contiguous() + need_resize = True + + subs = [] + for sub in attentions: + sub = cv2.resize(sub[0], size) if need_resize else sub[0] # numpy.array shape=size + subs.append(cv2.applyColorMap(sub, cv2.COLORMAP_JET)) # append a (size[0], size[1], 3) numpy array + subs = np.stack(subs) # (batch_size, size[0], size[1], 3) + return torch.from_numpy(subs).permute(0, 3, 1, 2).contiguous().to(device).float() / 255 -def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5): +def fuse_attention_map(images, attentions, alpha=0.5): """ :param images: B x H x W @@ -35,7 +43,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5): if attentions.size(1) != 1: warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}") return images - colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device) + colored_attentions = attention_colored_map(attentions, images.size()[-2:]) return images * alpha + colored_attentions * (1 - alpha)