change
This commit is contained in:
9
.idea/deployment.xml
generated
9
.idea/deployment.xml
generated
@@ -1,6 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="22d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="21d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="14d">
|
||||
<serverdata>
|
||||
@@ -16,6 +16,13 @@
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="21d">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/raycv" local="$PROJECT_DIR$" web="" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="22d">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
|
||||
@@ -3,31 +3,32 @@ engine: TAFG
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 2 # log image `image` times per epoch
|
||||
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
checkpoint:
|
||||
epoch_interval: 1 # one checkpoint every 1 epoch
|
||||
n_saved: 2
|
||||
|
||||
interval:
|
||||
print_per_iteration: 10 # print once per 10 iteration
|
||||
tensorboard:
|
||||
scalar: 100
|
||||
image: 2
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: TAHG-Generator
|
||||
_type: TAFG-Generator
|
||||
_bn_to_sync_bn: False
|
||||
style_in_channels: 3
|
||||
content_in_channels: 1
|
||||
num_blocks: 4
|
||||
content_in_channels: 24
|
||||
num_blocks: 8
|
||||
discriminator:
|
||||
_type: MultiScaleDiscriminator
|
||||
num_scale: 2
|
||||
discriminator_cfg:
|
||||
_type: base-PatchDiscriminator
|
||||
_type: pix2pixHD
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
use_spectral: True
|
||||
@@ -46,7 +47,7 @@ loss:
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L1'
|
||||
criterion: 'NL1'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 5
|
||||
@@ -63,10 +64,13 @@ loss:
|
||||
weight: 0
|
||||
fm:
|
||||
level: 1
|
||||
weight: 10
|
||||
weight: 1
|
||||
recon:
|
||||
level: 1
|
||||
weight: 5
|
||||
weight: 10
|
||||
style_recon:
|
||||
level: 1
|
||||
weight: 10
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
@@ -87,7 +91,7 @@ data:
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 256
|
||||
batch_size: 24
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
@@ -98,13 +102,13 @@ data:
|
||||
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
|
||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
||||
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
||||
edge_type: "landmark_canny"
|
||||
size: [128, 128]
|
||||
edge_type: "landmark_hed"
|
||||
size: [ 128, 128 ]
|
||||
random_pair: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [128, 128]
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
@@ -121,13 +125,14 @@ data:
|
||||
root_a: "/data/i2i/VoxCeleb2Anime/testA"
|
||||
root_b: "/data/i2i/VoxCeleb2Anime/testB"
|
||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
||||
edge_type: "hed"
|
||||
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
||||
edge_type: "landmark_hed"
|
||||
random_pair: False
|
||||
size: [128, 128]
|
||||
size: [ 128, 128 ]
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [128, 128]
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
|
||||
@@ -1,24 +1,20 @@
|
||||
name: selfie2anime
|
||||
engine: UGATIT
|
||||
engine: U-GAT-IT
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
distributed:
|
||||
model:
|
||||
# broadcast_buffers: False
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
checkpoint:
|
||||
epoch_interval: 1 # one checkpoint every 1 epoch
|
||||
n_saved: 2
|
||||
|
||||
interval:
|
||||
print_per_iteration: 10 # print once per 10 iteration
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 10
|
||||
image: 500
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 2 # log image `image` times per epoch
|
||||
|
||||
model:
|
||||
generator:
|
||||
@@ -59,12 +55,12 @@ optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 0.0001
|
||||
betas: [0.5, 0.999]
|
||||
betas: [ 0.5, 0.999 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 1e-4
|
||||
betas: [0.5, 0.999]
|
||||
betas: [ 0.5, 0.999 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
data:
|
||||
@@ -74,7 +70,7 @@ data:
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 4
|
||||
batch_size: 24
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
@@ -87,14 +83,14 @@ data:
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [286, 286]
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [256, 256]
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
std: [0.5, 0.5, 0.5]
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
@@ -110,11 +106,11 @@ data:
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [256, 256]
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
std: [0.5, 0.5, 0.5]
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
video_dataset:
|
||||
_type: SingleFolderDataset
|
||||
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
|
||||
@@ -124,6 +120,3 @@ data:
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
|
||||
@@ -177,6 +177,13 @@ class GenerationUnpairedDataset(Dataset):
|
||||
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
|
||||
|
||||
def normalize_tensor(tensor):
|
||||
tensor = tensor.float()
|
||||
tensor -= tensor.min()
|
||||
tensor /= tensor.max()
|
||||
return tensor
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
|
||||
@@ -200,17 +207,19 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
edge_type = self.edge_type
|
||||
use_landmark = False
|
||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
|
||||
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size))
|
||||
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
|
||||
if not use_landmark:
|
||||
return origin_edge
|
||||
else:
|
||||
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.{edge_type}.txt"
|
||||
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
|
||||
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
|
||||
dist_tensor = torch.from_numpy(dlib_landmark.dist_tensor(key_points))
|
||||
part_labels = torch.from_numpy(part_labels)
|
||||
edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
|
||||
edges = part_edge + edges
|
||||
return torch.cat([edges, dist_tensor, part_labels], dim=0)
|
||||
|
||||
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
|
||||
part_labels = normalize_tensor(torch.from_numpy(part_labels))
|
||||
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
|
||||
# edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
|
||||
# edges = part_edge + edges
|
||||
return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
from itertools import chain
|
||||
from math import ceil
|
||||
|
||||
from omegaconf import read_write, OmegaConf
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import ignite.distributed as idist
|
||||
|
||||
import data
|
||||
from engine.base.i2i import get_trainer, EngineKernel, build_model
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
|
||||
|
||||
class TAFGEngineKernel(EngineKernel):
|
||||
def __init__(self, config, logger):
|
||||
super().__init__(config, logger)
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
@@ -29,6 +27,11 @@ class TAFGEngineKernel(EngineKernel):
|
||||
|
||||
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
|
||||
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
|
||||
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
|
||||
|
||||
def _process_batch(self, batch, inference=False):
|
||||
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
|
||||
return batch
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
generators = dict(
|
||||
@@ -56,6 +59,7 @@ class TAFGEngineKernel(EngineKernel):
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
generator = self.generators["main"]
|
||||
batch = self._process_batch(batch, inference)
|
||||
with torch.set_grad_enabled(not inference):
|
||||
fake = dict(
|
||||
a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
|
||||
@@ -64,6 +68,7 @@ class TAFGEngineKernel(EngineKernel):
|
||||
return fake
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
batch = self._process_batch(batch)
|
||||
loss = dict()
|
||||
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
|
||||
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
|
||||
@@ -85,10 +90,15 @@ class TAFGEngineKernel(EngineKernel):
|
||||
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
|
||||
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
|
||||
loss["recon"] = self.recon_loss(generated["a"], batch["a"]) * self.config.loss.recon.weight
|
||||
loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
self.generators["main"].module.style_encoders["b"](batch["b"]),
|
||||
self.generators["main"].module.style_encoders["b"](generated["b"])
|
||||
)
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
# batch = self._process_batch(batch)
|
||||
for phase in self.discriminators.keys():
|
||||
pred_real = self.discriminators[phase](batch[phase])
|
||||
pred_fake = self.discriminators[phase](generated[phase].detach())
|
||||
@@ -105,31 +115,14 @@ class TAFGEngineKernel(EngineKernel):
|
||||
:param generated: dict of images
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
batch = self._process_batch(batch)
|
||||
edge = batch["edge_a"][:, 0:1, :, :]
|
||||
return dict(
|
||||
a=[batch[f"edge_a"].expand(-1, 3, -1, -1).detach(), batch["a"].detach(), generated["a"].detach()],
|
||||
b=[batch["b"].detach(), generated["b"].detach()]
|
||||
a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
|
||||
generated["b"].detach()]
|
||||
)
|
||||
|
||||
|
||||
def run(task, config, logger):
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
logger.info(f"start task {task}")
|
||||
with read_write(config):
|
||||
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
|
||||
|
||||
if task == "train":
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||
trainer = get_trainer(config, TAFGEngineKernel(config, logger), len(train_data_loader))
|
||||
if idist.get_rank() == 0:
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
trainer.state.test_dataset = test_dataset
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
return NotImplemented(f"invalid task: {task}")
|
||||
def run(task, config, _):
|
||||
kernel = TAFGEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
|
||||
153
engine/U-GAT-IT.py
Normal file
153
engine/U-GAT-IT.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from itertools import chain
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import ignite.distributed as idist
|
||||
|
||||
from model.weight_init import generation_init_weights
|
||||
from loss.gan import GANLoss
|
||||
from model.GAN.UGATIT import RhoClipper
|
||||
from model.GAN.residual_generator import GANImageBuffer
|
||||
from util.image import attention_colored_map
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
|
||||
|
||||
def mse_loss(x, target_flag):
|
||||
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
|
||||
def bce_loss(x, target_flag):
|
||||
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
|
||||
class UGATITEngineKernel(EngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
|
||||
self.rho_clipper = RhoClipper(0, 1)
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
generators = dict(
|
||||
a2b=build_model(self.config.model.generator),
|
||||
b2a=build_model(self.config.model.generator)
|
||||
)
|
||||
discriminators = dict(
|
||||
la=build_model(self.config.model.local_discriminator),
|
||||
lb=build_model(self.config.model.local_discriminator),
|
||||
ga=build_model(self.config.model.global_discriminator),
|
||||
gb=build_model(self.config.model.global_discriminator),
|
||||
)
|
||||
self.logger.debug(discriminators["ga"])
|
||||
self.logger.debug(generators["a2b"])
|
||||
|
||||
for m in chain(generators.values(), discriminators.values()):
|
||||
generation_init_weights(m)
|
||||
|
||||
return generators, discriminators
|
||||
|
||||
def setup_before_d(self):
|
||||
for generator in self.generators.values():
|
||||
generator.apply(self.rho_clipper)
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
def setup_before_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(False)
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
images = dict()
|
||||
heatmap = dict()
|
||||
cam_pred = dict()
|
||||
|
||||
with torch.set_grad_enabled(not inference):
|
||||
images["a2b"], cam_pred["a2b"], heatmap["a2b"] = self.generators["a2b"](batch["a"])
|
||||
images["b2a"], cam_pred["b2a"], heatmap["b2a"] = self.generators["b2a"](batch["b"])
|
||||
images["a2b2a"], _, heatmap["a2b2a"] = self.generators["b2a"](images["a2b"])
|
||||
images["b2a2b"], _, heatmap["b2a2b"] = self.generators["a2b"](images["b2a"])
|
||||
images["a2a"], cam_pred["a2a"], heatmap["a2a"] = self.generators["b2a"](batch["a"])
|
||||
images["b2b"], cam_pred["b2b"], heatmap["b2b"] = self.generators["a2b"](batch["b"])
|
||||
return dict(images=images, heatmap=heatmap, cam_pred=cam_pred)
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
for phase in "ab":
|
||||
cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
|
||||
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
|
||||
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
|
||||
generated["images"][f"{phase}2{phase}"])
|
||||
for dk in "lg":
|
||||
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
|
||||
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
|
||||
loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
|
||||
loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True)
|
||||
for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
|
||||
loss[f"cam_{t[-1]}"] = self.config.loss.cam.weight * (
|
||||
bce_loss(generated["cam_pred"][t], True) + bce_loss(generated["cam_pred"][f], False))
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
for phase in "ab":
|
||||
for level in "gl":
|
||||
generated_image = self.image_buffers[level + phase].query(
|
||||
generated["images"]["a2b" if phase == "b" else "b2a"])
|
||||
pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image)
|
||||
pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase])
|
||||
loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss(
|
||||
pred_fake, False, is_discriminator=True)
|
||||
loss[f"cam_{phase}_{level}"] = mse_loss(cam_fake_pred, False) + mse_loss(cam_real_pred, True)
|
||||
return loss
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
|
||||
"""
|
||||
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
:param batch:
|
||||
:param generated: dict of images
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
attention_a = attention_colored_map(generated["heatmap"]["a2b"].detach(), batch["a"].size()[-2:])
|
||||
attention_b = attention_colored_map(generated["heatmap"]["b2a"].detach(), batch["b"].size()[-2:])
|
||||
generated = {img: generated["images"][img].detach() for img in generated["images"]}
|
||||
return {
|
||||
"a": [batch["a"], attention_a, generated["a2b"], generated["a2a"], generated["a2b2a"]],
|
||||
"b": [batch["b"], attention_b, generated["b2a"], generated["b2b"], generated["b2a2b"]],
|
||||
}
|
||||
|
||||
|
||||
class UGATITTestEngineKernel(TestEngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
def build_generators(self) -> dict:
|
||||
generators = dict(
|
||||
a2b=build_model(self.config.model.generator),
|
||||
)
|
||||
return generators
|
||||
|
||||
def to_load(self):
|
||||
return {f"generator_{k}": self.generators[k] for k in self.generators}
|
||||
|
||||
def inference(self, batch):
|
||||
with torch.no_grad():
|
||||
fake, _, _ = self.generators["a2b"](batch["a"])
|
||||
return {"a": fake.detach()}
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
if task == "train":
|
||||
kernel = UGATITEngineKernel(config)
|
||||
if task == "test":
|
||||
kernel = UGATITTestEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
320
engine/UGATIT.py
320
engine/UGATIT.py
@@ -1,320 +0,0 @@
|
||||
from itertools import chain
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.utils import convert_tensor
|
||||
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
|
||||
from omegaconf import OmegaConf, read_write
|
||||
|
||||
import data
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
from model.GAN.residual_generator import GANImageBuffer
|
||||
from model.GAN.UGATIT import RhoClipper
|
||||
from util.image import make_2d_grid, fuse_attention_map, attention_colored_map
|
||||
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from util.build import build_model, build_optimizer
|
||||
|
||||
|
||||
def build_lr_schedulers(optimizers, config):
|
||||
g_milestones_values = [
|
||||
(0, config.optimizers.generator.lr),
|
||||
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
d_milestones_values = [
|
||||
(0, config.optimizers.discriminator.lr),
|
||||
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
return dict(
|
||||
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
|
||||
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
|
||||
)
|
||||
|
||||
|
||||
def get_trainer(config, logger):
|
||||
generators = dict(
|
||||
a2b=build_model(config.model.generator, config.distributed.model),
|
||||
b2a=build_model(config.model.generator, config.distributed.model),
|
||||
)
|
||||
discriminators = dict(
|
||||
la=build_model(config.model.local_discriminator, config.distributed.model),
|
||||
lb=build_model(config.model.local_discriminator, config.distributed.model),
|
||||
ga=build_model(config.model.global_discriminator, config.distributed.model),
|
||||
gb=build_model(config.model.global_discriminator, config.distributed.model),
|
||||
)
|
||||
for m in chain(generators.values(), discriminators.values()):
|
||||
generation_init_weights(m)
|
||||
|
||||
logger.debug(discriminators["ga"])
|
||||
logger.debug(generators["a2b"])
|
||||
|
||||
optimizers = dict(
|
||||
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
|
||||
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
|
||||
)
|
||||
logger.info(f"build optimizers:\n{optimizers}")
|
||||
|
||||
lr_schedulers = build_lr_schedulers(optimizers, config)
|
||||
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
|
||||
def mse_loss(x, target_flag):
|
||||
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
def bce_loss(x, target_flag):
|
||||
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
|
||||
rho_clipper = RhoClipper(0, 1)
|
||||
|
||||
def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
|
||||
discriminator_g):
|
||||
discriminator_g.requires_grad_(False)
|
||||
discriminator_l.requires_grad_(False)
|
||||
pred_fake_g, cam_gd_pred = discriminator_g(fake)
|
||||
pred_fake_l, cam_ld_pred = discriminator_l(fake)
|
||||
return {
|
||||
f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
|
||||
f"id_{name}": config.loss.id.weight * id_loss(real, identity),
|
||||
f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
|
||||
f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
|
||||
f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
|
||||
f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
|
||||
f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
|
||||
}
|
||||
|
||||
def criterion_discriminator(name, discriminator, real, fake):
|
||||
pred_real, cam_real = discriminator(real)
|
||||
pred_fake, cam_fake = discriminator(fake)
|
||||
# TODO: origin do not divide 2, but I think it better to divide 2.
|
||||
loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
|
||||
loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
|
||||
return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}
|
||||
|
||||
def _step(engine, real):
|
||||
real = convert_tensor(real, idist.device())
|
||||
|
||||
fake = dict()
|
||||
cam_generator_pred = dict()
|
||||
rec = dict()
|
||||
identity = dict()
|
||||
cam_identity_pred = dict()
|
||||
heatmap = dict()
|
||||
|
||||
fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
|
||||
fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
|
||||
rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
|
||||
rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
|
||||
identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
|
||||
identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
|
||||
|
||||
optimizers["g"].zero_grad()
|
||||
loss_g = dict()
|
||||
for n in ["a", "b"]:
|
||||
loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
|
||||
cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
|
||||
sum(loss_g.values()).backward()
|
||||
optimizers["g"].step()
|
||||
for generator in generators.values():
|
||||
generator.apply(rho_clipper)
|
||||
for discriminator in discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
optimizers["d"].zero_grad()
|
||||
loss_d = dict()
|
||||
for k in discriminators.keys():
|
||||
n = k[-1] # "a" or "b"
|
||||
loss_d.update(
|
||||
criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
|
||||
sum(loss_d.values()).backward()
|
||||
optimizers["d"].step()
|
||||
|
||||
for h in heatmap:
|
||||
heatmap[h] = heatmap[h].detach()
|
||||
generated_img = {f"real_{k}": real[k].detach() for k in real}
|
||||
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
|
||||
generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
|
||||
generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
|
||||
|
||||
return {
|
||||
"loss": {
|
||||
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
||||
"d": {ln: loss_d[ln].mean().item() for ln in loss_d},
|
||||
},
|
||||
"img": {
|
||||
"heatmap": heatmap,
|
||||
"generated": generated_img
|
||||
}
|
||||
}
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
for lr_shd in lr_schedulers.values():
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
|
||||
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
|
||||
|
||||
to_save = dict(trainer=trainer)
|
||||
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
|
||||
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
||||
to_save.update({f"generator_{k}": generators[k] for k in generators})
|
||||
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
|
||||
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
|
||||
def output_transform(output):
|
||||
loss = dict()
|
||||
for tl in output["loss"]:
|
||||
if isinstance(output["loss"][tl], dict):
|
||||
for l in output["loss"][tl]:
|
||||
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
||||
else:
|
||||
loss[tl] = output["loss"][tl]
|
||||
return loss
|
||||
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
|
||||
if tensorboard_handler is not None:
|
||||
tensorboard_handler.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
||||
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
|
||||
)
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
|
||||
def show_images(engine):
|
||||
output = engine.state.output
|
||||
image_order = dict(
|
||||
a=["real_a", "fake_b", "rec_a", "id_a"],
|
||||
b=["real_b", "fake_a", "rec_b", "id_b"]
|
||||
)
|
||||
output["img"]["generated"]["real_a"] = fuse_attention_map(
|
||||
output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
|
||||
output["img"]["generated"]["real_b"] = fuse_attention_map(
|
||||
output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
|
||||
for k in "ab":
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"train/{k}",
|
||||
make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]),
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed)
|
||||
random_start = torch.randperm(len(engine.state.test_dataset)-11, generator=g).tolist()[0]
|
||||
test_images = dict(
|
||||
a=[[], [], [], []],
|
||||
b=[[], [], [], []]
|
||||
)
|
||||
for i in range(random_start, random_start+10):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
|
||||
real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
|
||||
fake_b, _, heatmap_a2b = generators["a2b"](real_a)
|
||||
fake_a, _, heatmap_b2a = generators["b2a"](real_b)
|
||||
rec_a = generators["b2a"](fake_b)[0]
|
||||
rec_b = generators["a2b"](fake_a)[0]
|
||||
|
||||
for idx, im in enumerate(
|
||||
[attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]):
|
||||
test_images["a"][idx].append(im.cpu())
|
||||
for idx, im in enumerate(
|
||||
[attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]):
|
||||
test_images["b"][idx].append(im.cpu())
|
||||
for n in "ab":
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{n}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
def get_tester(config, logger):
|
||||
generator_a2b = build_model(config.model.generator, config.distributed.model)
|
||||
|
||||
def _step(engine, batch):
|
||||
real_a, path = convert_tensor(batch, idist.device())
|
||||
with torch.no_grad():
|
||||
fake_b = generator_a2b(real_a)[0]
|
||||
return {"path": path, "img": [real_a.detach(), fake_b.detach()]}
|
||||
|
||||
tester = Engine(_step)
|
||||
tester.logger = logger
|
||||
|
||||
to_load = dict(generator_a2b=generator_a2b)
|
||||
setup_common_handlers(tester, config, use_profiler=False, to_save=to_load)
|
||||
|
||||
@tester.on(Events.STARTED)
|
||||
def mkdir(engine):
|
||||
img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
|
||||
engine.state.img_output_dir = Path(img_output_dir)
|
||||
if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
|
||||
engine.logger.info(f"mkdir {engine.state.img_output_dir}")
|
||||
engine.state.img_output_dir.mkdir()
|
||||
|
||||
@tester.on(Events.ITERATION_COMPLETED)
|
||||
def save_images(engine):
|
||||
img_tensors = engine.state.output["img"]
|
||||
paths = engine.state.output["path"]
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
image_name = Path(paths[i]).name
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
|
||||
nrow=len(img_tensors))
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
def run(task, config, logger):
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
logger.info(f"start task {task}")
|
||||
with read_write(config):
|
||||
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
|
||||
|
||||
if task == "train":
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||
trainer = get_trainer(config, logger)
|
||||
if idist.get_rank() == 0:
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
trainer.state.test_dataset = test_dataset
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
assert config.resume_from is not None
|
||||
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, logger)
|
||||
try:
|
||||
tester.run(test_data_loader, max_epochs=1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
return NotImplemented(f"invalid task: {task}")
|
||||
@@ -1,32 +1,23 @@
|
||||
from itertools import chain
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.utils import convert_tensor
|
||||
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
|
||||
from model import MODEL
|
||||
from omegaconf import read_write, OmegaConf
|
||||
|
||||
from util.image import make_2d_grid
|
||||
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from util.build import build_optimizer
|
||||
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from engine.util.build import build_optimizer
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def build_model(cfg):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
|
||||
model = MODEL.build_with(cfg)
|
||||
if bn_to_sync_bn:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
return idist.auto_model(model)
|
||||
import data
|
||||
|
||||
|
||||
def build_lr_schedulers(optimizers, config):
|
||||
@@ -47,10 +38,26 @@ def build_lr_schedulers(optimizers, config):
|
||||
)
|
||||
|
||||
|
||||
class EngineKernel(object):
|
||||
def __init__(self, config, logger):
|
||||
class TestEngineKernel(object):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
self.logger = logging.getLogger(config.name)
|
||||
self.generators = self.build_generators()
|
||||
|
||||
def build_generators(self) -> dict:
|
||||
raise NotImplemented
|
||||
|
||||
def to_load(self):
|
||||
return {f"generator_{k}": self.generators[k] for k in self.generators}
|
||||
|
||||
def inference(self, batch):
|
||||
raise NotImplemented
|
||||
|
||||
|
||||
class EngineKernel(object):
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(config.name)
|
||||
self.generators, self.discriminators = self.build_models()
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
@@ -87,39 +94,43 @@ class EngineKernel(object):
|
||||
raise NotImplemented
|
||||
|
||||
|
||||
def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
def get_trainer(config, kernel: EngineKernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
generators, discriminators = ek.generators, ek.discriminators
|
||||
generators, discriminators = kernel.generators, kernel.discriminators
|
||||
optimizers = dict(
|
||||
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
|
||||
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
|
||||
)
|
||||
logger.info("build optimizers", optimizers)
|
||||
logger.info(f"build optimizers:\n{optimizers}")
|
||||
|
||||
lr_schedulers = build_lr_schedulers(optimizers, config)
|
||||
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
|
||||
|
||||
image_per_iteration = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
|
||||
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
|
||||
generated = ek.forward(batch)
|
||||
generated = kernel.forward(batch)
|
||||
|
||||
ek.setup_before_g()
|
||||
kernel.setup_before_g()
|
||||
optimizers["g"].zero_grad()
|
||||
loss_g = ek.criterion_generators(batch, generated)
|
||||
loss_g = kernel.criterion_generators(batch, generated)
|
||||
sum(loss_g.values()).backward()
|
||||
optimizers["g"].step()
|
||||
|
||||
ek.setup_before_d()
|
||||
kernel.setup_before_d()
|
||||
optimizers["d"].zero_grad()
|
||||
loss_d = ek.criterion_discriminators(batch, generated)
|
||||
loss_d = kernel.criterion_discriminators(batch, generated)
|
||||
sum(loss_d.values()).backward()
|
||||
optimizers["d"].step()
|
||||
|
||||
return {
|
||||
"loss": dict(g=loss_g, d=loss_d),
|
||||
"img": ek.intermediate_images(batch, generated)
|
||||
}
|
||||
if engine.state.iteration % image_per_iteration == 0:
|
||||
return {
|
||||
"loss": dict(g=loss_g, d=loss_d),
|
||||
"img": kernel.intermediate_images(batch, generated)
|
||||
}
|
||||
return {"loss": dict(g=loss_g, d=loss_d)}
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
@@ -131,33 +142,22 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
to_save = dict(trainer=trainer)
|
||||
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
|
||||
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
||||
to_save.update(ek.to_save())
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
|
||||
to_save.update(kernel.to_save())
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=config.handler.clear_cuda_cache,
|
||||
set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler,
|
||||
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
|
||||
def output_transform(output):
|
||||
loss = dict()
|
||||
for tl in output["loss"]:
|
||||
if isinstance(output["loss"][tl], dict):
|
||||
for l in output["loss"][tl]:
|
||||
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
||||
else:
|
||||
loss[tl] = output["loss"][tl]
|
||||
return loss
|
||||
|
||||
pairs_per_iteration = config.data.train.dataloader.batch_size
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
|
||||
if tensorboard_handler is not None:
|
||||
tensorboard_handler.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
||||
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||
)
|
||||
basic_image_event = Events.ITERATION_COMPLETED(
|
||||
every=image_per_iteration)
|
||||
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
|
||||
@trainer.on(basic_image_event)
|
||||
def show_images(engine):
|
||||
output = engine.state.output
|
||||
test_images = {}
|
||||
|
||||
for k in output["img"]:
|
||||
image_list = output["img"][k]
|
||||
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list),
|
||||
@@ -174,8 +174,8 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
batch[k] = batch[k].view(1, *batch[k].size())
|
||||
generated = ek.forward(batch)
|
||||
images = ek.intermediate_images(batch, generated)
|
||||
generated = kernel.forward(batch)
|
||||
images = kernel.intermediate_images(batch, generated)
|
||||
|
||||
for k in test_images:
|
||||
for j in range(len(images[k])):
|
||||
@@ -187,3 +187,78 @@ def get_trainer(config, ek: EngineKernel, iter_per_epoch):
|
||||
engine.state.iteration * pairs_per_iteration
|
||||
)
|
||||
return trainer
|
||||
|
||||
|
||||
def get_tester(config, kernel: TestEngineKernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
|
||||
def _step(engine, batch):
|
||||
real_a, path = convert_tensor(batch, idist.device())
|
||||
fake = kernel.inference({"a": real_a})["a"]
|
||||
return {"path": path, "img": [real_a.detach(), fake.detach()]}
|
||||
|
||||
tester = Engine(_step)
|
||||
tester.logger = logger
|
||||
|
||||
setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load())
|
||||
|
||||
@tester.on(Events.STARTED)
|
||||
def mkdir(engine):
|
||||
img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
|
||||
engine.state.img_output_dir = Path(img_output_dir)
|
||||
if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
|
||||
engine.logger.info(f"mkdir {engine.state.img_output_dir}")
|
||||
engine.state.img_output_dir.mkdir()
|
||||
|
||||
@tester.on(Events.ITERATION_COMPLETED)
|
||||
def save_images(engine):
|
||||
img_tensors = engine.state.output["img"]
|
||||
paths = engine.state.output["path"]
|
||||
batch_size = img_tensors[0].size(0)
|
||||
for i in range(batch_size):
|
||||
image_name = Path(paths[i]).name
|
||||
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
|
||||
nrow=len(img_tensors))
|
||||
|
||||
return tester
|
||||
|
||||
|
||||
def run_kernel(task, config, kernel):
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
logger = logging.getLogger(config.name)
|
||||
with read_write(config):
|
||||
real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
|
||||
config.max_iteration = config.max_pairs // real_batch_size + 1
|
||||
|
||||
if task == "train":
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader)
|
||||
dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size()
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs)
|
||||
with read_write(config):
|
||||
config.iterations_per_epoch = len(train_data_loader)
|
||||
|
||||
trainer = get_trainer(config, kernel)
|
||||
if idist.get_rank() == 0:
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
trainer.state.test_dataset = test_dataset
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
elif task == "test":
|
||||
assert config.resume_from is not None
|
||||
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
|
||||
logger.info(f"test with dataset:\n{test_dataset}")
|
||||
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
|
||||
tester = get_tester(config, kernel)
|
||||
try:
|
||||
tester.run(test_data_loader, max_epochs=1)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
return NotImplemented(f"invalid task: {task}")
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.contrib.metrics.gpu_info import GpuInfo
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, global_step_from_engine, OutputHandler, \
|
||||
WeightsScalarHandler, GradsHistHandler, WeightsHistHandler, GradsScalarHandler
|
||||
from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
|
||||
from ignite.metrics import Accuracy, Loss, RunningAverage
|
||||
from ignite.contrib.engines.common import save_best_model_by_val_score
|
||||
from ignite.contrib.handlers import ProgressBar
|
||||
|
||||
from util.build import build_model, build_optimizer
|
||||
from util.handler import setup_common_handlers
|
||||
from data.transform import transform_pipeline
|
||||
from data.dataset import LMDBDataset
|
||||
|
||||
|
||||
def warmup_trainer(config, logger):
|
||||
model = build_model(config.model, config.distributed.model)
|
||||
optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True,
|
||||
output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y))
|
||||
trainer.logger = logger
|
||||
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
|
||||
Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc")
|
||||
ProgressBar(ncols=0).attach(trainer)
|
||||
|
||||
if idist.get_rank() == 0:
|
||||
GpuInfo().attach(trainer, name='gpu')
|
||||
|
||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OutputHandler(
|
||||
tag="train",
|
||||
metric_names='all',
|
||||
global_step_transform=global_step_from_engine(trainer),
|
||||
),
|
||||
event_name=Events.EPOCH_COMPLETED
|
||||
)
|
||||
tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
|
||||
event_name=Events.EPOCH_COMPLETED(every=10))
|
||||
|
||||
tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))
|
||||
|
||||
tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
|
||||
event_name=Events.EPOCH_COMPLETED(every=10))
|
||||
|
||||
tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))
|
||||
|
||||
@trainer.on(Events.COMPLETED)
|
||||
def _():
|
||||
tb_logger.close()
|
||||
|
||||
to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
|
||||
setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED, to_save=to_save,
|
||||
save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5,
|
||||
metrics_to_print=["loss", "acc"])
|
||||
return trainer
|
||||
|
||||
|
||||
def run(task, config, logger):
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
logger.info(f"start task {task}")
|
||||
if task == "warmup":
|
||||
train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
|
||||
pipeline=config.baseline.data.dataset.train.pipeline)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
|
||||
trainer = warmup_trainer(config, logger)
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=400)
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
elif task == "protonet-wo":
|
||||
pass
|
||||
elif task == "protonet-w":
|
||||
pass
|
||||
else:
|
||||
return ValueError(f"invalid task: {task}")
|
||||
@@ -1,9 +0,0 @@
|
||||
from data.dataset import EpisodicDataset, LMDBDataset
|
||||
|
||||
|
||||
def prototypical_trainer(config, logger):
|
||||
pass
|
||||
|
||||
|
||||
def prototypical_dataloader(config):
|
||||
pass
|
||||
0
engine/util/__init__.py
Normal file
0
engine/util/__init__.py
Normal file
@@ -1,23 +1,19 @@
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import ignite.distributed as idist
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model import MODEL
|
||||
from util.distributed import auto_model
|
||||
import torch.optim as optim
|
||||
|
||||
|
||||
def build_model(cfg, distributed_args=None):
|
||||
def build_model(cfg):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
model_distributed_config = cfg.pop("_distributed", dict())
|
||||
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
|
||||
model = MODEL.build_with(cfg)
|
||||
|
||||
if model_distributed_config.get("bn_to_syncbn"):
|
||||
if bn_to_sync_bn:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
|
||||
distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
|
||||
return auto_model(model, **distributed_args)
|
||||
return idist.auto_model(model)
|
||||
|
||||
|
||||
def build_optimizer(params, cfg):
|
||||
@@ -7,7 +7,7 @@ import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
||||
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
|
||||
|
||||
|
||||
def empty_cuda_cache(_):
|
||||
@@ -16,6 +16,16 @@ def empty_cuda_cache(_):
|
||||
gc.collect()
|
||||
|
||||
|
||||
def step_transform_maker(stype: str, pairs_per_iteration=None):
|
||||
assert stype in ["item", "iteration", "epoch"]
|
||||
if stype == "item":
|
||||
return lambda engine, _: engine.state.iteration * pairs_per_iteration
|
||||
if stype == "iteration":
|
||||
return lambda engine, _: engine.state.iteration
|
||||
if stype == "epoch":
|
||||
return lambda engine, _: engine.state.epoch
|
||||
|
||||
|
||||
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
|
||||
to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
|
||||
"""
|
||||
@@ -41,9 +51,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
|
||||
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
|
||||
|
||||
@trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1))
|
||||
trainer.logger.info(f"data loader length: {config.iterations_per_epoch} iterations per epoch")
|
||||
|
||||
@trainer.on(Events.EPOCH_COMPLETED(once=1))
|
||||
def print_info(engine):
|
||||
engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
|
||||
engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
|
||||
|
||||
if stop_on_nan:
|
||||
@@ -66,7 +77,7 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
|
||||
if to_save is not None:
|
||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
|
||||
n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
|
||||
n_saved=config.handler.checkpoint.n_saved, filename_prefix=config.name)
|
||||
if config.resume_from is not None:
|
||||
@trainer.on(Events.STARTED)
|
||||
def resume(engine):
|
||||
@@ -77,8 +88,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
trainer.logger.info(f"load state_dict for {ckp.keys()}")
|
||||
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
|
||||
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
|
||||
checkpoint_handler)
|
||||
trainer.add_event_handler(
|
||||
Events.EPOCH_COMPLETED(every=config.handler.checkpoint.epoch_interval) | Events.COMPLETED,
|
||||
checkpoint_handler
|
||||
)
|
||||
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
|
||||
if end_event is not None:
|
||||
trainer.logger.debug(f"engine will stop on {end_event}")
|
||||
@@ -88,17 +101,48 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
engine.terminate()
|
||||
|
||||
|
||||
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
|
||||
if config.interval.tensorboard is None:
|
||||
def setup_tensorboard_handler(trainer: Engine, config, optimizers, step_type="item"):
|
||||
if config.handler.tensorboard is None:
|
||||
return None
|
||||
if idist.get_rank() == 0:
|
||||
# Create a logger
|
||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
|
||||
event_name=basic_event)
|
||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
|
||||
event_name=basic_event)
|
||||
tb_writer = tb_logger.writer
|
||||
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
|
||||
global_step_transform = step_transform_maker(step_type, pairs_per_iteration)
|
||||
|
||||
basic_event = Events.ITERATION_COMPLETED(
|
||||
every=max(config.iterations_per_epoch // config.handler.tensorboard.scalar, 1))
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OutputHandler(
|
||||
tag="metric", metric_names="all",
|
||||
global_step_transform=global_step_transform
|
||||
),
|
||||
event_name=basic_event
|
||||
)
|
||||
|
||||
@trainer.on(basic_event)
|
||||
def log_loss(engine):
|
||||
global_step = global_step_transform(engine, None)
|
||||
output_loss = engine.state.output["loss"]
|
||||
for total_loss in output_loss:
|
||||
if isinstance(output_loss[total_loss], dict):
|
||||
for ln in output_loss[total_loss]:
|
||||
tb_writer.add_scalar(f"train_{total_loss}/{ln}", output_loss[total_loss][ln], global_step)
|
||||
else:
|
||||
tb_writer.add_scalar(f"train/{total_loss}", output_loss[total_loss], global_step)
|
||||
|
||||
if isinstance(optimizers, dict):
|
||||
for name in optimizers:
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizers[name], tag=f"optimizer_{name}"),
|
||||
event_name=Events.ITERATION_STARTED
|
||||
)
|
||||
else:
|
||||
tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizers, tag=f"optimizer"),
|
||||
event_name=Events.ITERATION_STARTED)
|
||||
|
||||
@trainer.on(Events.COMPLETED)
|
||||
@idist.one_rank_only()
|
||||
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models.vgg as vgg
|
||||
|
||||
|
||||
@@ -97,12 +98,13 @@ class PerceptualLoss(nn.Module):
|
||||
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
|
||||
use_input_norm=use_input_norm)
|
||||
|
||||
if criterion == 'L1':
|
||||
self.criterion = torch.nn.L1Loss()
|
||||
elif criterion == "L2":
|
||||
self.criterion = torch.nn.MSELoss()
|
||||
else:
|
||||
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
|
||||
self.criterion = self.set_criterion(criterion)
|
||||
|
||||
def set_criterion(self, criterion: str):
|
||||
assert criterion in ["NL1", "NL2", "L1", "L2"]
|
||||
norm = F.instance_norm if criterion.startswith("N") else lambda x: x
|
||||
fn = F.l1_loss if criterion.endswith("L1") else F.mse_loss
|
||||
return lambda x, t: fn(norm(x), norm(t))
|
||||
|
||||
def forward(self, x, gt):
|
||||
"""Forward function.
|
||||
@@ -124,8 +126,7 @@ class PerceptualLoss(nn.Module):
|
||||
if self.perceptual_loss:
|
||||
percep_loss = 0
|
||||
for k in x_features.keys():
|
||||
percep_loss += self.criterion(
|
||||
x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
else:
|
||||
percep_loss = None
|
||||
|
||||
@@ -133,9 +134,8 @@ class PerceptualLoss(nn.Module):
|
||||
if self.style_loss:
|
||||
style_loss = 0
|
||||
for k in x_features.keys():
|
||||
style_loss += self.criterion(
|
||||
self._gram_mat(x_features[k]),
|
||||
self._gram_mat(gt_features[k])) * self.layer_weights[k]
|
||||
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
|
||||
self.layer_weights[k]
|
||||
else:
|
||||
style_loss = None
|
||||
|
||||
|
||||
8
main.py
8
main.py
@@ -33,10 +33,10 @@ def running(local_rank, config, task, backup_config=False, setup_output_dir=Fals
|
||||
|
||||
if setup_output_dir and config.resume_from is None:
|
||||
if output_dir.exists():
|
||||
# assert not any(output_dir.iterdir()), "output_dir must be empty"
|
||||
contains = list(output_dir.iterdir())
|
||||
assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \
|
||||
f"output_dir must by empty or only contains config.yml, but now got {len(contains)} files"
|
||||
assert len(list(output_dir.glob("events*"))) == 0
|
||||
assert len(list(output_dir.glob("*.pt"))) == 0
|
||||
if (output_dir / "train.log").exists() and idist.get_rank() == 0:
|
||||
(output_dir / "train.log").unlink()
|
||||
else:
|
||||
if idist.get_rank() == 0:
|
||||
output_dir.mkdir(parents=True)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .residual_generator import ResidualBlock
|
||||
from .base import ResidualBlock
|
||||
from model.registry import MODEL
|
||||
from torchvision.models import vgg19
|
||||
from model.normalization import select_norm_layer
|
||||
@@ -148,48 +148,65 @@ class Fusion(nn.Module):
|
||||
return self.end_fc(x)
|
||||
|
||||
|
||||
@MODEL.register_module("TAHG-Generator")
|
||||
class StyleGenerator(nn.Module):
|
||||
def __init__(self, style_in_channels, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"):
|
||||
super().__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.style_encoder = VGG19StyleEncoder(
|
||||
style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE")
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(style_dim, style_dim),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
|
||||
norm_type="NONE")
|
||||
|
||||
def forward(self, x):
|
||||
styles = self.fusion(self.fc(self.style_encoder(x)))
|
||||
return styles
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
|
||||
base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.style_encoders = nn.ModuleDict({
|
||||
"a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE"),
|
||||
"b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE")
|
||||
"a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
|
||||
base_channels=base_channels, padding_mode=padding_mode),
|
||||
"b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
|
||||
base_channels=base_channels, padding_mode=padding_mode),
|
||||
})
|
||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
|
||||
padding_mode=padding_mode, norm_type="IN")
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
self.adain_res = nn.ModuleList([
|
||||
self.adain_resnet_a = nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
self.adain_resnet_b = nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
self.decoders = nn.ModuleDict({
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
|
||||
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
|
||||
})
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(style_dim, style_dim),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
|
||||
norm_type="NONE")
|
||||
self.decoders = nn.ModuleDict({
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
|
||||
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode)
|
||||
})
|
||||
|
||||
def forward(self, content_img, style_img, which_decoder: str = "a"):
|
||||
x = self.content_encoder(content_img)
|
||||
styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img)))
|
||||
styles = self.style_encoders[which_decoder](style_img)
|
||||
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
|
||||
for i, ar in enumerate(self.adain_res):
|
||||
resnet = self.adain_resnet_a if which_decoder == "a" else self.adain_resnet_b
|
||||
for i, ar in enumerate(resnet):
|
||||
ar.norm1.set_style(styles[2 * i])
|
||||
ar.norm2.set_style(styles[2 * i + 1])
|
||||
x = ar(x)
|
||||
return self.decoders[which_decoder](x)
|
||||
|
||||
|
||||
@MODEL.register_module("TAHG-Discriminator")
|
||||
@MODEL.register_module("TAFG-Discriminator")
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
|
||||
padding_mode="reflect"):
|
||||
@@ -7,7 +7,7 @@ from model import MODEL
|
||||
|
||||
|
||||
# based SPADE or pix2pixHD Discriminator
|
||||
@MODEL.register_module("base-PatchDiscriminator")
|
||||
@MODEL.register_module("pix2pixHD-PatchDiscriminator")
|
||||
class PatchDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
|
||||
need_intermediate_feature=False):
|
||||
@@ -59,3 +59,26 @@ class PatchDiscriminator(nn.Module):
|
||||
for layer in self.conv_blocks:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
|
||||
super(ResidualBlock, self).__init__()
|
||||
if use_bias is None:
|
||||
# Only for IN, use bias since it does not have affine parameters.
|
||||
use_bias = norm_type == "IN"
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
bias=use_bias)
|
||||
self.norm1 = norm_layer(num_channels)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
bias=use_bias)
|
||||
self.norm2 = norm_layer(num_channels)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
x = self.relu1(self.norm1(self.conv1(x)))
|
||||
x = self.norm2(self.conv2(x))
|
||||
return x + res
|
||||
|
||||
@@ -58,27 +58,29 @@ class GANImageBuffer(object):
|
||||
return return_images
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
|
||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_dropout=False, use_bias=None):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
if use_bias is None:
|
||||
# Only for IN, use bias since it does not have affine parameters.
|
||||
use_bias = norm_type == "IN"
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
bias=use_bias)
|
||||
self.norm1 = norm_layer(num_channels)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
bias=use_bias)
|
||||
self.norm2 = norm_layer(num_channels)
|
||||
models = [nn.Sequential(
|
||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
||||
norm_layer(num_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
)]
|
||||
if use_dropout:
|
||||
models.append(nn.Dropout(0.5))
|
||||
models.append(nn.Sequential(
|
||||
nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias),
|
||||
norm_layer(num_channels),
|
||||
))
|
||||
self.block = nn.Sequential(*models)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
x = self.relu1(self.norm1(self.conv1(x)))
|
||||
x = self.norm2(self.conv2(x))
|
||||
return x + res
|
||||
return x + self.block(x)
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from model.registry import MODEL
|
||||
import model.GAN.residual_generator
|
||||
import model.GAN.TAHG
|
||||
import model.GAN.TAFG
|
||||
import model.GAN.UGATIT
|
||||
import model.fewshot
|
||||
import model.GAN.wrapper
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ignite.distributed import utils as idist
|
||||
from ignite.distributed.comp_models import native as idist_native
|
||||
from ignite.utils import setup_logger
|
||||
|
||||
|
||||
def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module:
|
||||
"""Helper method to adapt provided model for non-distributed and distributed configurations (supporting
|
||||
all available backends from :meth:`~ignite.distributed.utils.available_backends()`).
|
||||
|
||||
Internally, we perform to following:
|
||||
|
||||
- send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
|
||||
- wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
|
||||
- wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ignite.distribted as idist
|
||||
|
||||
model = idist.auto_model(model)
|
||||
|
||||
In addition with NVidia/Apex, it can be used in the following way:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ignite.distribted as idist
|
||||
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
|
||||
model = idist.auto_model(model)
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model to adapt.
|
||||
|
||||
Returns:
|
||||
torch.nn.Module
|
||||
|
||||
.. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
|
||||
.. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
|
||||
"""
|
||||
logger = setup_logger(__name__ + ".auto_model")
|
||||
|
||||
# Put model's parameters to device if its parameters are not on the device
|
||||
device = idist.device()
|
||||
if not all([p.device == device for p in model.parameters()]):
|
||||
model.to(device)
|
||||
|
||||
# distributed data parallel model
|
||||
if idist.get_world_size() > 1:
|
||||
if idist.backend() == idist_native.NCCL:
|
||||
lrank = idist.get_local_rank()
|
||||
logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs)
|
||||
elif idist.backend() == idist_native.GLOO:
|
||||
logger.info("Apply torch DistributedDataParallel on model")
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs)
|
||||
|
||||
# not distributed but multiple GPUs reachable so data parallel model
|
||||
elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
|
||||
logger.info("Apply torch DataParallel on model")
|
||||
model = torch.nn.parallel.DataParallel(model, **additional_kwargs)
|
||||
|
||||
return model
|
||||
@@ -1,26 +1,34 @@
|
||||
import torchvision.utils
|
||||
from matplotlib.pyplot import get_cmap
|
||||
import torch
|
||||
import warnings
|
||||
from torch.nn.functional import interpolate
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def attention_colored_map(attentions, size=None, cmap_name="jet"):
|
||||
def attention_colored_map(attentions, size=None):
|
||||
assert attentions.dim() == 4 and attentions.size(1) == 1
|
||||
device = attentions.device
|
||||
|
||||
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
attentions -= min_attentions
|
||||
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
|
||||
if size is not None and attentions.size()[-2:] != size:
|
||||
attentions = attentions.detach().cpu().numpy()
|
||||
attentions = (attentions * 255).astype(np.uint8)
|
||||
need_resize = False
|
||||
if size is not None and attentions.shape[-2:] != size:
|
||||
assert len(size) == 2, "for interpolate, size must be (x, y), have two dim"
|
||||
attentions = interpolate(attentions, size, mode="bilinear", align_corners=False)
|
||||
cmap = get_cmap(cmap_name)
|
||||
ca = cmap(attentions.squeeze(1).cpu())[:, :, :, :3]
|
||||
return torch.from_numpy(ca).permute(0, 3, 1, 2).contiguous()
|
||||
need_resize = True
|
||||
|
||||
subs = []
|
||||
for sub in attentions:
|
||||
sub = cv2.resize(sub[0], size) if need_resize else sub[0] # numpy.array shape=size
|
||||
subs.append(cv2.applyColorMap(sub, cv2.COLORMAP_JET)) # append a (size[0], size[1], 3) numpy array
|
||||
subs = np.stack(subs) # (batch_size, size[0], size[1], 3)
|
||||
return torch.from_numpy(subs).permute(0, 3, 1, 2).contiguous().to(device).float() / 255
|
||||
|
||||
|
||||
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||
def fuse_attention_map(images, attentions, alpha=0.5):
|
||||
"""
|
||||
|
||||
:param images: B x H x W
|
||||
@@ -35,7 +43,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||
if attentions.size(1) != 1:
|
||||
warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
|
||||
return images
|
||||
colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device)
|
||||
colored_attentions = attention_colored_map(attentions, images.size()[-2:])
|
||||
return images * alpha + colored_attentions * (1 - alpha)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user