Compare commits

...

9 Commits

Author SHA1 Message Date
8998c30c23 TSIT 2020-10-25 20:46:34 +08:00
0bec02bf6d 23333 2020-10-23 16:14:37 +08:00
f7b7b78669 imporved gan loss 2020-10-22 23:19:03 +08:00
376f5caeb7 v2 2020-10-22 22:42:01 +08:00
0019d4034c change a lot 2020-10-14 18:55:51 +08:00
0927fa3de5 add patch d 2020-10-13 10:31:17 +08:00
611901cbdf add ConvTranspose2d in Conv2d 2020-10-13 10:31:03 +08:00
a6ffab1445 add image buffers for gan 2020-10-13 10:30:27 +08:00
7b05b45156 update SPADE 2020-10-12 19:01:07 +08:00
26 changed files with 1070 additions and 286 deletions

2
.idea/deployment.xml generated
View File

@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="21d" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="14d">
<serverdata>

2
.idea/misc.xml generated
View File

@@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="14d-python" project-jdk-type="Python SDK" />
</project>

2
.idea/raycv.iml generated
View File

@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="14d-python" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">

View File

@@ -1,57 +1,85 @@
name: horse2zebra-CyCleGAN
engine: CyCleGAN
name: huawei-cycylegan-7
engine: CycleGAN
result_dir: ./result
max_pairs: 266800
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: False
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: CyCle-Generator
_type: CycleGAN-Generator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
base_channels: 64
num_blocks: 9
padding_mode: reflect
norm_type: IN
use_transpose_conv: False
pre_activation: True
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: False
loss:
gan:
loss_type: lsgan
loss_type: hinge
weight: 1.0
real_label_val: 1.0
real_label_val: 1
fake_label_val: 0.0
cycle:
level: 1
weight: 10.0
id:
level: 1
weight: 10.0
mgc:
weight: 1
fm:
weight: 0
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 2e-4
lr: 1e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 2e-4
lr: 4e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data:
train:
@@ -60,17 +88,28 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 6
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/horse2zebra/trainA"
root_b: "/data/i2i/horse2zebra/trainB"
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline:
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 286, 286 ]
@@ -82,17 +121,38 @@ data:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 4
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/horse2zebra/testA"
root_b: "/data/i2i/horse2zebra/testB"
random_pair: False
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:

View File

@@ -0,0 +1,167 @@
name: huawei-GauGAN-3
engine: GauGAN
result_dir: ./result
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: SPADEGenerator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
num_blocks: 7
use_vae: False
num_z_dim: 256
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: True
loss:
gan:
loss_type: hinge
weight: 1.0
real_label_val: 1
fake_label_val: 0.0
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 2
mgc:
weight: 5
fm:
weight: 5
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 1e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 286, 286 ]
- RandomCrop:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -1,7 +1,10 @@
name: VoxCeleb2Anime-TSIT
engine: TSIT
name: huawei-TSIT-1
engine: GauGAN
result_dir: ./result
max_pairs: 1500000
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: True
@@ -16,34 +19,39 @@ handler:
random: True
images: 10
misc:
random_seed: 324
model:
generator:
_type: TSIT-Generator
_bn_to_sync_bn: True
style_in_channels: 3
content_in_channels: 3
num_blocks: 5
input_layer_type: "conv7x7"
_add_spectral_norm: True
in_channels: 3
out_channels: 3
num_blocks: 7
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
discriminator_cfg:
_type: PatchDiscriminator
in_channels: 3
base_channels: 64
use_spectral: True
need_intermediate_feature: True
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: True
loss:
gan:
loss_type: hinge
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
real_label_val: 1
fake_label_val: 0.0
perceptual:
layer_weights:
"1": 0.03125
@@ -55,25 +63,18 @@ loss:
style_loss: False
perceptual_loss: True
weight: 1
style:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L2'
style_loss: True
perceptual_loss: False
weight: 0
mgc:
weight: 5
fm:
level: 1
weight: 1
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 0.0001
lr: 1e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
@@ -87,24 +88,35 @@ data:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
buffer_size: 0
dataloader:
batch_size: 8
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/faces/CelebA-Asian/trainA"
root_b: "/data/i2i/anime/your-name/faces"
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline:
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 170, 144 ]
size: [ 286, 286 ]
- RandomCrop:
size: [ 128, 128 ]
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
@@ -113,22 +125,28 @@ data:
test:
which: video_dataset
dataloader:
batch_size: 8
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/faces/CelebA-Asian/testA"
root_b: "/data/i2i/anime/your-name/faces"
random_pair: False
pipeline:
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 170, 144 ]
- RandomCrop:
size: [ 128, 128 ]
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]

View File

@@ -78,7 +78,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 4
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
@@ -102,7 +102,7 @@ data:
test:
which: video_dataset
dataloader:
batch_size: 8
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False

View File

@@ -38,9 +38,9 @@ class SingleFolderDataset(Dataset):
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
self.A = SingleFolderDataset(root_a, pipeline, with_path)
self.B = SingleFolderDataset(root_b, pipeline, with_path)
def __init__(self, root_a, root_b, random_pair, pipeline_a, pipeline_b, with_path=False):
self.A = SingleFolderDataset(root_a, pipeline_a, with_path)
self.B = SingleFolderDataset(root_b, pipeline_b, with_path)
self.with_path = with_path
self.random_pair = random_pair

View File

@@ -2,25 +2,27 @@ from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.gan import GANLoss
from model.GAN.base import GANImageBuffer
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import pixel_loss, gan_loss, feature_match_loss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights
class TAFGEngineKernel(EngineKernel):
class CycleGANEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.gan_loss = gan_loss(config.loss.gan)
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
self.edge_loss = LossContainer(config.loss.edge.weight, EdgeLoss(
"HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(idist.device()))
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
@@ -56,21 +58,23 @@ class TAFGEngineKernel(EngineKernel):
images["b2a"] = self.generators["b2a"](batch["b"])
images["a2b2a"] = self.generators["b2a"](images["a2b"])
images["b2a2b"] = self.generators["a2b"](images["b2a"])
if self.config.loss.id.weight > 0:
if self.id_loss.weight > 0:
images["a2a"] = self.generators["b2a"](batch["a"])
images["b2b"] = self.generators["a2b"](batch["b"])
return images
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
for phase in ["a2b", "b2a"]:
loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
generated[f"{phase}2{phase[0]}"], batch[phase[0]])
loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
self.discriminators[phase[-1]](generated[phase]), True)
if self.config.loss.id.weight > 0:
loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
for ph in "ab":
loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
prediction_fake = self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"])
loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
if self.fm_loss.weight > 0:
prediction_real = self.discriminators[ph](batch[ph])
loss[f"feature_match_{ph}"] = self.fm_loss(prediction_fake, prediction_real)
loss[f"edge_{ph}"] = self.edge_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph], gt_is_edge=False)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
@@ -97,5 +101,5 @@ class TAFGEngineKernel(EngineKernel):
def run(task, config, _):
kernel = TAFGEngineKernel(config)
kernel = CycleGANEngineKernel(config)
run_kernel(task, config, kernel)

86
engine/GauGAN.py Normal file
View File

@@ -0,0 +1,86 @@
from itertools import chain
import torch
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import gan_loss, feature_match_loss, perceptual_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights
class GauGANEngineKernel(EngineKernel):
    """Engine kernel for one-directional (a -> b) GauGAN-style translation.

    Uses a single generator (key ``"main"``) and a single discriminator
    (key ``"b"``), together with the adversarial, MGC, feature-matching and
    perceptual losses configured under ``config.loss``.
    """

    def __init__(self, config):
        """Build the loss objects and one fake-image buffer per discriminator.

        :param config: experiment configuration; this constructor reads
            ``config.loss.{gan,mgc,fm,perceptual}`` and
            ``config.data.train.buffer_size``.
        """
        super().__init__(config)

        self.gan_loss = gan_loss(config.loss.gan)
        self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
        self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
        self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))

        # Fall back to 50 only when no size is configured. The previous
        # `buffer_size or 50` also replaced an explicit 0 with 50, which made it
        # impossible to disable the buffer (GANImageBuffer treats 0 as "off").
        buffer_size = config.data.train.buffer_size
        if buffer_size is None:
            buffer_size = 50
        # self.discriminators is presumably populated by the base class from
        # build_models() -- TODO confirm against EngineKernel.
        self.image_buffers = {k: GANImageBuffer(buffer_size) for k in self.discriminators.keys()}

    def build_models(self) -> (dict, dict):
        """Create and weight-initialise the generator and discriminator.

        :return: ``(generators, discriminators)`` where generators is
            ``{"main": ...}`` and discriminators is ``{"b": ...}``.
        """
        generators = dict(
            main=build_model(self.config.model.generator)
        )
        discriminators = dict(
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["b"])
        self.logger.debug(generators["main"])

        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)

        return generators, discriminators

    def setup_after_g(self):
        """Re-enable gradients on every discriminator."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        """Disable gradients on every discriminator."""
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Translate ``batch["a"]`` with the main generator.

        :param batch: mapping that contains at least the key ``"a"``.
        :param inference: when True, gradient tracking is turned off.
        :return: ``{"a2b": generated}``.
        """
        images = dict()
        with torch.set_grad_enabled(not inference):
            images["a2b"] = self.generators["main"](batch["a"])
        return images

    def criterion_generators(self, batch, generated) -> dict:
        """Compute the generator-side losses.

        :param batch: mapping with keys ``"a"`` and ``"b"``.
        :param generated: output of :meth:`forward`.
        :return: dict with ``"gan"``, ``"mgc"``, ``"perceptual"`` and, when the
            feature-matching weight is positive, ``"feature_match"``.
        """
        loss = dict()
        prediction_fake = self.discriminators["b"](generated["a2b"])
        loss["gan"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
        loss["mgc"] = self.mgc_loss(generated["a2b"], batch["a"])
        loss["perceptual"] = self.perceptual_loss(generated["a2b"], batch["a"])

        # The extra discriminator pass on real images is only paid for when
        # feature matching is actually enabled.
        if self.fm_loss.weight > 0:
            prediction_real = self.discriminators["b"](batch["b"])
            loss["feature_match"] = self.fm_loss(prediction_fake, prediction_real)
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Compute the discriminator loss.

        The fake sample is the detached generator output passed through the
        image buffer; the result is the mean of the fake and real GAN losses.

        :return: ``{"b": loss}``.
        """
        loss = dict()
        generated_image = self.image_buffers["b"].query(generated["a2b"].detach())
        loss["b"] = (self.gan_loss(self.discriminators["b"](generated_image), False, is_discriminator=True) +
                     self.gan_loss(self.discriminators["b"](batch["b"]), True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """Collect images for visualisation.

        :param batch: mapping that contains the key ``"a"``.
        :param generated: dict of generated images (output of :meth:`forward`).
        :return: ``{"a": [input_a, a2b]}`` with both tensors detached; this
            kernel is one-directional, so only the ``"a"`` row is produced.
        """
        return dict(
            a=[batch["a"].detach(), generated["a2b"].detach()],
        )
def run(task, config, _):
    """Entry point: build a GauGAN kernel from ``config`` and pass it to ``run_kernel``.

    The third positional argument is accepted but not used.
    """
    run_kernel(task, config, GauGANEngineKernel(config))

View File

@@ -1,38 +1,31 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from engine.util.container import LossContainer
from engine.util.loss import bce_loss, mse_loss, pixel_loss, gan_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from loss.gan import GANLoss
from model.image_translation.UGATIT import RhoClipper
from util.image import attention_colored_map
def pixel_loss(level):
return nn.L1Loss() if level == 1 else nn.MSELoss()
class RhoClipper(object):
def __init__(self, clip_min, clip_max):
self.clip_min = clip_min
self.clip_max = clip_max
assert clip_min < clip_max
def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def __call__(self, module):
if hasattr(module, 'rho'):
w = module.rho.data
w = w.clamp(self.clip_min, self.clip_max)
module.rho.data = w
class UGATITEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.gan_loss = gan_loss(config.loss.gan)
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))

View File

@@ -101,9 +101,12 @@ class EngineKernel(object):
def _remove_no_grad_loss(loss_dict):
need_to_pop = []
for k in loss_dict:
if not isinstance(loss_dict[k], torch.Tensor):
loss_dict.pop(k)
need_to_pop.append(k)
for k in need_to_pop:
loss_dict.pop(k)
return loss_dict

View File

@@ -1,3 +1,6 @@
import torch
class LossContainer:
def __init__(self, weight, loss):
self.weight = weight
@@ -7,3 +10,57 @@ class LossContainer:
if self.weight > 0:
return self.weight * self.loss(*args, **kwargs)
return 0.0
class GANImageBuffer:
    """History pool of previously generated images.

    Feeding the discriminator a mix of current and older generator outputs,
    instead of only the newest ones, is meant to reduce model oscillation.

    Args:
        buffer_size (int): Capacity of the pool. With ``buffer_size = 0`` no
            pool is created and :meth:`query` returns its input unchanged.
        buffer_ratio (float): Probability of handing back a stored image
            (and storing the current one in its place) once the pool is full.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        self.buffer_ratio = buffer_ratio
        # The pool attributes only exist when a pool is actually requested.
        if self.buffer_size > 0:
            self.img_num = 0
            self.image_buffer = []

    def query(self, images):
        """Return a batch in which some images may be replaced by stored ones.

        Args:
            images (Tensor): Current image batch without history information.
        """
        if self.buffer_size == 0:
            # pooling disabled: hand the batch straight back
            return images

        selected = []
        for current in images:
            current = torch.unsqueeze(current.data, 0)

            if self.img_num < self.buffer_size:
                # pool not full yet: store the image and also return it
                self.img_num = self.img_num + 1
                self.image_buffer.append(current)
                selected.append(current)
                continue

            if torch.rand(1) < self.buffer_ratio:
                # swap: return a randomly chosen stored image and keep the
                # current one in its slot
                slot = torch.randint(0, self.buffer_size, (1,)).item()
                stored = self.image_buffer[slot].clone()
                self.image_buffer[slot] = current
                selected.append(stored)
            else:
                # otherwise pass the current image through untouched
                selected.append(current)

        return torch.cat(selected, 0)

48
engine/util/loss.py Normal file
View File

@@ -0,0 +1,48 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
def gan_loss(config):
    """Build a ``GANLoss`` from a config node and move it to ``idist.device()``.

    The ``weight`` entry is removed before the remaining entries are passed
    to ``GANLoss`` as keyword arguments; a missing ``weight`` raises KeyError.
    """
    kwargs = OmegaConf.to_container(config)
    del kwargs["weight"]
    return GANLoss(**kwargs).to(idist.device())
def perceptual_loss(config):
    """Build a ``PerceptualLoss`` from a config node and move it to ``idist.device()``.

    The ``weight`` entry is removed before the remaining entries are passed
    to ``PerceptualLoss`` as keyword arguments; a missing ``weight`` raises
    KeyError.
    """
    kwargs = OmegaConf.to_container(config)
    del kwargs["weight"]
    return PerceptualLoss(**kwargs).to(idist.device())
def pixel_loss(level):
    """Return ``nn.L1Loss()`` when ``level`` equals 1, otherwise ``nn.MSELoss()``."""
    if level == 1:
        return nn.L1Loss()
    return nn.MSELoss()
def mse_loss(x, target_flag):
    """Mean squared error of ``x`` against all ones (truthy flag) or all zeros."""
    make_target = torch.ones_like if target_flag else torch.zeros_like
    return F.mse_loss(x, make_target(x))
def bce_loss(x, target_flag):
    """Binary cross-entropy on logits ``x`` against all ones (truthy flag) or all zeros."""
    make_target = torch.ones_like if target_flag else torch.zeros_like
    return F.binary_cross_entropy_with_logits(x, make_target(x))
def feature_match_loss(level, weight_policy):
    """Create a feature-matching loss over discriminator intermediate features.

    :param level: 1 selects an L1 comparison, anything else selects MSE.
    :param weight_policy: ``"same"`` weights every layer by 1;
        ``"exponential_decline"`` weights layer ``i`` by ``2 ** i``.
        NOTE(review): ``2 ** i`` grows with depth despite the name -- confirm intent.
    :return: ``fm_loss(generated_features, target_features)`` returning a
        tensor of shape ``(1,)``.
    """
    assert weight_policy in ["same", "exponential_decline"]
    compare_loss = nn.L1Loss() if level == 1 else nn.MSELoss()

    def _as_scales(features):
        # A tuple of tensors is the output of ONE discriminator (features...,
        # prediction), the same convention GANLoss.forward uses. Wrap it so it
        # is handled as a single scale instead of being iterated tensor by
        # tensor, which silently used the batch dimension as the layer index.
        if isinstance(features, tuple) and len(features) > 0 and isinstance(features[0], torch.Tensor):
            return [features]
        return features

    def fm_loss(generated_features, target_features):
        """Sum per-layer distances, averaged over scales.

        The last entry of each scale (the final prediction) is excluded;
        target features are detached so gradients only reach the generator side.
        """
        generated_features = _as_scales(generated_features)
        target_features = _as_scales(target_features)
        num_scale = len(generated_features)

        # Accumulate on the features' own device; fall back to the distributed
        # default only when there is no feature to read a device from.
        if num_scale > 0 and len(generated_features[0]) > 0:
            device = generated_features[0][0].device
        else:
            device = idist.device()
        loss = torch.zeros(1, device=device)

        for s_i in range(num_scale):
            for i in range(len(generated_features[s_i]) - 1):
                weight = 1 if weight_policy == "same" else 2 ** i
                loss = loss + weight * compare_loss(
                    generated_features[s_i][i], target_features[s_i][i].detach()) / num_scale
        return loss

    return fm_loss

View File

@@ -1,3 +1,4 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
@@ -5,17 +6,59 @@ import torch.nn as nn
def gaussian_radial_basis_function(x, mu, sigma):
# (kernel_size) -> (batch_size, kernel_size, c*h*w)
mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
mu = mu.to(x.device)
# (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1)
return torch.exp((x - mu).pow(2) / (2 * sigma ** 2))
class ImporveMyLoss(torch.nn.Module):
    """Batched variant of the ERSMI-based loss.

    Precomputes 81 kernel centres per image (a 9 x 9 grid over
    ``[-1, 1]`` in steps of 0.25) and an 81 x 81 identity used as the
    regulariser, all placed on ``device``.
    """

    def __init__(self, device=idist.device()):
        super().__init__()
        centres = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).to(device)
        # x centres: the 9 values tiled 9 times; y centres: each value repeated 9 times
        self.x_mu_list = centres.repeat(9).view(-1, 81)
        self.y_mu_list = centres.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
        self.R = torch.eye(81).to(device)

    def batch_ERSMI(self, I1, I2):
        """Return the negated batch mean of the ERSMI estimate for ``I1`` vs ``I2``.

        Both inputs are indexed as 4-D tensors; a single-channel ``I2`` is
        repeated to three channels when ``I1`` is not single-channel.
        """
        batch_size = I1.shape[0]
        img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
        if I2.shape[1] == 1 and I1.shape[1] != 1:
            I2 = I2.repeat(1, 3, 1, 1)

        def kernel_F(y, mu_list, sigma):
            # centres expanded to (batch_size, 81, img_size)
            grid = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1)
            # flattened pixels expanded to (batch_size, 81, img_size)
            pixels = y.view(batch_size, 1, -1).repeat(1, 81, 1)
            diff = grid - pixels
            # NOTE(review): the exponent is positive here -- confirm this is intended.
            return torch.exp(diff.pow(2) / (2 * sigma ** 2))

        mat_K = kernel_F(I1, self.x_mu_list, 1)
        mat_L = kernel_F(I2, self.y_mu_list, 1)
        joint = mat_K * mat_L

        H1 = (mat_K @ mat_K.transpose(1, 2)) * (mat_L @ mat_L.transpose(1, 2)) / (img_size ** 2)
        h_hat = joint @ joint.transpose(1, 2) / img_size
        small_h_hat = mat_K.sum(2).view(batch_size, -1, 1) * mat_L.sum(2).view(batch_size, -1, 1) / (img_size ** 2)
        h_hat = 0.5 * H1 + 0.5 * h_hat

        # regularised solve with a fixed 0.05 weight on the identity
        alpha = (h_hat + 0.05 * self.R).inverse() @ small_h_hat
        ersmi = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
        return -ersmi.squeeze().mean()

    def forward(self, fakeI, realI):
        """Compute the loss between a generated batch and a reference batch."""
        return self.batch_ERSMI(fakeI, realI)
class MyLoss(torch.nn.Module):
def __init__(self):
super(MyLoss, self).__init__()
def forward(self, fakeI, realI):
fakeI = fakeI.cuda()
realI = realI.cuda()
def batch_ERSMI(I1, I2):
batch_size = I1.shape[0]
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
@@ -49,6 +92,7 @@ class MyLoss(torch.nn.Module):
alpha = alpha.matmul(h2)
ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
alpha) - 1).squeeze()
ersmi = -ersmi.mean()
return ersmi
@@ -61,16 +105,19 @@ class MGCLoss(nn.Module):
Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
"""
def __init__(self, beta=0.5, lambda_=0.05):
def __init__(self, mi_to_loss_way="opposite", beta=0.5, lambda_=0.05, device=idist.device()):
super().__init__()
self.beta = beta
self.lambda_ = lambda_
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0])
self.mu_x = mu.repeat(9)
self.mu_y = mu.unsqueeze(0).t().repeat(1, 9).view(-1)
assert mi_to_loss_way in ["opposite", "reciprocal"]
self.mi_to_loss_way = mi_to_loss_way
mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
self.mu_x = mu_x.flatten().to(device)
self.mu_y = mu_y.flatten().to(device)
self.R = torch.eye(81).unsqueeze(0).to(device)
@staticmethod
def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_):
def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_, R):
assert img1.size() == img2.size()
num_pixel = img1.size(1) * img1.size(2) * img2.size(3)
@@ -79,33 +126,104 @@ class MGCLoss(nn.Module):
mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)
mat_k_mul_mat_l = mat_k * mat_l
h_hat = (1 - beta) * (mat_k_mul_mat_l.matmul(mat_k_mul_mat_l.transpose(1, 2))) / num_pixel
h_hat += beta * (mat_k.matmul(mat_k.transpose(1, 2)) * mat_l.matmul(mat_l.transpose(1, 2))) / (num_pixel ** 2)
h_hat = (1 - beta) * (mat_k_mul_mat_l @ mat_k_mul_mat_l.transpose(1, 2)) / num_pixel
h_hat += beta * ((mat_k @ mat_k.transpose(1, 2)) * (mat_l @ mat_l.transpose(1, 2))) / (num_pixel ** 2)
small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)
R = torch.eye(h_hat.size(1)).to(img1.device)
alpha = (h_hat + lambda_ * R).inverse().matmul(small_h_hat)
rSMI = (2 * alpha.transpose(1, 2).matmul(small_h_hat)) - alpha.transpose(1, 2).matmul(h_hat).matmul(alpha) - 1
return rSMI
alpha = (h_hat + lambda_ * R).inverse() @ small_h_hat
rSMI = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
return rSMI.squeeze()
def forward(self, fake, real):
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_)
return -rSMI.squeeze().mean()
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
if self.mi_to_loss_way == "reciprocal":
return 1/rSMI.mean()
return -rSMI.mean()
if __name__ == '__main__':
mg = MGCLoss().to("cuda")
mg = MGCLoss(device=torch.device("cpu"))
my = MyLoss().to("cuda")
imy = ImporveMyLoss()
from data.transform import transform_pipeline
def norm(x):
x -= x.min()
x /= x.max()
return (x - 0.5) * 2
pipeline = transform_pipeline(
['Load', 'ToTensor', {'Normalize': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}}])
img_a1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_1.jpg")
img_a2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_2.jpg")
img_a3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_3.jpg")
img_b1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_1.jpg")
img_b2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_2.jpg")
img_b3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_3.jpg")
x1 = norm(torch.randn(5, 3, 256, 256))
x2 = norm(x1 * 2 + 1)
x3 = norm(torch.randn(5, 3, 256, 256))
x4 = norm(torch.exp(x3))
print(mg(x1, x1), mg(x1, x2), mg(x1, x3), mg(x1, x4))
img_a1.requires_grad_(True)
img_a2.requires_grad_(True)
img_a3.requires_grad_(True)
# print("MyLoss")
# l1 = my(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
# l2 = my(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
# l3 = my(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
# l = (l1+l2+l3)/3
# l.backward()
# print(img_a1.grad[0][0][0:10])
# print(img_a2.grad[0][0][0:10])
# print(img_a3.grad[0][0][0:10])
#
# img_a1.grad = None
# img_a2.grad = None
# img_a3.grad = None
#
# print("---")
# l = my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# l.backward()
# print(img_a1.grad[0][0][0:10])
# print(img_a2.grad[0][0][0:10])
# print(img_a3.grad[0][0][0:10])
# img_a1.grad = None
# img_a2.grad = None
# img_a3.grad = None
print("MGCLoss")
l1 = mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
l2 = mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
l3 = mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
l = (l1 + l2 + l3) / 3
l.backward()
print(img_a1.grad[0][0][0:10])
print(img_a2.grad[0][0][0:10])
print(img_a3.grad[0][0][0:10])
img_a1.grad = None
img_a2.grad = None
img_a3.grad = None
print("---")
l = mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
l.backward()
print(img_a1.grad[0][0][0:10])
print(img_a2.grad[0][0][0:10])
print(img_a3.grad[0][0][0:10])
# print("\nMGCLoss")
# mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
# mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
# mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
#
# print("---")
# mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
#
# import pprofile
#
# profiler = pprofile.Profile()
# with profiler:
# iter_times = 1000
# for _ in range(iter_times):
# mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# for _ in range(iter_times):
# my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# for _ in range(iter_times):
# imy(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# profiler.print_stats()

View File

@@ -1,4 +1,5 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
@@ -10,7 +11,7 @@ class GANLoss(nn.Module):
self.fake_label_val = fake_label_val
self.loss_type = loss_type
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
"""
gan loss forward
:param prediction: network prediction
@@ -37,3 +38,20 @@ class GANLoss(nn.Module):
return loss
else:
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
    """
    gan loss forward for every output layout a discriminator may produce
    :param prediction: a ``torch.Tensor`` (plain discriminator output), a ``tuple`` (one discriminator
        returning its intermediate features, the last element being the final prediction) or a ``list``
        with one entry per scale, each entry being either of the former two
    :param target_is_real: forwarded unchanged to ``single_forward``
    :param is_discriminator: forwarded unchanged to ``single_forward``
    :return: the ``single_forward`` loss; for a list, the sum of the per-scale losses
        (``0`` when the list is empty)
    :raises NotImplementedError: if ``prediction`` is neither a tensor, a list nor a tuple
    """
    if isinstance(prediction, torch.Tensor):
        # origin
        return self.single_forward(prediction, target_is_real, is_discriminator)
    if isinstance(prediction, list):
        # for multi scale discriminator, e.g. MultiScaleDiscriminator
        loss = 0
        for scale_output in prediction:
            # A scale yields a bare tensor when its discriminator does not return intermediate
            # features; indexing such a tensor with [-1] would select the last batch sample
            # instead of the final prediction, so only unwrap sequence outputs.
            if isinstance(scale_output, torch.Tensor):
                final_prediction = scale_output
            else:
                final_prediction = scale_output[-1]
            loss = loss + self.single_forward(final_prediction, target_is_real, is_discriminator)
        return loss
    if isinstance(prediction, tuple):
        # for single discriminator set `need_intermediate_feature` true
        return self.single_forward(prediction[-1], target_is_real, is_discriminator)
    raise NotImplementedError(f"not support discriminator output: {prediction}")

View File

@@ -1,3 +1,7 @@
from model.registry import MODEL, NORMALIZATION
import model.base.normalization
import model.image_translation
import model.image_translation.UGATIT
import model.image_translation.CycleGAN
import model.image_translation.pix2pixHD
import model.image_translation.GauGAN
import model.image_translation.TSIT

View File

@@ -52,35 +52,37 @@ class LinearBlock(nn.Module):
class Conv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, bias=None,
activation_type="ReLU", norm_type="NONE",
additional_norm_kwargs=None, **conv_kwargs):
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
pre_activation=False, use_transpose_conv=False, **conv_kwargs):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
self.pre_activation = pre_activation
# if caller not set bias, set bias automatically.
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
if use_transpose_conv:
# Only "zeros" padding mode is supported for ConvTranspose2d
conv_kwargs["padding_mode"] = "zeros"
conv = nn.ConvTranspose2d
else:
conv = nn.Conv2d
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
self.activation = _activation(activation_type)
if pre_activation:
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
self.activation = _activation(activation_type, inplace=False)
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
else:
# if caller not set bias, set bias automatically.
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
self.activation = _activation(activation_type)
def forward(self, x):
if self.pre_activation:
return self.convolution(self.activation(self.normalization(x)))
return self.activation(self.normalization(self.convolution(x)))
class ReverseConv2dBlock(nn.Module):
    """Convolution block applied in reverse order: normalization, then activation, then convolution."""

    def __init__(self, in_channels: int, out_channels: int,
                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
        """
        :param in_channels: channel count of the input; also the size handed to the normalization layer
        :param out_channels: channel count produced by the convolution
        :param activation_type: activation name passed to ``_activation``
        :param norm_type: normalization name passed to ``_normalization``
        :param additional_norm_kwargs: extra arguments passed through to ``_normalization``
        :param conv_kwargs: remaining keyword arguments, passed to ``nn.Conv2d``
        """
        super().__init__()
        # The normalization sees the block input, hence it is built for `in_channels`.
        self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
        self.activation = _activation(activation_type, inplace=False)
        self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)

    def forward(self, x):
        normalized = self.normalization(x)
        activated = self.activation(normalized)
        return self.convolution(activated)
class ResidualBlock(nn.Module):
def __init__(self, in_channels,
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
@@ -109,16 +111,17 @@ class ResidualBlock(nn.Module):
self.learn_skip_connection = in_channels != out_channels
conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
additional_norm_kwargs=additional_norm_kwargs,
padding_mode=padding_mode)
additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
padding_mode=padding_mode)
self.conv1 = conv_block(in_channels, in_channels, **conv_param)
self.conv2 = conv_block(in_channels, out_channels, **conv_param)
self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
if self.learn_skip_connection:
self.res_conv = conv_block(in_channels, out_channels, **conv_param)
conv_param['kernel_size'] = 1
conv_param['padding'] = 0
self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
def forward(self, x):
res = x if not self.learn_skip_connection else self.res_conv(x)

View File

@@ -1,5 +1,6 @@
import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock
@@ -20,7 +21,7 @@ class Encoder(nn.Module):
multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode=padding_mode,
kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode="zeros",
activation_type=activation_type, norm_type=down_conv_norm_type
))
self.out_channels = multiple_now * base_channels
@@ -43,7 +44,7 @@ class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
activation_type="ReLU", padding_mode='reflect',
up_conv_kernel_size=5, up_conv_norm_type="LN",
res_norm_type="AdaIN", pre_activation=False):
res_norm_type="AdaIN", pre_activation=False, use_transpose_conv=False):
super().__init__()
self.residual_blocks = nn.ModuleList([
ResidualBlock(
@@ -57,13 +58,23 @@ class Decoder(nn.Module):
sequence = list()
channels = in_channels
padding = (up_conv_kernel_size - 1) // 2
for i in range(num_up_sampling):
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
activation_type=activation_type, norm_type=up_conv_norm_type),
))
if use_transpose_conv:
sequence.append(Conv2dBlock(
channels, channels // 2, kernel_size=up_conv_kernel_size, stride=2,
padding=padding, output_padding=padding,
padding_mode=padding_mode,
activation_type=activation_type, norm_type=up_conv_norm_type,
use_transpose_conv=True
))
else:
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
padding=padding, padding_mode=padding_mode,
activation_type=activation_type, norm_type=up_conv_norm_type),
))
channels = channels // 2
sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
@@ -74,3 +85,67 @@ class Decoder(nn.Module):
for i, blk in enumerate(self.residual_blocks):
x = blk(x)
return self.up_sequence(x)
@MODEL.register_module("CycleGAN-Generator")
class Generator(nn.Module):
    """Generator built from an ``Encoder`` followed by a ``Decoder`` that carries no residual blocks."""

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
                 padding_mode='reflect', norm_type="IN", pre_activation=False, use_transpose_conv=True):
        """
        :param in_channels: channel count of the input image
        :param out_channels: channel count of the generated image
        :param base_channels: base channel count handed to the encoder
        :param num_blocks: number of residual blocks, all of them placed in the encoder (``num_res``)
        :param activation_type: activation name shared by encoder and decoder
        :param padding_mode: padding mode shared by encoder and decoder
        :param norm_type: normalization used for the down/up convolutions and the residual blocks
        :param pre_activation: forwarded to both encoder and decoder
        :param use_transpose_conv: forwarded to the decoder
        """
        super().__init__()
        # Settings that encoder and decoder receive identically.
        shared = dict(padding_mode=padding_mode, activation_type=activation_type, pre_activation=pre_activation)
        self.encoder = Encoder(in_channels, base_channels, num_conv=2, num_res=num_blocks,
                               down_conv_norm_type=norm_type, res_norm_type=norm_type, **shared)
        # The decoder width follows whatever the encoder reports as its output channel count.
        self.decoder = Decoder(self.encoder.out_channels, out_channels, num_up_sampling=2, num_residual_blocks=0,
                               up_conv_kernel_size=3, up_conv_norm_type=norm_type,
                               use_transpose_conv=use_transpose_conv, **shared)

    def forward(self, x):
        encoded = self.encoder(x)
        return self.decoder(encoded)
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    """Stack of ``Conv2dBlock`` layers closed by a plain single-channel ``nn.Conv2d``."""

    def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
        """
        :param in_channels: channel count of the input
        :param base_channels: output channels of the first block; later blocks use multiples of it
        :param num_conv: number of ``Conv2dBlock`` layers placed before the final convolution
        :param need_intermediate_feature: when true, ``forward`` returns the output of every layer
        :param norm_type: normalization name passed to every ``Conv2dBlock``
        :param padding_mode: padding mode used by every convolution
        :param activation_type: activation name passed to every ``Conv2dBlock``
        """
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature

        kernel_size = 4
        padding = (kernel_size - 1) // 2
        # Arguments common to every Conv2dBlock of the stack.
        shared = dict(kernel_size=kernel_size, padding=padding, padding_mode=padding_mode,
                      activation_type=activation_type, norm_type=norm_type)

        layers = [Conv2dBlock(in_channels, base_channels, stride=2, **shared)]
        channels = base_channels
        for depth in range(1, num_conv):
            # The width doubles per block and is capped at 8x the base width.
            next_channels = min(2 ** depth, 2 ** 3) * base_channels
            # Only the last block keeps the resolution; all earlier ones halve it.
            stride = 2 if depth < num_conv - 1 else 1
            layers.append(Conv2dBlock(channels, next_channels, stride=stride, **shared))
            channels = next_channels
        layers.append(nn.Conv2d(channels, 1, kernel_size, stride=1, padding=padding, padding_mode=padding_mode))

        # A ModuleList lets forward() step through the layers and collect each output.
        if self.need_intermediate_feature:
            self.sequence = nn.ModuleList(layers)
        else:
            self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """
        :return: the final layer's output, or a tuple with the output of every layer (final one last)
            when ``need_intermediate_feature`` is set
        """
        if not self.need_intermediate_feature:
            return self.sequence(x)
        features = []
        for layer in self.sequence:
            x = layer(x)
            features.append(x)
        return tuple(features)
if __name__ == '__main__':
    # Manual check: build both networks and print their module trees.
    generator_kwargs = dict(in_channels=3, out_channels=3)
    g = Generator(**generator_kwargs)
    print(g)
    discriminator_kwargs = dict(in_channels=3, base_channels=64, num_conv=4)
    pd = PatchDiscriminator(**discriminator_kwargs)
    print(pd)

View File

@@ -1,9 +1,13 @@
from collections import OrderedDict
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.base.module import ResidualBlock, ReverseConv2dBlock, Conv2dBlock
from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
from model import MODEL
class StyleEncoder(nn.Module):
def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
@@ -33,6 +37,92 @@ class StyleEncoder(nn.Module):
return self.fc_avg(x), self.fc_var(x)
class ImprovedSPADEGenerator(nn.Module):
    """
    Generator assembled from SPADE-normalized residual blocks with one output head per resolution.

    Only the constructor is implemented; ``forward`` is still a stub and returns ``None``.
    """

    def __init__(self, in_channels, out_channels, output_size, have_style_input, style_dim, start_size=(4, 4),
                 base_channels=64, padding_mode='reflect', activation_type="LeakyReLU", pre_activation=False):
        """
        :param in_channels: channels of the conditioning input; also used as SPADE ``condition_in_channels``
        :param out_channels: channels produced by every output head
        :param output_size: target resolution, one of 128, 256, 512, 1024
        :param have_style_input: when true, a style converter is built and the plain convolutions use AdaIN
        :param style_dim: width of the style vector fed to the style converter
        :param start_size: accepted but not used by this constructor
        :param base_channels: base width; block widths are multiples of it
        :param padding_mode: padding mode of the convolutions
        :param activation_type: activation name for the convolution and residual blocks
        :param pre_activation: forwarded to the plain ``Conv2dBlock`` layers
        """
        super().__init__()
        assert output_size in (128, 256, 512, 1024)
        self.output_size = output_size
        # math.log2 is exact for powers of two, whereas math.log(x, 2) is a float quotient that
        # int() would truncate if it came out a hair below the true exponent.
        size_exponent = int(math.log2(self.output_size))

        kernel_size = 3

        if have_style_input:
            self.style_converter = nn.Sequential(
                LinearBlock(style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
                LinearBlock(2 * style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
            )

        base_conv = partial(
            Conv2dBlock,
            pre_activation=pre_activation, activation_type=activation_type,
            norm_type="AdaIN" if have_style_input else "NONE",
            kernel_size=kernel_size, padding=(kernel_size - 1) // 2, padding_mode=padding_mode
        )

        base_residual_block = partial(
            ResidualBlock,
            padding_mode=padding_mode,
            activation_type=activation_type,
            norm_type="SPADE",
            pre_activation=True,
            additional_norm_kwargs=dict(
                condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
                activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
            )
        )

        sequence = OrderedDict()
        channels = (2 ** 4) * base_channels
        sequence["block_head"] = nn.Sequential(OrderedDict([
            ("conv_input", base_conv(in_channels=in_channels, out_channels=channels)),
            ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
            ("res_a", base_residual_block(in_channels=channels, out_channels=channels)),
            ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
            ("up", nn.Upsample(scale_factor=2, mode='nearest'))
        ]))

        # Each further block halves the width; sizes of 256 and above get three blocks, 128 gets two.
        for i in range(4, 9 - min(size_exponent, 8), -1):
            channels = (2 ** (i - 1)) * base_channels
            sequence[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
                ("res_a", base_residual_block(in_channels=channels * 2, out_channels=channels)),
                ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
            ]))
        self.sequence = nn.Sequential(sequence)

        # channels = 2*base_channels when output size is 256, 512, 1024
        # channels = 4*base_channels when output size is 128
        out_modules = OrderedDict()
        out_modules["out_1"] = nn.Sequential(
            Conv2dBlock(
                channels, out_channels, kernel_size=5, stride=1, padding=2,
                pre_activation=pre_activation,
                padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
            ),
            nn.Tanh()
        )
        # Sizes above 256 get one extra residual stage plus output head per additional doubling.
        for i in range(size_exponent - 8):
            channels = channels // 2
            out_modules[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
                ("res_a", base_residual_block(in_channels=2 * channels, out_channels=channels)),
                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
            ]))
            out_modules[f"out_{i + 2}"] = nn.Sequential(
                Conv2dBlock(
                    channels, out_channels, kernel_size=5, stride=1, padding=2,
                    pre_activation=pre_activation,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
                ),
                nn.Tanh()
            )
        self.out_modules = nn.ModuleDict(out_modules)

    def forward(self, seg, style=None):
        # Not implemented yet: the modules built above are never invoked here.
        pass
@MODEL.register_module()
class SPADEGenerator(nn.Module):
def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
padding_mode='reflect', activation_type="LeakyReLU"):
@@ -66,11 +156,8 @@ class SPADEGenerator(nn.Module):
)
))
self.sequence = nn.Sequential(*sequence)
self.output_converter = nn.Sequential(
ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
nn.Tanh()
)
self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")
def forward(self, seg, z=None):
if self.use_vae:
@@ -89,7 +176,8 @@ class SPADEGenerator(nn.Module):
x = blk(x)
return self.output_converter(x)
if __name__ == '__main__':
g = SPADEGenerator(3, 3, 7, False, 256)
print(g)
print(g(torch.randn(2, 3, 256, 256)).size())
print(g(torch.randn(2, 3, 256, 256)).size())

View File

@@ -0,0 +1,98 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from model import MODEL
from model.base.module import ResidualBlock, Conv2dBlock
class Interpolation(nn.Module):
    """Module form of ``F.interpolate`` with the scale factor, mode and corner alignment fixed up front."""

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        """
        :param scale_factor: passed to ``F.interpolate`` as ``scale_factor``
        :param mode: passed to ``F.interpolate`` as ``mode``
        :param align_corners: passed to ``F.interpolate`` as ``align_corners``
        """
        super().__init__()
        self.scale_factor, self.mode, self.align_corners = scale_factor, mode, align_corners

    def forward(self, x):
        # recompute_scale_factor is always passed as False.
        options = dict(scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
                       recompute_scale_factor=False)
        return F.interpolate(x, **options)

    def __repr__(self):
        return "Interpolation(scale_factor={}, mode={}, align_corners={})".format(
            self.scale_factor, self.mode, self.align_corners)
@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
    """
    Two-stream generator: a down-sampling stream extracts one feature map per scale from the input,
    and an up-sampling stream of FADE-normalized residual blocks turns ``z`` into the output while
    each block is fed the feature map of the matching scale.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        """
        :param in_channels: channel count of the input image
        :param out_channels: channel count of the generated image
        :param base_channels: width after the input layer; block widths are multiples of it, capped at 16x
        :param num_blocks: number of down-sampling blocks and, symmetrically, of up-sampling blocks
        :param padding_mode: padding mode of the convolutions
        :param activation_type: activation name for the convolution and residual blocks
        """
        super().__init__()
        self.input_layer = Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type="IN",
        )

        multiple_now = 1
        stream_sequence = []
        for i in range(1, num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResidualBlock(
                    multiple_prev * base_channels, out_channels=multiple_now * base_channels,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
            ))
        self.down_sequence = nn.ModuleList(stream_sequence)

        sequence = []
        # The up path starts from the width the down path ended on. That is 16x for num_blocks >= 4
        # (the previously hard-coded value); for fewer blocks the cap is never reached, and a fixed
        # 16x would not match the channels of `z` and of the deepest content feature.
        for i in range(num_blocks - 1, -1, -1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            sequence.append(nn.Sequential(
                ResidualBlock(
                    base_channels * multiple_prev,
                    out_channels=base_channels * multiple_now,
                    padding_mode=padding_mode,
                    activation_type=activation_type,
                    norm_type="FADE",
                    pre_activation=True,
                    additional_norm_kwargs=dict(
                        condition_in_channels=base_channels * multiple_prev, base_norm_type="BN",
                        padding_mode="zeros", gamma_bias=0.0
                    )
                ),
                Interpolation(scale_factor=2, mode="nearest")
            ))
        self.up_sequence = nn.Sequential(*sequence)

        self.output_layer = Conv2dBlock(
            base_channels, out_channels, kernel_size=3, stride=1, padding=1,
            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
        )

    def forward(self, x, z=None):
        """
        :param x: input image batch fed to the down-sampling stream
        :param z: optional start tensor of the up-sampling stream; drawn from a standard normal with
            the size of the deepest content feature when omitted
        :return: output of the final Tanh convolution block
        """
        c = self.input_layer(x)
        contents = []
        for blk in self.down_sequence:
            c = blk(c)
            contents.append(c)

        if z is None:
            # for image 256x256, z size: [batch_size, 1024, 2, 2]
            z = torch.randn(size=contents[-1].size(), device=contents[-1].device)

        # Before running the up path, give every residual block the content feature of its own
        # scale, deepest first (hence the pop from the end of the list).
        for blk in self.up_sequence:
            res = blk[0]
            c = contents.pop()
            res.conv1.normalization.set_feature(c)
            res.conv2.normalization.set_feature(c)
            if res.learn_skip_connection:
                res.res_conv.normalization.set_feature(c)
        return self.output_layer(self.up_sequence(z))
if __name__ == '__main__':
    # Manual shape check. Runs on the GPU when one is available and otherwise on the CPU,
    # instead of failing outright on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = Generator(3, 3).to(device)
    img = torch.randn(2, 3, 256, 256).to(device)
    print(g(img).size())

View File

@@ -6,19 +6,6 @@ from model.base.module import Conv2dBlock, LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class RhoClipper(object):
    """Callable that clamps the ``rho`` attribute of a module into ``[clip_min, clip_max]``."""

    def __init__(self, clip_min, clip_max):
        """
        :param clip_min: lower clamp bound
        :param clip_max: upper clamp bound; must be greater than ``clip_min``
        """
        self.clip_min, self.clip_max = clip_min, clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        # Modules without a `rho` attribute are left untouched.
        if not hasattr(module, 'rho'):
            return
        # Write the clamped values back through `.data`.
        module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)
class CAMClassifier(nn.Module):
def __init__(self, in_channels, activation_type="ReLU"):
super(CAMClassifier, self).__init__()

View File

@@ -0,0 +1,29 @@
import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
    """Runs several discriminators built from one config, each on a further down-sampled input."""

    def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"):
        """
        :param num_scale: number of discriminators to build
        :param discriminator_cfg: config handed to ``MODEL.build_with`` once per scale
        :param down_sample_method: ``"avg"`` or ``"bilinear"``
        """
        super().__init__()
        assert down_sample_method in ["avg", "bilinear"]
        self.down_sample_method = down_sample_method

        discriminators = []
        for _ in range(num_scale):
            discriminators.append(MODEL.build_with(discriminator_cfg))
        self.discriminator_list = nn.ModuleList(discriminators)

    def down_sample(self, x):
        """Down-sample ``x`` with the configured method (stride-2 average pooling or 0.5x bilinear)."""
        method = self.down_sample_method
        if method == "avg":
            return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
        if method == "bilinear":
            return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)

    def forward(self, x):
        """
        :return: list with one discriminator output per scale, the full-resolution one first
        """
        results = []
        current = x
        for discriminator in self.discriminator_list:
            results.append(discriminator(current))
            # The next discriminator works on the down-sampled version of this input.
            current = self.down_sample(current)
        return results

View File

@@ -1,76 +0,0 @@
import functools
import torch
import torch.nn as nn
def select_norm_layer(norm_type):
    """
    Map a normalization name to a factory that builds the layer from ``num_features``.

    :param norm_type: one of ``"BN"``, ``"IN"``, ``"LN"``, ``"NONE"``, ``"AdaIN"``
    :return: a callable taking ``num_features`` and returning the normalization module
    :raises NotImplementedError: if ``norm_type`` is not one of the names above
    """
    if norm_type == "BN":
        return functools.partial(nn.BatchNorm2d)
    elif norm_type == "IN":
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == "LN":
        return functools.partial(LayerNorm2d, affine=True)
    elif norm_type == "NONE":
        # The factory still accepts num_features so every branch is called the same way.
        return lambda num_features: nn.Identity()
    elif norm_type == "AdaIN":
        return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
    else:
        # `NotImplemented` is a non-callable constant; calling it raised a TypeError and lost
        # the message. `NotImplementedError` is the exception meant here.
        raise NotImplementedError(f'normalization layer {norm_type} is not found')
class LayerNorm2d(nn.Module):
    """Normalizes each sample over dims 1, 2 and 3, with an optional per-channel scale and shift."""

    def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
        """
        :param num_features: channel count; sizes the affine parameters
        :param eps: added to the variance before the square root
        :param affine: whether to create the ``channel_gamma`` / ``channel_beta`` parameters
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine

        if self.affine:
            # One scale and one shift per channel, shaped to broadcast over dims 0, 2 and 3.
            param_shape = (1, num_features, 1, 1)
            self.channel_gamma = nn.Parameter(torch.Tensor(*param_shape))
            self.channel_beta = nn.Parameter(torch.Tensor(*param_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw the scale with ``nn.init.uniform_`` and zero the shift; no-op without affine parameters."""
        if not self.affine:
            return
        nn.init.uniform_(self.channel_gamma)
        nn.init.zeros_(self.channel_beta)

    def forward(self, x):
        reduce_dims = [1, 2, 3]
        ln_mean = torch.mean(x, dim=reduce_dims, keepdim=True)
        ln_var = torch.var(x, dim=reduce_dims, keepdim=True)
        normalized = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
        if not self.affine:
            return normalized
        return self.channel_gamma * normalized + self.channel_beta

    def __repr__(self):
        class_name = self.__class__.__name__
        return f"{class_name}(num_features={self.num_features}, affine={self.affine})"
class AdaptiveInstanceNorm2d(nn.Module):
    """
    Instance normalization whose scale and shift are supplied from outside through ``set_style``.

    ``set_style`` must be called before every ``forward``: the forward pass clears the flag again.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = False, track_running_stats: bool = False):
        """All arguments are handed to the wrapped ``nn.InstanceNorm2d``."""
        super().__init__()
        self.num_features = num_features
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.norm = nn.InstanceNorm2d(num_features, eps=eps, momentum=momentum, affine=affine,
                                      track_running_stats=track_running_stats)
        # Filled in by set_style(); the flag records whether a style is pending for the next forward.
        self.gamma = None
        self.beta = None
        self.have_set_style = False

    def set_style(self, style):
        """
        Split ``style`` along dim 1 into the scale (first half) and the shift (second half).

        :param style: style tensor; two singleton dims are appended to it before the split
        """
        expanded = style.unsqueeze(-1).unsqueeze(-1)
        self.gamma, self.beta = torch.chunk(expanded, 2, dim=1)
        self.have_set_style = True

    def forward(self, x):
        assert self.have_set_style
        normalized = self.norm(x)
        styled = self.gamma * normalized + self.beta
        # The pending style is consumed by this call.
        self.have_set_style = False
        return styled

    def __repr__(self):
        class_name = self.__class__.__name__
        return f"{class_name}(num_features={self.num_features}, " \
               f"affine={self.affine}, track_running_stats={self.track_running_stats})"

View File

@@ -65,7 +65,8 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
elif classname.find('BatchNorm2d') != -1:
# BatchNorm Layer's weight is not a matrix;
# only normal distribution applies.
normal_init(m, 1.0, init_gain)
if m.weight is not None:
normal_init(m, 1.0, init_gain)
assert isinstance(module, nn.Module)
module.apply(init_func)

View File

@@ -1,8 +1,10 @@
import inspect
from omegaconf.dictconfig import DictConfig
from omegaconf import OmegaConf
from types import ModuleType
import warnings
from types import ModuleType
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
class _Registry:
def __init__(self, name):
@@ -51,11 +53,9 @@ class _Registry:
else:
raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
for k in args:
assert isinstance(k, str)
if k.startswith("_"):
warnings.warn(f"got param start with `_`: {k}, will remove it")
args.pop(k)
for invalid_key in [k for k in args.keys() if k.startswith("_")]:
warnings.warn(f"got param start with `_`: {invalid_key}, will remove it")
args.pop(invalid_key)
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
@@ -136,8 +136,11 @@ class Registry(_Registry):
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
if self._module_dict[module_name] == module_class:
warnings.warn(f'{module_name} is already registered in {self.name}, but is the same class')
return
raise KeyError(f'{module_name}:{self._module_dict[module_name]} is already registered in {self.name}'
f'so {module_class} can not be registered')
self._module_dict[module_name] = module_class
def register_module(self, name=None, force=False, module=None):