Compare commits

...

34 Commits

Author SHA1 Message Date
8998c30c23 TSIT 2020-10-25 20:46:34 +08:00
0bec02bf6d 23333 2020-10-23 16:14:37 +08:00
f7b7b78669 improved gan loss 2020-10-22 23:19:03 +08:00
376f5caeb7 v2 2020-10-22 22:42:01 +08:00
0019d4034c change a lot 2020-10-14 18:55:51 +08:00
0927fa3de5 add patch d 2020-10-13 10:31:17 +08:00
611901cbdf add ConvTranspose2d in Conv2d 2020-10-13 10:31:03 +08:00
a6ffab1445 add image buffers for gan 2020-10-13 10:30:27 +08:00
7b05b45156 update SPADE 2020-10-12 19:01:07 +08:00
2de00d0245 use loss container 2020-10-11 23:36:37 +08:00
74a7cfb2d8 move sn to engine 2020-10-11 23:35:29 +08:00
436bca88b4 add loss container 2020-10-11 23:09:04 +08:00
6070f08835 add GauGAN 2020-10-11 23:05:38 +08:00
06b2abd19a add flag to switch to norm-activ-conv 2020-10-11 19:02:42 +08:00
9c08b4cd09 move encoder, decoder to CycleGAN 2020-10-11 11:09:16 +08:00
04c6366c07 rewrite 2020-10-11 10:02:33 +08:00
6ea13df465 temp commit 2020-10-10 10:43:00 +08:00
776fe40199 change a lot 2020-09-26 17:48:26 +08:00
f67bcdf161 use base module rewrite TSIT 2020-09-26 17:48:10 +08:00
16f18ab2e2 func to apply sn 2020-09-26 17:47:24 +08:00
0f2b67e215 base model, Norm&Conv&ResNet 2020-09-26 17:45:51 +08:00
acf243cb12 working 2020-09-25 18:31:12 +08:00
fbea96f6d7 add new dataset type 2020-09-24 16:50:53 +08:00
ca55318253 add context loss 2020-09-24 16:38:03 +08:00
b01016edb5 TAFG update 2020-09-18 12:03:44 +08:00
61e04de8a5 TAFG 2020-09-17 09:34:53 +08:00
2ff4a91057 add MUNIT 2020-09-14 22:30:05 +08:00
f70658eaed small fix 2020-09-11 23:04:26 +08:00
340a344e91 add distance 2020-09-11 22:35:59 +08:00
85b5c3f589 fix small bug in U-GAT-IT 2020-09-11 22:34:43 +08:00
72d09aa483 update tester 2020-09-10 18:34:52 +08:00
7ea9c6d0d5 TAFG good result 2020-09-09 14:46:07 +08:00
87cbcf34d3 add tool to dump images in tensorboard event file 2020-09-09 09:08:11 +08:00
97ded53b30 update a lot 2020-09-07 21:38:10 +08:00
60 changed files with 3408 additions and 1238 deletions

4
.idea/deployment.xml generated
View File

@@ -1,11 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="15d" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="14d">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
<mapping deploy="raycv" local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>

2
.idea/misc.xml generated
View File

@@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="15d-python" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="14d-python" project-jdk-type="Python SDK" />
</project>

2
.idea/raycv.iml generated
View File

@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="15d-python" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="14d-python" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">

View File

@@ -1,57 +1,85 @@
name: horse2zebra-CyCleGAN
engine: CyCleGAN
name: huawei-cycylegan-7
engine: CycleGAN
result_dir: ./result
max_pairs: 266800
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: False
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: CyCle-Generator
_type: CycleGAN-Generator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
base_channels: 64
num_blocks: 9
padding_mode: reflect
norm_type: IN
use_transpose_conv: False
pre_activation: True
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: False
loss:
gan:
loss_type: lsgan
loss_type: hinge
weight: 1.0
real_label_val: 1.0
real_label_val: 1
fake_label_val: 0.0
cycle:
level: 1
weight: 10.0
id:
level: 1
weight: 10.0
mgc:
weight: 1
fm:
weight: 0
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 2e-4
lr: 1e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 2e-4
lr: 4e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data:
train:
@@ -60,17 +88,28 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 6
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/horse2zebra/trainA"
root_b: "/data/i2i/horse2zebra/trainB"
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline:
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 286, 286 ]
@@ -82,17 +121,38 @@ data:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 4
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/horse2zebra/testA"
root_b: "/data/i2i/horse2zebra/testB"
random_pair: False
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:

View File

@@ -0,0 +1,167 @@
name: huawei-GauGAN-3
engine: GauGAN
result_dir: ./result
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: SPADEGenerator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
num_blocks: 7
use_vae: False
num_z_dim: 256
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: True
loss:
gan:
loss_type: hinge
weight: 1.0
real_label_val: 1
fake_label_val: 0.0
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 2
mgc:
weight: 5
fm:
weight: 5
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 1e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 286, 286 ]
- RandomCrop:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -0,0 +1,132 @@
name: MUNIT-edges2shoes
engine: MUNIT
result_dir: ./result
max_pairs: 1000000
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
misc:
random_seed: 324
model:
generator:
_type: MUNIT-Generator
in_channels: 3
out_channels: 3
base_channels: 64
num_sampling: 2
num_style_dim: 8
num_style_conv: 4
num_content_res_blocks: 4
num_decoder_res_blocks: 4
num_fusion_dim: 256
num_fusion_blocks: 3
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
discriminator_cfg:
_type: PatchDiscriminator
in_channels: 3
base_channels: 64
use_spectral: True
need_intermediate_feature: True
loss:
gan:
loss_type: lsgan
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 0
recon:
level: 1
style:
weight: 1
content:
weight: 1
image:
weight: 10
cycle:
weight: 0
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 1
shuffle: True
num_workers: 1
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/edges2shoes/trainA"
root_b: "/data/i2i/edges2shoes/trainB"
random_pair: True
pipeline:
- Load
- Resize:
size: [ 286, 286 ]
- RandomCrop:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: dataset
dataloader:
batch_size: 8
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/edges2shoes/testA"
root_b: "/data/i2i/edges2shoes/testB"
random_pair: False
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -1,7 +1,7 @@
name: TAFG
name: TAFG-vox2
engine: TAFG
result_dir: ./result
max_pairs: 1500000
max_pairs: 1000000
handler:
clear_cuda_cache: True
@@ -11,11 +11,15 @@ handler:
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
misc:
random_seed: 324
random_seed: 1004
add_new_loss_epoch: -1
model:
generator:
@@ -23,7 +27,13 @@ model:
_bn_to_sync_bn: False
style_in_channels: 3
content_in_channels: 24
num_blocks: 8
use_spectral_norm: False
style_encoder_type: StyleEncoder
num_style_conv: 4
style_dim: 8
num_adain_blocks: 4
num_res_blocks: 4
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
@@ -47,30 +57,32 @@ loss:
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L2'
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 0.5
weight: 0
style:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L2'
"3": 1
criterion: 'L1'
style_loss: True
perceptual_loss: False
weight: 0
fm:
level: 1
weight: 10
recon:
level: 1
weight: 10
style_recon:
level: 1
weight: 0
weight: 1
content_recon:
level: 1
weight: 1
edge:
weight: 5
hed_pretrained_model_path: ./network-bsds500.pytorch
cycle:
level: 1
weight: 10
optimizers:
generator:
@@ -91,9 +103,9 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 24
batch_size: 8
shuffle: True
num_workers: 2
num_workers: 1
pin_memory: True
drop_last: True
dataset:
@@ -114,8 +126,9 @@ data:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 8
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
@@ -144,7 +157,7 @@ data:
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]

View File

@@ -0,0 +1,165 @@
name: huawei-TSIT-1
engine: GauGAN
result_dir: ./result
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: TSIT-Generator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
num_blocks: 7
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: True
loss:
gan:
loss_type: hinge
weight: 1.0
real_label_val: 1
fake_label_val: 0.0
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 1
mgc:
weight: 5
fm:
weight: 1
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 1e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 0
dataloader:
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 286, 286 ]
- RandomCrop:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -14,11 +14,15 @@ handler:
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: UGATIT-Generator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
base_channels: 64
@@ -27,11 +31,13 @@ model:
light: True
local_discriminator:
_type: UGATIT-Discriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_blocks: 5
global_discriminator:
_type: UGATIT-Discriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_blocks: 7
@@ -50,6 +56,8 @@ loss:
weight: 10.0
cam:
weight: 1000
mgc:
weight: 0
optimizers:
generator:
@@ -70,7 +78,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 24
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
@@ -92,8 +100,9 @@ data:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 8
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False

View File

@@ -0,0 +1,171 @@
name: talking_anime
engine: talking_anime
result_dir: ./result
max_pairs: 1000000
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 100 # log image `image` times per epoch
test:
random: True
images: 10
misc:
random_seed: 1004
loss:
gan:
loss_type: hinge
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
fm:
level: 1
weight: 1
style:
layer_weights:
"3": 1
criterion: 'L1'
style_loss: True
perceptual_loss: False
weight: 10
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 0
context:
layer_weights:
#"13": 1
"22": 1
weight: 5
recon:
level: 1
weight: 10
edge:
weight: 5
hed_pretrained_model_path: ./network-bsds500.pytorch
model:
face_generator:
_type: TAFG-SingleGenerator
_bn_to_sync_bn: False
style_in_channels: 3
content_in_channels: 1
use_spectral_norm: True
style_encoder_type: VGG19StyleEncoder
num_style_conv: 4
style_dim: 512
num_adain_blocks: 8
num_res_blocks: 8
anime_generator:
_type: TAFG-ResGenerator
_bn_to_sync_bn: False
in_channels: 6
use_spectral_norm: True
num_res_blocks: 8
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
discriminator_cfg:
_type: PatchDiscriminator
in_channels: 3
base_channels: 64
use_spectral: True
need_intermediate_feature: True
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
dataloader:
batch_size: 8
shuffle: True
num_workers: 1
pin_memory: True
drop_last: True
dataset:
_type: PoseFacesWithSingleAnime
root_face: "/data/i2i/VoxCeleb2Anime/trainA"
root_anime: "/data/i2i/VoxCeleb2Anime/trainB"
landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
num_face: 2
img_size: [ 128, 128 ]
with_order: False
face_pipeline:
- Load
- Resize:
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
anime_pipeline:
- Load
- Resize:
size: [ 144, 144 ]
- RandomCrop:
size: [ 128, 128 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: dataset
dataloader:
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: PoseFacesWithSingleAnime
root_face: "/data/i2i/VoxCeleb2Anime/testA"
root_anime: "/data/i2i/VoxCeleb2Anime/testB"
landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
num_face: 2
img_size: [ 128, 128 ]
with_order: False
face_pipeline:
- Load
- Resize:
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
anime_pipeline:
- Load
- Resize:
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -1,237 +0,0 @@
import os
import pickle
from pathlib import Path
from collections import defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import functional as F
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
import lmdb
from tqdm import tqdm
from .transform import transform_pipeline
from .registry import DATASET
from .util import dlib_landmark
def default_transform_way(transform, sample):
    """Apply ``transform`` to the first element of ``sample`` only.

    Returns a new list: the transformed first element followed by the
    remaining elements of ``sample`` unchanged.
    """
    transformed_head = transform(sample[0])
    untouched_tail = sample[1:]
    return [transformed_head, *untouched_tail]
class LMDBDataset(Dataset):
    """Dataset that reads pickled samples out of an LMDB database.

    The database layout is the one written by :meth:`lmdbify`: every sample is
    stored pickled under its decimal index, and the metadata lives under the
    ``$$len$$``, ``$$done_pipeline$$`` and ``$$essential_attr$$`` keys.

    NOTE(security): everything read from the database goes through
    ``pickle.loads``, which can execute arbitrary code. Only open databases
    that come from a trusted source.
    """

    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
                 **lmdb_kwargs):
        """Open ``lmdb_path`` and restore the metadata saved by :meth:`lmdbify`.

        :param lmdb_path: path of the LMDB database (directory or single file).
        :param pipeline: full transform pipeline; the steps already baked into
            the database are stripped and only the remainder is applied.
            ``None`` means no further transform steps.
        :param transform_way: callable ``(transform, sample) -> sample`` that
            decides how the remaining transform is applied to a stored sample.
        :param map_size: forwarded to ``lmdb.open``.
        :param readonly: forwarded to ``lmdb.open``.
        :param lmdb_kwargs: extra keyword arguments forwarded to ``lmdb.open``.
        :raises ValueError: if ``pipeline`` does not start with the stored
            ``done_pipeline``.
        """
        self.path = lmdb_path
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
                            lock=False, **lmdb_kwargs)

        with self.db.begin(write=False) as txn:
            self._len = pickle.loads(txn.get(b"$$len$$"))
            self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
            if pipeline is None:
                self.not_done_pipeline = []
            else:
                self.not_done_pipeline = self._remain_pipeline(pipeline)
            self.transform = transform_pipeline(self.not_done_pipeline)
            self.transform_way = transform_way
            # Restore the attributes the source dataset declared as essential.
            essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
            for ea in essential_attr:
                setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))

    def _remain_pipeline(self, pipeline):
        """Return the steps of ``pipeline`` that are not yet baked into the database.

        :raises ValueError: if ``pipeline`` is shorter than, or differs from,
            the stored ``done_pipeline`` prefix.
        """
        done = self.done_pipeline
        # A pipeline shorter than the stored prefix cannot match it; report it
        # as the same mismatch error instead of an IndexError.
        too_short = len(pipeline) < len(done)
        if too_short or any(pipeline[i] != dp for i, dp in enumerate(done)):
            raise ValueError(
                f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
        return pipeline[len(done):]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        """Load sample ``idx`` and apply the remaining transform via ``transform_way``."""
        with self.db.begin(write=False) as txn:
            sample = pickle.loads(txn.get("{}".format(idx).encode()))
        sample = self.transform_way(self.transform, sample)
        return sample

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        """Dump ``dataset`` into a new LMDB database at ``lmdb_path``.

        :param dataset: indexable dataset with ``__len__``; every item is pickled.
        :param done_pipeline: transform steps already applied to the stored
            samples, saved so readers can skip them.
        :param lmdb_path: destination path of the database.
        """
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        try:
            with env.begin(write=True) as txn:
                for i in tqdm(range(len(dataset)), ncols=0):
                    txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
                txn.put(b"$$len$$", pickle.dumps(len(dataset)))
                txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
                essential_attr = getattr(dataset, "essential_attr", list())
                txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
                for ea in essential_attr:
                    txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
        finally:
            # Release the environment even when writing fails.
            env.close()
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    """``ImageFolder`` that also keeps, per class, the indices of its samples.

    The loader is the identity function, so the transform pipeline is handed
    the sample path itself rather than a decoded image.
    """

    def __init__(self, root, pipeline):
        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
        # class index -> list of sample indices belonging to that class
        self.classes_list = defaultdict(list)
        self.essential_attr = ["classes_list"]
        for sample_index in range(len(self)):
            class_index = self.samples[sample_index][-1]
            self.classes_list[class_index].append(sample_index)
        assert len(self.classes_list) == len(self.classes)
class EpisodicDataset(Dataset):
    """Builds few-shot episodes on top of a dataset exposing ``classes_list``.

    Each item is one episode: ``num_class`` randomly chosen classes, with
    ``num_support`` support samples and ``num_query`` query samples per class.
    """

    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class
        assert self.num_class < len(self.origin.classes_list)
        self.num_query = num_query  # query samples per class
        self.num_support = num_support  # support samples per class
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        """Return the first element of ``self.origin[i]`` for every ``i`` in ``id_list``."""
        fetched = []
        for sample_id in id_list:
            fetched.append(self.origin[sample_id][0])
        return fetched

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        """Sample one episode; the index argument is ignored."""
        per_class = self.num_query + self.num_support
        chosen_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()

        supports, queries, targets = [], [], []
        for tag, class_id in enumerate(chosen_classes):
            members = self.origin.classes_list[class_id]
            if len(members) < per_class:
                # too few images in this class: draw with replacement
                picks = torch.randint(high=len(members), size=(per_class,)).tolist()
            else:
                # enough images: draw distinct ones
                picks = torch.randperm(len(members))[:per_class].tolist()

            support_ids = [members[p] for p in picks[:self.num_support]]
            query_ids = [members[p] for p in picks[self.num_support:]]
            supports.extend(self._fetch_list_data(support_ids))
            queries.extend(self._fetch_list_data(query_ids))
            targets.extend([tag] * self.num_query)

        return {
            "support": torch.stack(supports),
            "query": torch.stack(queries),
            "target": torch.tensor(targets)
        }

    def __repr__(self):
        return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
@DATASET.register_module()
class SingleFolderDataset(Dataset):
    """All image files found (recursively, following symlinks) under one folder.

    Items are ``pipeline(path)``; with ``with_path=True`` each item is the
    tuple ``(pipeline(path), path)``.
    """

    def __init__(self, root, pipeline, with_path=False):
        assert os.path.isdir(root)
        self.root = root
        # Walk in sorted order so sample indices are stable between runs.
        candidates = (
            os.path.join(folder, filename)
            for folder, _, filenames in sorted(os.walk(self.root, followlinks=True))
            for filename in sorted(filenames)
        )
        self.samples = [path for path in candidates if has_file_allowed_extension(path, IMG_EXTENSIONS)]
        self.pipeline = transform_pipeline(pipeline)
        self.with_path = with_path

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if self.with_path:
            return self.pipeline(self.samples[idx]), self.samples[idx]
        return self.pipeline(self.samples[idx])

    def __repr__(self):
        return f"<SingleFolderDataset root={self.root} len={len(self)}>"
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    """Pairs samples from two image folders that share one transform pipeline.

    Sample A is picked by ``idx`` (wrapping around); sample B is picked at
    random when ``random_pair`` is true, otherwise also by ``idx``.
    """

    def __init__(self, root_a, root_b, random_pair, pipeline):
        self.A = SingleFolderDataset(root_a, pipeline)
        self.B = SingleFolderDataset(root_b, pipeline)
        self.random_pair = random_pair

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        if self.random_pair:
            b_idx = torch.randint(len(self.B), (1,)).item()
        else:
            b_idx = idx % len(self.B)
        return {"a": self.A[a_idx], "b": self.B[b_idx]}

    def __len__(self):
        # The larger folder defines the length; the smaller one wraps around.
        size_a, size_b = len(self.A), len(self.B)
        return max(size_a, size_b)

    def __repr__(self):
        return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
def normalize_tensor(tensor):
    """Linearly rescale ``tensor`` to the range [0, 1] and return it as float.

    The input is never modified: the result is always a new tensor, even when
    ``tensor`` is already floating point. A constant tensor (max equals min)
    maps to all zeros instead of dividing by zero and producing NaN.

    :param tensor: non-empty tensor of any numeric dtype.
    :return: float tensor of the same shape with minimum 0 and, unless the
        input is constant, maximum 1.
    """
    tensor = tensor.float()
    # Out-of-place arithmetic: ``.float()`` returns the same object for float
    # inputs, so in-place ops would silently overwrite the caller's data.
    shifted = tensor - tensor.min()
    peak = shifted.max()
    if peak == 0:
        return shifted
    return shifted / peak
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
    """Unpaired A/B image dataset that also returns an edge map for sample A.

    ``edge_type`` selects which pre-computed edge image is loaded (``hed`` or
    ``canny``). The ``landmark_`` variants additionally concatenate tensors
    derived from the landmark file of the same image.
    """

    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256)):
        assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
        self.edge_type = edge_type
        self.size = size
        self.edges_path = Path(edges_path)
        self.landmarks_path = Path(landmarks_path)
        assert self.edges_path.exists()
        assert self.landmarks_path.exists()
        # with_path=True: get_edge needs the image path to locate the edge file.
        self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
        self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
        self.random_pair = random_pair

    def get_edge(self, origin_path):
        """Load the edge tensor that belongs to the image at ``origin_path``.

        The edge image is looked up at
        ``edges_path/<parent dir name>/<stem>.<edge type>.png`` and resized to
        ``self.size``. For the ``landmark_`` edge types the result is the
        concatenation of the edge map, the part edges, the distance tensor and
        the part labels read from
        ``landmarks_path/<parent dir name>/<stem>.txt``.
        """
        op = Path(origin_path)
        prefix = "landmark_"
        use_landmark = self.edge_type.startswith(prefix)
        # Slice the prefix off; str.lstrip would strip a *set of characters*
        # and could eat the start of the edge type name as well.
        edge_type = self.edge_type[len(prefix):] if use_landmark else self.edge_type

        edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
        # Context manager so the underlying file handle is closed promptly.
        with Image.open(edge_path) as edge_image:
            origin_edge = F.to_tensor(edge_image.resize(self.size, Image.BILINEAR))
        if not use_landmark:
            return origin_edge

        landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])

    def __getitem__(self, idx):
        """Return ``dict(a=..., b=..., edge_a=...)`` for index ``idx``."""
        a_idx = idx % len(self.A)
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        output = dict()
        output["a"], path_a = self.A[a_idx]
        output["b"], _ = self.B[b_idx]  # B's path is not needed
        output["edge_a"] = self.get_edge(path_a)
        return output

    def __len__(self):
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"

3
data/dataset/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from util.misc import import_submodules
# __all__ must be a sequence of strings; a dict_keys view is not indexable,
# which breaks `from <package> import *`. Materialise the keys into a list.
__all__ = list(import_submodules(__name__).keys())

63
data/dataset/few-shot.py Normal file
View File

@@ -0,0 +1,63 @@
from collections import defaultdict
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from data.registry import DATASET
from data.transform import transform_pipeline
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    """``ImageFolder`` variant that records the sample indices of every class.

    Uses an identity loader, so the transform pipeline is given the sample
    path rather than a loaded image.
    """

    def __init__(self, root, pipeline):
        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
        # maps class index -> indices of the samples in that class
        self.classes_list = defaultdict(list)
        self.essential_attr = ["classes_list"]
        for position in range(len(self)):
            label = self.samples[position][-1]
            self.classes_list[label].append(position)
        assert len(self.classes_list) == len(self.classes)
class EpisodicDataset(Dataset):
    """Few-shot episode sampler over a dataset that exposes ``classes_list``.

    Every item is a freshly sampled episode made of ``num_class`` random
    classes, each contributing ``num_support`` support samples and
    ``num_query`` query samples.
    """

    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class
        assert self.num_class < len(self.origin.classes_list)
        self.num_query = num_query  # query samples per class
        self.num_support = num_support  # support samples per class
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        """Collect the first element of ``self.origin[i]`` for each id in ``id_list``."""
        items = []
        for item_id in id_list:
            items.append(self.origin[item_id][0])
        return items

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        """Draw one episode; the requested index plays no role."""
        needed = self.num_query + self.num_support
        episode_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()

        support_set, query_set, target_set = [], [], []
        for tag, class_id in enumerate(episode_classes):
            image_ids = self.origin.classes_list[class_id]
            if len(image_ids) < needed:
                # class is too small: sample with replacement
                drawn = torch.randint(high=len(image_ids), size=(needed,)).tolist()
            else:
                # class is large enough: sample distinct images
                drawn = torch.randperm(len(image_ids))[:needed].tolist()

            support_ids = [image_ids[d] for d in drawn[:self.num_support]]
            query_ids = [image_ids[d] for d in drawn[self.num_support:]]
            support_set.extend(self._fetch_list_data(support_ids))
            query_set.extend(self._fetch_list_data(query_ids))
            target_set.extend([tag] * self.num_query)

        return {
            "support": torch.stack(support_set),
            "query": torch.stack(query_set),
            "target": torch.tensor(target_set)
        }

    def __repr__(self):
        return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"

View File

@@ -0,0 +1,62 @@
import os
import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
from data.registry import DATASET
from data.transform import transform_pipeline
@DATASET.register_module()
class SingleFolderDataset(Dataset):
    """All image files found (recursively, following symlinks) under one folder.

    Each item is a dict with key ``img`` (the pipeline output for the file)
    and, when ``with_path`` is true, key ``path`` (the file path).
    """

    def __init__(self, root, pipeline, with_path=False):
        assert os.path.isdir(root)
        self.root = root
        # Sorted traversal keeps the index -> file mapping stable between runs.
        candidates = (
            os.path.join(folder, filename)
            for folder, _, filenames in sorted(os.walk(self.root, followlinks=True))
            for filename in sorted(filenames)
        )
        self.samples = [path for path in candidates if has_file_allowed_extension(path, IMG_EXTENSIONS)]
        self.pipeline = transform_pipeline(pipeline)
        self.with_path = with_path

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = {"img": self.pipeline(self.samples[idx])}
        if not self.with_path:
            return item
        item["path"] = self.samples[idx]
        return item

    def __repr__(self):
        return f"<SingleFolderDataset root={self.root} len={len(self)} with_path={self.with_path}>"
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    """Pairs samples from two image folders, each with its own pipeline.

    Sample A is selected by ``idx`` (wrapping around); sample B is selected at
    random when ``random_pair`` is true, otherwise also by ``idx``.
    """

    def __init__(self, root_a, root_b, random_pair, pipeline_a, pipeline_b, with_path=False):
        """
        :param root_a: folder holding the images of domain A.
        :param root_b: folder holding the images of domain B.
        :param random_pair: pick the B sample at random instead of by index.
        :param pipeline_a: transform pipeline config for domain A.
        :param pipeline_b: transform pipeline config for domain B.
        :param with_path: also return the file paths as ``a_path`` / ``b_path``.
        """
        self.A = SingleFolderDataset(root_a, pipeline_a, with_path)
        self.B = SingleFolderDataset(root_b, pipeline_b, with_path)
        self.with_path = with_path
        self.random_pair = random_pair

    def __getitem__(self, idx):
        """Return ``dict(a=..., b=...)``, plus ``a_path``/``b_path`` when ``with_path`` is set."""
        a_idx = idx % len(self.A)
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        output_a = self.A[a_idx]
        output_b = self.B[b_idx]
        output = dict(a=output_a["img"], b=output_b["img"])
        if self.with_path:
            output["a_path"] = output_a["path"]
            output["b_path"] = output_b["path"]
        return output

    def __len__(self):
        # The larger folder defines the length; the smaller one wraps around.
        return max(len(self.A), len(self.B))

    def __repr__(self):
        # A and B have independent pipelines, so report both of them.
        return (f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\n"
                f"Pipeline A:\n{self.A.pipeline}\nPipeline B:\n{self.B.pipeline}")

65
data/dataset/lmdb.py Normal file
View File

@@ -0,0 +1,65 @@
import os
import pickle
import lmdb
from torch.utils.data import Dataset
from tqdm import tqdm
from data.transform import transform_pipeline
def default_transform_way(transform, sample):
    """Apply ``transform`` to the first element of ``sample`` only.

    :param transform: callable applied to ``sample[0]``
    :param sample: indexable, sliceable sequence
    :return: a new list: the transformed first element followed by the
        remaining elements unchanged
    """
    head = transform(sample[0])
    tail = sample[1:]
    return [head, *tail]
class LMDBDataset(Dataset):
    """Dataset that reads pickled samples from an LMDB database.

    The layout is the one written by :meth:`lmdbify`: sample ``i`` is stored
    under the key ``str(i)``; the keys ``$$len$$``, ``$$done_pipeline$$`` and
    ``$$essential_attr$$`` hold pickled metadata, and every name listed in
    ``$$essential_attr$$`` is stored under ``$<name>$``.

    NOTE: every value is decoded with ``pickle.loads``, which can execute
    arbitrary code. Only open databases that come from a trusted source.
    """

    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
                 **lmdb_kwargs):
        """Open the database and load its metadata.

        :param lmdb_path: path of the LMDB database; treated as a directory
            database when the path is an existing directory
        :param pipeline: full pipeline description; the steps already recorded
            in the database are dropped from its front. ``None`` means no
            additional transform.
        :param transform_way: callable ``(transform, sample)`` applied to each
            loaded sample
        :param map_size: forwarded to ``lmdb.open``
        :param readonly: forwarded to ``lmdb.open``
        :param lmdb_kwargs: extra keyword arguments forwarded to ``lmdb.open``
        :raises ValueError: if ``pipeline`` does not start with the pipeline
            saved in the database
        """
        self.path = lmdb_path
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
                            lock=False, **lmdb_kwargs)

        with self.db.begin(write=False) as txn:
            self._len = pickle.loads(txn.get(b"$$len$$"))
            self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
            if pipeline is None:
                self.not_done_pipeline = []
            else:
                self.not_done_pipeline = self._remain_pipeline(pipeline)
            self.transform = transform_pipeline(self.not_done_pipeline)
            self.transform_way = transform_way
            # Restore the extra attributes that lmdbify copied from the source dataset.
            essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
            for ea in essential_attr:
                setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))

    def _remain_pipeline(self, pipeline):
        """Return the part of ``pipeline`` that is not already baked into the database.

        :raises ValueError: if ``pipeline`` is shorter than the saved pipeline
            or differs from it in any of the saved positions
        """
        done = self.done_pipeline
        # The length check makes a too-short pipeline a ValueError instead of
        # an IndexError from indexing past its end.
        mismatch = len(pipeline) < len(done) or any(pipeline[i] != dp for i, dp in enumerate(done))
        if mismatch:
            raise ValueError(
                f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
        return pipeline[len(done):]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        """Return the sample count recorded in the database."""
        return self._len

    def __getitem__(self, idx):
        """Load sample ``idx`` and apply the remaining transform to it.

        :raises IndexError: if no sample is stored under ``idx``
        """
        with self.db.begin(write=False) as txn:
            raw = txn.get("{}".format(idx).encode())
        if raw is None:
            # txn.get returns None for a missing key; report it as an index
            # error rather than letting pickle.loads(None) raise TypeError.
            raise IndexError(f"no sample stored under index {idx} in {self.path}")
        sample = pickle.loads(raw)
        return self.transform_way(self.transform, sample)

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        """Serialize every item of ``dataset`` into an LMDB database.

        :param dataset: object supporting ``len()`` and integer indexing; an
            optional ``essential_attr`` attribute lists attribute names that
            are copied into the database as well
        :param done_pipeline: pipeline description already applied to the
            stored samples
        :param lmdb_path: destination path of the database
        """
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        try:
            with env.begin(write=True) as txn:
                for i in tqdm(range(len(dataset)), ncols=0):
                    txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
                txn.put(b"$$len$$", pickle.dumps(len(dataset)))
                txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
                essential_attr = getattr(dataset, "essential_attr", list())
                txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
                for ea in essential_attr:
                    txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
        finally:
            # Release the environment handle even if writing fails part-way.
            env.close()

View File

@@ -0,0 +1,122 @@
from collections import defaultdict
from itertools import permutations, combinations
from pathlib import Path
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as F
from data.registry import DATASET
from data.transform import transform_pipeline
from data.util import dlib_landmark
def normalize_tensor(tensor):
    """Rescale ``tensor`` linearly so its minimum is 0 and its maximum is 1.

    The input is never modified: ``Tensor.float()`` returns the very same
    object when the tensor is already float32, so the arithmetic is done out
    of place. A constant tensor (max equal to min) yields all zeros instead
    of the NaNs a 0/0 division would produce.

    :param tensor: non-empty tensor of any numeric dtype
    :return: new float32 tensor with values in ``[0, 1]``
    """
    tensor = tensor.float()
    shifted = tensor - tensor.min()
    peak = shifted.max()
    if peak == 0:
        return shifted
    return shifted / peak
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
    """Unpaired two-folder dataset whose items also carry an edge tensor.

    Each item is ``{"a": {...}, "b": {...}}`` where both inner dicts hold
    ``"img"``, ``"path"`` and ``"edge"``.
    """

    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
                 with_path=False):
        """Set up both image folders and the edge/landmark locations.

        :param root_a: root directory of side A
        :param root_b: root directory of side B
        :param random_pair: when true, the B item is drawn at random for every access
        :param pipeline: pipeline description shared by both sides
        :param edge_type: one of ``hed``, ``canny``, ``landmark_hed``, ``landmark_canny``
        :param edges_path: existing directory holding the edge images
        :param landmarks_path: existing directory holding the landmark files
        :param size: size the edge image is resized to, also passed to the landmark helpers
        :param with_path: stored on the instance; the item layout does not depend on it
        """
        assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
        self.edge_type = edge_type
        self.size = size
        self.edges_path = Path(edges_path)
        self.landmarks_path = Path(landmarks_path)
        assert self.edges_path.exists()
        assert self.landmarks_path.exists()
        # NOTE(review): SingleFolderDataset is not among the imports visible at
        # the top of this file - confirm it is in scope.
        self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
        self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
        self.random_pair = random_pair
        self.with_path = with_path

    @staticmethod
    def _img_and_path(sample):
        """Return ``{"img", "path"}`` from a dict-style or ``(img, path)`` sample.

        Tuple-unpacking a dict yields its keys, so dict samples must be read
        by key; ``(img, path)`` sequences are still accepted.
        """
        if isinstance(sample, dict):
            return dict(img=sample["img"], path=sample["path"])
        img, path = sample
        return dict(img=img, path=path)

    def get_edge(self, origin_path):
        """Load the edge tensor that belongs to the image at ``origin_path``.

        The edge image is looked up as
        ``<edges_path>/<parent folder name>/<stem>.<edge type>.png``. For the
        ``landmark_*`` edge types, images whose parent folder name ends with
        ``"A"`` additionally get the part edge, distance and part-label
        tensors concatenated after the edge tensor.
        """
        op = Path(origin_path)
        prefix = "landmark_"
        if self.edge_type.startswith(prefix):
            # Slice the prefix off: str.lstrip removes a set of characters,
            # not a prefix, and would only work by accident.
            edge_type = self.edge_type[len(prefix):]
            use_landmark = op.parent.name.endswith("A")
        else:
            edge_type = self.edge_type
            use_landmark = False
        edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
        # The context manager closes the file handle once the resized copy exists.
        with Image.open(edge_path) as edge_image:
            origin_edge = F.to_tensor(edge_image.resize(self.size, Image.BILINEAR))
        if not use_landmark:
            return origin_edge

        landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])

    def __getitem__(self, idx):
        """Return the A/B pair for ``idx`` with image, path and edge per side."""
        a_idx = idx % len(self.A)
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        output = dict(a=self._img_and_path(self.A[a_idx]), b=self._img_and_path(self.B[b_idx]))
        for p in "ab":
            output[p]["edge"] = self.get_edge(output[p]["path"])
        return output

    def __len__(self):
        """Return the size of the larger side."""
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
    """Groups of face images of one person, each paired with a random anime image.

    An item holds ``anime_img`` / ``anime_path`` plus, for every face ``i`` in
    the group, ``face_<i>``, ``pose_<i>`` and ``stem_<i>``.
    """

    def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
                 with_order=True):
        """Index the face groups and the anime folder.

        :param root_face: directory containing the ``*.jpg`` face images
        :param root_anime: root directory of the anime images
        :param landmark_path: directory containing ``<root_face name>/<stem>.txt`` landmark files
        :param num_face: number of faces per group
        :param face_pipeline: pipeline description applied to face image paths
        :param anime_pipeline: pipeline description for the anime folder
        :param img_size: size passed to the landmark helpers
        :param with_order: selects how groups are enumerated, see ``iter_folders``
        """
        self.num_face = num_face
        self.landmark_path = Path(landmark_path)
        self.with_order = with_order
        self.root_face = Path(root_face)
        self.root_anime = Path(root_anime)
        self.img_size = img_size
        self.face_samples = self.iter_folders()
        self.face_pipeline = transform_pipeline(face_pipeline)
        # NOTE(review): SingleFolderDataset is not among the imports visible at
        # the top of this file - confirm it is in scope.
        self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)

    def iter_folders(self):
        """Return every group of ``num_face`` file stems that share a 7-character prefix.

        Files are grouped by the first seven characters of their stem
        (presumably a person id - confirm against the data layout).
        """
        pics_per_person = defaultdict(list)
        # Sorted so that sample indices do not depend on filesystem listing order.
        for p in sorted(self.root_face.glob("*.jpg")):
            pics_per_person[p.stem[:7]].append(p.stem)
        # NOTE(review): with_order=True selects combinations (order-insensitive)
        # and False selects permutations - confirm this is the intended
        # meaning of the flag.
        group = combinations if self.with_order else permutations
        data = []
        for stems in pics_per_person.values():
            if len(stems) >= self.num_face:
                data.extend(group(stems, self.num_face))
        return data

    def read_pose(self, pose_txt):
        """Build the pose tensor (part labels, part edge, distance map) from a landmark file."""
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([part_labels, part_edge, dist_tensor])

    def __len__(self):
        """Return the number of face groups."""
        return len(self.face_samples)

    def __getitem__(self, idx):
        """Return face group ``idx`` together with one randomly drawn anime image."""
        output = dict()
        anime = self.B[torch.randint(len(self.B), (1,)).item()]
        if isinstance(anime, dict):
            # Tuple-unpacking a dict would yield its keys, so read by key.
            output["anime_img"], output["anime_path"] = anime["img"], anime["path"]
        else:
            output["anime_img"], output["anime_path"] = anime
        for i, f in enumerate(self.face_samples[idx]):
            output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
            output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
            output[f"stem_{i}"] = f
        return output

View File

@@ -28,7 +28,7 @@ class Load:
def transform_pipeline(pipeline_description):
if len(pipeline_description) == 0:
if pipeline_description is None or len(pipeline_description) == 0:
return lambda x: x
transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
return transforms.Compose(transform_list)

View File

@@ -2,25 +2,27 @@ from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.gan import GANLoss
from model.GAN.base import GANImageBuffer
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import pixel_loss, gan_loss, feature_match_loss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights
class TAFGEngineKernel(EngineKernel):
class CycleGANEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.gan_loss = gan_loss(config.loss.gan)
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
self.edge_loss = LossContainer(config.loss.edge.weight, EdgeLoss(
"HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(idist.device()))
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
@@ -56,21 +58,23 @@ class TAFGEngineKernel(EngineKernel):
images["b2a"] = self.generators["b2a"](batch["b"])
images["a2b2a"] = self.generators["b2a"](images["a2b"])
images["b2a2b"] = self.generators["a2b"](images["b2a"])
if self.config.loss.id.weight > 0:
if self.id_loss.weight > 0:
images["a2a"] = self.generators["b2a"](batch["a"])
images["b2b"] = self.generators["a2b"](batch["b"])
return images
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
for phase in ["a2b", "b2a"]:
loss[f"cycle_{phase[0]}"] = self.config.loss.cycle.weight * self.cycle_loss(
generated[f"{phase}2{phase[0]}"], batch[phase[0]])
loss[f"gan_{phase}"] = self.config.loss.gan.weight * self.gan_loss(
self.discriminators[phase[-1]](generated[phase]), True)
if self.config.loss.id.weight > 0:
loss[f"id_{phase[0]}"] = self.config.loss.id.weight * self.id_loss(
generated[f"{phase[0]}2{phase[0]}"], batch[phase[0]])
for ph in "ab":
loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
prediction_fake = self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"])
loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
if self.fm_loss.weight > 0:
prediction_real = self.discriminators[ph](batch[ph])
loss[f"feature_match_{ph}"] = self.fm_loss(prediction_fake, prediction_real)
loss[f"edge_{ph}"] = self.edge_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph], gt_is_edge=False)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
@@ -97,5 +101,5 @@ class TAFGEngineKernel(EngineKernel):
def run(task, config, _):
kernel = TAFGEngineKernel(config)
kernel = CycleGANEngineKernel(config)
run_kernel(task, config, kernel)

86
engine/GauGAN.py Normal file
View File

@@ -0,0 +1,86 @@
from itertools import chain
import torch
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import gan_loss, feature_match_loss, perceptual_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights
class GauGANEngineKernel(EngineKernel):
    """Training kernel with one generator (``"main"``) and one discriminator (``"b"``)."""

    def __init__(self, config):
        """Create the loss objects and one image buffer per discriminator."""
        super().__init__(config)
        loss_cfg = config.loss
        self.gan_loss = gan_loss(loss_cfg.gan)
        self.mgc_loss = LossContainer(loss_cfg.mgc.weight, MGCLoss("opposite"))
        self.fm_loss = LossContainer(loss_cfg.fm.weight, feature_match_loss(1, "same"))
        self.perceptual_loss = LossContainer(loss_cfg.perceptual.weight, perceptual_loss(loss_cfg.perceptual))
        # The buffer size falls back to 50 when the configured value is falsy.
        self.image_buffers = {
            name: GANImageBuffer(config.data.train.buffer_size or 50)
            for name in self.discriminators.keys()
        }

    def build_models(self) -> (dict, dict):
        """Build, log and weight-initialize the generator and discriminator."""
        model_cfg = self.config.model
        generators = {"main": build_model(model_cfg.generator)}
        discriminators = {"b": build_model(model_cfg.discriminator)}
        self.logger.debug(discriminators["b"])
        self.logger.debug(generators["main"])

        for group in (generators, discriminators):
            for net in group.values():
                generation_init_weights(net)

        return generators, discriminators

    def setup_after_g(self):
        """Re-enable gradients on every discriminator."""
        for net in self.discriminators.values():
            net.requires_grad_(True)

    def setup_before_g(self):
        """Disable gradients on every discriminator."""
        for net in self.discriminators.values():
            net.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run the generator on ``batch["a"]``; gradients are off when ``inference`` is true."""
        with torch.set_grad_enabled(not inference):
            translated = self.generators["main"](batch["a"])
        return {"a2b": translated}

    def criterion_generators(self, batch, generated) -> dict:
        """Return the generator losses: gan, mgc, perceptual and optionally feature_match."""
        fake = generated["a2b"]
        net_d = self.discriminators["b"]
        prediction_fake = net_d(fake)

        loss = {
            "gan": self.config.loss.gan.weight * self.gan_loss(prediction_fake, True),
            "mgc": self.mgc_loss(fake, batch["a"]),
            "perceptual": self.perceptual_loss(fake, batch["a"]),
        }
        if self.fm_loss.weight > 0:
            prediction_real = net_d(batch["b"])
            loss["feature_match"] = self.fm_loss(prediction_fake, prediction_real)
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Return the discriminator loss: the mean of the fake and the real term."""
        net_d = self.discriminators["b"]
        buffered_fake = self.image_buffers["b"].query(generated["a2b"].detach())
        fake_term = self.gan_loss(net_d(buffered_fake), False, is_discriminator=True)
        real_term = self.gan_loss(net_d(batch["b"]), True, is_discriminator=True)
        return {"b": (fake_term + real_term) / 2}

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        source = batch["a"].detach()
        translated = generated["a2b"].detach()
        return {"a": [source, translated]}
def run(task, config, _):
    """Build the GauGAN kernel from ``config`` and hand it to ``run_kernel``."""
    run_kernel(task, config, GauGANEngineKernel(config))

154
engine/MUNIT.py Normal file
View File

@@ -0,0 +1,154 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
def mse_loss(x, target_flag):
    """Mean squared error of ``x`` against all ones (truthy flag) or all zeros."""
    make_target = torch.ones_like if target_flag else torch.zeros_like
    return F.mse_loss(x, make_target(x))
def bce_loss(x, target_flag):
    """Binary cross entropy on logits ``x`` against all ones (truthy flag) or all zeros."""
    make_target = torch.ones_like if target_flag else torch.zeros_like
    return F.binary_cross_entropy_with_logits(x, make_target(x))
class MUNITEngineKernel(EngineKernel):
    """Training kernel with one generator and one discriminator per domain (``a``, ``b``)."""

    def __init__(self, config):
        """Create the perceptual, GAN and reconstruction losses from ``config.loss``."""
        super().__init__(config)

        # "weight" is not a constructor argument of the loss modules, so it is
        # removed from the keyword arguments before they are built.
        perceptual_kwargs = OmegaConf.to_container(config.loss.perceptual)
        del perceptual_kwargs["weight"]
        self.perceptual_loss = PerceptualLoss(**perceptual_kwargs).to(idist.device())

        gan_kwargs = OmegaConf.to_container(config.loss.gan)
        del gan_kwargs["weight"]
        self.gan_loss = GANLoss(**gan_kwargs).to(idist.device())

        if config.loss.recon.level == 1:
            self.recon_loss = nn.L1Loss()
        else:
            self.recon_loss = nn.MSELoss()
        self.train_generator_first = False

    def build_models(self) -> (dict, dict):
        """Build both generators and both discriminators from the same model configs."""
        model_cfg = self.config.model
        generators = {
            "a": build_model(model_cfg.generator),
            "b": build_model(model_cfg.generator),
        }
        discriminators = {
            "a": build_model(model_cfg.discriminator),
            "b": build_model(model_cfg.discriminator),
        }
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["a"])

        return generators, discriminators

    def setup_after_g(self):
        """Re-enable gradients on every discriminator."""
        for net in self.discriminators.values():
            net.requires_grad_(True)

    def setup_before_g(self):
        """Disable gradients on every discriminator."""
        for net in self.discriminators.values():
            net.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Encode both domains, reconstruct them and translate across domains.

        :return: dict with the sub-dicts ``styles``, ``contents`` and ``images``
        """
        styles, contents, images = {}, {}, {}
        with torch.set_grad_enabled(not inference):
            # Within-domain pass: encode, reconstruct, and draw a random style
            # shaped like the encoded one.
            for domain in "ab":
                net = self.generators[domain]
                contents[domain], styles[domain] = net.encode(batch[domain])
                images[f"{domain}2{domain}"] = net.decode(contents[domain], styles[domain])
                styles[f"random_{domain}"] = torch.randn_like(styles[domain]).to(idist.device())

            # Cross-domain pass: decode source content with a random target
            # style, then re-encode the translated image.
            for src, dst in (("a", "b"), ("b", "a")):
                key = f"{src}2{dst}"
                target_net = self.generators[dst]
                images[key] = target_net.decode(contents[src], styles[f"random_{dst}"])
                contents[key], styles[key] = target_net.encode(images[key])
                if self.config.loss.recon.cycle.weight > 0:
                    images[f"{key}2{src}"] = self.generators[src].decode(contents[key], styles[src])
        return dict(styles=styles, contents=contents, images=images)

    def criterion_generators(self, batch, generated) -> dict:
        """Return the per-domain generator losses keyed by ``<name>_<domain>``."""
        images = generated["images"]
        contents = generated["contents"]
        styles = generated["styles"]
        recon_cfg = self.config.loss.recon

        loss = {}
        for domain in "ab":
            outgoing = "a2b" if domain == "a" else "b2a"  # translation that starts in this domain
            incoming = "b2a" if domain == "a" else "a2b"  # translation that ends in this domain

            loss[f"recon_image_{domain}"] = recon_cfg.image.weight * self.recon_loss(
                batch[domain], images[f"{domain}2{domain}"])
            loss[f"recon_content_{domain}"] = recon_cfg.content.weight * self.recon_loss(
                contents[domain], contents[outgoing])
            loss[f"recon_style_{domain}"] = recon_cfg.style.weight * self.recon_loss(
                styles[f"random_{domain}"], styles[incoming])

            pred_fake = self.discriminators[domain](images[incoming])
            # The last entry of each sub-prediction is the actual prediction.
            # NOTE(review): config.loss.gan.weight is removed in __init__ and
            # is not applied to this term - confirm that is intended.
            loss[f"gan_{domain}"] = sum(self.gan_loss(sub_pred[-1], True) for sub_pred in pred_fake)

            if recon_cfg.cycle.weight > 0:
                loss[f"recon_cycle_{domain}"] = recon_cfg.cycle.weight * self.recon_loss(
                    batch[domain], images[f"{outgoing}2{domain}"])
            if self.config.loss.perceptual.weight > 0:
                loss[f"perceptual_{domain}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                    batch[domain], images[outgoing])
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Return one ``gan_<domain>`` loss per discriminator, summed over its outputs."""
        loss = {}
        for fake_key, domain in (("a2b", "b"), ("b2a", "a")):
            net_d = self.discriminators[domain]
            pred_real = net_d(batch[domain])
            pred_fake = net_d(generated["images"][fake_key].detach())
            total = 0
            for i, fake_outputs in enumerate(pred_fake):
                fake_term = self.gan_loss(fake_outputs[-1], False, is_discriminator=True)
                real_term = self.gan_loss(pred_real[i][-1], True, is_discriminator=True)
                total += (fake_term + real_term) / 2
            loss[f"gan_{domain}"] = total
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        detached = {name: img.detach() for name, img in generated["images"].items()}
        with_cycle = self.config.loss.recon.cycle.weight > 0

        images = {}
        for domain in "ab":
            outgoing = "a2b" if domain == "a" else "b2a"
            row = [batch[domain].detach(), detached[f"{domain}2{domain}"], detached[outgoing]]
            if with_cycle:
                row.append(detached[f"{outgoing}2{domain}"])
            images[domain] = row
        return images
class MUNITTestEngineKernel(TestEngineKernel):
    """Test-time kernel holding the two MUNIT generators (``a`` and ``b``)."""

    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        """Build both generators from the same generator config."""
        generators = dict(
            a=build_model(self.config.model.generator),
            b=build_model(self.config.model.generator)
        )
        return generators

    def to_load(self):
        """Map checkpoint names (``generator_<key>``) to the generator objects."""
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        """Translate ``batch[0]`` from domain a to domain b with a random style.

        The generators are keyed ``"a"`` / ``"b"`` (see ``build_generators``),
        so the translation follows the training kernel's a2b path: encode with
        generator ``a``, decode with generator ``b``.
        """
        with torch.no_grad():
            content, style = self.generators["a"].encode(batch[0])
            # NOTE(review): assumes both generators share one style shape
            # (they are built from the same config) - confirm.
            random_style = torch.randn_like(style)
            fake = self.generators["b"].decode(content, random_style)
        return fake.detach()
def run(task, config, _):
    """Build the MUNIT kernel for ``task`` and hand it to ``run_kernel``.

    :param task: ``"train"`` or ``"test"``
    :param config: configuration passed to the kernel and to ``run_kernel``
    :raises NotImplementedError: for any other task
    """
    if task == "train":
        kernel = MUNITEngineKernel(config)
    elif task == "test":
        kernel = MUNITTestEngineKernel(config)
    else:
        # NotImplemented is a constant, not an exception; raising it would
        # fail with a TypeError instead of reporting the unsupported task.
        raise NotImplementedError(f"unsupported task: {task!r}")
    run_kernel(task, config, kernel)

View File

@@ -1,20 +1,16 @@
from itertools import chain
from omegaconf import OmegaConf
import ignite.distributed as idist
import torch
import torch.nn as nn
import ignite.distributed as idist
from ignite.engine import Events
from omegaconf import read_write, OmegaConf
from model.weight_init import generation_init_weights
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
class TAFGEngineKernel(EngineKernel):
@@ -24,13 +20,21 @@ class TAFGEngineKernel(EngineKernel):
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
style_loss_cfg = OmegaConf.to_container(config.loss.style)
style_loss_cfg.pop("weight")
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
self.content_recon_loss = nn.L1Loss() if config.loss.content_recon.level == 1 else nn.MSELoss()
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
idist.device())
def _process_batch(self, batch, inference=False):
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
@@ -63,52 +67,103 @@ class TAFGEngineKernel(EngineKernel):
def forward(self, batch, inference=False) -> dict:
generator = self.generators["main"]
batch = self._process_batch(batch, inference)
styles = dict()
contents = dict()
images = dict()
with torch.set_grad_enabled(not inference):
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
)
return fake
contents["a"], styles["a"] = generator.encode(batch["a"]["edge"], batch["a"]["img"], "a", "a")
contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
for ph in "ab":
images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
styles[f"random_b"] = torch.randn_like(styles["b"]).to(idist.device())
images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")
contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
images["a2b"], "b", "b")
images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
return dict(styles=styles, contents=contents, images=images)
def criterion_generators(self, batch, generated) -> dict:
batch = self._process_batch(batch)
loss = dict()
loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
loss["perceptual"] = loss_perceptual * self.config.loss.perceptual.weight
for phase in "ab":
pred_fake = self.discriminators[phase](generated[phase])
loss[f"gan_{phase}"] = 0
for ph in "ab":
loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"])
loss[f"gan_{ph}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
if self.engine.state.epoch == self.config.misc.add_new_loss_epoch:
self.generators["main"].style_converters.requires_grad_(False)
self.generators["main"].style_encoders.requires_grad_(False)
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
pred_fake = self.discriminators[ph](generated["images"]["a2b"])
loss["gan_a2b"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
generated["contents"]["a"], generated["contents"]["recon_a"]
)
loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
generated["styles"]["random_b"], generated["styles"]["recon_b"]
)
if self.config.loss.perceptual.weight > 0:
loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
batch["a"]["img"], generated["images"]["a2b"]
)
if self.config.loss.cycle.weight > 0:
loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
batch["a"]["img"], generated["images"][f"cycle_a"]
)
# for ph in "ab":
#
# if self.config.loss.style.weight > 0:
# loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
# batch[ph]["img"], generated["images"][f"a2{ph}"]
# )
if self.config.loss.edge.weight > 0:
loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
)
if self.config.loss.fm.weight > 0 and phase == "b":
pred_real = self.discriminators[phase](batch[phase])
loss_fm = 0
num_scale_discriminator = len(pred_fake)
for i in range(num_scale_discriminator):
# last output is the final prediction, so we exclude it
num_intermediate_outputs = len(pred_fake[i]) - 1
for j in range(num_intermediate_outputs):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
# loss["style_recon"] = self.config.loss.style_recon.weight * self.style_recon_loss(
# self.generators["main"].module.style_encoders["b"](batch["b"]),
# self.generators["main"].module.style_encoders["b"](generated["b"])
# )
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
# batch = self._process_batch(batch)
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase])
pred_fake = self.discriminators[phase](generated[phase].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase]["img"])
pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
pred_fake_2 = self.discriminators[phase](generated["images"]["a2b"].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) +
self.gan_loss(pred_fake_2[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 3
else:
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase]["img"])
pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
@@ -119,17 +174,37 @@ class TAFGEngineKernel(EngineKernel):
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
batch = self._process_batch(batch)
edge = batch["edge_a"][:, 0:1, :, :]
return dict(
a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
generated["b"].detach()]
)
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
return dict(
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
batch["a"]["img"].detach(),
generated["images"]["a2a"].detach(),
generated["images"]["a2b"].detach(),
generated["images"]["cycle_a"].detach(),
],
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
batch["b"]["img"].detach(),
generated["images"]["b2b"].detach(),
generated["images"]["cycle_b"].detach()]
)
else:
return dict(
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
batch["a"]["img"].detach(),
generated["images"]["a2a"].detach(),
],
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
batch["b"]["img"].detach(),
generated["images"]["b2b"].detach(),
]
)
def change_engine(self, config, trainer):
@trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
def change_config(engine):
with read_write(config):
config.loss.perceptual.weight = 5
pass
# @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
# def change_config(engine):
# with read_write(config):
# config.loss.perceptual.weight = 5
def run(task, config, _):

119
engine/TSIT.py Normal file
View File

@@ -0,0 +1,119 @@
from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
class TSITEngineKernel(EngineKernel):
    """Training kernel with one generator (``"main"``) and one discriminator (``"b"``)."""

    def __init__(self, config):
        """Create the perceptual, GAN and feature-matching losses from ``config.loss``."""
        super().__init__(config)

        # "weight" is not a constructor argument of the loss modules, so it is
        # removed from the keyword arguments before they are built.
        perceptual_kwargs = OmegaConf.to_container(config.loss.perceptual)
        del perceptual_kwargs["weight"]
        self.perceptual_loss = PerceptualLoss(**perceptual_kwargs).to(idist.device())

        gan_kwargs = OmegaConf.to_container(config.loss.gan)
        del gan_kwargs["weight"]
        self.gan_loss = GANLoss(**gan_kwargs).to(idist.device())

        if config.loss.fm.level == 1:
            self.fm_loss = nn.L1Loss()
        else:
            self.fm_loss = nn.MSELoss()

    def build_models(self) -> (dict, dict):
        """Build, log and weight-initialize the generator and discriminator."""
        model_cfg = self.config.model
        generators = {"main": build_model(model_cfg.generator)}
        discriminators = {"b": build_model(model_cfg.discriminator)}
        self.logger.debug(discriminators["b"])
        self.logger.debug(generators["main"])

        for group in (generators, discriminators):
            for net in group.values():
                generation_init_weights(net)

        return generators, discriminators

    def setup_after_g(self):
        """Re-enable gradients on every discriminator."""
        for net in self.discriminators.values():
            net.requires_grad_(True)

    def setup_before_g(self):
        """Disable gradients on every discriminator."""
        for net in self.discriminators.values():
            net.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run the generator on ``batch["a"]``; gradients are off when ``inference`` is true."""
        with torch.set_grad_enabled(not inference):
            translated = self.generators["main"](content_img=batch["a"])
            fake = {"b": translated}
        return fake

    def criterion_generators(self, batch, generated) -> dict:
        """Return the weighted perceptual loss and the weighted GAN loss for domain b."""
        loss = {}
        loss["perceptual"] = self.perceptual_loss(generated["b"], batch["a"]) * self.config.loss.perceptual.weight
        pred_fake = self.discriminators["b"](generated["b"])
        # The last entry of each sub-prediction is the actual prediction.
        loss["gan_b"] = sum(
            self.config.loss.gan.weight * self.gan_loss(sub_pred[-1], True) for sub_pred in pred_fake
        )
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Return one ``gan_<name>`` loss per discriminator, summed over its outputs."""
        loss = {}
        for name in self.discriminators.keys():
            net_d = self.discriminators[name]
            pred_real = net_d(batch[name])
            pred_fake = net_d(generated[name].detach())
            total = 0
            for i, fake_outputs in enumerate(pred_fake):
                fake_term = self.gan_loss(fake_outputs[-1], False, is_discriminator=True)
                real_term = self.gan_loss(pred_real[i][-1], True, is_discriminator=True)
                total += (fake_term + real_term) / 2
            loss[f"gan_{name}"] = total
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        row = [batch["a"].detach(), batch["b"].detach(), generated["b"].detach()]
        return {"b": row}
class TSITTestEngineKernel(TestEngineKernel):
    """Test-time kernel that only builds the ``"main"`` generator and runs it without gradients."""

    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        """Build the generator described by ``config.model.generator`` and store it under ``"main"``."""
        return {"main": build_model(self.config.model.generator)}

    def to_load(self):
        """Map every generator to the checkpoint key ``"generator_<name>"``."""
        return {f"generator_{name}": network for name, network in self.generators.items()}

    def inference(self, batch):
        """
        Generate an image from the first element of ``batch["a"]`` (content) and ``batch["b"]`` (style).

        :param batch: mapping with entries ``"a"`` and ``"b"``; only index 0 of each is used
        :return: dict with the detached generator output under ``"a"``
        """
        content, style = batch["a"][0], batch["b"][0]
        with torch.no_grad():
            translated = self.generators["main"](content_img=content, style_img=style)
        return {"a": translated.detach()}
def run(task, config, _):
    """
    Entry point: build the kernel that matches ``task`` and hand it to ``run_kernel``.

    :param task: ``"train"`` or ``"test"``
    :param config: configuration object forwarded to the kernel and to ``run_kernel``
    :param _: unused
    :raises NotImplementedError: if ``task`` is neither ``"train"`` nor ``"test"``
    """
    if task == "train":
        kernel = TSITEngineKernel(config)
    elif task == "test":
        kernel = TSITTestEngineKernel(config)
    else:
        # ``NotImplemented`` is a comparison sentinel, not an exception; raising it fails with
        # ``TypeError``. ``NotImplementedError`` is the exception meant here.
        raise NotImplementedError(f"unsupported task: {task!r}")
    run_kernel(task, config, kernel)

View File

@@ -1,38 +1,38 @@
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
from loss.gan import GANLoss
from model.GAN.UGATIT import RhoClipper
from model.GAN.base import GANImageBuffer
from util.image import attention_colored_map
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from engine.util.container import LossContainer
from engine.util.loss import bce_loss, mse_loss, pixel_loss, gan_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from util.image import attention_colored_map
def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
class RhoClipper(object):
def __init__(self, clip_min, clip_max):
self.clip_min = clip_min
self.clip_max = clip_max
assert clip_min < clip_max
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def __call__(self, module):
if hasattr(module, 'rho'):
w = module.rho.data
w = w.clamp(self.clip_min, self.clip_max)
module.rho.data = w
class UGATITEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.gan_loss = gan_loss(config.loss.gan)
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
self.bce_loss = LossContainer(self.config.loss.cam.weight, bce_loss)
self.mse_loss = LossContainer(self.config.loss.gan.weight, mse_loss)
self.rho_clipper = RhoClipper(0, 1)
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
self.train_generator_first = False
def build_models(self) -> (dict, dict):
@@ -79,17 +79,17 @@ class UGATITEngineKernel(EngineKernel):
loss = dict()
for phase in "ab":
cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
generated["images"][f"{phase}2{phase}"])
loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
for dk in "lg":
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
loss[f"gan_cam_{phase}_{dk}"] = self.config.loss.gan.weight * mse_loss(cam_pred, True)
loss[f"gan_cam_{phase}_{dk}"] = self.mse_loss(cam_pred, True)
for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
loss[f"cam_{t[-1]}"] = self.config.loss.cam.weight * (
bce_loss(generated["cam_pred"][t], True) + bce_loss(generated["cam_pred"][f], False))
loss[f"cam_{t[-1]}"] = self.bce_loss(generated["cam_pred"][t], True) + \
self.bce_loss(generated["cam_pred"][f], False)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
@@ -135,8 +135,8 @@ class UGATITTestEngineKernel(TestEngineKernel):
def inference(self, batch):
with torch.no_grad():
fake, _, _ = self.generators["a2b"](batch["a"])
return {"a": fake.detach()}
fake, _, _ = self.generators["a2b"](batch[0])
return fake.detach()
def run(task, config, _):

View File

@@ -58,9 +58,13 @@ class EngineKernel(object):
self.logger = logging.getLogger(config.name)
self.generators, self.discriminators = self.build_models()
self.train_generator_first = True
self.engine = None
def bind_engine(self, engine):
self.engine = engine
def build_models(self) -> (dict, dict):
raise NotImplemented
raise NotImplementedError
def to_save(self):
to_save = {}
@@ -69,19 +73,19 @@ class EngineKernel(object):
return to_save
def setup_after_g(self):
raise NotImplemented
raise NotImplementedError
def setup_before_g(self):
raise NotImplemented
raise NotImplementedError
def forward(self, batch, inference=False) -> dict:
raise NotImplemented
raise NotImplementedError
def criterion_generators(self, batch, generated) -> dict:
raise NotImplemented
raise NotImplementedError
def criterion_discriminators(self, batch, generated) -> dict:
raise NotImplemented
raise NotImplementedError
def intermediate_images(self, batch, generated) -> dict:
"""
@@ -90,12 +94,22 @@ class EngineKernel(object):
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
raise NotImplemented
raise NotImplementedError
def change_engine(self, config, engine: Engine):
pass
def _remove_no_grad_loss(loss_dict):
    """
    Drop every entry of ``loss_dict`` whose value is not a ``torch.Tensor`` (e.g. a plain ``0``).

    The dict is modified in place and also returned.

    :param loss_dict: mapping from loss name to loss value
    :return: the same dict object, now containing tensor values only
    """
    # collect the keys first: a dict must not change size while it is being iterated
    non_tensor_keys = [name for name, value in loss_dict.items() if not isinstance(value, torch.Tensor)]
    for name in non_tensor_keys:
        del loss_dict[name]
    return loss_dict
def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name)
generators, discriminators = kernel.generators, kernel.discriminators
@@ -132,18 +146,21 @@ def get_trainer(config, kernel: EngineKernel):
generated = kernel.forward(batch)
if kernel.train_generator_first:
# simultaneous, train G with simultaneous D
loss_g = train_generators(batch, generated)
loss_d = train_discriminators(batch, generated)
else:
# update discriminators first, not simultaneous.
# train G with updated discriminators
loss_d = train_discriminators(batch, generated)
loss_g = train_generators(batch, generated)
if engine.state.iteration % iteration_per_image == 0:
return {
"loss": dict(g=loss_g, d=loss_d),
"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
"img": kernel.intermediate_images(batch, generated)
}
return {"loss": dict(g=loss_g, d=loss_d)}
return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}
trainer = Engine(_step)
trainer.logger = logger
@@ -151,9 +168,10 @@ def get_trainer(config, kernel: EngineKernel):
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
kernel.change_engine(config, trainer)
kernel.bind_engine(trainer)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
@@ -181,36 +199,51 @@ def get_trainer(config, kernel: EngineKernel):
for i in range(len(image_list)):
test_images[k].append([])
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
for i in range(random_start, random_start + 10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
generated = kernel.forward(batch)
images = kernel.intermediate_images(batch, generated)
g = torch.Generator()
g.manual_seed(config.misc.random_seed + engine.state.epoch
if config.handler.test.random else config.misc.random_seed)
random_start = \
torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
for i in range(random_start, random_start + config.handler.test.images):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].unsqueeze(0)
elif isinstance(batch[k], dict):
for kk in batch[k]:
if isinstance(batch[k][kk], torch.Tensor):
batch[k][kk] = batch[k][kk].unsqueeze(0)
generated = kernel.forward(batch, inference=True)
images = kernel.intermediate_images(batch, generated)
for k in test_images:
for j in range(len(images[k])):
test_images[k][j].append(images[k][j])
for k in test_images:
tensorboard_handler.writer.add_image(
f"test/{k}",
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
engine.state.iteration * pairs_per_iteration
)
for j in range(len(images[k])):
test_images[k][j].append(images[k][j])
for k in test_images:
tensorboard_handler.writer.add_image(
f"test/{k}",
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
engine.state.iteration * pairs_per_iteration
)
return trainer
def save_images_helper(output_dir, paths, images_list):
    """
    Save one image file per sample, placing the i-th image of every tensor in ``images_list`` side by side.

    :param output_dir: directory the files are written to
    :param paths: one source path per sample; only its file name is reused for the output file
    :param images_list: list of batched tensors, each indexed by sample
    """
    for sample_idx, source_path in enumerate(paths):
        row = [images[sample_idx] for images in images_list]
        target = Path(output_dir) / Path(source_path).name
        torchvision.utils.save_image(row, target, nrow=len(row), padding=0,
                                     normalize=True, range=(-1, 1))
def get_tester(config, kernel: TestEngineKernel):
logger = logging.getLogger(config.name)
def _step(engine, batch):
real_a, path = convert_tensor(batch, idist.device())
fake = kernel.inference({"a": real_a})["a"]
return {"path": path, "img": [real_a.detach(), fake.detach()]}
batch = convert_tensor(batch, idist.device())
return {"batch": batch, "generated": kernel.inference(batch)}
tester = Engine(_step)
tester.logger = logger
@@ -227,13 +260,14 @@ def get_tester(config, kernel: TestEngineKernel):
@tester.on(Events.ITERATION_COMPLETED)
def save_images(engine):
img_tensors = engine.state.output["img"]
paths = engine.state.output["path"]
batch_size = img_tensors[0].size(0)
for i in range(batch_size):
image_name = Path(paths[i]).name
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
nrow=len(img_tensors), padding=0, normalize=True, range=(-1, 1))
if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
images, paths = engine.state.output["batch"]
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"]])
else:
for k in engine.state.output['generated']:
images, paths = engine.state.output["batch"][k]
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"][k]])
return tester
@@ -264,7 +298,7 @@ def run_kernel(task, config, kernel):
print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, kernel)

153
engine/talking_anime.py Normal file
View File

@@ -0,0 +1,153 @@
from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.context_loss import ContextLoss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
class TAEngineKernel(EngineKernel):
    """Training kernel with an ``anime`` and a ``face`` generator/discriminator pair."""

    def __init__(self, config):
        super().__init__(config)
        device = idist.device()

        def without_weight(loss_cfg):
            # the loss modules do not take ``weight``; it is applied separately in the criteria
            kwargs = OmegaConf.to_container(loss_cfg)
            kwargs.pop("weight")
            return kwargs

        def pixel_criterion(level):
            # level 1 -> L1, anything else -> MSE
            return nn.L1Loss() if level == 1 else nn.MSELoss()

        loss_cfg = config.loss
        self.perceptual_loss = PerceptualLoss(**without_weight(loss_cfg.perceptual)).to(device)
        self.style_loss = PerceptualLoss(**without_weight(loss_cfg.style)).to(device)
        self.gan_loss = GANLoss(**without_weight(loss_cfg.gan)).to(device)
        self.context_loss = ContextLoss(**without_weight(loss_cfg.context)).to(device)
        self.recon_loss = pixel_criterion(loss_cfg.recon.level)
        self.fm_loss = pixel_criterion(loss_cfg.fm.level)
        self.edge_loss = EdgeLoss(
            "HED", hed_pretrained_model_path=loss_cfg.edge.hed_pretrained_model_path
        ).to(device)

    def build_models(self) -> (dict, dict):
        """
        Build both generators and both discriminators and initialise their weights.

        :return: tuple ``(generators, discriminators)``, each keyed by ``"anime"`` and ``"face"``
        """
        model_cfg = self.config.model
        generators = {
            "anime": build_model(model_cfg.anime_generator),
            "face": build_model(model_cfg.face_generator),
        }
        # both discriminators are built from the same config entry
        discriminators = {
            "anime": build_model(model_cfg.discriminator),
            "face": build_model(model_cfg.discriminator),
        }
        self.logger.debug(discriminators["face"])
        self.logger.debug(generators["face"])
        for network in chain(generators.values(), discriminators.values()):
            generation_init_weights(network)
        return generators, discriminators

    def setup_after_g(self):
        """Switch ``requires_grad`` back on for every discriminator."""
        for net in self.discriminators.values():
            net.requires_grad_(True)

    def setup_before_g(self):
        """Switch ``requires_grad`` off for every discriminator."""
        for net in self.discriminators.values():
            net.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """
        Run the anime generator, then the face generator.

        :param batch: mapping with the entries ``"face_0"``, ``"face_1"`` and ``"anime_img"``
        :param inference: when True, gradient tracking is disabled
        :return: dict with ``"fake_anime"`` and ``"fake_face"``
        """
        with torch.set_grad_enabled(not inference):
            # the anime image is flipped along dim 3 before being concatenated with face_1 on dim 1
            flipped_anime = torch.flip(batch["anime_img"], dims=[3])
            anime_input = torch.cat([batch["face_1"], flipped_anime], dim=1)
            target_pose_anime = self.generators["anime"](anime_input)
            # the face generator receives the channel-mean of the generated anime image
            anime_channel_mean = target_pose_anime.mean(dim=1, keepdim=True)
            target_pose_face = self.generators["face"](anime_channel_mean, batch["face_0"])
        return {"fake_anime": target_pose_anime, "fake_face": target_pose_face}

    def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
        """
        Compute the weighted GAN loss and, optionally, the weighted feature-matching loss.

        :param discriminator: callable returning, per scale, a sequence whose last item is the prediction
        :param generated_img: generated image to score
        :param match_img: image whose intermediate outputs are matched; ``None`` skips feature matching
        :return: tuple ``(loss_gan, loss_fm)``; ``loss_fm`` is ``0`` when ``match_img`` is ``None``
        """
        pred_fake = discriminator(generated_img)
        # last output of each scale is the actual prediction
        loss_gan = sum(
            self.config.loss.gan.weight * self.gan_loss(scale_out[-1], True) for scale_out in pred_fake
        )

        if match_img is None:
            # feature matching is not requested
            return loss_gan, 0

        pred_real = discriminator(match_img)
        num_scales = len(pred_fake)
        loss_fm = 0
        for scale_idx in range(num_scales):
            fake_feats, real_feats = pred_fake[scale_idx], pred_real[scale_idx]
            # the final prediction is excluded, only intermediate outputs are compared
            for layer_idx in range(len(fake_feats) - 1):
                loss_fm += self.fm_loss(fake_feats[layer_idx], real_feats[layer_idx].detach()) / num_scales
        return loss_gan, self.config.loss.fm.weight * loss_fm

    def criterion_generators(self, batch, generated) -> dict:
        """
        Compute all generator losses.

        :param batch: mapping with ``"face_1"`` and ``"anime_img"``
        :param generated: mapping with ``"fake_face"`` and ``"fake_anime"``
        :return: dict of named losses; the perceptual and context terms are only present
            when their configured weight is positive
        """
        weights = self.config.loss
        fake_face, fake_anime = generated["fake_face"], generated["fake_anime"]
        real_face, real_anime = batch["face_1"], batch["anime_img"]

        loss = {}
        loss["face_style"] = weights.style.weight * self.style_loss(fake_face, real_face)
        loss["face_recon"] = weights.recon.weight * self.recon_loss(fake_face, real_face)
        loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["face"], fake_face, real_face)
        loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["anime"], fake_anime, real_anime)
        loss["anime_edge"] = weights.edge.weight * self.edge_loss(fake_anime, real_face, gt_is_edge=False)

        if weights.perceptual.weight > 0:
            loss["anime_perceptual"] = weights.perceptual.weight * self.perceptual_loss(fake_anime, real_anime)
        if weights.context.weight > 0:
            loss["anime_context"] = weights.context.weight * self.context_loss(fake_anime, real_anime)
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """
        Compute the discriminator losses.

        :param batch: mapping with the real images ``"anime_img"`` and ``"face_1"``
        :param generated: mapping with ``"fake_anime"`` and ``"fake_face"``
        :return: dict with one ``"gan_<phase>"`` entry per discriminator
        """
        real_key = {"anime": "anime_img", "face": "face_1"}
        loss = {}
        for phase, discriminator in self.discriminators.items():
            pred_real = discriminator(batch[real_key[phase]])
            pred_fake = discriminator(generated[f"fake_{phase}"].detach())
            total = 0
            for idx, fake_out in enumerate(pred_fake):
                fake_term = self.gan_loss(fake_out[-1], False, is_discriminator=True)
                real_term = self.gan_loss(pred_real[idx][-1], True, is_discriminator=True)
                total += (fake_term + real_term) / 2
            loss[f"gan_{phase}"] = total
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        Collect the images to visualise for one step.

        :param batch: mapping with ``"face_0"``, ``"face_1"`` and ``"anime_img"``
        :param generated: mapping with ``"fake_anime"`` and ``"fake_face"``
        :return: ``{"b": [face_0, face_1, anime_img, fake_anime, fake_face]}``; the generated
            images are detached
        """
        return {
            "b": [
                batch["face_0"],
                batch["face_1"],
                batch["anime_img"],
                generated["fake_anime"].detach(),
                generated["fake_face"].detach(),
            ]
        }
def run(task, config, _):
    """
    Entry point: create a ``TAEngineKernel`` and hand it to ``run_kernel`` for the given task.

    :param task: task name forwarded unchanged to ``run_kernel``
    :param config: configuration object used for the kernel and for ``run_kernel``
    :param _: unused
    """
    run_kernel(task, config, TAEngineKernel(config))

View File

@@ -1,18 +1,28 @@
import torch
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.optim as optim
from omegaconf import OmegaConf
from model import MODEL
import torch.optim as optim
def add_spectral_norm(module):
    """
    Wrap ``module`` with spectral normalisation if it is a Linear / Conv2d / ConvTranspose2d layer.

    Modules of any other type, and modules that already carry a ``weight_u`` attribute,
    are returned untouched.

    :param module: the module to process
    :return: the module, spectrally normalised when applicable
    """
    normalisable = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)
    if not isinstance(module, normalisable) or hasattr(module, 'weight_u'):
        return module
    return nn.utils.spectral_norm(module)
def build_model(cfg):
    """
    Build a model from its config node and prepare it with ``idist.auto_model``.

    Two optional control keys are removed from the config before the model is built:
    ``_bn_to_sync_bn`` (convert BatchNorm layers to SyncBatchNorm) and
    ``_add_spectral_norm`` (apply ``add_spectral_norm`` to every sub-module).

    :param cfg: OmegaConf node describing the model
    :return: the result of ``idist.auto_model`` on the built model
    """
    options = OmegaConf.to_container(cfg)
    use_sync_bn = options.pop("_bn_to_sync_bn", False)
    use_spectral_norm = options.pop("_add_spectral_norm", False)

    model = MODEL.build_with(options)
    if use_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if use_spectral_norm:
        model.apply(add_spectral_norm)
    return idist.auto_model(model)

66
engine/util/container.py Normal file
View File

@@ -0,0 +1,66 @@
import torch
class LossContainer:
    """Pair a loss callable with its weight; the loss is only evaluated when the weight is positive."""

    def __init__(self, weight, loss):
        self.weight = weight
        self.loss = loss

    def __call__(self, *args, **kwargs):
        """Return ``weight * loss(*args, **kwargs)``, or ``0.0`` without calling the loss if weight is not > 0."""
        if not self.weight > 0:
            return 0.0
        return self.weight * self.loss(*args, **kwargs)
class GANImageBuffer:
    """This class implements an image buffer that stores previously
    generated images.

    This buffer allows us to update the discriminator using a history of
    generated images rather than the ones produced by the latest generator
    to reduce model oscillation.

    Args:
        buffer_size (int): The size of image buffer. If buffer_size <= 0,
            the buffer is disabled and ``query`` returns its input.
        buffer_ratio (float): The chance / possibility to use the images
            previously stored in the buffer.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        self.buffer_ratio = buffer_ratio
        # always initialise the storage so that ``query`` never hits a missing
        # attribute, whatever value ``buffer_size`` has
        self.img_num = 0
        self.image_buffer = []

    def query(self, images):
        """Query current image batch using a history of generated images.

        Args:
            images (Tensor): Current image batch without history information.

        Returns:
            Tensor: a batch of the same length in which each image is either
            the current one or one previously stored in the buffer. The input
            is returned unchanged when the buffer is disabled or the batch is
            empty.
        """
        # disabled buffer, or nothing to process (``torch.cat`` rejects an empty list)
        if self.buffer_size <= 0 or len(images) == 0:
            return images

        return_images = []
        for image in images:
            image = torch.unsqueeze(image.detach(), 0)
            if self.img_num < self.buffer_size:
                # the buffer is not full yet: store and return the current image
                self.img_num += 1
                self.image_buffer.append(image)
                return_images.append(image)
                continue

            if torch.rand(1) < self.buffer_ratio:
                # with probability ``buffer_ratio`` return a stored image and
                # replace it in the buffer with the current one
                random_id = torch.randint(0, self.buffer_size, (1,)).item()
                stored = self.image_buffer[random_id].clone()
                self.image_buffer[random_id] = image
                return_images.append(stored)
            else:
                # otherwise return the current image and leave the buffer as is
                return_images.append(image)

        return torch.cat(return_images, 0)

View File

@@ -85,7 +85,7 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
trainer.logger.info(f"load state_dict for {ckp.keys()}")
trainer.logger.info(f"load state_dict for {to_save.keys()}")
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
trainer.add_event_handler(

48
engine/util/loss.py Normal file
View File

@@ -0,0 +1,48 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
def gan_loss(config):
    """
    Build a ``GANLoss`` from a config node and move it to ``idist.device()``.

    :param config: OmegaConf node; every key except ``weight`` is passed to ``GANLoss``
    :return: the ``GANLoss`` instance
    """
    kwargs = OmegaConf.to_container(config)
    # ``weight`` must be present; it is not a constructor argument
    del kwargs["weight"]
    criterion = GANLoss(**kwargs)
    return criterion.to(idist.device())
def perceptual_loss(config):
    """
    Build a ``PerceptualLoss`` from a config node and move it to ``idist.device()``.

    :param config: OmegaConf node; every key except ``weight`` is passed to ``PerceptualLoss``
    :return: the ``PerceptualLoss`` instance
    """
    kwargs = OmegaConf.to_container(config)
    # ``weight`` must be present; it is not a constructor argument
    del kwargs["weight"]
    criterion = PerceptualLoss(**kwargs)
    return criterion.to(idist.device())
def pixel_loss(level):
    """
    Choose the pixel-wise criterion.

    :param level: ``1`` selects ``nn.L1Loss``; any other value selects ``nn.MSELoss``
    :return: a new loss module
    """
    if level == 1:
        return nn.L1Loss()
    return nn.MSELoss()
def mse_loss(x, target_flag):
    """
    Mean squared error between ``x`` and a constant target.

    :param x: prediction tensor
    :param target_flag: truthy -> target is all ones, falsy -> target is all zeros
    :return: the MSE as computed by ``F.mse_loss``
    """
    make_target = torch.ones_like if target_flag else torch.zeros_like
    return F.mse_loss(x, make_target(x))
def bce_loss(x, target_flag):
    """
    Binary cross entropy (on logits) between ``x`` and a constant target.

    :param x: tensor of logits
    :param target_flag: truthy -> target is all ones, falsy -> target is all zeros
    :return: the loss as computed by ``F.binary_cross_entropy_with_logits``
    """
    make_target = torch.ones_like if target_flag else torch.zeros_like
    return F.binary_cross_entropy_with_logits(x, make_target(x))
def feature_match_loss(level, weight_policy):
    """
    Create a feature-matching loss function.

    :param level: forwarded to ``pixel_loss`` to choose the comparison criterion
    :param weight_policy: ``"same"`` weights every layer with 1, ``"exponential_decline"``
        weights layer ``i`` with ``2 ** i``
    :return: function ``fm_loss(generated_features, target_features)`` returning a tensor of shape ``(1,)``
    """
    compare_loss = pixel_loss(level)
    assert weight_policy in ["same", "exponential_decline"]

    def layer_weight(layer_idx):
        # NOTE(review): 2 ** i grows with the layer index although the policy is
        # called "exponential_decline" - confirm this is the intended direction
        return 1 if weight_policy == "same" else 2 ** layer_idx

    def fm_loss(generated_features, target_features):
        num_scale = len(generated_features)
        total = torch.zeros(1, device=idist.device())
        for scale_idx in range(num_scale):
            generated_scale = generated_features[scale_idx]
            target_scale = target_features[scale_idx]
            # the last entry of every scale is skipped
            for layer_idx in range(len(generated_scale) - 1):
                layer_loss = compare_loss(generated_scale[layer_idx], target_scale[layer_idx].detach())
                total += layer_weight(layer_idx) * layer_loss / num_scale
        return total

    return fm_loss

View File

@@ -6,7 +6,6 @@ channels:
dependencies:
- python=3.8
- numpy
- ipython
- tqdm
- pyyaml
- pytorch=1.6.*

44
loss/I2I/context_loss.py Normal file
View File

@@ -0,0 +1,44 @@
import torch
import torch.nn.functional as F
from torch import nn
from .perceptual_loss import PerceptualVGG
class ContextLoss(nn.Module):
    """
    Contextual loss computed on VGG features.

    :param layer_weights: mapping from VGG layer name to the weight of that layer's loss
    :param h: divisor used inside the exponential when distances are turned into weights
    :param vgg_type: forwarded to ``PerceptualVGG``
    :param norm_image_with_imagenet_param: forwarded to ``PerceptualVGG``
    :param norm_img: when True, inputs are mapped with ``(x + 1) * 0.5`` before feature extraction
    :param eps: added to the minimum distance to avoid a division by zero
    """

    def __init__(self, layer_weights, h=0.1, vgg_type='vgg19', norm_image_with_imagenet_param=True, norm_img=True,
                 eps=1e-5):
        super(ContextLoss, self).__init__()
        self.eps = eps
        self.h = h
        self.layer_weights = layer_weights
        self.norm_img = norm_img
        self.vgg = PerceptualVGG(
            layer_name_list=list(layer_weights.keys()),
            vgg_type=vgg_type,
            norm_image_with_imagenet_param=norm_image_with_imagenet_param,
        )

    def single_forward(self, source_feature, target_feature):
        """
        Contextual loss for one pair of feature maps.

        :param source_feature: tensor indexed as (N, C, H, W)
        :param target_feature: tensor with the same number of elements per sample and channel
        :return: scalar tensor
        """
        batch_and_channels = source_feature.size()[:2]
        # both feature maps are centred with the spatial mean of the target
        target_mean = target_feature.mean(dim=[2, 3], keepdim=True)
        src = (source_feature - target_mean).view(*batch_and_channels, -1)  # NxCxHW
        tgt = (target_feature - target_mean).view(*batch_and_channels, -1)  # NxCxHW
        src = F.normalize(src, p=2, dim=1)
        tgt = F.normalize(tgt, p=2, dim=1)

        similarity = torch.bmm(src.transpose(1, 2), tgt)  # NxHWxHW
        cosine_distance = (1 - similarity) / 2
        nearest = cosine_distance.min(2, keepdim=True)[0]
        rel_distance = cosine_distance / (nearest + self.eps)

        affinity = torch.exp((1 - rel_distance) / self.h)
        cx = affinity / affinity.sum(dim=2, keepdim=True)
        best_match = cx.max(dim=1, keepdim=True)[0]
        return -torch.log(best_match.mean(dim=2)).mean()

    def forward(self, x, gt):
        """
        Weighted sum of the per-layer contextual losses between ``x`` and ``gt``.

        :param x: generated image
        :param gt: target image; it is detached before feature extraction
        :return: the total loss (``0`` if no layer is configured)
        """
        if self.norm_img:
            x, gt = (x + 1.) * 0.5, (gt + 1.) * 0.5

        # extract vgg features
        x_features = self.vgg(x)
        gt_features = self.vgg(gt.detach())

        return sum(
            self.single_forward(x_features[name], gt_features[name]) * self.layer_weights[name]
            for name in x_features.keys()
        )

View File

@@ -0,0 +1,229 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
def gaussian_radial_basis_function(x, mu, sigma):
    """
    Evaluate ``exp((x - mu)^2 / (2 * sigma^2))`` for every element of ``x`` and every centre in ``mu``.

    NOTE(review): the exponent carries no minus sign, unlike the textbook Gaussian RBF
    ``exp(-(x - mu)^2 / (2 sigma^2))``. The other losses in this file use the same
    expression, so it is kept as is - confirm against the reference implementation.

    :param x: 4-D tensor (batch_size, c, h, w)
    :param mu: 1-D tensor of centres (kernel_size)
    :param sigma: scalar width
    :return: tensor of shape (batch_size, kernel_size, c*h*w)
    """
    batch_size = x.size(0)
    elements_per_sample = x.size(1) * x.size(2) * x.size(3)
    # (kernel_size) -> (batch_size, kernel_size, c*h*w)
    centres = mu.view(1, mu.size(0), 1).expand(batch_size, -1, elements_per_sample)
    # (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
    values = x.view(batch_size, 1, -1).expand(-1, centres.size(1), -1)
    squared_offset = (values - centres).pow(2)
    return torch.exp(squared_offset / (2 * sigma ** 2))
class ImporveMyLoss(torch.nn.Module):
    """
    Variant of ``MyLoss`` that precomputes the kernel centres and the regulariser once.

    :param device: device the constant tensors are created on; ``None`` (default) resolves
        ``idist.device()`` when the instance is created
    """

    def __init__(self, device=None):
        super().__init__()
        # resolve the device at construction time; a default written as ``idist.device()`` in
        # the signature would be evaluated only once, when the module is imported
        if device is None:
            device = idist.device()
        mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).to(device)
        # 9 x 9 grid of centre pairs, flattened to 81 entries each
        self.x_mu_list = mu.repeat(9).view(-1, 81)
        self.y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
        self.R = torch.eye(81).to(device)

    @staticmethod
    def _kernel(values, mu_list, sigma):
        """Return ``exp((mu - v)^2 / (2 sigma^2))`` with shape (batch, len(mu_list), elements per sample)."""
        # broadcasting yields the same values as materialising both operands with ``repeat``;
        # ``reshape`` also accepts non-contiguous inputs
        offset = mu_list.reshape(1, -1, 1) - values.reshape(values.shape[0], 1, -1)
        return torch.exp(offset.pow(2) / (2 * sigma ** 2))

    def batch_ERSMI(self, I1, I2):
        """
        Negated batch mean of the ERSMI estimate between ``I1`` and ``I2``.

        :param I1: 4-D tensor (batch, channels, h, w)
        :param I2: tensor of the same shape, or with a single channel (it is then repeated
            to match the channel count of ``I1``)
        :return: scalar tensor
        """
        batch_size = I1.shape[0]
        img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
        if I2.shape[1] == 1 and I1.shape[1] != 1:
            # match the channel count of I1 (the usual case is 3 channels)
            I2 = I2.repeat(1, I1.shape[1], 1, 1)

        mat_K = self._kernel(I1, self.x_mu_list, 1)
        mat_L = self._kernel(I2, self.y_mu_list, 1)
        mat_k_l = mat_K * mat_L

        H1 = (mat_K @ mat_K.transpose(1, 2)) * (mat_L @ mat_L.transpose(1, 2)) / (img_size ** 2)
        h_hat = mat_k_l @ mat_k_l.transpose(1, 2) / img_size
        small_h_hat = mat_K.sum(2).view(batch_size, -1, 1) * mat_L.sum(2).view(batch_size, -1, 1) / (img_size ** 2)
        h_hat = 0.5 * H1 + 0.5 * h_hat

        # solve the regularised system for alpha
        alpha = (h_hat + 0.05 * self.R).inverse() @ small_h_hat
        ersmi = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
        return -ersmi.squeeze().mean()

    def forward(self, fakeI, realI):
        """Return ``batch_ERSMI(fakeI, realI)``."""
        return self.batch_ERSMI(fakeI, realI)
class MyLoss(torch.nn.Module):
    """
    ERSMI-based loss between a generated and a real image batch.

    The computation runs on the device of ``fakeI``; no device is hard-coded, so the loss
    also works on CPU-only hosts.
    """

    def __init__(self):
        super(MyLoss, self).__init__()

    @staticmethod
    def _kernel(values, mu_list, sigma):
        """Return ``exp((mu - v)^2 / (2 sigma^2))`` with shape (batch, len(mu_list), elements per sample)."""
        offset = mu_list.reshape(1, -1, 1) - values.reshape(values.shape[0], 1, -1)
        return torch.exp(offset.pow(2) / (2 * sigma ** 2))

    def forward(self, fakeI, realI):
        """
        Negated batch mean of the ERSMI estimate between ``fakeI`` and ``realI``.

        :param fakeI: 4-D tensor (batch, channels, h, w)
        :param realI: tensor of the same shape, or with a single channel (it is then repeated
            to match the channel count of ``fakeI``); it is moved to the device of ``fakeI``
        :return: scalar tensor on the device of ``fakeI``
        """
        device = fakeI.device
        realI = realI.to(device)

        batch_size = fakeI.shape[0]
        img_size = fakeI.shape[1] * fakeI.shape[2] * fakeI.shape[3]
        if realI.shape[1] == 1 and fakeI.shape[1] != 1:
            # match the channel count of fakeI (the usual case is 3 channels)
            realI = realI.repeat(1, fakeI.shape[1], 1, 1)

        # 9 x 9 grid of centre pairs, flattened to 81 entries each
        mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).to(device)
        x_mu_list = mu.repeat(9).view(-1, 81)
        y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)

        mat_K = self._kernel(fakeI, x_mu_list, 1)
        mat_L = self._kernel(realI, y_mu_list, 1)
        mat_KL = mat_K.mul(mat_L)

        H1 = (mat_K.matmul(mat_K.transpose(1, 2))).mul(mat_L.matmul(mat_L.transpose(1, 2))) / (img_size ** 2)
        H2 = mat_KL.matmul(mat_KL.transpose(1, 2)) / img_size
        h2 = (mat_K.sum(2).view(batch_size, -1, 1)).mul(mat_L.sum(2).view(batch_size, -1, 1)) / (img_size ** 2)
        H2 = 0.5 * H1 + 0.5 * H2

        # solve the regularised system for alpha
        regularised = H2 + 0.05 * torch.eye(H2.shape[-1], device=device)
        alpha = regularised.inverse().matmul(h2)

        ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2)
                 - ((alpha.transpose(1, 2)).matmul(H2)).matmul(alpha) - 1).squeeze()
        return -ersmi.mean()
class MGCLoss(nn.Module):
    """
    Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ

    :param mi_to_loss_way: ``"opposite"`` returns ``-mean(rSMI)``, ``"reciprocal"`` returns ``1 / mean(rSMI)``
    :param beta: mixing factor between the two terms of ``h_hat``
    :param lambda_: weight of the identity regulariser added before the matrix inverse
    :param device: device the constant tensors are created on; ``None`` (default) resolves
        ``idist.device()`` when the instance is created
    """

    def __init__(self, mi_to_loss_way="opposite", beta=0.5, lambda_=0.05, device=None):
        super().__init__()
        # resolve the device at construction time; a default written as ``idist.device()`` in
        # the signature would be evaluated only once, when the module is imported
        if device is None:
            device = idist.device()
        self.beta = beta
        self.lambda_ = lambda_
        assert mi_to_loss_way in ["opposite", "reciprocal"]
        self.mi_to_loss_way = mi_to_loss_way
        # 9 x 9 grid of centre pairs over [-1, 1] in steps of 0.25, flattened to 81 entries each
        mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
        self.mu_x = mu_x.flatten().to(device)
        self.mu_y = mu_y.flatten().to(device)
        self.R = torch.eye(81).unsqueeze(0).to(device)

    @staticmethod
    def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_, R):
        """
        Compute the rSMI estimate for every sample of the batch.

        :param img1: 4-D tensor (batch, c, h, w)
        :param img2: tensor of exactly the same size as ``img1``
        :param mu_x: kernel centres applied to ``img1``
        :param mu_y: kernel centres applied to ``img2``
        :param beta: mixing factor between the two terms of ``h_hat``
        :param lambda_: weight of the regulariser ``R``
        :param R: regulariser added to ``h_hat`` before the inverse
        :return: the per-sample estimates with all size-1 dimensions squeezed away
        """
        assert img1.size() == img2.size()

        num_pixel = img1.size(1) * img1.size(2) * img2.size(3)

        mat_k = gaussian_radial_basis_function(img1, mu_x, sigma=1)
        mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)

        mat_k_mul_mat_l = mat_k * mat_l
        h_hat = (1 - beta) * (mat_k_mul_mat_l @ mat_k_mul_mat_l.transpose(1, 2)) / num_pixel
        h_hat += beta * ((mat_k @ mat_k.transpose(1, 2)) * (mat_l @ mat_l.transpose(1, 2))) / (num_pixel ** 2)
        small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)

        # solve the regularised system for alpha
        alpha = (h_hat + lambda_ * R).inverse() @ small_h_hat
        rSMI = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
        return rSMI.squeeze()

    def forward(self, fake, real):
        """
        Turn the batch-mean rSMI between ``fake`` and ``real`` into a loss.

        :param fake: generated image batch
        :param real: real image batch of the same size
        :return: ``1 / mean`` for the ``"reciprocal"`` mode, ``-mean`` otherwise
        """
        rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
        if self.mi_to_loss_way == "reciprocal":
            return 1 / rSMI.mean()
        return -rSMI.mean()
# Manual scratch check, only executed when this file is run as a script.
# It back-propagates MGCLoss once as the average of three single-image calls and once as a
# single batched call, printing the first gradient rows of each input so they can be compared.
# NOTE(review): the image paths are hard-coded local files and ``MyLoss().to("cuda")`` needs a
# CUDA device; ``my`` and ``imy`` are only referenced by the commented-out experiments below.
if __name__ == '__main__':
    mg = MGCLoss(device=torch.device("cpu"))
    my = MyLoss().to("cuda")
    imy = ImporveMyLoss()
    from data.transform import transform_pipeline
    # load an image and normalise it with mean 0.5 / std 0.5 per channel
    pipeline = transform_pipeline(
        ['Load', 'ToTensor', {'Normalize': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}}])
    img_a1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_1.jpg")
    img_a2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_2.jpg")
    img_a3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_3.jpg")
    img_b1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_1.jpg")
    img_b2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_2.jpg")
    img_b3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_3.jpg")
    # gradients are inspected on the "a" images only
    img_a1.requires_grad_(True)
    img_a2.requires_grad_(True)
    img_a3.requires_grad_(True)
    # print("MyLoss")
    # l1 = my(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
    # l2 = my(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
    # l3 = my(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
    # l = (l1+l2+l3)/3
    # l.backward()
    # print(img_a1.grad[0][0][0:10])
    # print(img_a2.grad[0][0][0:10])
    # print(img_a3.grad[0][0][0:10])
    #
    # img_a1.grad = None
    # img_a2.grad = None
    # img_a3.grad = None
    #
    # print("---")
    # l = my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
    # l.backward()
    # print(img_a1.grad[0][0][0:10])
    # print(img_a2.grad[0][0][0:10])
    # print(img_a3.grad[0][0][0:10])
    # img_a1.grad = None
    # img_a2.grad = None
    # img_a3.grad = None
    # first pass: average of three single-image losses
    print("MGCLoss")
    l1 = mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
    l2 = mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
    l3 = mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
    l = (l1 + l2 + l3) / 3
    l.backward()
    print(img_a1.grad[0][0][0:10])
    print(img_a2.grad[0][0][0:10])
    print(img_a3.grad[0][0][0:10])
    # reset the gradients so the second pass does not accumulate onto the first
    img_a1.grad = None
    img_a2.grad = None
    img_a3.grad = None
    # second pass: the same three pairs as one batch
    print("---")
    l = mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
    l.backward()
    print(img_a1.grad[0][0][0:10])
    print(img_a2.grad[0][0][0:10])
    print(img_a3.grad[0][0][0:10])
    # print("\nMGCLoss")
    # mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
    # mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
    # mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
    #
    # print("---")
    # mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
    #
    # import pprofile
    #
    # profiler = pprofile.Profile()
    # with profiler:
    # iter_times = 1000
    # for _ in range(iter_times):
    # mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
    # for _ in range(iter_times):
    # my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
    # for _ in range(iter_times):
    # imy(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
    # profiler.print_stats()

View File

@@ -4,6 +4,49 @@ import torch.nn.functional as F
import torchvision.models.vgg as vgg
# Sequential(
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (1): ReLU(inplace=True)
# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (3): ReLU(inplace=True)
# (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (6): ReLU(inplace=True)
# (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (8): ReLU(inplace=True)
# (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (11): ReLU(inplace=True)
# (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (13): ReLU(inplace=True)
# (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (15): ReLU(inplace=True)
# (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (17): ReLU(inplace=True)
# (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (20): ReLU(inplace=True)
# (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (22): ReLU(inplace=True)
# (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (24): ReLU(inplace=True)
# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (26): ReLU(inplace=True)
# (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (29): ReLU(inplace=True)
# (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (31): ReLU(inplace=True)
# (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (33): ReLU(inplace=True)
# (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (35): ReLU(inplace=True)
# (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )
class PerceptualVGG(nn.Module):
"""VGG network used in calculating perceptual loss.
In this implementation, we allow users to choose whether use normalization
@@ -15,15 +58,15 @@ class PerceptualVGG(nn.Module):
list contains the name each layer in `vgg.feature`. An example
of this list is ['4', '10'].
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image.
norm_image_with_imagenet_param (bool): If True, normalize the input image.
Importantly, the input feature must in the range [0, 1].
Default: True.
"""
def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
def __init__(self, layer_name_list, vgg_type='vgg19', norm_image_with_imagenet_param=True):
super(PerceptualVGG, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
self.use_input_norm = norm_image_with_imagenet_param
# get vgg model and load pretrained vgg weight
# remove _vgg from attributes to avoid `find_unused_parameters` bug
@@ -75,7 +118,7 @@ class PerceptualLoss(nn.Module):
in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
norm_image_with_imagenet_param (bool): If True, normalize the input image in vgg.
Default: True.
perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
loss will be calculated.
@@ -88,15 +131,16 @@ class PerceptualLoss(nn.Module):
Importantly, the input image must be in range [-1, 1].
"""
def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
def __init__(self, layer_weights, vgg_type='vgg19', norm_image_with_imagenet_param=True, perceptual_loss=True,
style_loss=False, norm_img=True, criterion='L1'):
super(PerceptualLoss, self).__init__()
self.norm_img = norm_img
assert perceptual_loss ^ style_loss, "There must be one and only one true in style or perceptual"
self.perceptual_loss = perceptual_loss
self.style_loss = style_loss
self.layer_weights = layer_weights
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
use_input_norm=use_input_norm)
norm_image_with_imagenet_param=norm_image_with_imagenet_param)
self.percep_criterion, self.style_criterion = self.set_criterion(criterion)
@@ -127,8 +171,7 @@ class PerceptualLoss(nn.Module):
percep_loss = 0
for k in x_features.keys():
percep_loss += self.percep_criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
else:
percep_loss = None
return percep_loss
# calculate style loss
if self.style_loss:
@@ -136,10 +179,7 @@ class PerceptualLoss(nn.Module):
for k in x_features.keys():
style_loss += self.style_criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
self.layer_weights[k]
else:
style_loss = None
return percep_loss, style_loss
return style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.

View File

@@ -1,4 +1,5 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
@@ -10,7 +11,7 @@ class GANLoss(nn.Module):
self.fake_label_val = fake_label_val
self.loss_type = loss_type
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
"""
gan loss forward
:param prediction: network prediction
@@ -37,3 +38,20 @@ class GANLoss(nn.Module):
return loss
else:
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
    """
    gan loss forward for every supported discriminator output layout
    :param prediction: one of
        * a ``torch.Tensor``: scored directly;
        * a ``tuple``: output of a single discriminator that also returns intermediate
          features; only its last element is scored;
        * a ``list``: one entry per discriminator (e.g. a multi scale discriminator); every
          entry is either a tensor or a sequence whose last element is scored.
    :param target_is_real: forwarded unchanged to ``single_forward``
    :param is_discriminator: forwarded unchanged to ``single_forward``
    :return: the loss; for a ``list`` the sum of the per-entry losses (``0`` for an empty list)
    :raises NotImplementedError: if ``prediction`` is none of the types above
    """
    if isinstance(prediction, torch.Tensor):
        return self.single_forward(prediction, target_is_real, is_discriminator)
    if isinstance(prediction, tuple):
        return self.single_forward(prediction[-1], target_is_real, is_discriminator)
    if isinstance(prediction, list):
        loss = 0
        for p in prediction:
            # A bare tensor must not be indexed: `p[-1]` would keep only the last sample
            # of the batch instead of the final prediction.
            final = p if isinstance(p, torch.Tensor) else p[-1]
            loss += self.single_forward(final, target_is_real, is_discriminator)
        return loss
    raise NotImplementedError(f"not support discriminator output: {prediction}")

20
main.py
View File

@@ -1,16 +1,15 @@
from pathlib import Path
from importlib import import_module
import torch
import ignite
import ignite.distributed as idist
from ignite.utils import manual_seed
from util.misc import setup_logger
from pathlib import Path
import fire
import ignite
import ignite.distributed as idist
import torch
from ignite.utils import manual_seed
from omegaconf import OmegaConf
from util.misc import setup_logger
def log_basic_info(logger, config):
logger.info(f"Train {config.name}")
@@ -28,13 +27,12 @@ def log_basic_info(logger, config):
def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
if setup_random_seed:
manual_seed(config.misc.random_seed + idist.get_rank())
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else Path(config.output_dir)
config.output_dir = str(output_dir)
if setup_output_dir and config.resume_from is None:
if output_dir.exists():
assert len(list(output_dir.glob("events*"))) == 0
assert len(list(output_dir.glob("*.pt"))) == 0
assert len(list(output_dir.glob("events*"))) == 0, f"{output_dir} containers tensorboard event"
if (output_dir / "train.log").exists() and idist.get_rank() == 0:
(output_dir / "train.log").unlink()
else:

View File

@@ -1,62 +0,0 @@
import torch.nn as nn
from model.normalization import select_norm_layer
from model.registry import MODEL
from .base import ResidualBlock
@MODEL.register_module("CyCle-Generator")
class Generator(nn.Module):
    """Generator built from a 7x7 stem conv, a strided-conv encoder, residual blocks and a
    transposed-conv decoder ending in ``Tanh``.

    Args:
        in_channels: number of channels of the input.
        out_channels: number of channels of the output.
        base_channels: width of the stem convolution; doubled by each down-sampling step.
        num_blocks: number of residual blocks between encoder and decoder (must be >= 0).
        padding_mode: padding mode of the stem / output convolutions and the residual blocks.
        norm_type: key passed to ``select_norm_layer``; convolutions carry a bias only for "IN".
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
                 norm_type="IN"):
        super(Generator, self).__init__()
        assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"

        self.start_conv = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
                      bias=use_bias),
            norm_layer(num_features=base_channels),
            nn.ReLU(inplace=True)
        )

        # down sampling: every step halves the resolution and doubles the width
        num_down_sampling = 2
        down_layers = []
        for step in range(num_down_sampling):
            width = base_channels * 2 ** step
            down_layers += [
                nn.Conv2d(in_channels=width, out_channels=width * 2,
                          kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(num_features=width * 2),
                nn.ReLU(inplace=True)
            ]
        self.encoder = nn.Sequential(*down_layers)

        # The encoder doubles the width once per step, so the trunk width is
        # 2 ** steps * base (not steps ** 2 * base, which only matches for 2 steps).
        res_block_channels = 2 ** num_down_sampling * base_channels
        self.resnet_middle = nn.Sequential(
            *[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in range(num_blocks)])

        # up sampling: mirror of the encoder, halving the width at every step
        up_layers = []
        for step in range(num_down_sampling):
            width = base_channels * 2 ** (num_down_sampling - step)
            up_layers += [
                nn.ConvTranspose2d(width, width // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=use_bias),
                norm_layer(num_features=width // 2),
                nn.ReLU(inplace=True),
            ]
        self.decoder = nn.Sequential(*up_layers)

        self.end_conv = nn.Sequential(
            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
            nn.Tanh()
        )

    def forward(self, x):
        """Run stem -> encoder -> residual blocks -> decoder -> output conv."""
        x = self.encoder(self.start_conv(x))
        x = self.resnet_middle(x)
        return self.end_conv(self.decoder(x))

View File

@@ -1,241 +0,0 @@
import torch
import torch.nn as nn
from .base import ResidualBlock
from model.registry import MODEL
from torchvision.models import vgg19
from model.normalization import select_norm_layer
class VGG19StyleEncoder(nn.Module):
    """Style encoder that concatenates its own conv features with frozen VGG-19 features.

    Args:
        in_channels: channels of the input image (also fed to VGG-19 unchanged).
        base_channels: width of the first convolution.
        style_dim: size of the returned style vector.
        padding_mode: padding mode of the trainable convolutions.
        norm_type: key passed to ``select_norm_layer``.
        vgg19_layers: indices into ``vgg19().features`` whose outputs are tapped; the last
            index also decides where the VGG-19 stack is truncated.

    NOTE(review): the channel arithmetic in ``forward`` assumes each tapped VGG feature has as
    many channels as the trainable branch at that stage -- confirm when changing
    ``base_channels`` or ``vgg19_layers``.
    """

    def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
                 vgg19_layers=(0, 5, 10, 19)):
        super().__init__()
        self.vgg19_layers = vgg19_layers
        # keep only the layers up to the deepest tap and exclude them from training
        self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
        self.vgg19.requires_grad_(False)

        norm_layer = select_norm_layer(norm_type)

        self.conv0 = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
                      bias=True),
            norm_layer(base_channels),
            nn.ReLU(True),
        )
        stage_widths = [base_channels * (2 ** i) for i in range(1, 4)]
        self.conv = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(width, width, kernel_size=4, stride=2, padding=1,
                          padding_mode=padding_mode, bias=True),
                # normalize over the conv's actual output width (was `base_channels`,
                # which is wrong for every norm that depends on the feature count)
                norm_layer(width),
                nn.ReLU(True),
            ) for width in stage_widths
        ])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv1x1 = nn.Conv2d(base_channels * (2 ** 4), style_dim, kernel_size=1, stride=1, padding=0)

    def fixed_style_features(self, x):
        """Return the outputs of the VGG-19 layers listed in ``self.vgg19_layers``, in layer order."""
        features = []
        for index, layer in enumerate(self.vgg19):
            x = layer(x)
            if index in self.vgg19_layers:
                features.append(x)
        return features

    def forward(self, x):
        """Return a ``(batch, style_dim)`` style code for ``x``."""
        fsf = self.fixed_style_features(x)
        x = self.conv0(x)
        # each stage consumes the trainable features concatenated with one VGG tap
        for i, stage in enumerate(self.conv):
            x = stage(torch.cat([x, fsf[i]], dim=1))
        x = self.pool(torch.cat([x, fsf[-1]], dim=1))
        x = self.conv1x1(x)
        return x.view(x.size(0), -1)
class ContentEncoder(nn.Module):
    """Content encoder: 7x7 stem conv, two stride-2 convolutions, then residual blocks.

    Args:
        in_channels: channels of the input image.
        base_channels: width of the stem convolution; doubled by each down-sampling step.
        num_blocks: number of residual blocks appended after down-sampling.
        padding_mode: padding mode of the stem convolution and the residual blocks.
        norm_type: key passed to ``select_norm_layer``.
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=8, padding_mode='reflect', norm_type="IN"):
        super().__init__()
        norm_layer = select_norm_layer(norm_type)

        self.start_conv = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
                      bias=True),
            norm_layer(num_features=base_channels),
            nn.ReLU(inplace=True)
        )

        # down sampling
        num_down_sampling = 2
        down_layers = []
        for step in range(num_down_sampling):
            width = base_channels * 2 ** step
            down_layers += [
                nn.Conv2d(in_channels=width, out_channels=width * 2,
                          kernel_size=4, stride=2, padding=1, bias=True),
                norm_layer(num_features=width * 2),
                nn.ReLU(inplace=True)
            ]
        self.encoder = nn.Sequential(*down_layers)

        # width after `num_down_sampling` doublings (the former `num_down_sampling ** 2`
        # is only equal to this when there are exactly 2 steps)
        res_block_channels = 2 ** num_down_sampling * base_channels
        self.resnet = nn.Sequential(
            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])

    def forward(self, x):
        """Run stem -> down-sampling -> residual blocks."""
        x = self.start_conv(x)
        x = self.encoder(x)
        x = self.resnet(x)
        return x
class Decoder(nn.Module):
    """Decoder: residual blocks, then (nearest up-sample + 5x5 conv) steps, then a 7x7 conv
    with ``Tanh``.

    Args:
        out_channels: channels of the produced image.
        base_channels: width right before the output convolution.
        num_blocks: number of residual blocks applied first.
        num_down_sampling: number of up-sampling steps; the input is expected to have
            ``base_channels * 2 ** num_down_sampling`` channels.
        padding_mode: padding mode of all convolutions.
        norm_type: key passed to ``select_norm_layer``; up-sampling convolutions carry a bias
            only for "IN".
    """

    def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
                 norm_type="LN"):
        super(Decoder, self).__init__()
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"

        # Must equal the input width of the first up-sampling conv below; it was hard-coded
        # to 2 ** 2, which broke every `num_down_sampling` other than 2.
        res_block_channels = (2 ** num_down_sampling) * base_channels

        self.resnet = nn.Sequential(
            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])

        # up sampling
        up_layers = []
        for step in range(num_down_sampling):
            width = base_channels * 2 ** (num_down_sampling - step)
            up_layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(width, width // 2, kernel_size=5, stride=1,
                          padding=2, padding_mode=padding_mode, bias=use_bias),
                norm_layer(num_features=width // 2),
                nn.ReLU(inplace=True),
            ]
        self.decoder = nn.Sequential(*up_layers)

        self.end_conv = nn.Sequential(
            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
            nn.Tanh()
        )

    def forward(self, x):
        """Run residual blocks -> up-sampling -> output conv."""
        x = self.resnet(x)
        x = self.decoder(x)
        x = self.end_conv(x)
        return x
class Fusion(nn.Module):
    """Stack of fully connected layers: one input block, ``n_blocks - 2`` hidden blocks
    (Linear -> norm -> ReLU each) and a final plain Linear.

    Args:
        in_features: size of the input vector.
        out_features: size of the output vector.
        base_features: width of every hidden layer.
        n_blocks: total number of Linear layers.
        norm_type: key passed to ``select_norm_layer``.
    """

    def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
        super().__init__()
        norm_layer = select_norm_layer(norm_type)

        def fc_block(width_in):
            # Linear -> norm -> ReLU, always producing `base_features` outputs
            return nn.Sequential(
                nn.Linear(width_in, base_features),
                norm_layer(base_features),
                nn.ReLU(True),
            )

        self.start_fc = fc_block(in_features)
        hidden = [fc_block(base_features) for _ in range(n_blocks - 2)]
        self.fcs = nn.Sequential(*hidden)
        self.end_fc = nn.Sequential(
            nn.Linear(base_features, out_features),
        )

    def forward(self, x):
        """Apply input block, hidden blocks and output layer in order."""
        return self.end_fc(self.fcs(self.start_fc(x)))
class StyleGenerator(nn.Module):
    """Maps a style image to one flat vector of style parameters.

    The image goes through ``VGG19StyleEncoder``, a Linear + ReLU layer and a ``Fusion`` MLP
    whose output has ``num_blocks * 2 * res_block_channels * 2`` entries, where
    ``res_block_channels = 4 * base_channels``.
    """

    def __init__(self, style_in_channels, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"):
        super().__init__()
        self.num_blocks = num_blocks
        self.style_encoder = VGG19StyleEncoder(
            style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE")
        self.fc = nn.Sequential(
            nn.Linear(style_dim, style_dim),
            nn.ReLU(True),
        )
        res_block_channels = 2 ** 2 * base_channels
        total_style_params = num_blocks * 2 * res_block_channels * 2
        self.fusion = Fusion(style_dim, total_style_params, base_features=256, n_blocks=3,
                             norm_type="NONE")

    def forward(self, x):
        """Return the style parameter vector for style image ``x``."""
        style_code = self.style_encoder(x)
        hidden = self.fc(style_code)
        return self.fusion(hidden)
@MODEL.register_module("TAFG-Generator")
class Generator(nn.Module):
    """Two-branch ("a" / "b") generator sharing one content encoder.

    Each branch owns a ``StyleGenerator``, a list of "AdaIN" residual blocks and a
    ``Decoder``; ``forward`` selects the branch with ``which_decoder``.
    """

    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
                 base_channels=64, padding_mode="reflect"):
        super(Generator, self).__init__()
        self.num_blocks = num_blocks

        self.style_encoders = nn.ModuleDict({
            branch: StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_blocks,
                                   base_channels=base_channels, padding_mode=padding_mode)
            for branch in ("a", "b")
        })
        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
                                              padding_mode=padding_mode, norm_type="IN")

        res_block_channels = 2 ** 2 * base_channels

        def adain_trunk():
            # one list of AdaIN residual blocks per branch
            return nn.ModuleList(
                [ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True)
                 for _ in range(num_blocks)])

        self.adain_resnet_a = adain_trunk()
        self.adain_resnet_b = adain_trunk()

        self.decoders = nn.ModuleDict({
            branch: Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode)
            for branch in ("a", "b")
        })

    def forward(self, content_img, style_img, which_decoder: str = "a"):
        """Encode ``content_img``, inject the style of ``style_img`` and decode.

        The style vector is split into ``2 * num_blocks`` chunks along dim 1; block ``i``
        receives chunk ``2 * i`` on ``norm1`` and chunk ``2 * i + 1`` on ``norm2``.
        """
        x = self.content_encoder(content_img)
        style_vector = self.style_encoders[which_decoder](style_img)
        chunks = torch.chunk(style_vector, self.num_blocks * 2, dim=1)
        if which_decoder == "a":
            trunk = self.adain_resnet_a
        else:
            trunk = self.adain_resnet_b
        for idx, block in enumerate(trunk):
            block.norm1.set_style(chunks[2 * idx])
            block.norm2.set_style(chunks[2 * idx + 1])
            x = block(x)
        return self.decoders[which_decoder](x)
@MODEL.register_module("TAFG-Discriminator")
class Discriminator(nn.Module):
    """Discriminator: 7x7 stem block, stride-2 convolutions, then residual blocks.

    Args:
        in_channels: channels of the input image.
        base_channels: width of the stem block.
        num_down_sampling: number of stride-2 convolutions; width multiplier is capped at 4.
        num_blocks: number of residual blocks appended at the end.
        norm_type: key passed to ``select_norm_layer``; convolutions carry a bias only for "IN".
        padding_mode: padding mode of the stem convolution and the residual blocks.
    """

    def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
                 padding_mode="reflect"):
        super(Discriminator, self).__init__()
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"

        stem = nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
                      bias=use_bias),
            norm_layer(num_features=base_channels),
            nn.ReLU(inplace=True)
        )
        layers = [stem]

        # stacked intermediate layers,
        # gradually increasing the number of filters
        multiple_now = 1
        for level in range(1, num_down_sampling + 1):
            multiple_prev, multiple_now = multiple_now, min(2 ** level, 4)
            layers.extend([
                nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
                          padding=1, stride=2, bias=use_bias),
                norm_layer(base_channels * multiple_now),
                nn.LeakyReLU(0.2, inplace=True)
            ])
        layers.extend(
            ResidualBlock(base_channels * multiple_now, padding_mode, norm_type) for _ in range(num_blocks))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the output of the sequential model for ``x``."""
        return self.model(x)

View File

@@ -1,236 +0,0 @@
import torch
import torch.nn as nn
from .base import ResidualBlock
from model.registry import MODEL
class RhoClipper(object):
    """Callable that clamps a module's ``rho`` values into ``[clip_min, clip_max]``.

    Modules without a ``rho`` attribute are left untouched.
    (Presumably meant to be passed to ``nn.Module.apply`` -- confirm at the call site.)
    """

    def __init__(self, clip_min, clip_max):
        self.clip_min, self.clip_max = clip_min, clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        if not hasattr(module, 'rho'):
            return
        clamped = module.rho.data.clamp(self.clip_min, self.clip_max)
        module.rho.data = clamped
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """U-GAT-IT style generator with a class-activation-map (CAM) branch.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        base_channels: width of the first convolution.
        num_blocks: number of residual blocks in the encoder and of AdaILN blocks in the decoder.
        img_size: input resolution; only used to size the first Linear layer when ``light`` is False.
        light: if True, features are globally average-pooled before the gamma/beta MLP.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False):
        assert (num_blocks >= 0)
        super(Generator, self).__init__()
        self.input_channels = in_channels
        self.output_channels = out_channels
        self.base_channels = base_channels
        self.num_blocks = num_blocks
        self.img_size = img_size
        self.light = light

        n_down_sampling = 2
        encoder_layers = [
            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
                      padding_mode="reflect", bias=False),
            nn.InstanceNorm2d(base_channels),
            nn.ReLU(True),
        ]
        for level in range(n_down_sampling):
            width = base_channels * 2 ** level
            encoder_layers.extend([
                nn.Conv2d(width, width * 2, kernel_size=3, stride=2,
                          padding=1, bias=False, padding_mode="reflect"),
                nn.InstanceNorm2d(width * 2),
                nn.ReLU(True),
            ])

        # Down-Sampling Bottleneck
        mult = 2 ** n_down_sampling
        feat = base_channels * mult
        encoder_layers.extend(ResidualBlock(feat, use_bias=False) for _ in range(num_blocks))
        self.down_encoder = nn.Sequential(*encoder_layers)

        # Class Activation Map
        self.gap_fc = nn.Linear(feat, 1, bias=False)
        self.gmp_fc = nn.Linear(feat, 1, bias=False)
        self.conv1x1 = nn.Conv2d(feat * 2, feat, kernel_size=1, stride=1, bias=True)
        self.relu = nn.ReLU(True)

        # Gamma, Beta block: the two variants differ only in the first layer's input size
        if self.light:
            fc_in = feat
        else:
            fc_in = img_size // mult * img_size // mult * feat
        self.fc = nn.Sequential(
            nn.Linear(fc_in, feat, bias=False),
            nn.ReLU(True),
            nn.Linear(feat, feat, bias=False),
            nn.ReLU(True),
        )
        self.gamma = nn.Linear(feat, feat, bias=False)
        self.beta = nn.Linear(feat, feat, bias=False)

        # Up-Sampling Bottleneck
        self.up_bottleneck = nn.ModuleList(
            [ResnetAdaILNBlock(feat, use_bias=False) for _ in range(num_blocks)])

        # Up-Sampling
        decoder_layers = []
        for level in range(n_down_sampling):
            width = base_channels * 2 ** (n_down_sampling - level)
            decoder_layers.extend([
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(width, width // 2, kernel_size=3, stride=1,
                          padding=1, padding_mode="reflect", bias=False),
                ILN(width // 2),
                nn.ReLU(True),
            ])
        decoder_layers.extend([
            nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
                      padding_mode="reflect", bias=False),
            nn.Tanh(),
        ])
        self.up_decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(image, cam_logit, heatmap)``.

        ``cam_logit`` concatenates the average-pool and max-pool logits; ``heatmap`` is the
        channel-wise sum of the features after the 1x1 convolution.
        """
        functional = torch.nn.functional
        feat = self.down_encoder(x)
        batch = feat.shape[0]

        avg_pooled = functional.adaptive_avg_pool2d(feat, 1)
        gap_logit = self.gap_fc(avg_pooled.view(batch, -1))
        # re-weight every channel with the weight of the corresponding logit layer
        gap = feat * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        max_pooled = functional.adaptive_max_pool2d(feat, 1)
        gmp_logit = self.gmp_fc(max_pooled.view(batch, -1))
        gmp = feat * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = self.relu(self.conv1x1(torch.cat([gap, gmp], 1)))

        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            pooled = functional.adaptive_avg_pool2d(x, 1)
            hidden = self.fc(pooled.view(pooled.shape[0], -1))
        else:
            hidden = self.fc(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(hidden), self.beta(hidden)

        for block in self.up_bottleneck:
            x = block(x, gamma, beta)
        return self.up_decoder(x), cam_logit, heatmap
class ResnetAdaILNBlock(nn.Module):
    """Residual block of two 3x3 reflect-padded convolutions, each followed by ``AdaILN``.

    Args:
        dim: number of input and output channels.
        use_bias: whether the convolutions carry a bias.
    """

    def __init__(self, dim, use_bias):
        super(ResnetAdaILNBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
        self.conv1 = nn.Conv2d(dim, dim, **conv_kwargs)
        self.norm1 = AdaILN(dim)
        self.relu1 = nn.ReLU(True)

        self.conv2 = nn.Conv2d(dim, dim, **conv_kwargs)
        self.norm2 = AdaILN(dim)

    def forward(self, x, gamma, beta):
        """Apply conv/norm/ReLU/conv/norm with the given ``gamma``/``beta`` and add ``x``."""
        hidden = self.relu1(self.norm1(self.conv1(x), gamma, beta))
        hidden = self.norm2(self.conv2(hidden), gamma, beta)
        return hidden + x
def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    """Blend instance- and layer-normalized versions of ``x`` and apply an affine transform.

    Instance statistics are taken over dims ``[2, 3]``, layer statistics over dims
    ``[1, 2, 3]``. The result is ``rho * IN(x) + (1 - rho) * LN(x)``, scaled by ``gamma`` and
    shifted by ``beta``.

    Args:
        x: 4-D input tensor.
        gamma: 2-D scale; two trailing singleton dims are appended before multiplying.
        beta: 2-D shift, expanded the same way as ``gamma``.
        rho: 4-D mixing weight, expanded along dim 0 to ``x.shape[0]``.
        eps: value added to the variances before the square root.
    """

    def standardize(dims):
        mean = torch.mean(x, dim=dims, keepdim=True)
        var = torch.var(x, dim=dims, keepdim=True)
        return (x - mean) / torch.sqrt(var + eps)

    instance_normed = standardize([2, 3])
    layer_normed = standardize([1, 2, 3])
    batch = x.shape[0]
    mixed = rho.expand(batch, -1, -1, -1) * instance_normed + (1 - rho.expand(batch, -1, -1, -1)) * layer_normed
    scale = gamma.unsqueeze(2).unsqueeze(3)
    shift = beta.unsqueeze(2).unsqueeze(3)
    return mixed * scale + shift
class AdaILN(nn.Module):
    """Instance/layer normalization with a learnable per-channel ``rho`` and externally
    supplied ``gamma`` / ``beta``.

    Args:
        num_features: number of channels ``rho`` is created for.
        eps: forwarded to ``instance_layer_normalization``.
        default_rho: initial value of every ``rho`` entry.
    """

    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super(AdaILN, self).__init__()
        self.eps = eps
        rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        rho.data.fill_(default_rho)
        self.rho = rho

    def forward(self, x, gamma, beta):
        """Normalize ``x`` using the caller's ``gamma`` and ``beta`` and this layer's ``rho``."""
        return instance_layer_normalization(x, gamma, beta, rho=self.rho, eps=self.eps)
class ILN(nn.Module):
    """Instance/layer normalization whose ``rho``, ``gamma`` and ``beta`` are all learnable.

    ``rho`` and ``beta`` start at 0 and ``gamma`` at 1.

    Args:
        num_features: number of channels the parameters are created for.
        eps: forwarded to ``instance_layer_normalization``.
    """

    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        self.gamma = nn.Parameter(torch.Tensor(1, num_features))
        self.beta = nn.Parameter(torch.Tensor(1, num_features))
        for parameter, initial in ((self.rho, 0.0), (self.gamma, 1.0), (self.beta, 0.0)):
            parameter.data.fill_(initial)

    def forward(self, x):
        """Normalize ``x`` with the layer's own parameters, expanded to the batch size."""
        batch = x.shape[0]
        gamma = self.gamma.expand(batch, -1)
        beta = self.beta.expand(batch, -1)
        return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """Spectrally normalized discriminator with a class-activation-map (CAM) branch.

    Args:
        in_channels: channels of the input image.
        base_channels: width of the first convolution block.
        num_blocks: controls the encoder depth (``num_blocks - 2`` stride-2 blocks followed by
            one stride-1 block).
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5):
        super(Discriminator, self).__init__()
        blocks = [self.build_conv_block(in_channels, base_channels)]
        for level in range(1, num_blocks - 2):
            width = base_channels * 2 ** (level - 1)
            blocks.append(self.build_conv_block(width, width * 2))

        width = base_channels * 2 ** (num_blocks - 2 - 1)
        blocks.append(self.build_conv_block(width, width * 2, stride=1))
        self.encoder = nn.Sequential(*blocks)

        # Class Activation Map
        cam_channels = base_channels * 2 ** (num_blocks - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(cam_channels, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(cam_channels, 1, bias=False))
        self.conv1x1 = nn.Conv2d(cam_channels * 2, cam_channels, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(cam_channels, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect"))

    @staticmethod
    def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"):
        """Return a spectrally normalized convolution followed by ``LeakyReLU(0.2)``."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         bias=True, padding=padding, padding_mode=padding_mode)
        return nn.Sequential(*[
            nn.utils.spectral_norm(conv),
            nn.LeakyReLU(0.2, True),
        ])

    def forward(self, x, return_heatmap=False):
        """Return ``(prediction, cam_logit)`` and, if ``return_heatmap``, also the heatmap.

        The heatmap is the channel-wise sum of the features after the 1x1 convolution.
        """
        functional = torch.nn.functional
        feat = self.encoder(x)
        batch_size = feat.size(0)

        avg_pooled = functional.adaptive_avg_pool2d(feat, 1)  # B x C x 1 x 1, avg of per channel
        gap_logit = self.gap_fc(avg_pooled.view(batch_size, -1))
        gap = feat * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        max_pooled = functional.adaptive_max_pool2d(feat, 1)
        gmp_logit = self.gmp_fc(max_pooled.view(batch_size, -1))
        gmp = feat * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        fused = self.leaky_relu(self.conv1x1(torch.cat([gap, gmp], 1)))

        if not return_heatmap:
            return self.conv(fused), cam_logit
        heatmap = torch.sum(fused, dim=1, keepdim=True)
        return self.conv(fused), cam_logit, heatmap

View File

@@ -1,139 +0,0 @@
import math
import torch
import torch.nn as nn
from model.normalization import select_norm_layer
from model import MODEL
class GANImageBuffer(object):
    """This class implements an image buffer that stores previously
    generated images.

    This buffer allows us to update the discriminator using a history of
    generated images rather than the ones produced by the latest generator
    to reduce model oscillation.

    Args:
        buffer_size (int): The size of image buffer. If buffer_size <= 0,
            no buffer is used and ``query`` returns its input unchanged.
        buffer_ratio (float): The chance / possibility to use the images
            previously stored in the buffer.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        # Always create the (possibly unused) buffer state so that `query` can never hit a
        # missing attribute, whatever `buffer_size` is.
        self.img_num = 0
        self.image_buffer = []
        self.buffer_ratio = buffer_ratio

    def query(self, images):
        """Query current image batch using a history of generated images.

        Args:
            images (Tensor): Current image batch without history information.

        Returns:
            Tensor: ``images`` itself when the buffer is disabled or the batch is empty;
            otherwise a new batch of the same length, cut off from the autograd graph, in
            which each image is either the current one or one taken from the history.
        """
        if self.buffer_size <= 0:  # buffer disabled, do nothing
            return images
        if len(images) == 0:  # nothing to store or replace; torch.cat would reject an empty list
            return images

        return_images = []
        for image in images:
            # store history without keeping the generator graph alive
            image = torch.unsqueeze(image.detach(), 0)
            # if the buffer is not full, keep inserting current images
            if self.img_num < self.buffer_size:
                self.img_num = self.img_num + 1
                self.image_buffer.append(image)
                return_images.append(image)
                continue
            # by self.buffer_ratio, the buffer will return a previously
            # stored image, and insert the current image into the buffer
            if torch.rand(1) < self.buffer_ratio:
                random_id = torch.randint(0, self.buffer_size, (1,)).item()
                image_tmp = self.image_buffer[random_id].clone()
                self.image_buffer[random_id] = image
                return_images.append(image_tmp)
            # by (1 - self.buffer_ratio), the buffer will return the
            # current image
            else:
                return_images.append(image)
        # collect all the images and return
        return torch.cat(return_images, 0)
# Based on the SPADE / pix2pixHD discriminator.
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    """Patch discriminator made of ``num_conv`` 4x4 convolution blocks plus a 1-channel
    output convolution.

    Args:
        in_channels: channels of the input.
        base_channels: width of the first block; later blocks use up to 8x this width.
        num_conv: number of convolution blocks before the output convolution.
        use_spectral: wrap the intermediate convolutions in spectral normalization.
        norm_type: key passed to ``select_norm_layer``; intermediate convolutions carry a bias
            only for "IN".
        need_intermediate_feature: if True, ``forward`` returns the outputs of all blocks.
    """

    def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
                 need_intermediate_feature=False):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature

        kernel_size = 4
        padding = math.ceil((kernel_size - 1.0) / 2)
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"
        padding_mode = "zeros"

        sequence = [nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, False)
        )]
        multiple_now = 1
        for i in range(1, num_conv):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 3)
            # the last intermediate block keeps the resolution
            stride = 1 if i == num_conv - 1 else 2
            sequence.append(nn.Sequential(
                self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
                                  kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
                norm_layer(base_channels * multiple_now),
                nn.LeakyReLU(0.2, inplace=False),
            ))
        # `multiple_now` is the width multiplier of the block just built, which is exactly
        # what the output convolution receives. Recomputing it as min(2 ** num_conv, 8)
        # gave a mismatching input width whenever num_conv < 4.
        sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
                                  padding_mode=padding_mode))
        self.conv_blocks = nn.ModuleList(sequence)

    @staticmethod
    def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
                     bias=True, padding_mode: str = 'zeros'):
        """Build a ``Conv2d``, wrapped in spectral normalization when ``use_spectral`` is set."""
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias, padding_mode=padding_mode)
        if not use_spectral:
            return conv
        return nn.utils.spectral_norm(conv)

    def forward(self, x):
        """Return the final prediction, or a tuple with every block's output (the prediction
        last) when ``need_intermediate_feature`` is True."""
        if self.need_intermediate_feature:
            intermediate_feature = []
            for layer in self.conv_blocks:
                x = layer(x)
                intermediate_feature.append(x)
            return tuple(intermediate_feature)
        for layer in self.conv_blocks:
            x = layer(x)
        return x
@MODEL.register_module()
class ResidualBlock(nn.Module):
    """Residual block: conv -> norm -> ReLU -> conv -> norm, added to the input.

    Args:
        num_channels: number of input and output channels.
        padding_mode: padding mode of both 3x3 convolutions.
        norm_type: key passed to ``select_norm_layer``.
        use_bias: whether the convolutions carry a bias; ``None`` means "only for IN".
    """

    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
        super(ResidualBlock, self).__init__()

        # Only for IN, use bias since it does not have affine parameters.
        bias = (norm_type == "IN") if use_bias is None else use_bias
        norm_layer = select_norm_layer(norm_type)
        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode=padding_mode, bias=bias)

        self.conv1 = nn.Conv2d(num_channels, num_channels, **conv_kwargs)
        self.norm1 = norm_layer(num_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, **conv_kwargs)
        self.norm2 = norm_layer(num_channels)

    def forward(self, x):
        """Return ``x`` plus the output of the two conv/norm stages."""
        out = self.conv1(x)
        out = self.norm1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.norm2(out)
        return out + x

View File

@@ -1,6 +1,7 @@
from model.registry import MODEL
import model.GAN.CycleGAN
import model.GAN.TAFG
import model.GAN.UGATIT
import model.GAN.wrapper
import model.GAN.base
from model.registry import MODEL, NORMALIZATION
import model.base.normalization
import model.image_translation.UGATIT
import model.image_translation.CycleGAN
import model.image_translation.pix2pixHD
import model.image_translation.GauGAN
import model.image_translation.TSIT

128
model/base/module.py Normal file
View File

@@ -0,0 +1,128 @@
import torch.nn as nn
from model.registry import NORMALIZATION
_DO_NO_THING_FUNC = lambda x: x
def _use_bias_checker(norm_type):
return norm_type not in ["IN", "BN", "AdaIN", "FADE", "SPADE"]
def _normalization(norm, num_features, additional_kwargs=None):
    """Build the normalization layer registered under ``norm``.

    :param norm: registry key; "NONE" yields the identity placeholder
    :param num_features: passed to the builder as ``num_features``
    :param additional_kwargs: optional extra builder arguments; they override the defaults
    :return: the identity placeholder or the object returned by ``NORMALIZATION.build_with``
    """
    if norm == "NONE":
        return _DO_NO_THING_FUNC
    build_args = {"_type": norm, "num_features": num_features}
    if additional_kwargs is not None:
        build_args.update(additional_kwargs)
    return NORMALIZATION.build_with(build_args)
def _activation(activation, inplace=True):
if activation == "NONE":
return _DO_NO_THING_FUNC
elif activation == "ReLU":
return nn.ReLU(inplace=inplace)
elif activation == "LeakyReLU":
return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
elif activation == "Tanh":
return nn.Tanh()
else:
raise NotImplementedError(f"{activation} not valid")
class LinearBlock(nn.Module):
    """Linear layer followed by optional normalization and activation.

    Args:
        in_features: input size of the linear layer.
        out_features: output size of the linear layer (also the norm's feature count).
        bias: bias flag of the linear layer; ``None`` derives it from ``norm_type``.
        activation_type: name understood by ``_activation``.
        norm_type: name understood by ``_normalization``.
    """

    def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type

        if bias is None:
            bias = _use_bias_checker(norm_type)
        self.linear = nn.Linear(in_features, out_features, bias)
        self.normalization = _normalization(norm_type, out_features)
        self.activation = _activation(activation_type)

    def forward(self, x):
        """Apply linear -> normalization -> activation."""
        projected = self.linear(x)
        normalized = self.normalization(projected)
        return self.activation(normalized)
class Conv2dBlock(nn.Module):
    """(Transposed) convolution combined with optional normalization and activation.

    Default order is conv -> norm -> activation; with ``pre_activation`` it is
    norm -> activation -> conv and the norm is sized for ``in_channels``.

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        bias: bias flag of the convolution in the default order; ``None`` derives it from
            ``norm_type``. It is not applied in pre-activation mode.
        activation_type: name understood by ``_activation``.
        norm_type: name understood by ``_normalization``.
        additional_norm_kwargs: extra arguments for the normalization builder.
        pre_activation: switch to the norm -> activation -> conv order.
        use_transpose_conv: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
        **conv_kwargs: forwarded to the convolution constructor.
    """

    def __init__(self, in_channels: int, out_channels: int, bias=None,
                 activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
                 pre_activation=False, use_transpose_conv=False, **conv_kwargs):
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type
        self.pre_activation = pre_activation

        conv_cls = nn.ConvTranspose2d if use_transpose_conv else nn.Conv2d
        if use_transpose_conv:
            # Only "zeros" padding mode is supported for ConvTranspose2d
            conv_kwargs["padding_mode"] = "zeros"

        if not pre_activation:
            # if caller not set bias, set bias automatically.
            conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
            self.convolution = conv_cls(in_channels, out_channels, **conv_kwargs)
            self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
            self.activation = _activation(activation_type)
            return

        self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
        # not in-place: the activation works directly on the block's (normalized) input
        self.activation = _activation(activation_type, inplace=False)
        self.convolution = conv_cls(in_channels, out_channels, **conv_kwargs)

    def forward(self, x):
        """Apply the three stages in the order selected by ``pre_activation``."""
        if not self.pre_activation:
            return self.activation(self.normalization(self.convolution(x)))
        activated = self.activation(self.normalization(x))
        return self.convolution(activated)
class ResidualBlock(nn.Module):
    """Residual block made of two 3x3 ``Conv2dBlock`` layers plus a skip connection."""

    def __init__(self, in_channels,
                 padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
                 out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
        """
        Residual Conv Block

        :param in_channels: channels of the input tensor
        :param padding_mode: padding mode of the convolutions
        :param activation_type: activation used inside both conv blocks
        :param norm_type: normalization used inside both conv blocks
        :param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
        :param out_channels: channels of the output tensor, defaults to ``in_channels``.
            When it differs from ``in_channels`` the skip path is a learned 1x1 conv block.
        :param out_activation_type: accepted for interface compatibility; the block
            currently does not apply any activation after the residual sum, so this
            value has no effect
        :param additional_norm_kwargs: extra keyword arguments for the normalization layers
        """
        super().__init__()
        self.norm_type = norm_type

        if out_channels is None:
            out_channels = in_channels

        self.learn_skip_connection = in_channels != out_channels

        conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
                          additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
                          padding_mode=padding_mode)

        self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
        self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)

        if self.learn_skip_connection:
            # same settings as the main path, but a 1x1 kernel without padding
            skip_param = {**conv_param, "kernel_size": 1, "padding": 0}
            self.res_conv = Conv2dBlock(in_channels, out_channels, **skip_param)

    def forward(self, x):
        """Return ``conv2(conv1(x))`` plus the (identity or learned) skip path."""
        res = x if not self.learn_skip_connection else self.res_conv(x)
        return self.conv2(self.conv1(x)) + res

143
model/base/normalization.py Normal file
View File

@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import NORMALIZATION
from model.base.module import Conv2dBlock
# Abbreviation -> name of the torch.nn class registered under that abbreviation.
_VALID_NORM_AND_ABBREVIATION = {
    "IN": "InstanceNorm2d",
    "BN": "BatchNorm2d",
}

# Make the plain torch.nn normalization layers buildable through the registry.
for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
    NORMALIZATION.register_module(name=abbr, module=getattr(nn, name))
@NORMALIZATION.register_module("ADE")
class AdaptiveDenormalization(nn.Module):
    """Parameter-free normalization followed by an externally supplied affine transform.

    ``gamma`` and ``beta`` must be provided through :meth:`set_condition` before
    every call of :meth:`forward`; the flag is cleared after each forward pass.

    :param num_features: number of channels of the normalized tensor
    :param base_norm_type: "IN" (instance norm) or "BN" (batch norm), both without affine
    :param gamma_bias: constant added to ``gamma`` before scaling
    """

    def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
        super().__init__()
        self.num_features = num_features
        self.base_norm_type = base_norm_type
        self.norm = self.base_norm(num_features)
        self.gamma = None
        self.gamma_bias = gamma_bias
        self.beta = None
        self.have_set_condition = False

    def base_norm(self, num_features):
        """Build the affine-free base normalization layer.

        :raises NotImplementedError: if ``base_norm_type`` is neither "IN" nor "BN"
            (previously ``None`` was returned and the failure only showed up in ``forward``)
        """
        if self.base_norm_type == "IN":
            return nn.InstanceNorm2d(num_features, affine=False)
        if self.base_norm_type == "BN":
            return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
        raise NotImplementedError(f"base_norm_type {self.base_norm_type} not valid")

    def set_condition(self, gamma, beta):
        """Store the scale and shift to be used by the next forward pass."""
        self.gamma, self.beta = gamma, beta
        self.have_set_condition = True

    def forward(self, x):
        """Normalize ``x`` and apply ``(gamma + gamma_bias) * x + beta``."""
        assert self.have_set_condition, "set the condition before every forward pass"
        x = self.norm(x)
        x = (self.gamma + self.gamma_bias) * x + self.beta
        # the condition is single-use: it has to be set again before the next call
        self.have_set_condition = False
        return x
#
# def __repr__(self):
# return f"{self.__class__.__name__}(num_features={self.num_features}, " \
# f"base_norm_type={self.base_norm_type})"
@NORMALIZATION.register_module("AdaIN")
class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
    """Adaptive instance normalization: instance norm with a style-provided affine.

    :param num_features: number of channels of the normalized tensor
    """

    def __init__(self, num_features: int):
        # the parent stores ``num_features`` and builds the "IN" base norm
        super().__init__(num_features, "IN")

    def set_style(self, style):
        """Split ``style`` along dim 1 into gamma and beta for the next forward pass.

        Two trailing singleton dimensions are appended first.
        Assumes ``style`` is (batch, 2 * num_features) -- TODO confirm against callers.
        """
        expanded = style.unsqueeze(-1).unsqueeze(-1)
        gamma, beta = expanded.chunk(2, 1)
        super().set_condition(gamma, beta)
@NORMALIZATION.register_module("FADE")
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
    """Denormalization whose gamma/beta are predicted from a feature map by 3x3 convs.

    :param num_features: number of channels of the normalized tensor
    :param condition_in_channels: channels of the conditioning feature map
    :param base_norm_type: "IN" or "BN", see ``AdaptiveDenormalization``
    :param padding_mode: padding mode of the two prediction convolutions
    :param gamma_bias: constant added to the predicted gamma
    """

    def __init__(self, num_features: int, condition_in_channels,
                 base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0):
        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
        conv_args = dict(kernel_size=3, padding=1, padding_mode=padding_mode)
        # beta first, then gamma: keeps the original parameter creation order
        self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, **conv_args)
        self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, **conv_args)

    def set_feature(self, feature):
        """Predict gamma and beta from ``feature`` and store them for the next forward pass."""
        super().set_condition(self.gamma_conv(feature), self.beta_conv(feature))
@NORMALIZATION.register_module("SPADE")
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
    """Denormalization whose gamma/beta are predicted from a condition image.

    The condition image first passes a shared conv block, then two separate
    3x3 convolutions produce gamma and beta.

    :param num_features: number of channels of the normalized tensor
    :param condition_in_channels: channels of the condition image
    :param base_channels: channels of the shared hidden feature
    :param base_norm_type: "IN" or "BN", see ``AdaptiveDenormalization``
    :param activation_type: activation of the shared conv block
    :param padding_mode: padding mode of all three convolutions
    :param gamma_bias: constant added to the predicted gamma
    """

    def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
                 activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
        conv_args = dict(kernel_size=3, padding=1, padding_mode=padding_mode)
        self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels,
                                           activation_type=activation_type, norm_type="NONE", **conv_args)
        # beta first, then gamma: keeps the original parameter creation order
        self.beta_conv = nn.Conv2d(base_channels, num_features, **conv_args)
        self.gamma_conv = nn.Conv2d(base_channels, num_features, **conv_args)

    def set_condition_image(self, condition_image):
        """Predict gamma and beta from ``condition_image`` for the next forward pass."""
        hidden = self.base_conv_block(condition_image)
        super().set_condition(self.gamma_conv(hidden), self.beta_conv(hidden))
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    """Blend instance norm and layer norm of ``x``, then scale and shift.

    Computes ``rho * IN(x) + (1 - rho) * LN(x)`` where the layer norm runs over
    all non-batch dimensions, and returns that times ``gamma`` plus ``beta``.

    :param x: input tensor; assumed (batch, channels, H, W) -- TODO confirm
    :param gamma: scale; singleton dims are inserted at positions 2 and 3
    :param beta: shift; singleton dims are inserted at positions 2 and 3
    :param rho: blending weight, must broadcast against ``x``
    :param eps: epsilon passed to both normalizations
    """
    instance_part = F.instance_norm(x, eps=eps)
    layer_part = F.layer_norm(x, x.size()[1:], eps=eps)
    blended = rho * instance_part + (1 - rho) * layer_part

    scale = gamma.unsqueeze(2).unsqueeze(3)
    shift = beta.unsqueeze(2).unsqueeze(3)
    return blended * scale + shift
@NORMALIZATION.register_module("ILN")
class ILN(nn.Module):
    """Instance-layer normalization with learned blend (rho), scale (gamma) and shift (beta).

    :param num_features: number of channels; size of each learned parameter
    :param eps: epsilon passed to the underlying normalizations
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        # allocated uninitialized; reset_parameters() fills them right below
        self.rho = nn.Parameter(torch.empty(num_features))
        self.gamma = nn.Parameter(torch.empty(num_features))
        self.beta = nn.Parameter(torch.empty(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize rho to 0, gamma to 1, beta to 0."""
        nn.init.zeros_(self.rho)
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, x):
        """Apply instance-layer normalization using the learned parameters."""
        rho = self.rho.view(1, -1, 1, 1)
        gamma = self.gamma.view(1, -1)
        beta = self.beta.view(1, -1)
        return _instance_layer_normalization(x, gamma, beta, rho, self.eps)
@NORMALIZATION.register_module("AdaILN")
class AdaILN(nn.Module):
    """Instance-layer normalization with learned rho and externally supplied gamma/beta.

    ``gamma`` and ``beta`` must be provided through :meth:`set_condition` before
    every call of :meth:`forward`.

    :param num_features: number of channels; size of the learned ``rho``
    :param eps: epsilon passed to the underlying normalizations
    :param default_rho: initial value of every ``rho`` entry
    """

    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super().__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.empty(num_features))
        nn.init.constant_(self.rho, default_rho)

        self.gamma = None
        self.beta = None
        self.have_set_condition = False

    def set_condition(self, gamma, beta):
        """Store the scale and shift to be used by the next forward pass."""
        self.gamma, self.beta = gamma, beta
        self.have_set_condition = True

    def forward(self, x):
        """Normalize ``x`` with the stored condition, then clear the condition flag."""
        assert self.have_set_condition
        rho = self.rho.view(1, -1, 1, 1)
        out = _instance_layer_normalization(x, self.gamma, self.beta, rho, self.eps)
        self.have_set_condition = False
        return out

View File

@@ -0,0 +1,151 @@
import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock
class Encoder(nn.Module):
    """Convolutional encoder: 7x7 stem, stride-2 down-sampling convs, residual blocks.

    The channel count doubles at every down-sampling step until it reaches
    ``base_channels * 2 ** max_down_sampling_multiple``. The final channel count
    is exposed as ``self.out_channels``.

    :param in_channels: channels of the input image
    :param base_channels: channels after the stem convolution
    :param num_conv: number of stride-2 down-sampling convolutions
    :param num_res: number of residual blocks appended after down-sampling
    :param max_down_sampling_multiple: log2 cap of the channel multiplier
    :param padding_mode: padding mode of the stem and the residual blocks
        (the down-sampling convs always use "zeros")
    :param activation_type: activation used throughout
    :param down_conv_norm_type: normalization of the stem and down-sampling convs
    :param down_conv_kernel_size: kernel size of the down-sampling convs (padding is fixed at 1)
    :param res_norm_type: normalization inside the residual blocks
    :param pre_activation: forwarded to the residual blocks
    """

    def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
                 padding_mode='reflect', activation_type="ReLU",
                 down_conv_norm_type="IN", down_conv_kernel_size=3,
                 res_norm_type="IN", pre_activation=False):
        super().__init__()

        layers = [Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=down_conv_norm_type
        )]

        multiple_cap = 2 ** max_down_sampling_multiple
        width = base_channels
        for level in range(1, num_conv + 1):
            next_width = min(2 ** level, multiple_cap) * base_channels
            layers.append(Conv2dBlock(
                width, next_width,
                kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode="zeros",
                activation_type=activation_type, norm_type=down_conv_norm_type
            ))
            width = next_width

        self.out_channels = width

        for _ in range(num_res):
            layers.append(ResidualBlock(
                self.out_channels,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type=res_norm_type,
                pre_activation=pre_activation
            ))

        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` by running all layers in order."""
        return self.sequence(x)
class Decoder(nn.Module):
    """Decoder: residual blocks, up-sampling stages that halve the channels, 7x7 Tanh output conv.

    The residual blocks are kept in ``self.residual_blocks`` (an ``nn.ModuleList``)
    so callers can reach their normalization layers to set conditions.

    :param in_channels: channels of the input feature map
    :param out_channels: channels of the produced image
    :param num_up_sampling: number of x2 up-sampling stages
    :param num_residual_blocks: number of residual blocks before up-sampling
    :param activation_type: activation used in residual and up-sampling blocks
    :param padding_mode: padding mode of the convolutions
    :param up_conv_kernel_size: kernel size of the up-sampling convolutions
    :param up_conv_norm_type: normalization of the up-sampling convolutions.
        NOTE(review): the default "LN" must be registered in ``NORMALIZATION`` -- confirm
    :param res_norm_type: normalization inside the residual blocks
    :param pre_activation: forwarded to the residual blocks
    :param use_transpose_conv: up-sample with a stride-2 transposed convolution
        instead of ``nn.Upsample`` followed by a stride-1 convolution
    """

    def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                 activation_type="ReLU", padding_mode='reflect',
                 up_conv_kernel_size=5, up_conv_norm_type="LN",
                 res_norm_type="AdaIN", pre_activation=False, use_transpose_conv=False):
        super().__init__()
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(
                in_channels,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type=res_norm_type,
                pre_activation=pre_activation
            ) for _ in range(num_residual_blocks)
        ])

        sequence = list()
        channels = in_channels
        padding = (up_conv_kernel_size - 1) // 2
        # ConvTranspose2d requires output_padding < stride (2 here). Using ``padding``
        # directly broke for kernels >= 5; capping at 1 keeps the previous value for
        # smaller kernels and gives an exact x2 size for odd kernels.
        output_padding = min(padding, 1)
        for _ in range(num_up_sampling):
            if use_transpose_conv:
                sequence.append(Conv2dBlock(
                    channels, channels // 2, kernel_size=up_conv_kernel_size, stride=2,
                    padding=padding, output_padding=output_padding,
                    padding_mode=padding_mode,
                    activation_type=activation_type, norm_type=up_conv_norm_type,
                    use_transpose_conv=True
                ))
            else:
                sequence.append(nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
                                padding=padding, padding_mode=padding_mode,
                                activation_type=activation_type, norm_type=up_conv_norm_type),
                ))
            channels = channels // 2
        sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
                                    padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))

        self.up_sequence = nn.Sequential(*sequence)

    def forward(self, x):
        """Run the residual blocks, then the up-sampling stages and the output conv."""
        for blk in self.residual_blocks:
            x = blk(x)
        return self.up_sequence(x)
@MODEL.register_module("CycleGAN-Generator")
class Generator(nn.Module):
    """CycleGAN generator: an ``Encoder`` with residual blocks followed by a ``Decoder``.

    The encoder down-samples twice and holds all ``num_blocks`` residual blocks;
    the decoder has no residual blocks and up-samples twice with 3x3 convolutions.

    :param in_channels: channels of the input image
    :param out_channels: channels of the output image
    :param base_channels: channels after the encoder stem
    :param num_blocks: number of residual blocks in the encoder
    :param activation_type: activation used throughout
    :param padding_mode: padding mode of the convolutions
    :param norm_type: normalization used by encoder and decoder convolutions
    :param pre_activation: forwarded to encoder and decoder
    :param use_transpose_conv: forwarded to the decoder
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
                 padding_mode='reflect', norm_type="IN", pre_activation=False, use_transpose_conv=True):
        super().__init__()
        common = dict(padding_mode=padding_mode, activation_type=activation_type,
                      pre_activation=pre_activation)

        self.encoder = Encoder(
            in_channels, base_channels, num_conv=2, num_res=num_blocks,
            down_conv_norm_type=norm_type, res_norm_type=norm_type, **common
        )
        self.decoder = Decoder(
            self.encoder.out_channels, out_channels, num_up_sampling=2, num_residual_blocks=0,
            up_conv_kernel_size=3, up_conv_norm_type=norm_type,
            use_transpose_conv=use_transpose_conv, **common
        )

    def forward(self, x):
        """Translate ``x`` by encoding and then decoding it."""
        encoded = self.encoder(x)
        return self.decoder(encoded)
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    """Patch discriminator built from 4x4 conv blocks and a final 1-channel convolution.

    All conv blocks use stride 2 except the last one (stride 1). The channel
    multiplier doubles per block and is capped at 8.

    :param in_channels: channels of the input image
    :param base_channels: channels after the first conv block
    :param num_conv: number of conv blocks before the final convolution
    :param need_intermediate_feature: make ``forward`` return the output of every
        layer (as a tuple) instead of only the last one
    :param norm_type: normalization of the conv blocks
    :param padding_mode: padding mode of all convolutions
    :param activation_type: activation of the conv blocks
    """

    def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature

        kernel_size = 4
        padding = (kernel_size - 1) // 2
        block_args = dict(kernel_size=kernel_size, padding=padding, padding_mode=padding_mode,
                          activation_type=activation_type, norm_type=norm_type)

        layers = [Conv2dBlock(in_channels, base_channels, stride=2, **block_args)]

        width = base_channels
        for level in range(1, num_conv):
            next_width = min(2 ** level, 2 ** 3) * base_channels
            # only the last conv block keeps the resolution
            stride = 2 if level != num_conv - 1 else 1
            layers.append(Conv2dBlock(width, next_width, stride=stride, **block_args))
            width = next_width

        layers.append(nn.Conv2d(
            width, 1, kernel_size, stride=1, padding=padding, padding_mode=padding_mode))

        if self.need_intermediate_feature:
            # a ModuleList so forward can collect the output of each layer
            self.sequence = nn.ModuleList(layers)
        else:
            self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Return the final prediction map, or a tuple of all layer outputs."""
        if not self.need_intermediate_feature:
            return self.sequence(x)

        features = []
        for layer in self.sequence:
            x = layer(x)
            features.append(x)
        return tuple(features)
if __name__ == '__main__':
    # Manual smoke check: build both networks with typical settings and print their layout.
    g = Generator(in_channels=3, out_channels=3)
    print(g)
    pd = PatchDiscriminator(in_channels=3, base_channels=64, num_conv=4)
    print(pd)

View File

@@ -0,0 +1,183 @@
from collections import OrderedDict
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
from model import MODEL
class StyleEncoder(nn.Module):
    """Convolutional encoder that maps an image to two ``style_dim`` vectors.

    A stride-1 stem is followed by ``num_conv`` stride-2 conv blocks; the
    flattened result feeds two separate linear heads (``fc_avg`` and ``fc_var``).

    :param in_channels: channels of the input image
    :param style_dim: size of each returned vector
    :param num_conv: number of stride-2 conv blocks
    :param end_size: spatial size (H, W) of the feature map after the last conv;
        the caller must feed images that actually end up at this size
    :param base_channels: channels after the stem convolution
    :param norm_type: normalization of the conv blocks
    :param padding_mode: padding mode of the convolutions
    :param activation_type: activation of the conv blocks
    """

    def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type
        )]
        # The stem outputs ``base_channels`` (multiple 1). Starting at 0 gave the
        # first down-sampling conv zero input channels.
        multiple_now = 1
        max_multiple = 3
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** max_multiple)
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=norm_type
            ))
        self.sequence = nn.Sequential(*sequence)

        # Size the heads from the real final channel count; this equals the old
        # ``base_channels * 2 ** max_multiple`` whenever num_conv >= max_multiple.
        flat_features = multiple_now * base_channels * end_size[0] * end_size[1]
        self.fc_avg = nn.Linear(flat_features, style_dim)
        self.fc_var = nn.Linear(flat_features, style_dim)

    def forward(self, x):
        """Return the pair ``(fc_avg(features), fc_var(features))``."""
        x = self.sequence(x)
        x = x.view(x.size(0), -1)
        return self.fc_avg(x), self.fc_var(x)
class ImprovedSPADEGenerator(nn.Module):
    """Work-in-progress SPADE generator with optional style input and multi-scale outputs.

    Only the layers are built; the forward pass is not implemented yet.

    :param in_channels: channels of the condition (segmentation) input
    :param out_channels: channels of the generated image
    :param output_size: target resolution, one of 128, 256, 512, 1024
    :param have_style_input: build the style MLP and use "AdaIN" in the plain conv blocks
    :param style_dim: size of the style vector
    :param start_size: currently unused
    :param base_channels: channel unit; the head works on ``16 * base_channels``
    :param padding_mode: padding mode of the convolutions
    :param activation_type: activation used throughout
    :param pre_activation: forwarded to the plain conv blocks and output convs
        (the SPADE residual blocks always use pre-activation)
    """

    def __init__(self, in_channels, out_channels, output_size, have_style_input, style_dim, start_size=(4, 4),
                 base_channels=64, padding_mode='reflect', activation_type="LeakyReLU", pre_activation=False):
        super().__init__()
        assert output_size in (128, 256, 512, 1024)
        self.output_size = output_size
        # log2 is exact for the power-of-two sizes accepted above
        size_exponent = int(math.log2(output_size))

        kernel_size = 3

        if have_style_input:
            self.style_converter = nn.Sequential(
                LinearBlock(style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
                LinearBlock(2 * style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
            )

        base_conv = partial(
            Conv2dBlock,
            pre_activation=pre_activation, activation_type=activation_type,
            norm_type="AdaIN" if have_style_input else "NONE",
            kernel_size=kernel_size, padding=(kernel_size - 1) // 2, padding_mode=padding_mode
        )

        base_residual_block = partial(
            ResidualBlock,
            padding_mode=padding_mode,
            activation_type=activation_type,
            norm_type="SPADE",
            pre_activation=True,
            additional_norm_kwargs=dict(
                condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
                activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
            )
        )

        sequence = OrderedDict()
        channels = (2 ** 4) * base_channels
        sequence["block_head"] = nn.Sequential(OrderedDict([
            ("conv_input", base_conv(in_channels=in_channels, out_channels=channels)),
            ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
            ("res_a", base_residual_block(in_channels=channels, out_channels=channels)),
            ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
            ("up", nn.Upsample(scale_factor=2, mode='nearest'))
        ]))

        for i in range(4, 9 - min(size_exponent, 8), -1):
            channels = (2 ** (i - 1)) * base_channels
            sequence[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
                ("res_a", base_residual_block(in_channels=channels * 2, out_channels=channels)),
                ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
            ]))
        self.sequence = nn.Sequential(sequence)
        # channels = 2*base_channels when output size is 256, 512, 1024
        # channels = 4*base_channels when output size is 128

        out_modules = OrderedDict()
        out_modules["out_1"] = nn.Sequential(
            Conv2dBlock(
                channels, out_channels, kernel_size=5, stride=1, padding=2,
                pre_activation=pre_activation,
                padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
            ),
            nn.Tanh()
        )
        # one extra refinement block and output head per doubling above 256
        for i in range(size_exponent - 8):
            channels = channels // 2
            out_modules[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
                ("res_a", base_residual_block(in_channels=2 * channels, out_channels=channels)),
                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
            ]))
            out_modules[f"out_{i + 2}"] = nn.Sequential(
                Conv2dBlock(
                    channels, out_channels, kernel_size=5, stride=1, padding=2,
                    pre_activation=pre_activation,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
                ),
                nn.Tanh()
            )
        self.out_modules = nn.ModuleDict(out_modules)

    def forward(self, seg, style=None):
        """Not implemented yet.

        :raises NotImplementedError: always; the stub used to return ``None`` silently
        """
        raise NotImplementedError(f"{self.__class__.__name__}.forward is not implemented")
@MODEL.register_module()
class SPADEGenerator(nn.Module):
    """SPADE generator: a stack of SPADE residual blocks with x2 up-sampling in between.

    The input feature comes either from a latent vector (``use_vae``) or from the
    segmentation map resized to ``start_size``. Every residual block is conditioned
    on the segmentation map resized to the block's input resolution.

    :param in_channels: channels of the segmentation map
    :param out_channels: channels of the generated image
    :param num_blocks: number of residual blocks (``num_blocks - 1`` up-samplings)
    :param use_vae: start from a latent vector instead of the segmentation map
    :param num_z_dim: size of the latent vector
    :param start_size: spatial size of the first feature map
    :param base_channels: channel unit; the first feature has ``16 * base_channels``
    :param padding_mode: padding mode of the convolutions
    :param activation_type: activation of the residual blocks
    """

    def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.sx, self.sy = start_size
        self.use_vae = use_vae
        self.num_z_dim = num_z_dim

        if use_vae:
            self.input_converter = nn.Linear(num_z_dim, 16 * base_channels * self.sx * self.sy)
        else:
            self.input_converter = nn.Conv2d(in_channels, 16 * base_channels, kernel_size=3, padding=1)

        layers = []
        width = base_channels * 16
        for level in reversed(range(num_blocks)):
            next_width = base_channels * min(2 ** level, 2 ** 4)
            # every block except the first is preceded by an up-sampling
            if level != num_blocks - 1:
                layers.append(nn.Upsample(scale_factor=2))
            layers.append(ResidualBlock(
                width,
                out_channels=next_width,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type="SPADE",
                pre_activation=True,
                additional_norm_kwargs=dict(
                    condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
                    activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
                )
            ))
            width = next_width
        self.sequence = nn.Sequential(*layers)

        self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
                                            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")

    def forward(self, seg, z=None):
        """Generate an image from ``seg``.

        :param seg: segmentation map used as condition
        :param z: latent vector for the VAE mode; sampled from a standard normal if omitted
        """
        batch_size = seg.size(0)
        if self.use_vae:
            if z is None:
                z = torch.randn(batch_size, self.num_z_dim, device=seg.device)
            x = self.input_converter(z).view(batch_size, -1, self.sx, self.sy)
        else:
            x = self.input_converter(F.interpolate(seg, size=(self.sx, self.sy)))

        for blk in self.sequence:
            if isinstance(blk, ResidualBlock):
                # condition every SPADE layer of the block on seg at the current resolution
                seg_at_scale = F.interpolate(seg, size=x.size()[2:], mode='nearest')
                conditioned = [blk.conv1, blk.conv2]
                if blk.learn_skip_connection:
                    conditioned.append(blk.res_conv)
                for conv_block in conditioned:
                    conv_block.normalization.set_condition_image(seg_at_scale)
            x = blk(x)
        return self.output_converter(x)
if __name__ == '__main__':
    # Manual smoke check: 7 blocks, no VAE input; print the layout and one output size.
    g = SPADEGenerator(3, 3, 7, False, 256)
    print(g)
    sample = torch.randn(2, 3, 256, 256)
    print(g(sample).size())

View File

@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
from model import MODEL
from model.base.module import LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class StyleEncoder(nn.Module):
    """Style encoder: down-sampling ``Encoder``, global average pooling and a 1x1 conv.

    :param in_channels: channels of the input image
    :param out_dim: size of the returned style code
    :param num_conv: number of down-sampling convolutions
    :param base_channels: channels after the encoder stem
    :param max_down_sampling_multiple: log2 cap of the encoder channel multiplier
    :param padding_mode: padding mode of the encoder
    :param activation_type: activation of the encoder
    :param norm_type: normalization of the encoder convolutions
    :param pre_activation: forwarded to the encoder
    """

    def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
                 max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE",
                 pre_activation=False):
        super().__init__()
        self.down_encoder = Encoder(
            in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
            padding_mode=padding_mode, activation_type=activation_type,
            down_conv_norm_type=norm_type, down_conv_kernel_size=4, pre_activation=pre_activation,
        )
        sequence = list()
        sequence.append(nn.AdaptiveAvgPool2d(1))
        # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
        sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim, kernel_size=1, stride=1, padding=0))
        self.sequence = nn.Sequential(*sequence)

    def forward(self, image):
        """Return the style code of ``image``, flattened to (batch, -1)."""
        # The pooling head expects encoder features, so the image has to pass
        # ``down_encoder`` first (this step was missing before).
        features = self.down_encoder(image)
        return self.sequence(features).view(image.size(0), -1)
class MLPFusion(nn.Module):
    """Multi-layer perceptron made of ``LinearBlock`` layers.

    There is always an input block and an output block, plus ``n_blocks - 2``
    hidden blocks. Every block, including the last, applies ``activation_type``.

    :param in_features: size of the input vector
    :param out_features: size of the output vector
    :param base_features: width of the hidden layers
    :param n_blocks: total number of blocks
    :param activation_type: activation of every block
    :param norm_type: normalization of every block
    """

    def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        block_args = dict(activation_type=activation_type, norm_type=norm_type)

        layers = [LinearBlock(in_features, base_features, **block_args)]
        for _ in range(n_blocks - 2):
            layers.append(LinearBlock(base_features, base_features, **block_args))
        layers.append(LinearBlock(base_features, out_features, **block_args))

        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` through all blocks."""
        return self.sequence(x)
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
    """MUNIT generator: content encoder, style encoder, style MLP and AdaIN decoder.

    :param in_channels: channels of the input image
    :param out_channels: channels of the output image
    :param base_channels: channels after the encoder stems
    :param style_dim: size of the style code
    :param num_mlp_base_feature: hidden width of the style MLP
    :param num_mlp_blocks: number of blocks in the style MLP
    :param max_down_sampling_multiple: log2 cap of the style encoder channel multiplier
    :param num_content_down_sampling: down-samplings in the content encoder
        (and therefore up-samplings in the decoder)
    :param num_style_down_sampling: down-samplings in the style encoder
    :param encoder_num_residual_blocks: residual blocks in the content encoder
    :param decoder_num_residual_blocks: AdaIN residual blocks in the decoder
    :param padding_mode: padding mode of the convolutions
    :param activation_type: activation used throughout
    :param pre_activation: forwarded to encoders and decoder
    """

    def __init__(self, in_channels, out_channels, base_channels=64, style_dim=8,
                 num_mlp_base_feature=256, num_mlp_blocks=3,
                 max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
                 encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
                 padding_mode='reflect', activation_type="ReLU", pre_activation=False):
        super().__init__()
        self.content_encoder = Encoder(
            in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
            max_down_sampling_multiple=num_content_down_sampling,
            padding_mode=padding_mode, activation_type=activation_type,
            down_conv_norm_type="IN", down_conv_kernel_size=4,
            res_norm_type="IN", pre_activation=pre_activation
        )

        self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
                                          max_down_sampling_multiple, padding_mode, activation_type,
                                          norm_type="NONE", pre_activation=pre_activation)

        # Use what the content encoder really outputs; identical to the former
        # ``base_channels * 2 ** max_down_sampling_multiple`` for the default arguments.
        content_channels = self.content_encoder.out_channels

        # each decoder block has two AdaIN layers, each needing gamma and beta
        self.fusion = MLPFusion(style_dim, decoder_num_residual_blocks * 2 * content_channels * 2,
                                num_mlp_base_feature, num_mlp_blocks, activation_type,
                                norm_type="NONE")

        # The decoder consumes content features, so it is sized by ``content_channels``
        # (not by the image channels) and undoes the content encoder's down-sampling.
        self.decoder = Decoder(content_channels, out_channels, num_content_down_sampling,
                               decoder_num_residual_blocks,
                               activation_type=activation_type, padding_mode=padding_mode,
                               up_conv_kernel_size=5, up_conv_norm_type="LN",
                               res_norm_type="AdaIN", pre_activation=pre_activation)

    def encode(self, x):
        """Return ``(content, style)`` of image ``x``."""
        return self.content_encoder(x), self.style_encoder(x)

    def decode(self, content, style):
        """Render an image from a content feature and a style code."""
        as_param_style = torch.chunk(self.fusion(style), 2 * len(self.decoder.residual_blocks), dim=1)
        # set style for decoder
        for i, blk in enumerate(self.decoder.residual_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        # the result was previously computed but never returned
        return self.decoder(content)

    def forward(self, x):
        """Reconstruct ``x`` from its own content and style."""
        content, style = self.encode(x)
        return self.decode(content, style)

View File

@@ -0,0 +1,98 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from model import MODEL
from model.base.module import ResidualBlock, Conv2dBlock
class Interpolation(nn.Module):
    """Module wrapper around ``F.interpolate`` with a fixed scale factor.

    :param scale_factor: multiplier for the spatial size
    :param mode: interpolation mode passed to ``F.interpolate``
    :param align_corners: passed to ``F.interpolate``
    """

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Resize ``x`` using the stored settings (scale factor is not recomputed)."""
        options = dict(
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            recompute_scale_factor=False,
        )
        return F.interpolate(x, **options)

    def __repr__(self):
        settings = f"scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners}"
        return f"Interpolation({settings})"
@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
    """TSIT generator: a down-sampling content stream conditions a FADE up-sampling stream.

    The content stream halves the resolution ``num_blocks`` times and keeps every
    intermediate feature. The generation stream starts from ``z`` and, block by
    block from the deepest level upwards, is conditioned on those features through
    FADE normalization.

    :param in_channels: channels of the input image
    :param out_channels: channels of the output image
    :param base_channels: channel unit; multipliers are capped at 16
    :param num_blocks: number of down-sampling and of up-sampling blocks
    :param padding_mode: padding mode of the convolutions
    :param activation_type: activation used throughout
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.input_layer = Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type="IN",
        )

        down_blocks = []
        width = base_channels
        for level in range(1, num_blocks + 1):
            next_width = min(2 ** level, 2 ** 4) * base_channels
            down_blocks.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResidualBlock(
                    width, out_channels=next_width,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
            ))
            width = next_width
        self.down_sequence = nn.ModuleList(down_blocks)

        up_blocks = []
        width = base_channels * 16
        for level in reversed(range(num_blocks)):
            next_width = base_channels * min(2 ** level, 2 ** 4)
            up_blocks.append(nn.Sequential(
                ResidualBlock(
                    width,
                    out_channels=next_width,
                    padding_mode=padding_mode,
                    activation_type=activation_type,
                    norm_type="FADE",
                    pre_activation=True,
                    additional_norm_kwargs=dict(
                        condition_in_channels=width, base_norm_type="BN",
                        padding_mode="zeros", gamma_bias=0.0
                    )
                ),
                Interpolation(scale_factor=2, mode="nearest")
            ))
            width = next_width
        self.up_sequence = nn.Sequential(*up_blocks)

        self.output_layer = Conv2dBlock(
            base_channels, out_channels, kernel_size=3, stride=1, padding=1,
            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
        )

    def forward(self, x, z=None):
        """Generate an image conditioned on ``x``.

        :param x: content image
        :param z: start tensor of the generation stream; standard normal noise with
            the shape of the deepest content feature if omitted
        """
        feature = self.input_layer(x)
        contents = []
        for down in self.down_sequence:
            feature = down(feature)
            contents.append(feature)

        if z is None:
            # for image 256x256, z size: [batch_size, 1024, 2, 2]
            deepest = contents[-1]
            z = torch.randn(size=deepest.size(), device=deepest.device)

        # deepest content feature conditions the first up block, and so on upwards
        for up, content in zip(self.up_sequence, reversed(contents)):
            res = up[0]
            res.conv1.normalization.set_feature(content)
            res.conv2.normalization.set_feature(content)
            if res.learn_skip_connection:
                res.res_conv.normalization.set_feature(content)
        return self.output_layer(self.up_sequence(z))
if __name__ == '__main__':
    # Manual smoke check. Runs on the GPU when one is available and falls back to
    # the CPU otherwise, instead of crashing on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = Generator(3, 3).to(device)
    img = torch.randn(2, 3, 256, 256, device=device)
    print(g(img).size())

View File

@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class CAMClassifier(nn.Module):
    """Class-activation-map head with an average-pooled and a max-pooled branch.

    Each branch classifies the pooled feature with a bias-free linear layer; the
    weights of those layers are reused to re-weight the feature map channels.

    :param in_channels: channels of the input feature map
    :param activation_type: activation of the 1x1 fusion convolution
    """

    def __init__(self, in_channels, activation_type="ReLU"):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.avg_fc = nn.Linear(in_channels, 1, bias=False)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.max_fc = nn.Linear(in_channels, 1, bias=False)
        self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, kernel_size=1, stride=1, bias=True,
                                       activation_type=activation_type, norm_type="NONE")

    def forward(self, x):
        """Return ``(fused_feature, logits)``; logits hold the avg and max branch side by side."""
        batch_size = x.size(0)
        avg_logit = self.avg_fc(self.avg_pool(x).view(batch_size, -1))
        max_logit = self.max_fc(self.max_pool(x).view(batch_size, -1))

        # classifier weights act as per-channel attention on the feature map
        avg_weighted = x * self.avg_fc.weight.unsqueeze(2).unsqueeze(3)
        max_weighted = x * self.max_fc.weight.unsqueeze(2).unsqueeze(3)
        fused = self.fusion_conv(torch.cat([avg_weighted, max_weighted], dim=1))

        return fused, torch.cat([avg_logit, max_logit], 1)
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """U-GAT-IT generator: encoder, CAM attention, gamma/beta MLP and AdaILN decoder.

    :param in_channels: channels of the input image
    :param out_channels: channels of the output image
    :param base_channels: channels after the encoder stem
    :param num_blocks: residual blocks in the encoder and in the decoder
    :param img_size: input resolution; sizes the first MLP layer in non-light mode
    :param light: feed the MLP with globally pooled features instead of the
        flattened feature map
    :param activation_type: activation of encoder, CAM head and decoder
    :param norm_type: normalization of the encoder
    :param padding_mode: padding mode of the convolutions
    :param pre_activation: forwarded to encoder and decoder
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
                 activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
        super().__init__()
        self.light = light

        n_down_sampling = 2
        self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
                               padding_mode=padding_mode, activation_type=activation_type,
                               down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type,
                               pre_activation=pre_activation)
        mult = 2 ** n_down_sampling
        feature_channels = base_channels * mult
        self.cam = CAMClassifier(feature_channels, activation_type)

        # Gamma, Beta block
        if self.light:
            first_in_features = base_channels * mult
        else:
            # evaluated left to right; expression kept exactly as before
            first_in_features = img_size // mult * img_size // mult * base_channels * mult
        self.fc = nn.Sequential(
            LinearBlock(first_in_features, feature_channels, False, "ReLU", "NONE"),
            LinearBlock(feature_channels, feature_channels, False, "ReLU", "NONE")
        )

        self.gamma = nn.Linear(feature_channels, feature_channels, bias=False)
        self.beta = nn.Linear(feature_channels, feature_channels, bias=False)

        self.decoder = Decoder(
            feature_channels, out_channels, n_down_sampling, num_blocks,
            activation_type=activation_type, padding_mode=padding_mode,
            up_conv_kernel_size=3, up_conv_norm_type="ILN",
            res_norm_type="AdaILN", pre_activation=pre_activation
        )

    def forward(self, x):
        """Return ``(image, cam_logit, heatmap)`` for input ``x``."""
        x = self.encoder(x)

        x, cam_logit = self.cam(x)
        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            pooled = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
            x_ = self.fc(pooled.view(pooled.shape[0], -1))
        else:
            x_ = self.fc(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)

        # every AdaILN layer of the decoder gets the same gamma/beta
        for blk in self.decoder.residual_blocks:
            for conv_block in (blk.conv1, blk.conv2):
                conv_block.normalization.set_condition(gamma, beta)
        return self.decoder(x), cam_logit, heatmap
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """U-GAT-IT discriminator: 4x4 conv blocks, CAM attention and a 1-channel output conv.

    :param in_channels: channels of the input image
    :param base_channels: channels after the first conv block
    :param num_blocks: depth setting; ``num_blocks - 2`` stride-2 blocks are followed
        by one stride-1 block
    :param activation_type: activation of the conv blocks and the CAM head
    :param norm_type: normalization of the conv blocks
    :param padding_mode: padding mode of the conv blocks
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5,
                 activation_type="LeakyReLU", norm_type="NONE", padding_mode='reflect'):
        super().__init__()
        block_args = dict(kernel_size=4, padding=1, padding_mode=padding_mode,
                          activation_type=activation_type, norm_type=norm_type)

        layers = [Conv2dBlock(in_channels, base_channels, stride=2, **block_args)]
        for i in range(num_blocks - 3):
            layers.append(Conv2dBlock(
                base_channels * (2 ** i), base_channels * (2 ** i) * 2, stride=2, **block_args))
        # last block doubles the channels once more but keeps the resolution
        layers.append(Conv2dBlock(
            base_channels * (2 ** (num_blocks - 3)), base_channels * (2 ** (num_blocks - 2)),
            stride=1, **block_args))
        self.sequence = nn.Sequential(*layers)

        mult = 2 ** (num_blocks - 2)
        self.cam = CAMClassifier(base_channels * mult, activation_type)
        self.conv = nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False,
                              padding_mode="reflect")

    def forward(self, x, return_heatmap=False):
        """Return ``(prediction, cam_logit)``, plus a heatmap when ``return_heatmap`` is set."""
        features, cam_logit = self.cam(self.sequence(x))
        if not return_heatmap:
            return self.conv(features), cam_logit

        heatmap = torch.sum(features, dim=1, keepdim=True)
        return self.conv(features), cam_logit, heatmap

View File

View File

@@ -6,16 +6,20 @@ from model import MODEL
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
def __init__(self, num_scale, discriminator_cfg):
def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"):
super().__init__()
assert down_sample_method in ["avg", "bilinear"]
self.down_sample_method = down_sample_method
self.discriminator_list = nn.ModuleList([
MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
])
@staticmethod
def down_sample(x):
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
def down_sample(self, x):
if self.down_sample_method == "avg":
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
if self.down_sample_method == "bilinear":
return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
def forward(self, x):
results = []

View File

@@ -1,75 +0,0 @@
import torch.nn as nn
import functools
import torch
def select_norm_layer(norm_type):
    """Return a factory that builds the requested normalization layer.

    The returned callable takes ``num_features`` and creates the layer.

    :param norm_type: one of "BN", "IN", "LN", "NONE", "AdaIN"
    :return: callable ``num_features -> nn.Module``
    :raises NotImplementedError: for any other ``norm_type``
    """
    if norm_type == "BN":
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    elif norm_type == "IN":
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    elif norm_type == "LN":
        return functools.partial(LayerNorm2d, affine=True)
    elif norm_type == "NONE":
        return lambda num_features: nn.Identity()
    elif norm_type == "AdaIN":
        return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
    else:
        # ``NotImplemented`` is a constant, not an exception class; calling it raised TypeError
        raise NotImplementedError(f'normalization layer {norm_type} is not found')
class LayerNorm2d(nn.Module):
    """Layer normalization over dims 1, 2 and 3 with optional per-channel affine.

    Statistics are computed per sample over dimensions ``[1, 2, 3]``; when
    ``affine`` is true a ``(1, num_features, 1, 1)`` scale and shift are
    learned.
    """

    def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if affine:
            param_shape = (1, num_features, 1, 1)
            # Allocated uninitialised; values are set in reset_parameters().
            self.channel_gamma = nn.Parameter(torch.Tensor(*param_shape))
            self.channel_beta = nn.Parameter(torch.Tensor(*param_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the affine parameters (no-op without affine)."""
        if not self.affine:
            return
        nn.init.uniform_(self.channel_gamma)
        nn.init.zeros_(self.channel_beta)

    def forward(self, x):
        """Normalize ``x`` per sample and apply the affine transform if enabled."""
        reduce_dims = [1, 2, 3]
        mean = torch.mean(x, dim=reduce_dims, keepdim=True)
        var = torch.var(x, dim=reduce_dims, keepdim=True)
        normalized = (x - mean) / torch.sqrt(var + self.eps)
        if not self.affine:
            return normalized
        return self.channel_gamma * normalized + self.channel_beta

    def __repr__(self):
        return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
class AdaptiveInstanceNorm2d(nn.Module):
    """Instance normalization whose scale and shift are supplied externally.

    ``set_style`` must be called before every ``forward``; the style tensor
    is split in two along dim 1 into ``gamma`` and ``beta``.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
                 affine: bool = False, track_running_stats: bool = False):
        super().__init__()
        self.num_features = num_features
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.norm = nn.InstanceNorm2d(
            num_features,
            eps=eps,
            momentum=momentum,
            affine=affine,
            track_running_stats=track_running_stats,
        )
        # Filled in by set_style(); consumed (and invalidated) by forward().
        self.gamma = None
        self.beta = None
        self.have_set_style = False

    def set_style(self, style):
        """Store the scale/shift pair for the next forward pass."""
        # Append two singleton dims so the halves broadcast over H and W.
        expanded = style.view(*style.size(), 1, 1)
        gamma, beta = torch.chunk(expanded, 2, dim=1)
        self.gamma = gamma
        self.beta = beta
        self.have_set_style = True

    def forward(self, x):
        """Normalize ``x`` and modulate it with the previously set style."""
        assert self.have_set_style
        modulated = self.gamma * self.norm(x) + self.beta
        # A style is single-use: require a fresh set_style() before the next call.
        self.have_set_style = False
        return modulated

    def __repr__(self):
        return f"{self.__class__.__name__}(num_features={self.num_features}, " \
               f"affine={self.affine}, track_running_stats={self.track_running_stats})"

View File

@@ -1,3 +1,4 @@
from util.registry import Registry
MODEL = Registry("model")
NORMALIZATION = Registry("normalization")

View File

@@ -65,7 +65,8 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
elif classname.find('BatchNorm2d') != -1:
# BatchNorm Layer's weight is not a matrix;
# only normal distribution applies.
normal_init(m, 1.0, init_gain)
if m.weight is not None:
normal_init(m, 1.0, init_gain)
assert isinstance(module, nn.Module)
module.apply(init_func)

8
run.sh
View File

@@ -5,16 +5,18 @@ TASK=$2
GPUS=$3
MORE_ARG=${*:4}
RANDOM_MASTER=$(shuf -i 2000-65000 -n 1)
_command="print(len('${GPUS}'.split(',')))"
GPU_COUNT=$(python3 -c "${_command}")
echo "GPU_COUNT:${GPU_COUNT}"
echo CUDA_VISIBLE_DEVICES=$GPUS \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG"
CUDA_VISIBLE_DEVICES=$GPUS \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
--master_port=${RANDOM_MASTER} \
main.py "$TASK" "$CONFIG" $MORE_ARG --backup_config --setup_output_dir --setup_random_seed

46
tool/dump_tensorboard.py Normal file
View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python
# edit from https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2
import argparse
import pathlib
import cv2
import numpy as np
from tensorboard.backend.event_processing import event_accumulator
def save(outdir: pathlib.Path, tag, event_acc):
    """Write every image logged under ``tag`` to ``outdir`` as a PNG file.

    Files are named ``<tag with '/' replaced by '_'>@<index>.png`` where
    ``index`` is the position of the event in ``event_acc.Images(tag)``.

    :param outdir: directory the PNG files are written to
    :param tag: image tag to export
    :param event_acc: object exposing ``Images(tag)``
    """
    for index, event in enumerate(event_acc.Images(tag)):
        # The event carries the encoded image bytes; decode them with OpenCV.
        raw = np.frombuffer(event.encoded_image_string, dtype=np.uint8)
        decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        target = outdir / f"{tag.replace('/', '_')}@{index}.png"
        cv2.imwrite(target.as_posix(), decoded)
# ffmpeg -framerate 1 -i ./tmp/test_b/%04d.jpg -vcodec mpeg4 test_b.mp4
def main():
    """Dump images stored in a TensorBoard event file to PNG files.

    Command-line arguments:
      --path    event file (or directory) passed to ``EventAccumulator``
      --outdir  directory to write the images to (created if missing)
      --tag     optional single image tag to export; all image tags are
                exported when omitted

    An unknown ``--tag`` is reported through ``parser.error`` (usage message
    and non-zero exit) rather than an ``assert``, which would be stripped
    under ``python -O``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, required=True)
    parser.add_argument('--outdir', type=str, required=True)
    parser.add_argument("--tag", type=str, required=False)
    args = parser.parse_args()

    event_acc = event_accumulator.EventAccumulator(args.path, size_guidance={'images': 0})
    event_acc.Reload()
    image_tags = event_acc.Tags()['images']

    # Validate before touching the filesystem so a bad tag leaves no empty directory.
    if args.tag is not None and args.tag not in image_tags:
        parser.error(f"{args.tag} not in {image_tags}")

    outdir = pathlib.Path(args.outdir)
    outdir.mkdir(exist_ok=True, parents=True)

    selected_tags = image_tags if args.tag is None else [args.tag]
    for tag in selected_tags:
        save(outdir, tag, event_acc)
if __name__ == '__main__':
main()

32
tool/encoder_distance.py Normal file
View File

@@ -0,0 +1,32 @@
from pathlib import Path
import torch
#
# data = {}
#
# for i in range(1, 422 + 1):
# _, names = torch.load(f"/tmp/pt/batch{i}.pt")
# generated = torch.load(f"/tmp/pt/generated{i}.pt")
# print(len(names))
# for j, n in enumerate(names):
# data[Path(names[j]).stem] = generated[j]
#
# torch.save(data, "/tmp/data.pt")
# Keys of ``data`` have the form "<video>@<frame index>".
data = torch.load("/tmp/data.pt")

# Group frames by their exact video name. A prefix test (``startswith``)
# would also pick up frames of other videos whose name merely begins with
# the same characters, e.g. "vid1" vs "vid10@3".
frames_by_video = {}
for key, value in data.items():
    parts = key.split("@")
    frames_by_video.setdefault(parts[0], {})[int(parts[-1])] = value

videos = sorted(frames_by_video)
for video in videos:
    print(video)
    videos_data = frames_by_video[video]
    # Mean absolute difference of every frame 2..N against frame 1.
    to_save = [
        torch.mean(torch.abs(videos_data[i] - videos_data[1])).cpu()
        for i in range(2, len(videos_data) + 1)
    ]
    torch.save(to_save, f"{video}.pt")
    print(f"{video}.pt")

14
tool/inspect_model.py Normal file
View File

@@ -0,0 +1,14 @@
import sys
import torch
from omegaconf import OmegaConf
from engine.util.build import build_model
# Usage: python tool/inspect_model.py <config file> <checkpoint file>
# Load the configuration given as the first command-line argument.
config = OmegaConf.load(sys.argv[1])
# Build the generator described by the `model.generator` section of the config.
generator = build_model(config.model.generator)
# Load the checkpoint (second argument) onto the CPU so no GPU is required.
ckp = torch.load(sys.argv[2], map_location="cpu")
# NOTE(review): `.module` implies build_model returns a wrapper object
# (presumably DataParallel-like) — confirm against engine.util.build.
generator.module.load_state_dict(ckp["generator_main"])

View File

@@ -0,0 +1,13 @@
from pathlib import Path
import sys
from collections import defaultdict
from itertools import permutations
# Group the stems of all *.jpg files in the directory given as the first
# command-line argument by their first seven characters.
# NOTE(review): the prefix is presumably a person/identity id — confirm.
pids = defaultdict(list)
for image_path in Path(sys.argv[1]).glob("*.jpg"):
    stem = image_path.stem
    pids[stem[:7]].append(stem)

# Collect every ordered pair of distinct stems that share a prefix.
data = []
for stems in pids.values():
    data.extend(permutations(stems, 2))

View File

@@ -1,6 +1,25 @@
import importlib
import logging
from typing import Optional
import pkgutil
from pathlib import Path
from typing import Optional
def import_submodules(package, recursive=True):
    """ Import all submodules of a module, recursively, including subpackages

    Only the direct children of ``package`` are listed at each level (via
    ``pkgutil.iter_modules``) and subpackages are entered by explicit
    recursion with their fully qualified name. ``pkgutil.walk_packages``
    without a prefix would instead try to import each subpackage by its bare
    name, which can pick up an unrelated top-level package of the same name.

    The result is keyed by the bare (unqualified) module name, so modules in
    different subpackages that share a name overwrite each other in the
    returned dict; they are still imported.

    :param package: package (name or actual module)
    :type package: str | module
    :param recursive: also import the contents of subpackages
    :type recursive: bool
    :rtype: dict[str, types.ModuleType]
    """
    if isinstance(package, str):
        package = importlib.import_module(package)
    results = {}
    for _finder, name, is_pkg in pkgutil.iter_modules(package.__path__):
        full_name = package.__name__ + '.' + name
        results[name] = importlib.import_module(full_name)
        if recursive and is_pkg:
            results.update(import_submodules(full_name, recursive=recursive))
    return results
def setup_logger(
@@ -8,7 +27,6 @@ def setup_logger(
level: int = logging.INFO,
logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
filepath: Optional[str] = None,
file_level: int = logging.DEBUG,
distributed_rank: Optional[int] = None,
) -> logging.Logger:
"""Setups logger: name, level, format etc.
@@ -18,7 +36,6 @@ def setup_logger(
level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
file_level (int): Optional logging level for logging file.
distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
If None, distributed_rank is initialized to the rank of process.

View File

@@ -1,8 +1,10 @@
import inspect
from omegaconf.dictconfig import DictConfig
from omegaconf import OmegaConf
from types import ModuleType
import warnings
from types import ModuleType
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
class _Registry:
def __init__(self, name):
@@ -51,11 +53,9 @@ class _Registry:
else:
raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
for k in args:
assert isinstance(k, str)
if k.startswith("_"):
warnings.warn(f"got param start with `_`: {k}, will remove it")
args.pop(k)
for invalid_key in [k for k in args.keys() if k.startswith("_")]:
warnings.warn(f"got param start with `_`: {invalid_key}, will remove it")
args.pop(invalid_key)
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
@@ -136,8 +136,11 @@ class Registry(_Registry):
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
if self._module_dict[module_name] == module_class:
warnings.warn(f'{module_name} is already registered in {self.name}, but is the same class')
return
raise KeyError(f'{module_name}:{self._module_dict[module_name]} is already registered in {self.name}'
f'so {module_class} can not be registered')
self._module_dict[module_name] = module_class
def register_module(self, name=None, force=False, module=None):