This commit is contained in:
2020-09-05 10:33:35 +08:00
parent 2469bf15fe
commit 39c754374c
21 changed files with 550 additions and 705 deletions

9
.idea/deployment.xml generated
View File

@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="22d" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="21d" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="14d">
<serverdata>
@@ -16,6 +16,13 @@
</mappings>
</serverdata>
</paths>
<paths name="21d">
<serverdata>
<mappings>
<mapping deploy="/raycv" local="$PROJECT_DIR$" web="" />
</mappings>
</serverdata>
</paths>
<paths name="22d">
<serverdata>
<mappings>

View File

@@ -3,31 +3,32 @@ engine: TAFG
result_dir: ./result
max_pairs: 1000000
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
misc:
random_seed: 324
checkpoint:
epoch_interval: 1 # one checkpoint every 1 epoch
n_saved: 2
interval:
print_per_iteration: 10 # print once per 10 iteration
tensorboard:
scalar: 100
image: 2
model:
generator:
_type: TAHG-Generator
_type: TAFG-Generator
_bn_to_sync_bn: False
style_in_channels: 3
content_in_channels: 1
num_blocks: 4
content_in_channels: 24
num_blocks: 8
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
discriminator_cfg:
_type: base-PatchDiscriminator
_type: pix2pixHD
in_channels: 3
base_channels: 64
use_spectral: True
@@ -46,7 +47,7 @@ loss:
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
criterion: 'NL1'
style_loss: False
perceptual_loss: True
weight: 5
@@ -63,10 +64,13 @@ loss:
weight: 0
fm:
level: 1
weight: 10
weight: 1
recon:
level: 1
weight: 5
weight: 10
style_recon:
level: 1
weight: 10
optimizers:
generator:
@@ -87,7 +91,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 256
batch_size: 24
shuffle: True
num_workers: 2
pin_memory: True
@@ -98,13 +102,13 @@ data:
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
edge_type: "landmark_canny"
size: [128, 128]
edge_type: "landmark_hed"
size: [ 128, 128 ]
random_pair: True
pipeline:
- Load
- Resize:
size: [128, 128]
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
@@ -121,13 +125,14 @@ data:
root_a: "/data/i2i/VoxCeleb2Anime/testA"
root_b: "/data/i2i/VoxCeleb2Anime/testB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
edge_type: "landmark_hed"
random_pair: False
size: [128, 128]
size: [ 128, 128 ]
pipeline:
- Load
- Resize:
size: [128, 128]
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]

View File

@@ -1,24 +1,20 @@
name: selfie2anime
engine: UGATIT
engine: U-GAT-IT
result_dir: ./result
max_pairs: 1000000
distributed:
model:
# broadcast_buffers: False
misc:
random_seed: 324
checkpoint:
epoch_interval: 1 # one checkpoint every 1 epoch
n_saved: 2
interval:
print_per_iteration: 10 # print once per 10 iteration
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 10
image: 500
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
model:
generator:
@@ -59,12 +55,12 @@ optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [0.5, 0.999]
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 1e-4
betas: [0.5, 0.999]
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data:
@@ -74,7 +70,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 4
batch_size: 24
shuffle: True
num_workers: 2
pin_memory: True
@@ -87,14 +83,14 @@ data:
pipeline:
- Load
- Resize:
size: [286, 286]
size: [ 286, 286 ]
- RandomCrop:
size: [256, 256]
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
dataloader:
batch_size: 8
@@ -110,11 +106,11 @@ data:
pipeline:
- Load
- Resize:
size: [256, 256]
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
@@ -124,6 +120,3 @@ data:
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -177,6 +177,13 @@ class GenerationUnpairedDataset(Dataset):
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
def normalize_tensor(tensor):
tensor = tensor.float()
tensor -= tensor.min()
tensor /= tensor.max()
return tensor
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
@@ -200,17 +207,19 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
edge_type = self.edge_type
use_landmark = False
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size))
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
if not use_landmark:
return origin_edge
else:
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.{edge_type}.txt"
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
dist_tensor = torch.from_numpy(dlib_landmark.dist_tensor(key_points))
part_labels = torch.from_numpy(part_labels)
edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
edges = part_edge + edges
return torch.cat([edges, dist_tensor, part_labels], dim=0)
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
part_labels = normalize_tensor(torch.from_numpy(part_labels))
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
# edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
# edges = part_edge + edges
return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])
def __getitem__(self, idx):
a_idx = idx % len(self.A)

View File

@@ -1,24 +1,22 @@
from itertools import chain
from math import ceil
from omegaconf import read_write, OmegaConf
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
import data
from engine.base.i2i import get_trainer, EngineKernel, build_model
from model.weight_init import generation_init_weights
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
class TAFGEngineKernel(EngineKernel):
def __init__(self, config, logger):
super().__init__(config, logger)
def __init__(self, config):
super().__init__(config)
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
@@ -29,6 +27,11 @@ class TAFGEngineKernel(EngineKernel):
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
def _process_batch(self, batch, inference=False):
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
return batch
def build_models(self) -> (dict, dict):
generators = dict(
@@ -56,6 +59,7 @@ class TAFGEngineKernel(EngineKernel):
def forward(self, batch, inference=False) -> dict:
generator = self.generators["main"]
batch = self._process_batch(batch, inference)
with torch.set_grad_enabled(not inference):
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
@@ -64,6 +68,7 @@ class TAFGEngineKernel(EngineKernel):
return fake
def criterion_generators(self, batch, generated) -> dict:
batch = self._process_batch(batch)
loss = dict()
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
@@ -85,10 +90,15 @@ class TAFGEngineKernel(EngineKernel):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
self.generators["main"].module.style_encoders["b"](batch["b"]),
self.generators["main"].module.style_encoders["b"](generated["b"])
)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
# batch = self._process_batch(batch)
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase])
pred_fake = self.discriminators[phase](generated[phase].detach())
@@ -105,31 +115,14 @@ class TAFGEngineKernel(EngineKernel):
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
batch = self._process_batch(batch)
edge = batch["edge_a"][:, 0:1, :, :]
return dict(
a=[batch[f"edge_a"].expand(-1, 3, -1, -1).detach(), batch["a"].detach(), generated["a"].detach()],
b=[batch["b"].detach(), generated["b"].detach()]
a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
generated["b"].detach()]
)
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
with read_write(config):
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, TAFGEngineKernel(config, logger), len(train_data_loader))
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())
else:
return NotImplemented(f"invalid task: {task}")
def run(task, config, _):
kernel = TAFGEngineKernel(config)
run_kernel(task, config, kernel)

153
engine/U-GAT-IT.py Normal file
View File

@@ -0,0 +1,153 @@
from itertools import chain
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
from model.weight_init import generation_init_weights
from loss.gan import GANLoss
from model.GAN.UGATIT import RhoClipper
from model.GAN.residual_generator import GANImageBuffer
from util.image import attention_colored_map
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
def mse_loss(x, target_flag):
    """LSGAN-style adversarial loss: MSE between `x` and a constant target map.

    :param x: discriminator (or CAM) output tensor
    :param target_flag: True -> target of all ones, False -> all zeros
    :return: scalar mean-squared-error loss tensor
    """
    if target_flag:
        target = torch.ones_like(x)
    else:
        target = torch.zeros_like(x)
    return F.mse_loss(x, target)
def bce_loss(x, target_flag):
    """Binary cross-entropy (with logits) against a constant target map.

    :param x: raw logits tensor
    :param target_flag: True -> target of all ones, False -> all zeros
    :return: scalar BCE-with-logits loss tensor
    """
    target = torch.ones_like(x) if target_flag else torch.zeros_like(x)
    return F.binary_cross_entropy_with_logits(x, target)
class UGATITEngineKernel(EngineKernel):
    """Training kernel for U-GAT-IT: two generators (a2b, b2a) plus local and
    global discriminators per domain, with CAM attention losses.

    NOTE(review): assumes EngineKernel.__init__ calls build_models() and
    exposes self.generators / self.discriminators / self.logger — confirm
    against the base class.
    """

    def __init__(self, config):
        super().__init__(config)
        # GANLoss takes the config minus "weight"; weights are applied manually
        # in the criterion_* methods below.
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        # level 1 -> L1 loss, any other level -> MSE, for cycle/identity terms.
        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
        self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
        # Clips the AdaLIN rho parameters into [0, 1] after each generator step.
        self.rho_clipper = RhoClipper(0, 1)
        # One replay buffer per discriminator key ("la", "lb", "ga", "gb").
        self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
                              self.discriminators.keys()}

    def build_models(self) -> (dict, dict):
        """Build and weight-initialize the (generators, discriminators) dicts."""
        generators = dict(
            a2b=build_model(self.config.model.generator),
            b2a=build_model(self.config.model.generator)
        )
        discriminators = dict(
            la=build_model(self.config.model.local_discriminator),
            lb=build_model(self.config.model.local_discriminator),
            ga=build_model(self.config.model.global_discriminator),
            gb=build_model(self.config.model.global_discriminator),
        )
        self.logger.debug(discriminators["ga"])
        self.logger.debug(generators["a2b"])
        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)
        return generators, discriminators

    def setup_before_d(self):
        # After the G step: clip rho, then re-enable discriminator gradients.
        for generator in self.generators.values():
            generator.apply(self.rho_clipper)
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # Freeze discriminators so the generator step leaves D grads untouched.
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run both generators on the batch.

        Generators return a (image, cam_logit, heatmap) triple. Produces
        translations (a2b, b2a), cycle reconstructions (a2b2a, b2a2b), and
        identity mappings (a2a via b2a, b2b via a2b).

        :param batch: dict with image tensors under keys "a" and "b"
        :param inference: when True, disable autograd for the whole pass
        :return: dict(images=..., heatmap=..., cam_pred=...)
        """
        images = dict()
        heatmap = dict()
        cam_pred = dict()
        with torch.set_grad_enabled(not inference):
            images["a2b"], cam_pred["a2b"], heatmap["a2b"] = self.generators["a2b"](batch["a"])
            images["b2a"], cam_pred["b2a"], heatmap["b2a"] = self.generators["b2a"](batch["b"])
            images["a2b2a"], _, heatmap["a2b2a"] = self.generators["b2a"](images["a2b"])
            images["b2a2b"], _, heatmap["b2a2b"] = self.generators["a2b"](images["b2a"])
            images["a2a"], cam_pred["a2a"], heatmap["a2a"] = self.generators["b2a"](batch["a"])
            images["b2b"], cam_pred["b2b"], heatmap["b2b"] = self.generators["a2b"](batch["b"])
        return dict(images=images, heatmap=heatmap, cam_pred=cam_pred)

    def criterion_generators(self, batch, generated) -> dict:
        """Weighted generator losses: cycle, identity, adversarial and CAM terms."""
        loss = dict()
        for phase in "ab":
            cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
            loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
            loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
                                                                            generated["images"][f"{phase}2{phase}"])
            for dk in "lg":
                # The fake image that lands in domain `phase`.
                generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
                pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
                loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
                loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True)
        for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
            # CAM classifier: translation should score 1, identity mapping 0.
            loss[f"cam_{t[-1]}"] = self.config.loss.cam.weight * (
                    bce_loss(generated["cam_pred"][t], True) + bce_loss(generated["cam_pred"][f], False))
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Discriminator losses on buffered fakes vs. real images, plus CAM terms."""
        loss = dict()
        for phase in "ab":
            for level in "gl":
                # Query the replay buffer so D also sees historical fakes.
                generated_image = self.image_buffers[level + phase].query(
                    generated["images"]["a2b" if phase == "b" else "b2a"])
                pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image)
                pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase])
                loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss(
                    pred_fake, False, is_discriminator=True)
                loss[f"cam_{phase}_{level}"] = mse_loss(cam_fake_pred, False) + mse_loss(cam_real_pred, True)
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        attention_a = attention_colored_map(generated["heatmap"]["a2b"].detach(), batch["a"].size()[-2:])
        attention_b = attention_colored_map(generated["heatmap"]["b2a"].detach(), batch["b"].size()[-2:])
        # Detach everything so logging does not keep the autograd graph alive.
        generated = {img: generated["images"][img].detach() for img in generated["images"]}
        return {
            "a": [batch["a"], attention_a, generated["a2b"], generated["a2a"], generated["a2b2a"]],
            "b": [batch["b"], attention_b, generated["b2a"], generated["b2b"], generated["b2a2b"]],
        }
class UGATITTestEngineKernel(TestEngineKernel):
    """Inference-only kernel: runs the a2b generator on test batches."""

    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        # Only the a -> b direction is needed at test time.
        return {"a2b": build_model(self.config.model.generator)}

    def to_load(self):
        # Checkpoint keys follow the "generator_<name>" convention of training.
        return {f"generator_{name}": model for name, model in self.generators.items()}

    def inference(self, batch):
        generator = self.generators["a2b"]
        with torch.no_grad():
            # Generator returns (image, cam_logit, heatmap); keep the image.
            translated = generator(batch["a"])[0]
        return {"a": translated.detach()}
def run(task, config, _):
    """Entry point: build the kernel matching `task` and hand it to run_kernel.

    :param task: either "train" or "test"
    :param config: OmegaConf configuration for the run
    :param _: unused third argument (kept for the engine entry-point interface)
    :raises ValueError: if `task` is neither "train" nor "test"
    """
    if task == "train":
        kernel = UGATITEngineKernel(config)
    elif task == "test":
        kernel = UGATITTestEngineKernel(config)
    else:
        # Original fell through with `kernel` unbound and crashed with a
        # NameError at run_kernel(); fail with a clear message instead.
        raise ValueError(f"invalid task: {task}")
    run_kernel(task, config, kernel)

View File

@@ -1,320 +0,0 @@
from itertools import chain
from math import ceil
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import OmegaConf, read_write
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid, fuse_attention_map, attention_colored_map
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer
def build_lr_schedulers(optimizers, config):
g_milestones_values = [
(0, config.optimizers.generator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
d_milestones_values = [
(0, config.optimizers.discriminator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
return dict(
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
)
def get_trainer(config, logger):
generators = dict(
a2b=build_model(config.model.generator, config.distributed.model),
b2a=build_model(config.model.generator, config.distributed.model),
)
discriminators = dict(
la=build_model(config.model.local_discriminator, config.distributed.model),
lb=build_model(config.model.local_discriminator, config.distributed.model),
ga=build_model(config.model.global_discriminator, config.distributed.model),
gb=build_model(config.model.global_discriminator, config.distributed.model),
)
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
logger.debug(discriminators["ga"])
logger.debug(generators["a2b"])
optimizers = dict(
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info(f"build optimizers:\n{optimizers}")
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
rho_clipper = RhoClipper(0, 1)
def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
discriminator_g):
discriminator_g.requires_grad_(False)
discriminator_l.requires_grad_(False)
pred_fake_g, cam_gd_pred = discriminator_g(fake)
pred_fake_l, cam_ld_pred = discriminator_l(fake)
return {
f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
f"id_{name}": config.loss.id.weight * id_loss(real, identity),
f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
}
def criterion_discriminator(name, discriminator, real, fake):
pred_real, cam_real = discriminator(real)
pred_fake, cam_fake = discriminator(fake)
# TODO: origin do not divide 2, but I think it better to divide 2.
loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}
def _step(engine, real):
real = convert_tensor(real, idist.device())
fake = dict()
cam_generator_pred = dict()
rec = dict()
identity = dict()
cam_identity_pred = dict()
heatmap = dict()
fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
optimizers["g"].zero_grad()
loss_g = dict()
for n in ["a", "b"]:
loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
sum(loss_g.values()).backward()
optimizers["g"].step()
for generator in generators.values():
generator.apply(rho_clipper)
for discriminator in discriminators.values():
discriminator.requires_grad_(True)
optimizers["d"].zero_grad()
loss_d = dict()
for k in discriminators.keys():
n = k[-1] # "a" or "b"
loss_d.update(
criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
sum(loss_d.values()).backward()
optimizers["d"].step()
for h in heatmap:
heatmap[h] = heatmap[h].detach()
generated_img = {f"real_{k}": real[k].detach() for k in real}
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
return {
"loss": {
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
"d": {ln: loss_d[ln].mean().item() for ln in loss_d},
},
"img": {
"heatmap": heatmap,
"generated": generated_img
}
}
trainer = Engine(_step)
trainer.logger = logger
for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update({f"generator_{k}": generators[k] for k in generators})
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
if tensorboard_handler is not None:
tensorboard_handler.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
)
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
def show_images(engine):
output = engine.state.output
image_order = dict(
a=["real_a", "fake_b", "rec_a", "id_a"],
b=["real_b", "fake_a", "rec_b", "id_b"]
)
output["img"]["generated"]["real_a"] = fuse_attention_map(
output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
output["img"]["generated"]["real_b"] = fuse_attention_map(
output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
for k in "ab":
tensorboard_handler.writer.add_image(
f"train/{k}",
make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]),
engine.state.iteration
)
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
random_start = torch.randperm(len(engine.state.test_dataset)-11, generator=g).tolist()[0]
test_images = dict(
a=[[], [], [], []],
b=[[], [], [], []]
)
for i in range(random_start, random_start+10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
fake_b, _, heatmap_a2b = generators["a2b"](real_a)
fake_a, _, heatmap_b2a = generators["b2a"](real_b)
rec_a = generators["b2a"](fake_b)[0]
rec_b = generators["a2b"](fake_a)[0]
for idx, im in enumerate(
[attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]):
test_images["a"][idx].append(im.cpu())
for idx, im in enumerate(
[attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]):
test_images["b"][idx].append(im.cpu())
for n in "ab":
tensorboard_handler.writer.add_image(
f"test/{n}",
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
engine.state.iteration
)
return trainer
def get_tester(config, logger):
generator_a2b = build_model(config.model.generator, config.distributed.model)
def _step(engine, batch):
real_a, path = convert_tensor(batch, idist.device())
with torch.no_grad():
fake_b = generator_a2b(real_a)[0]
return {"path": path, "img": [real_a.detach(), fake_b.detach()]}
tester = Engine(_step)
tester.logger = logger
to_load = dict(generator_a2b=generator_a2b)
setup_common_handlers(tester, config, use_profiler=False, to_save=to_load)
@tester.on(Events.STARTED)
def mkdir(engine):
img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
engine.state.img_output_dir = Path(img_output_dir)
if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
engine.logger.info(f"mkdir {engine.state.img_output_dir}")
engine.state.img_output_dir.mkdir()
@tester.on(Events.ITERATION_COMPLETED)
def save_images(engine):
img_tensors = engine.state.output["img"]
paths = engine.state.output["path"]
batch_size = img_tensors[0].size(0)
for i in range(batch_size):
image_name = Path(paths[i]).name
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
nrow=len(img_tensors))
return tester
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
with read_write(config):
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger)
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, logger)
try:
tester.run(test_data_loader, max_epochs=1)
except Exception:
import traceback
print(traceback.format_exc())
else:
return NotImplemented(f"invalid task: {task}")

View File

@@ -1,32 +1,23 @@
from itertools import chain
from math import ceil
from pathlib import Path
import logging
from pathlib import Path
import torch
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from model import MODEL
from omegaconf import read_write, OmegaConf
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_optimizer
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from engine.util.build import build_optimizer
from omegaconf import OmegaConf
def build_model(cfg):
cfg = OmegaConf.to_container(cfg)
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
model = MODEL.build_with(cfg)
if bn_to_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
return idist.auto_model(model)
import data
def build_lr_schedulers(optimizers, config):
@@ -47,10 +38,26 @@ def build_lr_schedulers(optimizers, config):
)
class EngineKernel(object):
def __init__(self, config, logger):
class TestEngineKernel(object):
def __init__(self, config):
self.config = config
self.logger = logger
self.logger = logging.getLogger(config.name)
self.generators = self.build_generators()
def build_generators(self) -> dict:
raise NotImplemented
def to_load(self):
return {f"generator_{k}": self.generators[k] for k in self.generators}
def inference(self, batch):
raise NotImplemented
class EngineKernel(object):
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(config.name)
self.generators, self.discriminators = self.build_models()
def build_models(self) -> (dict, dict):
@@ -87,39 +94,43 @@ class EngineKernel(object):
raise NotImplemented
def get_trainer(config, ek: EngineKernel, iter_per_epoch):
def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name)
generators, discriminators = ek.generators, ek.discriminators
generators, discriminators = kernel.generators, kernel.discriminators
optimizers = dict(
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info("build optimizers", optimizers)
logger.info(f"build optimizers:\n{optimizers}")
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
generated = ek.forward(batch)
generated = kernel.forward(batch)
ek.setup_before_g()
kernel.setup_before_g()
optimizers["g"].zero_grad()
loss_g = ek.criterion_generators(batch, generated)
loss_g = kernel.criterion_generators(batch, generated)
sum(loss_g.values()).backward()
optimizers["g"].step()
ek.setup_before_d()
kernel.setup_before_d()
optimizers["d"].zero_grad()
loss_d = ek.criterion_discriminators(batch, generated)
loss_d = kernel.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
return {
"loss": dict(g=loss_g, d=loss_d),
"img": ek.intermediate_images(batch, generated)
}
if engine.state.iteration % image_per_iteration == 0:
return {
"loss": dict(g=loss_g, d=loss_d),
"img": kernel.intermediate_images(batch, generated)
}
return {"loss": dict(g=loss_g, d=loss_d)}
trainer = Engine(_step)
trainer.logger = logger
@@ -131,33 +142,22 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update(ek.to_save())
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
to_save.update(kernel.to_save())
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=config.handler.clear_cuda_cache,
set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
pairs_per_iteration = config.data.train.dataloader.batch_size
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
if tensorboard_handler is not None:
tensorboard_handler.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
)
basic_image_event = Events.ITERATION_COMPLETED(
every=image_per_iteration)
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
@trainer.on(basic_image_event)
def show_images(engine):
output = engine.state.output
test_images = {}
for k in output["img"]:
image_list = output["img"][k]
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
@@ -174,8 +174,8 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
generated = ek.forward(batch)
images = ek.intermediate_images(batch, generated)
generated = kernel.forward(batch)
images = kernel.intermediate_images(batch, generated)
for k in test_images:
for j in range(len(images[k])):
@@ -187,3 +187,78 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
engine.state.iteration * pairs_per_iteration
)
return trainer
def get_tester(config, kernel: TestEngineKernel):
    """Build an inference-only ignite Engine that runs the kernel's generator
    over the test set and writes (input, output) image pairs to disk."""
    logger = logging.getLogger(config.name)

    def _step(engine, batch):
        # batch is (image, path); convert_tensor moves the tensor part to the
        # current device. NOTE(review): assumes paths survive convert_tensor
        # unchanged — confirm against the dataset's collate output.
        real_a, path = convert_tensor(batch, idist.device())
        # Only the "a" branch of the kernel is exercised at test time.
        fake = kernel.inference({"a": real_a})["a"]
        return {"path": path, "img": [real_a.detach(), fake.detach()]}

    tester = Engine(_step)
    tester.logger = logger
    # `to_save` is used here for *loading* generator weights from a checkpoint
    # (see kernel.to_load), not for saving new ones.
    setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load())

    @tester.on(Events.STARTED)
    def mkdir(engine):
        # Resolve the output directory; fall back under output_dir when no
        # explicit img_output_dir is configured.
        img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
        engine.state.img_output_dir = Path(img_output_dir)
        # Only rank 0 creates the directory to avoid a multi-process race.
        if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
            engine.logger.info(f"mkdir {engine.state.img_output_dir}")
            engine.state.img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        img_tensors = engine.state.output["img"]
        paths = engine.state.output["path"]
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            # Name each file after its source image; real/fake are tiled
            # side by side (nrow = number of image tensors).
            image_name = Path(paths[i]).name
            torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
                                         nrow=len(img_tensors))

    return tester
def run_kernel(task, config, kernel):
    """Run a training or testing job with the given engine kernel.

    Args:
        task: "train" or "test".
        config: experiment config (OmegaConf); mutated in place to set
            ``max_iteration`` (and ``iterations_per_epoch`` when training).
        kernel: EngineKernel (train) / TestEngineKernel (test) instance.

    Raises:
        ValueError: if ``task`` is neither "train" nor "test".
    """
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger = logging.getLogger(config.name)
    with read_write(config):
        # Translate the pair budget (max_pairs) into an iteration budget for
        # the whole job, accounting for the effective global batch size.
        real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
        config.max_iteration = config.max_pairs // real_batch_size + 1
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader)
        # idist.auto_dataloader divides the batch size across processes, so
        # scale it up first to keep the configured per-process batch size.
        dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size()
        train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs)
        with read_write(config):
            config.iterations_per_epoch = len(train_data_loader)
        trainer = get_trainer(config, kernel)
        if idist.get_rank() == 0:
            # Only rank 0 keeps a test dataset for periodic visual previews.
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            # Log with full traceback instead of printing to stdout.
            logger.exception("trainer crashed")
    elif task == "test":
        assert config.resume_from is not None, "testing requires a checkpoint (config.resume_from)"
        test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, kernel)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            logger.exception("tester crashed")
    else:
        # BUG fix: was `return NotImplemented(f"...")` — the NotImplemented
        # sentinel is not callable, so that line raised a TypeError instead of
        # reporting the bad task. Raise a proper ValueError.
        raise ValueError(f"invalid task: {task}")

View File

@@ -1,85 +0,0 @@
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
import ignite.distributed as idist
from ignite.contrib.metrics.gpu_info import GpuInfo
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, global_step_from_engine, OutputHandler, \
WeightsScalarHandler, GradsHistHandler, WeightsHistHandler, GradsScalarHandler
from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
from ignite.metrics import Accuracy, Loss, RunningAverage
from ignite.contrib.engines.common import save_best_model_by_val_score
from ignite.contrib.handlers import ProgressBar
from util.build import build_model, build_optimizer
from util.handler import setup_common_handlers
from data.transform import transform_pipeline
from data.dataset import LMDBDataset
def warmup_trainer(config, logger):
    """Create the supervised trainer for the baseline warm-up classification stage.

    Builds model/optimizer/cross-entropy loss from ``config`` and attaches
    running metrics, a progress bar, and — on rank 0 only — GPU info,
    tensorboard weight/gradient logging, and periodic checkpointing.
    """
    model = build_model(config.model, config.distributed.model)
    optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
    loss_fn = nn.CrossEntropyLoss()
    # output_transform keeps (loss, y_pred, y) so that both the running loss
    # average and the accuracy metric below can read from engine output.
    trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True,
                                        output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y))
    trainer.logger = logger
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc")
    ProgressBar(ncols=0).attach(trainer)
    if idist.get_rank() == 0:
        # GPU utilisation metric and tensorboard logging only on the main rank.
        GpuInfo().attach(trainer, name='gpu')
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="train",
                metric_names='all',
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.EPOCH_COMPLETED
        )
        # Scalar summaries of weights/grads every 10 epochs; histograms are
        # heavier, so only every 25 epochs.
        tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))
        tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))

        @trainer.on(Events.COMPLETED)
        def _():
            # Flush and close the tensorboard writer at the end of the run.
            tb_logger.close()
    to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED, to_save=to_save,
                          save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5,
                          metrics_to_print=["loss", "acc"])
    return trainer
def run(task, config, logger):
    """Dispatch a few-shot baseline task.

    Args:
        task: one of "warmup", "protonet-wo", "protonet-w".
        config: experiment configuration.
        logger: logger used for progress messages.

    Raises:
        ValueError: if ``task`` is not a known task name.
    """
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "warmup":
        train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
                                    pipeline=config.baseline.data.dataset.train.pipeline)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
        trainer = warmup_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=400)
        except Exception:
            # Log with full traceback instead of printing to stdout.
            logger.exception("warmup trainer crashed")
    elif task == "protonet-wo":
        pass  # TODO: prototypical network (without warm-up weights)
    elif task == "protonet-w":
        pass  # TODO: prototypical network (with warm-up weights)
    else:
        # BUG fix: was `return ValueError(...)`, which silently returned the
        # exception instance to the caller instead of raising it.
        raise ValueError(f"invalid task: {task}")

View File

@@ -1,9 +0,0 @@
from data.dataset import EpisodicDataset, LMDBDataset
def prototypical_trainer(config, logger):
    """Build the prototypical-network trainer. Not implemented yet (stub)."""
    pass
def prototypical_dataloader(config):
    """Build the episodic dataloader for prototypical training. Stub."""
    pass

0
engine/util/__init__.py Normal file
View File

View File

@@ -1,23 +1,19 @@
import torch
import torch.optim as optim
import ignite.distributed as idist
from omegaconf import OmegaConf
from model import MODEL
from util.distributed import auto_model
import torch.optim as optim
def build_model(cfg, distributed_args=None):
def build_model(cfg):
cfg = OmegaConf.to_container(cfg)
model_distributed_config = cfg.pop("_distributed", dict())
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
model = MODEL.build_with(cfg)
if model_distributed_config.get("bn_to_syncbn"):
if bn_to_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
return auto_model(model, **distributed_args)
return idist.auto_model(model)
def build_optimizer(params, cfg):

View File

@@ -7,7 +7,7 @@ import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
def empty_cuda_cache(_):
@@ -16,6 +16,16 @@ def empty_cuda_cache(_):
gc.collect()
def step_transform_maker(stype: str, pairs_per_iteration=None):
    """Return a ``global_step_transform`` callable for tensorboard handlers.

    "item" reports the number of processed pairs (iteration *
    pairs_per_iteration), while "iteration" and "epoch" report the raw
    engine counters. The second argument of the returned callable (the
    event name) is ignored.
    """
    assert stype in ["item", "iteration", "epoch"]
    transforms = {
        "item": lambda engine, _: engine.state.iteration * pairs_per_iteration,
        "iteration": lambda engine, _: engine.state.iteration,
        "epoch": lambda engine, _: engine.state.epoch,
    }
    return transforms[stype]
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
"""
@@ -41,9 +51,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
@trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1))
trainer.logger.info(f"data loader length: {config.iterations_per_epoch} iterations per epoch")
@trainer.on(Events.EPOCH_COMPLETED(once=1))
def print_info(engine):
engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
if stop_on_nan:
@@ -66,7 +77,7 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
if to_save is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
n_saved=config.handler.checkpoint.n_saved, filename_prefix=config.name)
if config.resume_from is not None:
@trainer.on(Events.STARTED)
def resume(engine):
@@ -77,8 +88,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
trainer.logger.info(f"load state_dict for {ckp.keys()}")
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
checkpoint_handler)
trainer.add_event_handler(
Events.EPOCH_COMPLETED(every=config.handler.checkpoint.epoch_interval) | Events.COMPLETED,
checkpoint_handler
)
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
if end_event is not None:
trainer.logger.debug(f"engine will stop on {end_event}")
@@ -88,17 +101,48 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
engine.terminate()
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
if config.interval.tensorboard is None:
def setup_tensorboard_handler(trainer: Engine, config, optimizers, step_type="item"):
if config.handler.tensorboard is None:
return None
if idist.get_rank() == 0:
# Create a logger
tb_logger = TensorboardLogger(log_dir=config.output_dir)
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
event_name=basic_event)
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
event_name=basic_event)
tb_writer = tb_logger.writer
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
global_step_transform = step_transform_maker(step_type, pairs_per_iteration)
basic_event = Events.ITERATION_COMPLETED(
every=max(config.iterations_per_epoch // config.handler.tensorboard.scalar, 1))
tb_logger.attach(
trainer,
log_handler=OutputHandler(
tag="metric", metric_names="all",
global_step_transform=global_step_transform
),
event_name=basic_event
)
@trainer.on(basic_event)
def log_loss(engine):
global_step = global_step_transform(engine, None)
output_loss = engine.state.output["loss"]
for total_loss in output_loss:
if isinstance(output_loss[total_loss], dict):
for ln in output_loss[total_loss]:
tb_writer.add_scalar(f"train_{total_loss}/{ln}", output_loss[total_loss][ln], global_step)
else:
tb_writer.add_scalar(f"train/{total_loss}", output_loss[total_loss], global_step)
if isinstance(optimizers, dict):
for name in optimizers:
tb_logger.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers[name], tag=f"optimizer_{name}"),
event_name=Events.ITERATION_STARTED
)
else:
tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizers, tag=f"optimizer"),
event_name=Events.ITERATION_STARTED)
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()

View File

@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.vgg as vgg
@@ -97,12 +98,13 @@ class PerceptualLoss(nn.Module):
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
use_input_norm=use_input_norm)
if criterion == 'L1':
self.criterion = torch.nn.L1Loss()
elif criterion == "L2":
self.criterion = torch.nn.MSELoss()
else:
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
self.criterion = self.set_criterion(criterion)
def set_criterion(self, criterion: str):
    """Build the feature-distance function used by the perceptual loss.

    ``criterion`` is one of "L1", "L2", "NL1", "NL2". The "N" prefix
    instance-normalizes both feature maps before measuring the distance;
    the suffix selects L1 versus mean-squared error.
    """
    assert criterion in ["NL1", "NL2", "L1", "L2"]
    distance = F.l1_loss if "L1" in criterion else F.mse_loss
    if criterion.startswith("N"):
        return lambda x, t: distance(F.instance_norm(x), F.instance_norm(t))
    return lambda x, t: distance(x, t)
def forward(self, x, gt):
"""Forward function.
@@ -124,8 +126,7 @@ class PerceptualLoss(nn.Module):
if self.perceptual_loss:
percep_loss = 0
for k in x_features.keys():
percep_loss += self.criterion(
x_features[k], gt_features[k]) * self.layer_weights[k]
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
else:
percep_loss = None
@@ -133,9 +134,8 @@ class PerceptualLoss(nn.Module):
if self.style_loss:
style_loss = 0
for k in x_features.keys():
style_loss += self.criterion(
self._gram_mat(x_features[k]),
self._gram_mat(gt_features[k])) * self.layer_weights[k]
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
self.layer_weights[k]
else:
style_loss = None

View File

@@ -33,10 +33,10 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
if setup_output_dir and config.resume_from is None:
if output_dir.exists():
# assert not any(output_dir.iterdir()), "output_dir must be empty"
contains = list(output_dir.iterdir())
assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \
f"output_dir must by empty or only contains config.yml, but now got {len(contains)} files"
assert len(list(output_dir.glob("events*"))) == 0
assert len(list(output_dir.glob("*.pt"))) == 0
if (output_dir / "train.log").exists() and idist.get_rank() == 0:
(output_dir / "train.log").unlink()
else:
if idist.get_rank() == 0:
output_dir.mkdir(parents=True)

View File

@@ -1,6 +1,6 @@
import torch
import torch.nn as nn
from .residual_generator import ResidualBlock
from .base import ResidualBlock
from model.registry import MODEL
from torchvision.models import vgg19
from model.normalization import select_norm_layer
@@ -148,48 +148,65 @@ class Fusion(nn.Module):
return self.end_fc(x)
@MODEL.register_module("TAHG-Generator")
class StyleGenerator(nn.Module):
    """Encode a style image into a flat vector of AdaIN style parameters.

    A VGG19-based encoder produces a ``style_dim`` embedding, which an MLP
    (``fc`` followed by ``Fusion``) expands to
    ``num_blocks * 2 * res_block_channels * 2`` values — enough for two AdaIN
    normalization layers per residual block, each needing scale and shift per
    channel. The caller chunks the output (see the Generator's forward).
    """

    def __init__(self, style_in_channels, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"):
        """
        Args:
            style_in_channels: channels of the input style image.
            style_dim: size of the intermediate style embedding.
            num_blocks: number of AdaIN residual blocks the styles feed.
            base_channels: encoder width; residual blocks run at 4x this.
            padding_mode: conv padding mode for the encoder.
        """
        super().__init__()
        self.num_blocks = num_blocks
        self.style_encoder = VGG19StyleEncoder(
            style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE")
        self.fc = nn.Sequential(
            nn.Linear(style_dim, style_dim),
            nn.ReLU(True),
        )
        # Matches the channel count of the AdaIN residual blocks downstream.
        res_block_channels = 2 ** 2 * base_channels
        self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
                             norm_type="NONE")

    def forward(self, x):
        # encoder -> MLP -> fusion; returns the flat style-parameter vector.
        styles = self.fusion(self.fc(self.style_encoder(x)))
        return styles
@MODEL.register_module("TAFG-Generator")
class Generator(nn.Module):
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
base_channels=64, padding_mode="reflect"):
super(Generator, self).__init__()
self.num_blocks = num_blocks
self.style_encoders = nn.ModuleDict({
"a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
padding_mode=padding_mode, norm_type="NONE"),
"b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
padding_mode=padding_mode, norm_type="NONE")
"a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
base_channels=base_channels, padding_mode=padding_mode),
"b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
base_channels=base_channels, padding_mode=padding_mode),
})
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
padding_mode=padding_mode, norm_type="IN")
res_block_channels = 2 ** 2 * base_channels
self.adain_res = nn.ModuleList([
self.adain_resnet_a = nn.ModuleList([
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
])
self.adain_resnet_b = nn.ModuleList([
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
])
self.decoders = nn.ModuleDict({
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
})
self.fc = nn.Sequential(
nn.Linear(style_dim, style_dim),
nn.ReLU(True),
)
self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
norm_type="NONE")
self.decoders = nn.ModuleDict({
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode)
})
def forward(self, content_img, style_img, which_decoder: str = "a"):
x = self.content_encoder(content_img)
styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img)))
styles = self.style_encoders[which_decoder](style_img)
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
for i, ar in enumerate(self.adain_res):
resnet = self.adain_resnet_a if which_decoder == "a" else self.adain_resnet_b
for i, ar in enumerate(resnet):
ar.norm1.set_style(styles[2 * i])
ar.norm2.set_style(styles[2 * i + 1])
x = ar(x)
return self.decoders[which_decoder](x)
@MODEL.register_module("TAHG-Discriminator")
@MODEL.register_module("TAFG-Discriminator")
class Discriminator(nn.Module):
def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
padding_mode="reflect"):

View File

@@ -7,7 +7,7 @@ from model import MODEL
# based SPADE or pix2pixHD Discriminator
@MODEL.register_module("base-PatchDiscriminator")
@MODEL.register_module("pix2pixHD-PatchDiscriminator")
class PatchDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
need_intermediate_feature=False):
@@ -59,3 +59,26 @@ class PatchDiscriminator(nn.Module):
for layer in self.conv_blocks:
x = layer(x)
return x
@MODEL.register_module()
class ResidualBlock(nn.Module):
    """Two 3x3 convs with normalization and an identity skip connection.

    No activation follows the second normalization before the residual add.
    ``norm1`` / ``norm2`` are exposed as attributes because callers (e.g. the
    AdaIN generator) set style parameters on them directly via ``set_style``.
    """

    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
        """
        Args:
            num_channels: channel count, kept constant through the block.
            padding_mode: conv padding mode, e.g. "reflect" or "zeros".
            norm_type: key understood by ``select_norm_layer`` ("IN", "AdaIN", ...).
            use_bias: conv bias flag; when None it is derived from norm_type.
        """
        super(ResidualBlock, self).__init__()
        if use_bias is None:
            # Only for IN, use bias since it does not have affine parameters.
            use_bias = norm_type == "IN"
        norm_layer = select_norm_layer(norm_type)
        self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
                               bias=use_bias)
        self.norm1 = norm_layer(num_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
                               bias=use_bias)
        self.norm2 = norm_layer(num_channels)

    def forward(self, x):
        # conv -> norm -> relu -> conv -> norm, then add the untouched input.
        res = x
        x = self.relu1(self.norm1(self.conv1(x)))
        x = self.norm2(self.conv2(x))
        return x + res

View File

@@ -58,27 +58,29 @@ class GANImageBuffer(object):
return return_images
@MODEL.register_module()
class ResidualBlock(nn.Module):
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
super(ResidualBlock, self).__init__()
if use_bias is None:
# Only for IN, use bias since it does not have affine parameters.
use_bias = norm_type == "IN"
norm_layer = select_norm_layer(norm_type)
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
bias=use_bias)
self.norm1 = norm_layer(num_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
bias=use_bias)
self.norm2 = norm_layer(num_channels)
models = [nn.Sequential(
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
norm_layer(num_channels),
nn.ReLU(inplace=True),
)]
if use_dropout:
models.append(nn.Dropout(0.5))
models.append(nn.Sequential(
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
norm_layer(num_channels),
))
self.block = nn.Sequential(*models)
def forward(self, x):
res = x
x = self.relu1(self.norm1(self.conv1(x)))
x = self.norm2(self.conv2(x))
return x + res
return x + self.block(x)
@MODEL.register_module()

View File

@@ -1,6 +1,6 @@
from model.registry import MODEL
import model.GAN.residual_generator
import model.GAN.TAHG
import model.GAN.TAFG
import model.GAN.UGATIT
import model.fewshot
import model.GAN.wrapper

View File

@@ -1,66 +0,0 @@
import torch
import torch.nn as nn
from ignite.distributed import utils as idist
from ignite.distributed.comp_models import native as idist_native
from ignite.utils import setup_logger
def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module:
    """Helper method to adapt provided model for non-distributed and distributed configurations (supporting
    all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

    Internally, we perform the following:

    - send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
    - wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
    - wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available.

    Examples:

    .. code-block:: python

        import ignite.distributed as idist

        model = idist.auto_model(model)

    In addition with NVidia/Apex, it can be used in the following way:

    .. code-block:: python

        import ignite.distributed as idist

        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
        model = idist.auto_model(model)

    Args:
        model (torch.nn.Module): model to adapt.
        **additional_kwargs: forwarded verbatim to ``DistributedDataParallel``
            / ``DataParallel`` (e.g. ``find_unused_parameters``) — the extra
            flexibility this local copy adds over ``idist.auto_model``.

    Returns:
        torch.nn.Module

    .. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
    .. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
    """
    logger = setup_logger(__name__ + ".auto_model")

    # Put model's parameters to device if its parameters are not on the device
    device = idist.device()
    if not all([p.device == device for p in model.parameters()]):
        model.to(device)

    # distributed data parallel model
    if idist.get_world_size() > 1:
        if idist.backend() == idist_native.NCCL:
            # NCCL assumes one process per GPU: pin DDP to the local rank's device.
            lrank = idist.get_local_rank()
            logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs)
        elif idist.backend() == idist_native.GLOO:
            # GLOO (CPU or mixed): no device pinning needed.
            logger.info("Apply torch DistributedDataParallel on model")
            model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs)

    # not distributed but multiple GPUs reachable so data parallel model
    elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
        logger.info("Apply torch DataParallel on model")
        model = torch.nn.parallel.DataParallel(model, **additional_kwargs)

    return model

View File

@@ -1,26 +1,34 @@
import torchvision.utils
from matplotlib.pyplot import get_cmap
import torch
import warnings
from torch.nn.functional import interpolate
import numpy as np
import cv2
def attention_colored_map(attentions, size=None, cmap_name="jet"):
def attention_colored_map(attentions, size=None):
assert attentions.dim() == 4 and attentions.size(1) == 1
device = attentions.device
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
attentions -= min_attentions
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
if size is not None and attentions.size()[-2:] != size:
attentions = attentions.detach().cpu().numpy()
attentions = (attentions * 255).astype(np.uint8)
need_resize = False
if size is not None and attentions.shape[-2:] != size:
assert len(size) == 2, "for interpolate, size must be (x, y), have two dim"
attentions = interpolate(attentions, size, mode="bilinear", align_corners=False)
cmap = get_cmap(cmap_name)
ca = cmap(attentions.squeeze(1).cpu())[:, :, :, :3]
return torch.from_numpy(ca).permute(0, 3, 1, 2).contiguous()
need_resize = True
subs = []
for sub in attentions:
sub = cv2.resize(sub[0], size) if need_resize else sub[0] # numpy.array shape=size
subs.append(cv2.applyColorMap(sub, cv2.COLORMAP_JET)) # append a (size[0], size[1], 3) numpy array
subs = np.stack(subs) # (batch_size, size[0], size[1], 3)
return torch.from_numpy(subs).permute(0, 3, 1, 2).contiguous().to(device).float() / 255
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
def fuse_attention_map(images, attentions, alpha=0.5):
"""
:param images: B x H x W
@@ -35,7 +43,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
if attentions.size(1) != 1:
warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
return images
colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device)
colored_attentions = attention_colored_map(attentions, images.size()[-2:])
return images * alpha + colored_attentions * (1 - alpha)