Compare commits

...

39 Commits

Author SHA1 Message Date
8998c30c23 TSIT 2020-10-25 20:46:34 +08:00
0bec02bf6d 23333 2020-10-23 16:14:37 +08:00
f7b7b78669 imporved gan loss 2020-10-22 23:19:03 +08:00
376f5caeb7 v2 2020-10-22 22:42:01 +08:00
0019d4034c change a lot 2020-10-14 18:55:51 +08:00
0927fa3de5 add patch d 2020-10-13 10:31:17 +08:00
611901cbdf add ConvTranspose2d in Conv2d 2020-10-13 10:31:03 +08:00
a6ffab1445 add image buffers for gan 2020-10-13 10:30:27 +08:00
7b05b45156 update SPADE 2020-10-12 19:01:07 +08:00
2de00d0245 use loss container 2020-10-11 23:36:37 +08:00
74a7cfb2d8 move sn to engine 2020-10-11 23:35:29 +08:00
436bca88b4 add loss container 2020-10-11 23:09:04 +08:00
6070f08835 add GauGAN 2020-10-11 23:05:38 +08:00
06b2abd19a add flag to switch to norm-activ-conv 2020-10-11 19:02:42 +08:00
9c08b4cd09 move encoder, decoder to CycleGAN 2020-10-11 11:09:16 +08:00
04c6366c07 rewrite 2020-10-11 10:02:33 +08:00
6ea13df465 temp commit 2020-10-10 10:43:00 +08:00
776fe40199 change a lot 2020-09-26 17:48:26 +08:00
f67bcdf161 use base module rewrite TSIT 2020-09-26 17:48:10 +08:00
16f18ab2e2 func to apply sn 2020-09-26 17:47:24 +08:00
0f2b67e215 base model, Norm&Conv&ResNet 2020-09-26 17:45:51 +08:00
acf243cb12 working 2020-09-25 18:31:12 +08:00
fbea96f6d7 add new dataset type 2020-09-24 16:50:53 +08:00
ca55318253 add context loss 2020-09-24 16:38:03 +08:00
b01016edb5 TAFG update 2020-09-18 12:03:44 +08:00
61e04de8a5 TAFG 2020-09-17 09:34:53 +08:00
2ff4a91057 add MUNIT 2020-09-14 22:30:05 +08:00
f70658eaed small fix 2020-09-11 23:04:26 +08:00
340a344e91 add distance 2020-09-11 22:35:59 +08:00
85b5c3f589 fix small bug in U-GAT-IT 2020-09-11 22:34:43 +08:00
72d09aa483 update tester 2020-09-10 18:34:52 +08:00
7ea9c6d0d5 TAFG good result 2020-09-09 14:46:07 +08:00
87cbcf34d3 add tool to dump images in tensorboard event file 2020-09-09 09:08:11 +08:00
97ded53b30 update a lot 2020-09-07 21:38:10 +08:00
ab545843bf almost 0.1 2020-09-06 10:34:52 +08:00
e3c760d0c5 update 2020-09-05 22:00:17 +08:00
39c754374c change 2020-09-05 10:33:35 +08:00
2469bf15fe TAFG 0.01 2020-09-03 09:34:38 +08:00
14d4247112 base 2020-09-01 17:56:18 +08:00
75 changed files with 4307 additions and 2380 deletions

11
.idea/deployment.xml generated
View File

@@ -1,11 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="22d" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="14d">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
<mapping deploy="raycv" local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
@@ -16,6 +16,13 @@
</mappings>
</serverdata>
</paths>
<paths name="21d">
<serverdata>
<mappings>
<mapping deploy="/raycv" local="$PROJECT_DIR$" web="" />
</mappings>
</serverdata>
</paths>
<paths name="22d">
<serverdata>
<mappings>

2
.idea/misc.xml generated
View File

@@ -1,4 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="22d-base" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="14d-python" project-jdk-type="Python SDK" />
</project>

2
.idea/raycv.iml generated
View File

@@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="22d-base" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="14d-python" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">

8
.idea/sshConfigs.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SshConfigs">
<configs>
<sshConfig host="222.20.77.126" id="38d32db7-46b2-4b95-a40c-d17e8eeca6c1" keyPath="C:\Users\wr\.ssh\sg_id_rsa" port="50001" nameFormat="DESCRIPTIVE" username="dancer" />
</configs>
</component>
</project>

View File

@@ -1,51 +0,0 @@
name: cross-domain-1
engine: crossdomain
result_dir: ./result
distributed:
model:
# broadcast_buffers: False
misc:
random_seed: 1004
checkpoints:
interval: 2000
log:
logger:
level: 20 # DEBUG(10) INFO(20)
model:
_type: resnet10
baseline:
plusplus: False
optimizers:
_type: Adam
data:
dataloader:
batch_size: 1200
shuffle: True
num_workers: 16
pin_memory: True
drop_last: True
dataset:
train:
path: /data/few-shot/mini_imagenet_full_size/train
lmdb_path: /data/few-shot/lmdb/mini-ImageNet/train.lmdb
pipeline:
- Load
- RandomResizedCrop:
size: [224, 224]
- ColorJitter:
brightness: 0.4
contrast: 0.4
saturation: 0.4
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]

View File

@@ -0,0 +1,163 @@
name: huawei-cycylegan-7
engine: CycleGAN
result_dir: ./result
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: CycleGAN-Generator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
base_channels: 64
num_blocks: 9
use_transpose_conv: False
pre_activation: True
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: False
loss:
gan:
loss_type: hinge
weight: 1.0
real_label_val: 1
fake_label_val: 0.0
cycle:
level: 1
weight: 10.0
id:
level: 1
weight: 10.0
mgc:
weight: 1
fm:
weight: 0
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 1e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 286, 286 ]
- RandomCrop:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -0,0 +1,167 @@
name: huawei-GauGAN-3
engine: GauGAN
result_dir: ./result
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: SPADEGenerator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
num_blocks: 7
use_vae: False
num_z_dim: 256
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: True
loss:
gan:
loss_type: hinge
weight: 1.0
real_label_val: 1
fake_label_val: 0.0
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 2
mgc:
weight: 5
fm:
weight: 5
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 1e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 286, 286 ]
- RandomCrop:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -0,0 +1,132 @@
name: MUNIT-edges2shoes
engine: MUNIT
result_dir: ./result
max_pairs: 1000000
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
misc:
random_seed: 324
model:
generator:
_type: MUNIT-Generator
in_channels: 3
out_channels: 3
base_channels: 64
num_sampling: 2
num_style_dim: 8
num_style_conv: 4
num_content_res_blocks: 4
num_decoder_res_blocks: 4
num_fusion_dim: 256
num_fusion_blocks: 3
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
discriminator_cfg:
_type: PatchDiscriminator
in_channels: 3
base_channels: 64
use_spectral: True
need_intermediate_feature: True
loss:
gan:
loss_type: lsgan
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 0
recon:
level: 1
style:
weight: 1
content:
weight: 1
image:
weight: 10
cycle:
weight: 0
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 1
shuffle: True
num_workers: 1
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/edges2shoes/trainA"
root_b: "/data/i2i/edges2shoes/trainB"
random_pair: True
pipeline:
- Load
- Resize:
size: [ 286, 286 ]
- RandomCrop:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: dataset
dataloader:
batch_size: 8
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/edges2shoes/testA"
root_b: "/data/i2i/edges2shoes/testB"
random_pair: False
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -1,70 +1,99 @@
name: TAHG
engine: TAHG
name: TAFG-vox2
engine: TAFG
result_dir: ./result
max_pairs: 1000000
distributed:
model:
# broadcast_buffers: False
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
misc:
random_seed: 324
checkpoint:
epoch_interval: 1 # one checkpoint every 1 epoch
n_saved: 2
interval:
print_per_iteration: 10 # print once per 10 iteration
tensorboard:
scalar: 100
image: 2
random_seed: 1004
add_new_loss_epoch: -1
model:
generator:
_type: TAHG-Generator
_type: TAFG-Generator
_bn_to_sync_bn: False
style_in_channels: 3
content_in_channels: 1
num_blocks: 4
content_in_channels: 24
use_spectral_norm: False
style_encoder_type: StyleEncoder
num_style_conv: 4
style_dim: 8
num_adain_blocks: 4
num_res_blocks: 4
discriminator:
_type: TAHG-Discriminator
in_channels: 3
_type: MultiScaleDiscriminator
num_scale: 2
discriminator_cfg:
_type: PatchDiscriminator
in_channels: 3
base_channels: 64
use_spectral: True
need_intermediate_feature: True
loss:
gan:
loss_type: lsgan
loss_type: hinge
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
edge:
criterion: 'L1'
hed_pretrained_model_path: "./network-bsds500.pytorch"
weight: 1
perceptual:
layer_weights:
"3": 1.0
# "0": 1.0
# "5": 1.0
# "10": 1.0
# "19": 1.0
criterion: 'L2'
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 0
style:
layer_weights:
"3": 1
criterion: 'L1'
style_loss: True
perceptual_loss: False
weight: 20
weight: 10
recon:
level: 1
weight: 10
style_recon:
level: 1
weight: 1
content_recon:
level: 1
weight: 1
edge:
weight: 5
hed_pretrained_model_path: ./network-bsds500.pytorch
cycle:
level: 1
weight: 10
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [ 0.5, 0.999 ]
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 1e-4
betas: [ 0.5, 0.999 ]
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
@@ -74,9 +103,9 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 160
batch_size: 8
shuffle: True
num_workers: 2
num_workers: 1
pin_memory: True
drop_last: True
dataset:
@@ -84,20 +113,22 @@ data:
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
size: [128, 128]
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
edge_type: "landmark_hed"
size: [ 128, 128 ]
random_pair: True
pipeline:
- Load
- Resize:
size: [128, 128]
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 8
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
@@ -107,13 +138,14 @@ data:
root_a: "/data/i2i/VoxCeleb2Anime/testA"
root_b: "/data/i2i/VoxCeleb2Anime/testB"
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
edge_type: "hed"
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
edge_type: "landmark_hed"
random_pair: False
size: [128, 128]
size: [ 128, 128 ]
pipeline:
- Load
- Resize:
size: [128, 128]
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
@@ -125,7 +157,7 @@ data:
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]

View File

@@ -0,0 +1,165 @@
name: huawei-TSIT-1
engine: GauGAN
result_dir: ./result
max_pairs: 1000000
misc:
random_seed: 324
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: TSIT-Generator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
num_blocks: 7
# discriminator:
# _type: MultiScaleDiscriminator
# _add_spectral_norm: True
# num_scale: 2
# down_sample_method: "bilinear"
# discriminator_cfg:
# _type: PatchDiscriminator
# in_channels: 3
# base_channels: 64
# num_conv: 4
# need_intermediate_feature: True
discriminator:
_type: PatchDiscriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_conv: 4
need_intermediate_feature: True
loss:
gan:
loss_type: hinge
weight: 1.0
real_label_val: 1
fake_label_val: 0.0
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 1
mgc:
weight: 5
fm:
weight: 1
edge:
weight: 0
hed_pretrained_model_path: ./network-bsds500.pytorch
optimizers:
generator:
_type: Adam
lr: 1e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
buffer_size: 0
dataloader:
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/all_face"
root_b: "/data/selfie2anime/trainB/"
random_pair: True
pipeline_a:
- Load
- RandomCrop:
size: [ 178, 178 ]
- Resize:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 286, 286 ]
- RandomCrop:
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/face2cartoon/test/human"
root_b: "/data/face2cartoon/test/anime"
random_pair: True
pipeline_a:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
pipeline_b:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
with_path: True
pipeline:
- Load
- Resize:
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -1,24 +1,20 @@
name: VoxCeleb2Anime
engine: UGATIT
name: selfie2anime-vox2
engine: U-GAT-IT
result_dir: ./result
max_pairs: 1000000
distributed:
model:
# broadcast_buffers: False
misc:
random_seed: 324
checkpoint:
epoch_interval: 1 # one checkpoint every 1 epoch
n_saved: 2
interval:
print_per_iteration: 10 # print once per 10 iteration
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 10
image: 500
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
model:
generator:
@@ -27,7 +23,7 @@ model:
out_channels: 3
base_channels: 64
num_blocks: 4
img_size: 128
img_size: 256
light: True
local_discriminator:
_type: UGATIT-Discriminator
@@ -74,7 +70,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 20
batch_size: 6
shuffle: True
num_workers: 2
pin_memory: True
@@ -87,9 +83,9 @@ data:
pipeline:
- Load
- Resize:
size: [ 135, 135 ]
size: [ 286, 286 ]
- RandomCrop:
size: [ 128, 128 ]
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
@@ -97,7 +93,7 @@ data:
std: [ 0.5, 0.5, 0.5 ]
test:
dataloader:
batch_size: 8
batch_size: 4
shuffle: False
num_workers: 1
pin_memory: False
@@ -110,7 +106,7 @@ data:
pipeline:
- Load
- Resize:
size: [ 128, 128 ]
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
@@ -122,7 +118,7 @@ data:
pipeline:
- Load
- Resize:
size: [ 128, 128 ]
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]

View File

@@ -1,28 +1,28 @@
name: selfie2anime
engine: UGATIT
engine: U-GAT-IT
result_dir: ./result
max_pairs: 1000000
distributed:
model:
# broadcast_buffers: False
misc:
random_seed: 324
checkpoint:
epoch_interval: 1 # one checkpoint every 1 epoch
n_saved: 2
interval:
print_per_iteration: 10 # print once per 10 iteration
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 10
image: 500
scalar: 100 # log scalar `scalar` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: UGATIT-Generator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
base_channels: 64
@@ -31,11 +31,13 @@ model:
light: True
local_discriminator:
_type: UGATIT-Discriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_blocks: 5
global_discriminator:
_type: UGATIT-Discriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_blocks: 7
@@ -54,17 +56,19 @@ loss:
weight: 10.0
cam:
weight: 1000
mgc:
weight: 0
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [0.5, 0.999]
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 1e-4
betas: [0.5, 0.999]
betas: [ 0.5, 0.999 ]
weight_decay: 0.0001
data:
@@ -74,7 +78,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 4
batch_size: 1
shuffle: True
num_workers: 2
pin_memory: True
@@ -87,17 +91,18 @@ data:
pipeline:
- Load
- Resize:
size: [286, 286]
size: [ 286, 286 ]
- RandomCrop:
size: [256, 256]
size: [ 256, 256 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: video_dataset
dataloader:
batch_size: 8
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
@@ -110,11 +115,11 @@ data:
pipeline:
- Load
- Resize:
size: [256, 256]
size: [ 256, 256 ]
- ToTensor
- Normalize:
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
video_dataset:
_type: SingleFolderDataset
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"

View File

@@ -1,103 +0,0 @@
name: horse2zebra
engine: cyclegan
result_dir: ./result
max_iteration: 16600
distributed:
model:
# broadcast_buffers: False
misc:
random_seed: 324
checkpoints:
interval: 2000
log:
logger:
level: 20 # DEBUG(10) INFO(20)
model:
generator:
_type: ResGenerator
in_channels: 3
out_channels: 3
base_channels: 64
num_blocks: 9
padding_mode: reflect
norm_type: IN
use_dropout: False
discriminator:
_type: PatchDiscriminator
# _distributed:
# bn_to_syncbn: False
in_channels: 3
base_channels: 64
num_conv: 3
norm_type: IN
loss:
gan:
loss_type: lsgan
weight: 1.0
real_label_val: 1.0
fake_label_val: 0.0
cycle:
level: 1
weight: 10.0
id:
level: 1
weight: 0
optimizers:
generator:
_type: Adam
lr: 2e-4
betas: [0.5, 0.999]
discriminator:
_type: Adam
lr: 2e-4
betas: [0.5, 0.999]
data:
train:
buffer_size: 50
dataloader:
batch_size: 16
shuffle: True
num_workers: 4
pin_memory: True
drop_last: True
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/horse2zebra/trainA"
root_b: "/data/i2i/horse2zebra/trainB"
random_pair: True
pipeline:
- Load
- Resize:
size: [286, 286]
- RandomCrop:
size: [256, 256]
- RandomHorizontalFlip
- ToTensor
scheduler:
start: 8300
target_lr: 0
test:
dataloader:
batch_size: 4
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: GenerationUnpairedDataset
root_a: "/data/i2i/horse2zebra/testA"
root_b: "/data/i2i/horse2zebra/testB"
random_pair: False
pipeline:
- Load
- Resize:
size: [256, 256]
- ToTensor

View File

@@ -0,0 +1,171 @@
name: talking_anime
engine: talking_anime
result_dir: ./result
max_pairs: 1000000
handler:
clear_cuda_cache: True
set_epoch_for_dist_sampler: True
checkpoint:
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 100 # log image `image` times per epoch
test:
random: True
images: 10
misc:
random_seed: 1004
loss:
gan:
loss_type: hinge
real_label_val: 1.0
fake_label_val: 0.0
weight: 1.0
fm:
level: 1
weight: 1
style:
layer_weights:
"3": 1
criterion: 'L1'
style_loss: True
perceptual_loss: False
weight: 10
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'L1'
style_loss: False
perceptual_loss: True
weight: 0
context:
layer_weights:
#"13": 1
"22": 1
weight: 5
recon:
level: 1
weight: 10
edge:
weight: 5
hed_pretrained_model_path: ./network-bsds500.pytorch
model:
face_generator:
_type: TAFG-SingleGenerator
_bn_to_sync_bn: False
style_in_channels: 3
content_in_channels: 1
use_spectral_norm: True
style_encoder_type: VGG19StyleEncoder
num_style_conv: 4
style_dim: 512
num_adain_blocks: 8
num_res_blocks: 8
anime_generator:
_type: TAFG-ResGenerator
_bn_to_sync_bn: False
in_channels: 6
use_spectral_norm: True
num_res_blocks: 8
discriminator:
_type: MultiScaleDiscriminator
num_scale: 2
discriminator_cfg:
_type: PatchDiscriminator
in_channels: 3
base_channels: 64
use_spectral: True
need_intermediate_feature: True
optimizers:
generator:
_type: Adam
lr: 0.0001
betas: [ 0, 0.9 ]
weight_decay: 0.0001
discriminator:
_type: Adam
lr: 4e-4
betas: [ 0, 0.9 ]
weight_decay: 0.0001
data:
train:
scheduler:
start_proportion: 0.5
target_lr: 0
dataloader:
batch_size: 8
shuffle: True
num_workers: 1
pin_memory: True
drop_last: True
dataset:
_type: PoseFacesWithSingleAnime
root_face: "/data/i2i/VoxCeleb2Anime/trainA"
root_anime: "/data/i2i/VoxCeleb2Anime/trainB"
landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
num_face: 2
img_size: [ 128, 128 ]
with_order: False
face_pipeline:
- Load
- Resize:
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
anime_pipeline:
- Load
- Resize:
size: [ 144, 144 ]
- RandomCrop:
size: [ 128, 128 ]
- RandomHorizontalFlip
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
test:
which: dataset
dataloader:
batch_size: 1
shuffle: False
num_workers: 1
pin_memory: False
drop_last: False
dataset:
_type: PoseFacesWithSingleAnime
root_face: "/data/i2i/VoxCeleb2Anime/testA"
root_anime: "/data/i2i/VoxCeleb2Anime/testB"
landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
num_face: 2
img_size: [ 128, 128 ]
with_order: False
face_pipeline:
- Load
- Resize:
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
anime_pipeline:
- Load
- Resize:
size: [ 128, 128 ]
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]

View File

@@ -1,210 +0,0 @@
import os
import pickle
from pathlib import Path
from collections import defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import functional as F
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
import lmdb
from tqdm import tqdm
from .transform import transform_pipeline
from .registry import DATASET
def default_transform_way(transform, sample):
return [transform(sample[0]), *sample[1:]]
class LMDBDataset(Dataset):
def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
**lmdb_kwargs):
self.path = lmdb_path
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
lock=False, **lmdb_kwargs)
with self.db.begin(write=False) as txn:
self._len = pickle.loads(txn.get(b"$$len$$"))
self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
if pipeline is None:
self.not_done_pipeline = []
else:
self.not_done_pipeline = self._remain_pipeline(pipeline)
self.transform = transform_pipeline(self.not_done_pipeline)
self.transform_way = transform_way
essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
for ea in essential_attr:
setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))
def _remain_pipeline(self, pipeline):
for i, dp in enumerate(self.done_pipeline):
if pipeline[i] != dp:
raise ValueError(
f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
return pipeline[len(self.done_pipeline):]
def __repr__(self):
return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"
def __len__(self):
return self._len
def __getitem__(self, idx):
with self.db.begin(write=False) as txn:
sample = pickle.loads(txn.get("{}".format(idx).encode()))
sample = self.transform_way(self.transform, sample)
return sample
@staticmethod
def lmdbify(dataset, done_pipeline, lmdb_path):
env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=0):
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
txn.put(b"$$len$$", pickle.dumps(len(dataset)))
txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
essential_attr = getattr(dataset, "essential_attr", list())
txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
for ea in essential_attr:
txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
def __init__(self, root, pipeline):
super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
self.classes_list = defaultdict(list)
self.essential_attr = ["classes_list"]
for i in range(len(self)):
self.classes_list[self.samples[i][-1]].append(i)
assert len(self.classes_list) == len(self.classes)
class EpisodicDataset(Dataset):
def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
self.origin = origin_dataset
self.num_class = num_class
assert self.num_class < len(self.origin.classes_list)
self.num_query = num_query # K
self.num_support = num_support # K
self.num_episodes = num_episodes
def _fetch_list_data(self, id_list):
return [self.origin[i][0] for i in id_list]
def __len__(self):
return self.num_episodes
def __getitem__(self, _):
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
support_set = []
query_set = []
target_set = []
for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c]
if len(image_list) >= self.num_query + self.num_support:
# have enough images belong to this class
idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
else:
idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
support_set.extend(support)
query_set.extend(query)
target_set.extend([tag] * self.num_query)
return {
"support": torch.stack(support_set),
"query": torch.stack(query_set),
"target": torch.tensor(target_set)
}
def __repr__(self):
return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
@DATASET.register_module()
class SingleFolderDataset(Dataset):
def __init__(self, root, pipeline, with_path=False):
assert os.path.isdir(root)
self.root = root
samples = []
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
for fn in sorted(fns):
path = os.path.join(r, fn)
if has_file_allowed_extension(path, IMG_EXTENSIONS):
samples.append(path)
self.samples = samples
self.pipeline = transform_pipeline(pipeline)
self.with_path = with_path
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
if not self.with_path:
return self.pipeline(self.samples[idx])
else:
return self.pipeline(self.samples[idx]), self.samples[idx]
def __repr__(self):
return f"<SingleFolderDataset root={self.root} len={len(self)}>"
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline):
self.A = SingleFolderDataset(root_a, pipeline)
self.B = SingleFolderDataset(root_b, pipeline)
self.random_pair = random_pair
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
return dict(a=self.A[a_idx], b=self.B[b_idx])
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
assert self.edges_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
def get_edge(self, origin_path):
op = Path(origin_path)
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
img = Image.open(edge_path).resize(self.size)
return F.to_tensor(img)
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output = dict()
output["a"], path_a = self.A[a_idx]
output["b"], path_b = self.B[b_idx]
output["edge_a"] = self.get_edge(path_a)
output["edge_b"] = self.get_edge(path_b)
return output
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"

3
data/dataset/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from util.misc import import_submodules
__all__ = import_submodules(__name__).keys()

63
data/dataset/few-shot.py Normal file
View File

@@ -0,0 +1,63 @@
from collections import defaultdict
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from data.registry import DATASET
from data.transform import transform_pipeline
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
def __init__(self, root, pipeline):
super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
self.classes_list = defaultdict(list)
self.essential_attr = ["classes_list"]
for i in range(len(self)):
self.classes_list[self.samples[i][-1]].append(i)
assert len(self.classes_list) == len(self.classes)
class EpisodicDataset(Dataset):
def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
self.origin = origin_dataset
self.num_class = num_class
assert self.num_class < len(self.origin.classes_list)
self.num_query = num_query # K
self.num_support = num_support # K
self.num_episodes = num_episodes
def _fetch_list_data(self, id_list):
return [self.origin[i][0] for i in id_list]
def __len__(self):
return self.num_episodes
def __getitem__(self, _):
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
support_set = []
query_set = []
target_set = []
for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c]
if len(image_list) >= self.num_query + self.num_support:
# have enough images belong to this class
idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
else:
idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
support_set.extend(support)
query_set.extend(query)
target_set.extend([tag] * self.num_query)
return {
"support": torch.stack(support_set),
"query": torch.stack(query_set),
"target": torch.tensor(target_set)
}
def __repr__(self):
return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"

View File

@@ -0,0 +1,62 @@
import os
import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
from data.registry import DATASET
from data.transform import transform_pipeline
@DATASET.register_module()
class SingleFolderDataset(Dataset):
def __init__(self, root, pipeline, with_path=False):
assert os.path.isdir(root)
self.root = root
samples = []
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
for fn in sorted(fns):
path = os.path.join(r, fn)
if has_file_allowed_extension(path, IMG_EXTENSIONS):
samples.append(path)
self.samples = samples
self.pipeline = transform_pipeline(pipeline)
self.with_path = with_path
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
output = dict(img=self.pipeline(self.samples[idx]))
if self.with_path:
output["path"] = self.samples[idx]
return output
def __repr__(self):
return f"<SingleFolderDataset root={self.root} len={len(self)} with_path={self.with_path}>"
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline_a, pipeline_b, with_path=False):
self.A = SingleFolderDataset(root_a, pipeline_a, with_path)
self.B = SingleFolderDataset(root_b, pipeline_b, with_path)
self.with_path = with_path
self.random_pair = random_pair
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output_a = self.A[a_idx]
output_b = self.B[b_idx]
output = dict(a=output_a["img"], b=output_b["img"])
if self.with_path:
output["a_path"] = output_a["path"]
output["b_path"] = output_b["path"]
return output
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"

65
data/dataset/lmdb.py Normal file
View File

@@ -0,0 +1,65 @@
import os
import pickle
import lmdb
from torch.utils.data import Dataset
from tqdm import tqdm
from data.transform import transform_pipeline
def default_transform_way(transform, sample):
return [transform(sample[0]), *sample[1:]]
class LMDBDataset(Dataset):
def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
**lmdb_kwargs):
self.path = lmdb_path
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
lock=False, **lmdb_kwargs)
with self.db.begin(write=False) as txn:
self._len = pickle.loads(txn.get(b"$$len$$"))
self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
if pipeline is None:
self.not_done_pipeline = []
else:
self.not_done_pipeline = self._remain_pipeline(pipeline)
self.transform = transform_pipeline(self.not_done_pipeline)
self.transform_way = transform_way
essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
for ea in essential_attr:
setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))
def _remain_pipeline(self, pipeline):
for i, dp in enumerate(self.done_pipeline):
if pipeline[i] != dp:
raise ValueError(
f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
return pipeline[len(self.done_pipeline):]
def __repr__(self):
return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"
def __len__(self):
return self._len
def __getitem__(self, idx):
with self.db.begin(write=False) as txn:
sample = pickle.loads(txn.get("{}".format(idx).encode()))
sample = self.transform_way(self.transform, sample)
return sample
@staticmethod
def lmdbify(dataset, done_pipeline, lmdb_path):
env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=0):
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
txn.put(b"$$len$$", pickle.dumps(len(dataset)))
txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
essential_attr = getattr(dataset, "essential_attr", list())
txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
for ea in essential_attr:
txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))

View File

@@ -0,0 +1,122 @@
from collections import defaultdict
from itertools import permutations, combinations
from pathlib import Path
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as F
from data.registry import DATASET
from data.transform import transform_pipeline
from data.util import dlib_landmark
def normalize_tensor(tensor):
tensor = tensor.float()
tensor -= tensor.min()
tensor /= tensor.max()
return tensor
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
with_path=False):
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
self.landmarks_path = Path(landmarks_path)
assert self.edges_path.exists()
assert self.landmarks_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
self.with_path = with_path
def get_edge(self, origin_path):
op = Path(origin_path)
if self.edge_type.startswith("landmark_"):
edge_type = self.edge_type.lstrip("landmark_")
use_landmark = op.parent.name.endswith("A")
else:
edge_type = self.edge_type
use_landmark = False
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
if not use_landmark:
return origin_edge
else:
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
part_labels = normalize_tensor(torch.from_numpy(part_labels))
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
# edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
# edges = part_edge + edges
return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output = dict(a={}, b={})
output["a"]["img"], output["a"]["path"] = self.A[a_idx]
output["b"]["img"], output["b"]["path"] = self.B[b_idx]
for p in "ab":
output[p]["edge"] = self.get_edge(output[p]["path"])
return output
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
with_order=True):
self.num_face = num_face
self.landmark_path = Path(landmark_path)
self.with_order = with_order
self.root_face = Path(root_face)
self.root_anime = Path(root_anime)
self.img_size = img_size
self.face_samples = self.iter_folders()
self.face_pipeline = transform_pipeline(face_pipeline)
self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)
def iter_folders(self):
pics_per_person = defaultdict(list)
for p in self.root_face.glob("*.jpg"):
pics_per_person[p.stem[:7]].append(p.stem)
data = []
for p in pics_per_person:
if len(pics_per_person[p]) >= self.num_face:
if self.with_order:
data.extend(list(combinations(pics_per_person[p], self.num_face)))
else:
data.extend(list(permutations(pics_per_person[p], self.num_face)))
return data
def read_pose(self, pose_txt):
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
part_labels = normalize_tensor(torch.from_numpy(part_labels))
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
return torch.cat([part_labels, part_edge, dist_tensor])
def __len__(self):
return len(self.face_samples)
def __getitem__(self, idx):
output = dict()
output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
for i, f in enumerate(self.face_samples[idx]):
output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
output[f"stem_{i}"] = f
return output

View File

@@ -28,7 +28,7 @@ class Load:
def transform_pipeline(pipeline_description):
if len(pipeline_description) == 0:
if pipeline_description is None or len(pipeline_description) == 0:
return lambda x: x
transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
return transforms.Compose(transform_list)

105
engine/CycleGAN.py Normal file
View File

@@ -0,0 +1,105 @@
from itertools import chain
import ignite.distributed as idist
import torch
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import pixel_loss, gan_loss, feature_match_loss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights
class CycleGANEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
self.gan_loss = gan_loss(config.loss.gan)
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
self.edge_loss = LossContainer(config.loss.edge.weight, EdgeLoss(
"HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(idist.device()))
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
def build_models(self) -> (dict, dict):
generators = dict(
a2b=build_model(self.config.model.generator),
b2a=build_model(self.config.model.generator)
)
discriminators = dict(
a=build_model(self.config.model.discriminator),
b=build_model(self.config.model.discriminator)
)
self.logger.debug(discriminators["a"])
self.logger.debug(generators["a2b"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators
def setup_after_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
def setup_before_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)
def forward(self, batch, inference=False) -> dict:
images = dict()
with torch.set_grad_enabled(not inference):
images["a2b"] = self.generators["a2b"](batch["a"])
images["b2a"] = self.generators["b2a"](batch["b"])
images["a2b2a"] = self.generators["b2a"](images["a2b"])
images["b2a2b"] = self.generators["a2b"](images["b2a"])
if self.id_loss.weight > 0:
images["a2a"] = self.generators["b2a"](batch["a"])
images["b2b"] = self.generators["a2b"](batch["b"])
return images
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
for ph in "ab":
loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
prediction_fake = self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"])
loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
if self.fm_loss.weight > 0:
prediction_real = self.discriminators[ph](batch[ph])
loss[f"feature_match_{ph}"] = self.fm_loss(prediction_fake, prediction_real)
loss[f"edge_{ph}"] = self.edge_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph], gt_is_edge=False)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
for phase in "ab":
generated_image = self.image_buffers[phase].query(generated["b2a" if phase == "a" else "a2b"].detach())
loss[f"gan_{phase}"] = (self.gan_loss(self.discriminators[phase](generated_image), False,
is_discriminator=True) +
self.gan_loss(self.discriminators[phase](batch[phase]), True,
is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
return dict(
a=[batch["a"].detach(), generated["a2b"].detach(), generated["a2b2a"].detach()],
b=[batch["b"].detach(), generated["b2a"].detach(), generated["b2a2b"].detach()],
)
def run(task, config, _):
kernel = CycleGANEngineKernel(config)
run_kernel(task, config, kernel)

86
engine/GauGAN.py Normal file
View File

@@ -0,0 +1,86 @@
from itertools import chain
import torch
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from engine.util.container import GANImageBuffer, LossContainer
from engine.util.loss import gan_loss, feature_match_loss, perceptual_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
from model.weight_init import generation_init_weights
class GauGANEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
self.gan_loss = gan_loss(config.loss.gan)
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
def build_models(self) -> (dict, dict):
generators = dict(
main=build_model(self.config.model.generator)
)
discriminators = dict(
b=build_model(self.config.model.discriminator)
)
self.logger.debug(discriminators["b"])
self.logger.debug(generators["main"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators
def setup_after_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
def setup_before_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)
def forward(self, batch, inference=False) -> dict:
images = dict()
with torch.set_grad_enabled(not inference):
images["a2b"] = self.generators["main"](batch["a"])
return images
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
prediction_fake = self.discriminators["b"](generated["a2b"])
loss["gan"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
loss["mgc"] = self.mgc_loss(generated["a2b"], batch["a"])
loss["perceptual"] = self.perceptual_loss(generated["a2b"], batch["a"])
if self.fm_loss.weight > 0:
prediction_real = self.discriminators["b"](batch["b"])
loss["feature_match"] = self.fm_loss(prediction_fake, prediction_real)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
generated_image = self.image_buffers["b"].query(generated["a2b"].detach())
loss["b"] = (self.gan_loss(self.discriminators["b"](generated_image), False, is_discriminator=True) +
self.gan_loss(self.discriminators["b"](batch["b"]), True, is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
return dict(
a=[batch["a"].detach(), generated["a2b"].detach()],
)
def run(task, config, _):
kernel = GauGANEngineKernel(config)
run_kernel(task, config, kernel)

154
engine/MUNIT.py Normal file
View File

@@ -0,0 +1,154 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
class MUNITEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
self.train_generator_first = False
def build_models(self) -> (dict, dict):
generators = dict(
a=build_model(self.config.model.generator),
b=build_model(self.config.model.generator)
)
discriminators = dict(
a=build_model(self.config.model.discriminator),
b=build_model(self.config.model.discriminator)
)
self.logger.debug(discriminators["a"])
self.logger.debug(generators["a"])
return generators, discriminators
def setup_after_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
def setup_before_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)
def forward(self, batch, inference=False) -> dict:
styles = dict()
contents = dict()
images = dict()
with torch.set_grad_enabled(not inference):
for phase in "ab":
contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
for phase in ("a2b", "b2a"):
# images["a2b"] = Gb.decode(content_a, random_style_b)
images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
# contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
if self.config.loss.recon.cycle.weight > 0:
images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
return dict(styles=styles, contents=contents, images=images)
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
for phase in "ab":
loss[f"recon_image_{phase}"] = self.config.loss.recon.image.weight * self.recon_loss(
batch[phase], generated["images"]["{0}2{0}".format(phase)])
loss[f"recon_content_{phase}"] = self.config.loss.recon.content.weight * self.recon_loss(
generated["contents"][phase], generated["contents"]["a2b" if phase == "a" else "b2a"])
loss[f"recon_style_{phase}"] = self.config.loss.recon.style.weight * self.recon_loss(
generated["styles"][f"random_{phase}"], generated["styles"]["b2a" if phase == "a" else "a2b"])
pred_fake = self.discriminators[phase](generated["images"]["b2a" if phase == "a" else "a2b"])
loss[f"gan_{phase}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
if self.config.loss.recon.cycle.weight > 0:
loss[f"recon_cycle_{phase}"] = self.config.loss.recon.cycle.weight * self.recon_loss(
batch[phase], generated["images"]["a2b2a" if phase == "a" else "b2a2b"])
if self.config.loss.perceptual.weight > 0:
loss[f"perceptual_{phase}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
for phase in ("a2b", "b2a"):
pred_real = self.discriminators[phase[-1]](batch[phase[-1]])
pred_fake = self.discriminators[phase[-1]](generated["images"][phase].detach())
loss[f"gan_{phase[-1]}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase[-1]}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
generated = {img: generated["images"][img].detach() for img in generated["images"]}
images = dict()
for phase in "ab":
images[phase] = [batch[phase].detach(), generated["{0}2{0}".format(phase)],
generated["a2b" if phase == "a" else "b2a"]]
if self.config.loss.recon.cycle.weight > 0:
images[phase].append(generated["a2b2a" if phase == "a" else "b2a2b"])
return images
class MUNITTestEngineKernel(TestEngineKernel):
def __init__(self, config):
super().__init__(config)
def build_generators(self) -> dict:
generators = dict(
a=build_model(self.config.model.generator),
b=build_model(self.config.model.generator)
)
return generators
def to_load(self):
return {f"generator_{k}": self.generators[k] for k in self.generators}
def inference(self, batch):
with torch.no_grad():
fake, _, _ = self.generators["a2b"](batch[0])
return fake.detach()
def run(task, config, _):
if task == "train":
kernel = MUNITEngineKernel(config)
run_kernel(task, config, kernel)
elif task == "test":
kernel = MUNITTestEngineKernel(config)
run_kernel(task, config, kernel)
else:
raise NotImplemented

212
engine/TAFG.py Normal file
View File

@@ -0,0 +1,212 @@
from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
class TAFGEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
style_loss_cfg = OmegaConf.to_container(config.loss.style)
style_loss_cfg.pop("weight")
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
self.content_recon_loss = nn.L1Loss() if config.loss.content_recon.level == 1 else nn.MSELoss()
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
idist.device())
def _process_batch(self, batch, inference=False):
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
return batch
def build_models(self) -> (dict, dict):
generators = dict(
main=build_model(self.config.model.generator)
)
discriminators = dict(
a=build_model(self.config.model.discriminator),
b=build_model(self.config.model.discriminator)
)
self.logger.debug(discriminators["a"])
self.logger.debug(generators["main"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators
def setup_after_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
def setup_before_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)
def forward(self, batch, inference=False) -> dict:
generator = self.generators["main"]
batch = self._process_batch(batch, inference)
styles = dict()
contents = dict()
images = dict()
with torch.set_grad_enabled(not inference):
contents["a"], styles["a"] = generator.encode(batch["a"]["edge"], batch["a"]["img"], "a", "a")
contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
for ph in "ab":
images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
styles[f"random_b"] = torch.randn_like(styles["b"]).to(idist.device())
images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")
contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
images["a2b"], "b", "b")
images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
return dict(styles=styles, contents=contents, images=images)
def criterion_generators(self, batch, generated) -> dict:
batch = self._process_batch(batch)
loss = dict()
for ph in "ab":
loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"])
loss[f"gan_{ph}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
if self.engine.state.epoch == self.config.misc.add_new_loss_epoch:
self.generators["main"].style_converters.requires_grad_(False)
self.generators["main"].style_encoders.requires_grad_(False)
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
pred_fake = self.discriminators[ph](generated["images"]["a2b"])
loss["gan_a2b"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
generated["contents"]["a"], generated["contents"]["recon_a"]
)
loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
generated["styles"]["random_b"], generated["styles"]["recon_b"]
)
if self.config.loss.perceptual.weight > 0:
loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
batch["a"]["img"], generated["images"]["a2b"]
)
if self.config.loss.cycle.weight > 0:
loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
batch["a"]["img"], generated["images"][f"cycle_a"]
)
# for ph in "ab":
#
# if self.config.loss.style.weight > 0:
# loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
# batch[ph]["img"], generated["images"][f"a2{ph}"]
# )
if self.config.loss.edge.weight > 0:
loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase]["img"])
pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
pred_fake_2 = self.discriminators[phase](generated["images"]["a2b"].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) +
self.gan_loss(pred_fake_2[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 3
else:
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase]["img"])
pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
batch = self._process_batch(batch)
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
return dict(
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
batch["a"]["img"].detach(),
generated["images"]["a2a"].detach(),
generated["images"]["a2b"].detach(),
generated["images"]["cycle_a"].detach(),
],
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
batch["b"]["img"].detach(),
generated["images"]["b2b"].detach(),
generated["images"]["cycle_b"].detach()]
)
else:
return dict(
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
batch["a"]["img"].detach(),
generated["images"]["a2a"].detach(),
],
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
batch["b"]["img"].detach(),
generated["images"]["b2b"].detach(),
]
)
def change_engine(self, config, trainer):
pass
# @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
# def change_config(engine):
# with read_write(config):
# config.loss.perceptual.weight = 5
def run(task, config, _):
kernel = TAFGEngineKernel(config)
run_kernel(task, config, kernel)

View File

@@ -1,245 +0,0 @@
from itertools import chain
from math import ceil
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import OmegaConf, read_write
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from util.image import make_2d_grid
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer
def build_lr_schedulers(optimizers, config):
g_milestones_values = [
(0, config.optimizers.generator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
d_milestones_values = [
(0, config.optimizers.discriminator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
return dict(
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
)
def get_trainer(config, logger, train_data_loader):
generator = build_model(config.model.generator, config.distributed.model)
discriminators = dict(
a=build_model(config.model.discriminator, config.distributed.model),
b=build_model(config.model.discriminator, config.distributed.model),
)
generation_init_weights(generator)
for m in discriminators.values():
generation_init_weights(m)
logger.debug(discriminators["a"])
logger.debug(generator)
optimizers = dict(
g=build_optimizer(generator.parameters(), config.optimizers.generator),
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info(f"build optimizers:\n{optimizers}")
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
edge_loss_cfg = OmegaConf.to_container(config.loss.edge)
edge_loss_cfg.pop("weight")
edge_loss = EdgeLoss(**edge_loss_cfg).to(idist.device())
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real = dict(a=batch["a"], b=batch["b"])
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
)
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
optimizers["g"].zero_grad()
loss_g = dict()
for d in "ab":
discriminators[d].requires_grad_(False)
pred_fake = discriminators[d](fake[d])
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
_, t = perceptual_loss(fake[d], real[d])
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"])
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"])
loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"])
sum(loss_g.values()).backward()
optimizers["g"].step()
for discriminator in discriminators.values():
discriminator.requires_grad_(True)
optimizers["d"].zero_grad()
loss_d = dict()
for k in discriminators.keys():
pred_real = discriminators[k](real[k])
pred_fake = discriminators[k](image_buffers[k].query(fake[k].detach()))
loss_d[f"gan_{k}"] = (gan_loss(pred_real, True, is_discriminator=True) +
gan_loss(pred_fake, False, is_discriminator=True)) / 2
sum(loss_d.values()).backward()
optimizers["d"].step()
generated_img = {f"real_{k}": real[k].detach() for k in real}
generated_img["rec_b"] = rec_b.detach()
generated_img["rec_bb"] = rec_b.detach()
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"})
return {
"loss": {
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
"d": {ln: loss_d[ln].mean().item() for ln in loss_d},
},
"img": generated_img
}
trainer = Engine(_step)
trainer.logger = logger
for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update({"generator": generator})
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
iter_per_epoch = len(train_data_loader)
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
if tensorboard_handler is not None:
tensorboard_handler.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
)
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
def show_images(engine):
output = engine.state.output
image_order = dict(
a=["edge_a", "real_a", "fake_a", "fake_b"],
b=["edge_b", "real_b", "rec_b", "rec_bb"]
)
for k in "ab":
tensorboard_handler.writer.add_image(
f"train/{k}",
make_2d_grid([output["img"][o] for o in image_order[k]]),
engine.state.iteration
)
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
test_images = dict(
a=[[], [], [], []],
b=[[], [], [], []]
)
for i in range(random_start, random_start + 10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
batch[k] = batch[k].view(1, *batch[k].size())
real = dict(a=batch["a"], b=batch["b"])
fake = dict(
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
)
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
test_images["a"][0].append(batch["edge_a"])
test_images["a"][1].append(batch["a"])
test_images["a"][2].append(fake["a"])
test_images["a"][3].append(fake["b"])
test_images["b"][0].append(batch["edge_b"])
test_images["b"][1].append(batch["b"])
test_images["b"][2].append(rec_b)
test_images["b"][3].append(rec_bb)
for n in "ab":
tensorboard_handler.writer.add_image(
f"test/{n}",
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
engine.state.iteration
)
return trainer
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
with read_write(config):
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger, train_data_loader)
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())
else:
return NotImplemented(f"invalid task: {task}")

119
engine/TSIT.py Normal file
View File

@@ -0,0 +1,119 @@
from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
class TSITEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
def build_models(self) -> (dict, dict):
generators = dict(
main=build_model(self.config.model.generator)
)
discriminators = dict(
b=build_model(self.config.model.discriminator)
)
self.logger.debug(discriminators["b"])
self.logger.debug(generators["main"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators
def setup_after_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
def setup_before_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)
def forward(self, batch, inference=False) -> dict:
with torch.set_grad_enabled(not inference):
fake = dict(
b=self.generators["main"](content_img=batch["a"])
)
return fake
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
loss["perceptual"] = self.perceptual_loss(generated["b"], batch["a"]) * self.config.loss.perceptual.weight
for phase in "b":
pred_fake = self.discriminators[phase](generated[phase])
loss[f"gan_{phase}"] = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[phase])
pred_fake = self.discriminators[phase](generated[phase].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
return dict(
b=[batch["a"].detach(), batch["b"].detach(), generated["b"].detach()]
)
class TSITTestEngineKernel(TestEngineKernel):
def __init__(self, config):
super().__init__(config)
def build_generators(self) -> dict:
generators = dict(
main=build_model(self.config.model.generator)
)
return generators
def to_load(self):
return {f"generator_{k}": self.generators[k] for k in self.generators}
def inference(self, batch):
with torch.no_grad():
fake = self.generators["main"](content_img=batch["a"][0], style_img=batch["b"][0])
return {"a": fake.detach()}
def run(task, config, _):
if task == "train":
kernel = TSITEngineKernel(config)
run_kernel(task, config, kernel)
elif task == "test":
kernel = TSITTestEngineKernel(config)
run_kernel(task, config, kernel)
else:
raise NotImplemented

150
engine/U-GAT-IT.py Normal file
View File

@@ -0,0 +1,150 @@
import torch
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from engine.util.container import LossContainer
from engine.util.loss import bce_loss, mse_loss, pixel_loss, gan_loss
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from util.image import attention_colored_map
class RhoClipper(object):
def __init__(self, clip_min, clip_max):
self.clip_min = clip_min
self.clip_max = clip_max
assert clip_min < clip_max
def __call__(self, module):
if hasattr(module, 'rho'):
w = module.rho.data
w = w.clamp(self.clip_min, self.clip_max)
module.rho.data = w
class UGATITEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
self.gan_loss = gan_loss(config.loss.gan)
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
self.bce_loss = LossContainer(self.config.loss.cam.weight, bce_loss)
self.mse_loss = LossContainer(self.config.loss.gan.weight, mse_loss)
self.rho_clipper = RhoClipper(0, 1)
self.train_generator_first = False
def build_models(self) -> (dict, dict):
generators = dict(
a2b=build_model(self.config.model.generator),
b2a=build_model(self.config.model.generator)
)
discriminators = dict(
la=build_model(self.config.model.local_discriminator),
lb=build_model(self.config.model.local_discriminator),
ga=build_model(self.config.model.global_discriminator),
gb=build_model(self.config.model.global_discriminator),
)
self.logger.debug(discriminators["ga"])
self.logger.debug(generators["a2b"])
return generators, discriminators
def setup_after_g(self):
for generator in self.generators.values():
generator.apply(self.rho_clipper)
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
def setup_before_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)
def forward(self, batch, inference=False) -> dict:
images = dict()
heatmap = dict()
cam_pred = dict()
with torch.set_grad_enabled(not inference):
images["a2b"], cam_pred["a2b"], heatmap["a2b"] = self.generators["a2b"](batch["a"])
images["b2a"], cam_pred["b2a"], heatmap["b2a"] = self.generators["b2a"](batch["b"])
images["a2b2a"], _, heatmap["a2b2a"] = self.generators["b2a"](images["a2b"])
images["b2a2b"], _, heatmap["b2a2b"] = self.generators["a2b"](images["b2a"])
images["a2a"], cam_pred["a2a"], heatmap["a2a"] = self.generators["b2a"](batch["a"])
images["b2b"], cam_pred["b2b"], heatmap["b2b"] = self.generators["a2b"](batch["b"])
return dict(images=images, heatmap=heatmap, cam_pred=cam_pred)
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
for phase in "ab":
cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
for dk in "lg":
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
loss[f"gan_cam_{phase}_{dk}"] = self.mse_loss(cam_pred, True)
for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
loss[f"cam_{t[-1]}"] = self.bce_loss(generated["cam_pred"][t], True) + \
self.bce_loss(generated["cam_pred"][f], False)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
for phase in "ab":
for level in "gl":
generated_image = generated["images"]["b2a" if phase == "a" else "a2b"].detach()
pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image)
pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase])
loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss(
pred_fake, False, is_discriminator=True)
loss[f"cam_{phase}_{level}"] = mse_loss(cam_fake_pred, False) + mse_loss(cam_real_pred, True)
return loss
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
attention_a = attention_colored_map(generated["heatmap"]["a2b"].detach(), batch["a"].size()[-2:])
attention_b = attention_colored_map(generated["heatmap"]["b2a"].detach(), batch["b"].size()[-2:])
generated = {img: generated["images"][img].detach() for img in generated["images"]}
return {
"a": [batch["a"], attention_a, generated["a2b"], generated["a2a"], generated["a2b2a"]],
"b": [batch["b"], attention_b, generated["b2a"], generated["b2b"], generated["b2a2b"]],
}
class UGATITTestEngineKernel(TestEngineKernel):
def __init__(self, config):
super().__init__(config)
def build_generators(self) -> dict:
generators = dict(
a2b=build_model(self.config.model.generator),
)
return generators
def to_load(self):
return {f"generator_{k}": self.generators[k] for k in self.generators}
def inference(self, batch):
with torch.no_grad():
fake, _, _ = self.generators["a2b"](batch[0])
return fake.detach()
def run(task, config, _):
if task == "train":
kernel = UGATITEngineKernel(config)
run_kernel(task, config, kernel)
elif task == "test":
kernel = UGATITTestEngineKernel(config)
run_kernel(task, config, kernel)
else:
raise NotImplemented

View File

@@ -1,320 +0,0 @@
from itertools import chain
from math import ceil
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from omegaconf import OmegaConf, read_write
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from model.GAN.UGATIT import RhoClipper
from util.image import make_2d_grid, fuse_attention_map, attention_colored_map
from util.handler import setup_common_handlers, setup_tensorboard_handler
from util.build import build_model, build_optimizer
def build_lr_schedulers(optimizers, config):
g_milestones_values = [
(0, config.optimizers.generator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
d_milestones_values = [
(0, config.optimizers.discriminator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
return dict(
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
)
def get_trainer(config, logger):
generators = dict(
a2b=build_model(config.model.generator, config.distributed.model),
b2a=build_model(config.model.generator, config.distributed.model),
)
discriminators = dict(
la=build_model(config.model.local_discriminator, config.distributed.model),
lb=build_model(config.model.local_discriminator, config.distributed.model),
ga=build_model(config.model.global_discriminator, config.distributed.model),
gb=build_model(config.model.global_discriminator, config.distributed.model),
)
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
logger.debug(discriminators["ga"])
logger.debug(generators["a2b"])
optimizers = dict(
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info(f"build optimizers:\n{optimizers}")
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
rho_clipper = RhoClipper(0, 1)
def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
discriminator_g):
discriminator_g.requires_grad_(False)
discriminator_l.requires_grad_(False)
pred_fake_g, cam_gd_pred = discriminator_g(fake)
pred_fake_l, cam_ld_pred = discriminator_l(fake)
return {
f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
f"id_{name}": config.loss.id.weight * id_loss(real, identity),
f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
}
def criterion_discriminator(name, discriminator, real, fake):
pred_real, cam_real = discriminator(real)
pred_fake, cam_fake = discriminator(fake)
# TODO: origin do not divide 2, but I think it better to divide 2.
loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}
def _step(engine, real):
real = convert_tensor(real, idist.device())
fake = dict()
cam_generator_pred = dict()
rec = dict()
identity = dict()
cam_identity_pred = dict()
heatmap = dict()
fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])
optimizers["g"].zero_grad()
loss_g = dict()
for n in ["a", "b"]:
loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
sum(loss_g.values()).backward()
optimizers["g"].step()
for generator in generators.values():
generator.apply(rho_clipper)
for discriminator in discriminators.values():
discriminator.requires_grad_(True)
optimizers["d"].zero_grad()
loss_d = dict()
for k in discriminators.keys():
n = k[-1] # "a" or "b"
loss_d.update(
criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
sum(loss_d.values()).backward()
optimizers["d"].step()
for h in heatmap:
heatmap[h] = heatmap[h].detach()
generated_img = {f"real_{k}": real[k].detach() for k in real}
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})
return {
"loss": {
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
"d": {ln: loss_d[ln].mean().item() for ln in loss_d},
},
"img": {
"heatmap": heatmap,
"generated": generated_img
}
}
trainer = Engine(_step)
trainer.logger = logger
for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update({f"generator_{k}": generators[k] for k in generators})
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
if tensorboard_handler is not None:
tensorboard_handler.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
)
@trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
def show_images(engine):
output = engine.state.output
image_order = dict(
a=["real_a", "fake_b", "rec_a", "id_a"],
b=["real_b", "fake_a", "rec_b", "id_b"]
)
output["img"]["generated"]["real_a"] = fuse_attention_map(
output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
output["img"]["generated"]["real_b"] = fuse_attention_map(
output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
for k in "ab":
tensorboard_handler.writer.add_image(
f"train/{k}",
make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]),
engine.state.iteration
)
with torch.no_grad():
g = torch.Generator()
g.manual_seed(config.misc.random_seed)
random_start = torch.randperm(len(engine.state.test_dataset)-11, generator=g).tolist()[0]
test_images = dict(
a=[[], [], [], []],
b=[[], [], [], []]
)
for i in range(random_start, random_start+10):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
fake_b, _, heatmap_a2b = generators["a2b"](real_a)
fake_a, _, heatmap_b2a = generators["b2a"](real_b)
rec_a = generators["b2a"](fake_b)[0]
rec_b = generators["a2b"](fake_a)[0]
for idx, im in enumerate(
[attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]):
test_images["a"][idx].append(im.cpu())
for idx, im in enumerate(
[attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]):
test_images["b"][idx].append(im.cpu())
for n in "ab":
tensorboard_handler.writer.add_image(
f"test/{n}",
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
engine.state.iteration
)
return trainer
def get_tester(config, logger):
generator_a2b = build_model(config.model.generator, config.distributed.model)
def _step(engine, batch):
real_a, path = convert_tensor(batch, idist.device())
with torch.no_grad():
fake_b = generator_a2b(real_a)[0]
return {"path": path, "img": [real_a.detach(), fake_b.detach()]}
tester = Engine(_step)
tester.logger = logger
to_load = dict(generator_a2b=generator_a2b)
setup_common_handlers(tester, config, use_profiler=False, to_save=to_load)
@tester.on(Events.STARTED)
def mkdir(engine):
img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
engine.state.img_output_dir = Path(img_output_dir)
if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
engine.logger.info(f"mkdir {engine.state.img_output_dir}")
engine.state.img_output_dir.mkdir()
@tester.on(Events.ITERATION_COMPLETED)
def save_images(engine):
img_tensors = engine.state.output["img"]
paths = engine.state.output["path"]
batch_size = img_tensors[0].size(0)
for i in range(batch_size):
image_name = Path(paths[i]).name
torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
nrow=len(img_tensors))
return tester
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
with read_write(config):
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger)
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, logger)
try:
tester.run(test_data_loader, max_epochs=1)
except Exception:
import traceback
print(traceback.format_exc())
else:
return NotImplemented(f"invalid task: {task}")

311
engine/base/i2i.py Normal file
View File

@@ -0,0 +1,311 @@
import logging
from itertools import chain
from pathlib import Path
import ignite.distributed as idist
import torch
import torchvision
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.engine import Events, Engine
from ignite.metrics import RunningAverage
from ignite.utils import convert_tensor
from math import ceil
from omegaconf import read_write, OmegaConf
import data
from engine.util.build import build_optimizer
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
from util.image import make_2d_grid
def build_lr_schedulers(optimizers, config):
# TODO: support more scheduler type
g_milestones_values = [
(0, config.optimizers.generator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
d_milestones_values = [
(0, config.optimizers.discriminator.lr),
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
(config.max_iteration, config.data.train.scheduler.target_lr)
]
return dict(
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
)
class TestEngineKernel(object):
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(config.name)
self.generators = self.build_generators()
def build_generators(self) -> dict:
raise NotImplemented
def to_load(self):
return {f"generator_{k}": self.generators[k] for k in self.generators}
def inference(self, batch):
raise NotImplemented
class EngineKernel(object):
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(config.name)
self.generators, self.discriminators = self.build_models()
self.train_generator_first = True
self.engine = None
def bind_engine(self, engine):
self.engine = engine
def build_models(self) -> (dict, dict):
raise NotImplementedError
def to_save(self):
to_save = {}
to_save.update({f"generator_{k}": self.generators[k] for k in self.generators})
to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
return to_save
def setup_after_g(self):
raise NotImplementedError
def setup_before_g(self):
raise NotImplementedError
def forward(self, batch, inference=False) -> dict:
raise NotImplementedError
def criterion_generators(self, batch, generated) -> dict:
raise NotImplementedError
def criterion_discriminators(self, batch, generated) -> dict:
raise NotImplementedError
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
raise NotImplementedError
def change_engine(self, config, engine: Engine):
pass
def _remove_no_grad_loss(loss_dict):
need_to_pop = []
for k in loss_dict:
if not isinstance(loss_dict[k], torch.Tensor):
need_to_pop.append(k)
for k in need_to_pop:
loss_dict.pop(k)
return loss_dict
def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name)
generators, discriminators = kernel.generators, kernel.discriminators
optimizers = dict(
g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
)
logger.info(f"build optimizers:\n{optimizers}")
lr_schedulers = build_lr_schedulers(optimizers, config)
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
iteration_per_image = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)
def train_generators(batch, generated):
kernel.setup_before_g()
optimizers["g"].zero_grad()
loss_g = kernel.criterion_generators(batch, generated)
sum(loss_g.values()).backward()
optimizers["g"].step()
kernel.setup_after_g()
return loss_g
def train_discriminators(batch, generated):
optimizers["d"].zero_grad()
loss_d = kernel.criterion_discriminators(batch, generated)
sum(loss_d.values()).backward()
optimizers["d"].step()
return loss_d
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
generated = kernel.forward(batch)
if kernel.train_generator_first:
# simultaneous, train G with simultaneous D
loss_g = train_generators(batch, generated)
loss_d = train_discriminators(batch, generated)
else:
# update discriminators first, not simultaneous.
# train G with updated discriminators
loss_d = train_discriminators(batch, generated)
loss_g = train_generators(batch, generated)
if engine.state.iteration % iteration_per_image == 0:
return {
"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
"img": kernel.intermediate_images(batch, generated)
}
return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}
trainer = Engine(_step)
trainer.logger = logger
for lr_shd in lr_schedulers.values():
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
kernel.change_engine(config, trainer)
kernel.bind_engine(trainer)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
to_save = dict(trainer=trainer)
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
to_save.update(kernel.to_save())
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=config.handler.clear_cuda_cache,
set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler,
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
if tensorboard_handler is not None:
basic_image_event = Events.ITERATION_COMPLETED(
every=iteration_per_image)
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
@trainer.on(basic_image_event)
def show_images(engine):
output = engine.state.output
test_images = {}
for k in output["img"]:
image_list = output["img"][k]
tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list, range=(-1, 1)),
engine.state.iteration * pairs_per_iteration)
test_images[k] = []
for i in range(len(image_list)):
test_images[k].append([])
g = torch.Generator()
g.manual_seed(config.misc.random_seed + engine.state.epoch
if config.handler.test.random else config.misc.random_seed)
random_start = \
torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
for i in range(random_start, random_start + config.handler.test.images):
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].unsqueeze(0)
elif isinstance(batch[k], dict):
for kk in batch[k]:
if isinstance(batch[k][kk], torch.Tensor):
batch[k][kk] = batch[k][kk].unsqueeze(0)
generated = kernel.forward(batch, inference=True)
images = kernel.intermediate_images(batch, generated)
for k in test_images:
for j in range(len(images[k])):
test_images[k][j].append(images[k][j])
for k in test_images:
tensorboard_handler.writer.add_image(
f"test/{k}",
make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
engine.state.iteration * pairs_per_iteration
)
return trainer
def save_images_helper(output_dir, paths, images_list):
batch_size = len(paths)
for i in range(batch_size):
image_name = Path(paths[i]).name
img_list = [img[i] for img in images_list]
torchvision.utils.save_image(img_list, Path(output_dir) / image_name, nrow=len(img_list), padding=0,
normalize=True, range=(-1, 1))
def get_tester(config, kernel: TestEngineKernel):
logger = logging.getLogger(config.name)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
return {"batch": batch, "generated": kernel.inference(batch)}
tester = Engine(_step)
tester.logger = logger
setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load())
@tester.on(Events.STARTED)
def mkdir(engine):
img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
engine.state.img_output_dir = Path(img_output_dir)
if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
engine.logger.info(f"mkdir {engine.state.img_output_dir}")
engine.state.img_output_dir.mkdir()
@tester.on(Events.ITERATION_COMPLETED)
def save_images(engine):
if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
images, paths = engine.state.output["batch"]
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"]])
else:
for k in engine.state.output['generated']:
images, paths = engine.state.output["batch"][k]
save_images_helper(config.img_output_dir, paths, [images, engine.state.output["generated"][k]])
return tester
def run_kernel(task, config, kernel):
logger = logging.getLogger(config.name)
with read_write(config):
real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
config.max_iteration = ceil(config.max_pairs / real_batch_size)
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader)
dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size()
train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs)
with read_write(config):
config.iterations_per_epoch = len(train_data_loader)
trainer = get_trainer(config, kernel)
if idist.get_rank() == 0:
test_dataset = data.DATASET.build_with(config.data.test.dataset)
trainer.state.test_dataset = test_dataset
try:
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
except Exception:
import traceback
print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, kernel)
try:
tester.run(test_data_loader, max_epochs=1)
except Exception:
import traceback
print(traceback.format_exc())
else:
return NotImplemented(f"invalid task: {task}")

View File

@@ -1,85 +0,0 @@
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
import ignite.distributed as idist
from ignite.contrib.metrics.gpu_info import GpuInfo
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, global_step_from_engine, OutputHandler, \
WeightsScalarHandler, GradsHistHandler, WeightsHistHandler, GradsScalarHandler
from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
from ignite.metrics import Accuracy, Loss, RunningAverage
from ignite.contrib.engines.common import save_best_model_by_val_score
from ignite.contrib.handlers import ProgressBar
from util.build import build_model, build_optimizer
from util.handler import setup_common_handlers
from data.transform import transform_pipeline
from data.dataset import LMDBDataset
def warmup_trainer(config, logger):
model = build_model(config.model, config.distributed.model)
optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
loss_fn = nn.CrossEntropyLoss()
trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True,
output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y))
trainer.logger = logger
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
Accuracy(output_transform=lambda x: (x[1], x[2])).attach(trainer, "acc")
ProgressBar(ncols=0).attach(trainer)
if idist.get_rank() == 0:
GpuInfo().attach(trainer, name='gpu')
tb_logger = TensorboardLogger(log_dir=config.output_dir)
tb_logger.attach(
trainer,
log_handler=OutputHandler(
tag="train",
metric_names='all',
global_step_transform=global_step_from_engine(trainer),
),
event_name=Events.EPOCH_COMPLETED
)
tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
event_name=Events.EPOCH_COMPLETED(every=10))
tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))
tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
event_name=Events.EPOCH_COMPLETED(every=10))
tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))
@trainer.on(Events.COMPLETED)
def _():
tb_logger.close()
to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED, to_save=to_save,
save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5,
metrics_to_print=["loss", "acc"])
return trainer
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
if task == "warmup":
train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
pipeline=config.baseline.data.dataset.train.pipeline)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
trainer = warmup_trainer(config, logger)
try:
trainer.run(train_data_loader, max_epochs=400)
except Exception:
import traceback
print(traceback.format_exc())
elif task == "protonet-wo":
pass
elif task == "protonet-w":
pass
else:
return ValueError(f"invalid task: {task}")

View File

@@ -1,268 +0,0 @@
import itertools
from pathlib import Path
import torch
import torch.nn as nn
import torchvision.utils
import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from ignite.utils import convert_tensor
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
from omegaconf import OmegaConf
import data
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
from model.GAN.residual_generator import GANImageBuffer
from util.image import make_2d_grid
from util.handler import setup_common_handlers
from util.build import build_model, build_optimizer
def get_trainer(config, logger):
generator_a = build_model(config.model.generator, config.distributed.model)
generator_b = build_model(config.model.generator, config.distributed.model)
discriminator_a = build_model(config.model.discriminator, config.distributed.model)
discriminator_b = build_model(config.model.discriminator, config.distributed.model)
for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
generation_init_weights(m)
logger.info(discriminator_a)
logger.info(generator_a)
optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
config.optimizers.generator)
optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
config.optimizers.discriminator)
milestones_values = [
(0, config.optimizers.generator.lr),
(100, config.optimizers.generator.lr),
(200, config.data.train.scheduler.target_lr)
]
lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)
milestones_values = [
(0, config.optimizers.discriminator.lr),
(100, config.optimizers.discriminator.lr),
(200, config.data.train.scheduler.target_lr)
]
lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real_a, real_b = batch["a"], batch["b"]
fake_b = generator_a(real_a) # G_A(A)
rec_a = generator_b(fake_b) # G_B(G_A(A))
fake_a = generator_b(real_b) # G_B(B)
rec_b = generator_a(fake_a) # G_A(G_B(B))
optimizer_g.zero_grad()
discriminator_a.requires_grad_(False)
discriminator_b.requires_grad_(False)
loss_g = dict(
cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
)
if config.loss.id.weight > 0:
loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b), # G_A(B)
loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a), # G_B(A)
sum(loss_g.values()).backward()
optimizer_g.step()
discriminator_a.requires_grad_(True)
discriminator_b.requires_grad_(True)
optimizer_d.zero_grad()
loss_d_a = dict(
real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
)
loss_d_b = dict(
real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
)
(sum(loss_d_a.values()) * 0.5).backward()
(sum(loss_d_b.values()) * 0.5).backward()
optimizer_d.step()
return {
"loss": {
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
"d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
"d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
},
"img": [
real_a.detach(),
fake_b.detach(),
rec_a.detach(),
real_b.detach(),
fake_a.detach(),
rec_b.detach()
]
}
trainer = Engine(_step)
trainer.logger = logger
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")
to_save = dict(
generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
)
setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
filename_prefix=config.name, to_save=to_save,
print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
save_interval_event=Events.ITERATION_COMPLETED(
every=config.checkpoints.interval) | Events.COMPLETED)
@trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
def terminate(engine):
engine.terminate()
if idist.get_rank() == 0:
# Create a logger
tb_logger = TensorboardLogger(log_dir=config.output_dir)
tb_writer = tb_logger.writer
# Attach the logger to the trainer to log training loss at each iteration
def global_step_transform(*args, **kwargs):
return trainer.state.iteration
def output_transform(output):
loss = dict()
for tl in output["loss"]:
if isinstance(output["loss"][tl], dict):
for l in output["loss"][tl]:
loss[f"{tl}_{l}"] = output["loss"][tl][l]
else:
loss[tl] = output["loss"][tl]
return loss
tb_logger.attach(
trainer,
log_handler=OutputHandler(
tag="loss",
metric_names=["loss_g", "loss_d_a", "loss_d_b"],
global_step_transform=global_step_transform,
output_transform=output_transform
),
event_name=Events.ITERATION_COMPLETED(every=50)
)
tb_logger.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
event_name=Events.ITERATION_STARTED(every=50)
)
@trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
def show_images(engine):
tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()
def _():
# We need to close the logger with we are done
tb_logger.close()
return trainer
def get_tester(config, logger):
generator_a = build_model(config.model.generator, config.distributed.model)
generator_b = build_model(config.model.generator, config.distributed.model)
def _step(engine, batch):
batch = convert_tensor(batch, idist.device())
real_a, real_b = batch["a"], batch["b"]
with torch.no_grad():
fake_b = generator_a(real_a) # G_A(A)
rec_a = generator_b(fake_b) # G_B(G_A(A))
fake_a = generator_b(real_b) # G_B(B)
rec_b = generator_a(fake_a) # G_A(G_B(B))
return [
real_a.detach(),
fake_b.detach(),
rec_a.detach(),
real_b.detach(),
fake_a.detach(),
rec_b.detach()
]
tester = Engine(_step)
tester.logger = logger
if idist.get_rank == 0:
ProgressBar(ncols=0).attach(tester)
to_load = dict(generator_a=generator_a, generator_b=generator_b)
setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)
@tester.on(Events.STARTED)
@idist.one_rank_only()
def mkdir(engine):
img_output_dir = Path(config.output_dir) / "test_images"
if not img_output_dir.exists():
engine.logger.info(f"mkdir {img_output_dir}")
img_output_dir.mkdir()
@tester.on(Events.ITERATION_COMPLETED)
def save_images(engine):
img_tensors = engine.state.output
batch_size = img_tensors[0].size(0)
for i in range(batch_size):
torchvision.utils.save_image([img[i] for img in img_tensors],
Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
nrow=len(img_tensors))
return tester
def run(task, config, logger):
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
logger.info(f"start task {task}")
if task == "train":
train_dataset = data.DATASET.build_with(config.data.train.dataset)
logger.info(f"train with dataset:\n{train_dataset}")
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
trainer = get_trainer(config, logger)
try:
trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
except Exception:
import traceback
print(traceback.format_exc())
elif task == "test":
assert config.resume_from is not None
test_dataset = data.DATASET.build_with(config.data.test.dataset)
logger.info(f"test with dataset:\n{test_dataset}")
test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
tester = get_tester(config, logger)
try:
tester.run(test_data_loader, max_epochs=1)
except Exception:
import traceback
print(traceback.format_exc())
else:
return NotImplemented(f"invalid task: {task}")

View File

@@ -1,9 +0,0 @@
from data.dataset import EpisodicDataset, LMDBDataset
def prototypical_trainer(config, logger):
pass
def prototypical_dataloader(config):
pass

153
engine/talking_anime.py Normal file
View File

@@ -0,0 +1,153 @@
from itertools import chain
import ignite.distributed as idist
import torch
import torch.nn as nn
from omegaconf import OmegaConf
from engine.base.i2i import EngineKernel, run_kernel
from engine.util.build import build_model
from loss.I2I.context_loss import ContextLoss
from loss.I2I.edge_loss import EdgeLoss
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
from model.weight_init import generation_init_weights
class TAEngineKernel(EngineKernel):
def __init__(self, config):
super().__init__(config)
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
perceptual_loss_cfg.pop("weight")
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
style_loss_cfg = OmegaConf.to_container(config.loss.style)
style_loss_cfg.pop("weight")
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
context_loss_cfg = OmegaConf.to_container(config.loss.context)
context_loss_cfg.pop("weight")
self.context_loss = ContextLoss(**context_loss_cfg).to(idist.device())
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
idist.device())
def build_models(self) -> (dict, dict):
generators = dict(
anime=build_model(self.config.model.anime_generator),
face=build_model(self.config.model.face_generator)
)
discriminators = dict(
anime=build_model(self.config.model.discriminator),
face=build_model(self.config.model.discriminator)
)
self.logger.debug(discriminators["face"])
self.logger.debug(generators["face"])
for m in chain(generators.values(), discriminators.values()):
generation_init_weights(m)
return generators, discriminators
def setup_after_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(True)
def setup_before_g(self):
for discriminator in self.discriminators.values():
discriminator.requires_grad_(False)
def forward(self, batch, inference=False) -> dict:
with torch.set_grad_enabled(not inference):
target_pose_anime = self.generators["anime"](
torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1))
target_pose_face = self.generators["face"](target_pose_anime.mean(dim=1, keepdim=True), batch["face_0"])
return dict(fake_anime=target_pose_anime, fake_face=target_pose_face)
def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
pred_fake = discriminator(generated_img)
loss_gan = 0
for sub_pred_fake in pred_fake:
# last output is actual prediction
loss_gan += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
if match_img is None:
# do not cal feature match loss
return loss_gan, 0
pred_real = discriminator(match_img)
loss_fm = 0
num_scale_discriminator = len(pred_fake)
for i in range(num_scale_discriminator):
# last output is the final prediction, so we exclude it
num_intermediate_outputs = len(pred_fake[i]) - 1
for j in range(num_intermediate_outputs):
loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
loss_fm = self.config.loss.fm.weight * loss_fm
return loss_gan, loss_fm
def criterion_generators(self, batch, generated) -> dict:
loss = dict()
loss["face_style"] = self.config.loss.style.weight * self.style_loss(
generated["fake_face"], batch["face_1"]
)
loss["face_recon"] = self.config.loss.recon.weight * self.recon_loss(
generated["fake_face"], batch["face_1"]
)
loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss(
self.discriminators["face"], generated["fake_face"], batch["face_1"])
loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss(
self.discriminators["anime"], generated["fake_anime"], batch["anime_img"])
loss["anime_edge"] = self.config.loss.edge.weight * self.edge_loss(
generated["fake_anime"], batch["face_1"], gt_is_edge=False,
)
if self.config.loss.perceptual.weight > 0:
loss["anime_perceptual"] = self.config.loss.perceptual.weight * self.perceptual_loss(
generated["fake_anime"], batch["anime_img"]
)
if self.config.loss.context.weight > 0:
loss["anime_context"] = self.config.loss.context.weight * self.context_loss(
generated["fake_anime"], batch["anime_img"],
)
return loss
def criterion_discriminators(self, batch, generated) -> dict:
loss = dict()
real = {"anime": "anime_img", "face": "face_1"}
for phase in self.discriminators.keys():
pred_real = self.discriminators[phase](batch[real[phase]])
pred_fake = self.discriminators[phase](generated[f"fake_{phase}"].detach())
loss[f"gan_{phase}"] = 0
for i in range(len(pred_fake)):
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
return loss
def intermediate_images(self, batch, generated) -> dict:
"""
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
:param batch:
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(),
generated["fake_face"].detach()]
return dict(
b=[img for img in images]
)
def run(task, config, _):
kernel = TAEngineKernel(config)
run_kernel(task, config, kernel)

0
engine/util/__init__.py Normal file
View File

33
engine/util/build.py Normal file
View File

@@ -0,0 +1,33 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.optim as optim
from omegaconf import OmegaConf
from model import MODEL
def add_spectral_norm(module):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def build_model(cfg):
cfg = OmegaConf.to_container(cfg)
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
add_spectral_norm_flag = cfg.pop("_add_spectral_norm", False)
model = MODEL.build_with(cfg)
if bn_to_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if add_spectral_norm_flag:
model.apply(add_spectral_norm)
return idist.auto_model(model)
def build_optimizer(params, cfg):
assert "_type" in cfg
cfg = OmegaConf.to_container(cfg)
optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
return idist.auto_optim(optimizer)

66
engine/util/container.py Normal file
View File

@@ -0,0 +1,66 @@
import torch
class LossContainer:
def __init__(self, weight, loss):
self.weight = weight
self.loss = loss
def __call__(self, *args, **kwargs):
if self.weight > 0:
return self.weight * self.loss(*args, **kwargs)
return 0.0
class GANImageBuffer:
"""This class implements an image buffer that stores previously
generated images.
This buffer allows us to update the discriminator using a history of
generated images rather than the ones produced by the latest generator
to reduce model oscillation.
Args:
buffer_size (int): The size of image buffer. If buffer_size = 0,
no buffer will be created.
buffer_ratio (float): The chance / possibility to use the images
previously stored in the buffer.
"""
def __init__(self, buffer_size, buffer_ratio=0.5):
self.buffer_size = buffer_size
# create an empty buffer
if self.buffer_size > 0:
self.img_num = 0
self.image_buffer = []
self.buffer_ratio = buffer_ratio
def query(self, images):
"""Query current image batch using a history of generated images.
Args:
images (Tensor): Current image batch without history information.
"""
if self.buffer_size == 0: # if the buffer size is 0, do nothing
return images
return_images = []
for image in images:
image = torch.unsqueeze(image.data, 0)
# if the buffer is not full, keep inserting current images
if self.img_num < self.buffer_size:
self.img_num = self.img_num + 1
self.image_buffer.append(image)
return_images.append(image)
else:
use_buffer = torch.rand(1) < self.buffer_ratio
# by self.buffer_ratio, the buffer will return a previously
# stored image, and insert the current image into the buffer
if use_buffer:
random_id = torch.randint(0, self.buffer_size, (1,)).item()
image_tmp = self.image_buffer[random_id].clone()
self.image_buffer[random_id] = image
return_images.append(image_tmp)
# by (1 - self.buffer_ratio), the buffer will return the
# current image
else:
return_images.append(image)
# collect all the images and return
return_images = torch.cat(return_images, 0)
return return_images

View File

@@ -7,7 +7,7 @@ import ignite.distributed as idist
from ignite.engine import Events, Engine
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
def empty_cuda_cache(_):
@@ -16,6 +16,16 @@ def empty_cuda_cache(_):
gc.collect()
def step_transform_maker(stype: str, pairs_per_iteration=None):
assert stype in ["item", "iteration", "epoch"]
if stype == "item":
return lambda engine, _: engine.state.iteration * pairs_per_iteration
if stype == "iteration":
return lambda engine, _: engine.state.iteration
if stype == "epoch":
return lambda engine, _: engine.state.epoch
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
"""
@@ -41,9 +51,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
@trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1))
trainer.logger.info(f"data loader length: {config.iterations_per_epoch} iterations per epoch")
@trainer.on(Events.EPOCH_COMPLETED(once=1))
def print_info(engine):
engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
if stop_on_nan:
@@ -66,7 +77,7 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
if to_save is not None:
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
n_saved=config.handler.checkpoint.n_saved, filename_prefix=config.name)
if config.resume_from is not None:
@trainer.on(Events.STARTED)
def resume(engine):
@@ -74,11 +85,13 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
trainer.logger.info(f"load state_dict for {ckp.keys()}")
trainer.logger.info(f"load state_dict for {to_save.keys()}")
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
checkpoint_handler)
trainer.add_event_handler(
Events.EPOCH_COMPLETED(every=config.handler.checkpoint.epoch_interval) | Events.COMPLETED,
checkpoint_handler
)
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
if end_event is not None:
trainer.logger.debug(f"engine will stop on {end_event}")
@@ -88,17 +101,48 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
engine.terminate()
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
if config.interval.tensorboard is None:
def setup_tensorboard_handler(trainer: Engine, config, optimizers, step_type="item"):
if config.handler.tensorboard is None:
return None
if idist.get_rank() == 0:
# Create a logger
tb_logger = TensorboardLogger(log_dir=config.output_dir)
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
event_name=basic_event)
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
event_name=basic_event)
tb_writer = tb_logger.writer
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
global_step_transform = step_transform_maker(step_type, pairs_per_iteration)
basic_event = Events.ITERATION_COMPLETED(
every=max(config.iterations_per_epoch // config.handler.tensorboard.scalar, 1))
tb_logger.attach(
trainer,
log_handler=OutputHandler(
tag="metric", metric_names="all",
global_step_transform=global_step_transform
),
event_name=basic_event
)
@trainer.on(basic_event)
def log_loss(engine):
global_step = global_step_transform(engine, None)
output_loss = engine.state.output["loss"]
for total_loss in output_loss:
if isinstance(output_loss[total_loss], dict):
for ln in output_loss[total_loss]:
tb_writer.add_scalar(f"train_{total_loss}/{ln}", output_loss[total_loss][ln], global_step)
else:
tb_writer.add_scalar(f"train/{total_loss}", output_loss[total_loss], global_step)
if isinstance(optimizers, dict):
for name in optimizers:
tb_logger.attach(
trainer,
log_handler=OptimizerParamsHandler(optimizers[name], tag=f"optimizer_{name}"),
event_name=Events.ITERATION_STARTED
)
else:
tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizers, tag=f"optimizer"),
event_name=Events.ITERATION_STARTED)
@trainer.on(Events.COMPLETED)
@idist.one_rank_only()

48
engine/util/loss.py Normal file
View File

@@ -0,0 +1,48 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss
def gan_loss(config):
gan_loss_cfg = OmegaConf.to_container(config)
gan_loss_cfg.pop("weight")
return GANLoss(**gan_loss_cfg).to(idist.device())
def perceptual_loss(config):
perceptual_loss_cfg = OmegaConf.to_container(config)
perceptual_loss_cfg.pop("weight")
return PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
def pixel_loss(level):
return nn.L1Loss() if level == 1 else nn.MSELoss()
def mse_loss(x, target_flag):
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def bce_loss(x, target_flag):
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
def feature_match_loss(level, weight_policy):
compare_loss = pixel_loss(level)
assert weight_policy in ["same", "exponential_decline"]
def fm_loss(generated_features, target_features):
num_scale = len(generated_features)
loss = torch.zeros(1, device=idist.device())
for s_i in range(num_scale):
for i in range(len(generated_features[s_i]) - 1):
weight = 1 if weight_policy == "same" else 2 ** i
loss += weight * compare_loss(generated_features[s_i][i], target_features[s_i][i].detach()) / num_scale
return loss
return fm_loss

View File

@@ -6,7 +6,6 @@ channels:
dependencies:
- python=3.8
- numpy
- ipython
- tqdm
- pyyaml
- pytorch=1.6.*
@@ -17,6 +16,6 @@ dependencies:
- omegaconf
- python-lmdb
- fire
# - opencv
- opencv
# - jupyterlab

44
loss/I2I/context_loss.py Normal file
View File

@@ -0,0 +1,44 @@
import torch
import torch.nn.functional as F
from torch import nn
from .perceptual_loss import PerceptualVGG
class ContextLoss(nn.Module):
def __init__(self, layer_weights, h=0.1, vgg_type='vgg19', norm_image_with_imagenet_param=True, norm_img=True,
eps=1e-5):
super(ContextLoss, self).__init__()
self.eps = eps
self.h = h
self.layer_weights = layer_weights
self.norm_img = norm_img
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
norm_image_with_imagenet_param=norm_image_with_imagenet_param)
def single_forward(self, source_feature, target_feature):
mean_target_feature = target_feature.mean(dim=[2, 3], keepdim=True)
source_feature = (source_feature - mean_target_feature).view(*source_feature.size()[:2], -1) # NxCxHW
target_feature = (target_feature - mean_target_feature).view(*source_feature.size()[:2], -1) # NxCxHW
source_feature = F.normalize(source_feature, p=2, dim=1)
target_feature = F.normalize(target_feature, p=2, dim=1)
cosine_distance = (1 - torch.bmm(source_feature.transpose(1, 2), target_feature)) / 2 # NxHWxHW
rel_distance = cosine_distance / (cosine_distance.min(2, keepdim=True)[0] + self.eps)
w = torch.exp((1 - rel_distance) / self.h)
cx = w.div(w.sum(dim=2, keepdim=True))
cx = cx.max(dim=1, keepdim=True)[0].mean(dim=2)
return -torch.log(cx).mean()
def forward(self, x, gt):
if self.norm_img:
x = (x + 1.) * 0.5
gt = (gt + 1.) * 0.5
# extract vgg features
x_features = self.vgg(x)
gt_features = self.vgg(gt.detach())
loss = 0
for k in x_features.keys():
loss += self.single_forward(x_features[k], gt_features[k]) * self.layer_weights[k]
return loss

View File

@@ -0,0 +1,229 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
def gaussian_radial_basis_function(x, mu, sigma):
# (kernel_size) -> (batch_size, kernel_size, c*h*w)
mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
# (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1)
return torch.exp((x - mu).pow(2) / (2 * sigma ** 2))
class ImporveMyLoss(torch.nn.Module):
def __init__(self, device=idist.device()):
super().__init__()
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).to(device)
self.x_mu_list = mu.repeat(9).view(-1, 81)
self.y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
self.R = torch.eye(81).to(device)
def batch_ERSMI(self, I1, I2):
batch_size = I1.shape[0]
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
if I2.shape[1] == 1 and I1.shape[1] != 1:
I2 = I2.repeat(1, 3, 1, 1)
def kernel_F(y, mu_list, sigma):
tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1) # [81, 784]
tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
tmp_y = tmp_mu - tmp_y
mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
return mat_L
mat_K = kernel_F(I1, self.x_mu_list, 1)
mat_L = kernel_F(I2, self.y_mu_list, 1)
mat_k_l = mat_K * mat_L
H1 = (mat_K @ mat_K.transpose(1, 2)) * (mat_L @ mat_L.transpose(1, 2)) / (img_size ** 2)
h_hat = mat_k_l @ mat_k_l.transpose(1, 2) / img_size
small_h_hat = mat_K.sum(2).view(batch_size, -1, 1) * mat_L.sum(2).view(batch_size, -1, 1) / (img_size ** 2)
h_hat = 0.5 * H1 + 0.5 * h_hat
alpha = (h_hat + 0.05 * self.R).inverse() @ small_h_hat
ersmi = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
ersmi = -ersmi.squeeze().mean()
return ersmi
def forward(self, fakeI, realI):
return self.batch_ERSMI(fakeI, realI)
class MyLoss(torch.nn.Module):
def __init__(self):
super(MyLoss, self).__init__()
def forward(self, fakeI, realI):
fakeI = fakeI.cuda()
realI = realI.cuda()
def batch_ERSMI(I1, I2):
batch_size = I1.shape[0]
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
if I2.shape[1] == 1 and I1.shape[1] != 1:
I2 = I2.repeat(1, 3, 1, 1)
def kernel_F(y, mu_list, sigma):
tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1).cuda() # [81, 784]
tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
tmp_y = tmp_mu - tmp_y
mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
return mat_L
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).cuda()
x_mu_list = mu.repeat(9).view(-1, 81)
y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
mat_K = kernel_F(I1, x_mu_list, 1)
mat_L = kernel_F(I2, y_mu_list, 1)
H1 = ((mat_K.matmul(mat_K.transpose(1, 2))).mul(mat_L.matmul(mat_L.transpose(1, 2))) / (
img_size ** 2)).cuda()
H2 = ((mat_K.mul(mat_L)).matmul((mat_K.mul(mat_L)).transpose(1, 2)) / img_size).cuda()
h2 = ((mat_K.sum(2).view(batch_size, -1, 1)).mul(mat_L.sum(2).view(batch_size, -1, 1)) / (
img_size ** 2)).cuda()
H2 = 0.5 * H1 + 0.5 * H2
tmp = H2 + 0.05 * torch.eye(len(H2[0])).cuda()
alpha = (tmp.inverse())
alpha = alpha.matmul(h2)
ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
alpha) - 1).squeeze()
ersmi = -ersmi.mean()
return ersmi
batch_loss = batch_ERSMI(fakeI, realI)
return batch_loss
class MGCLoss(nn.Module):
"""
Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
"""
def __init__(self, mi_to_loss_way="opposite", beta=0.5, lambda_=0.05, device=idist.device()):
super().__init__()
self.beta = beta
self.lambda_ = lambda_
assert mi_to_loss_way in ["opposite", "reciprocal"]
self.mi_to_loss_way = mi_to_loss_way
mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
self.mu_x = mu_x.flatten().to(device)
self.mu_y = mu_y.flatten().to(device)
self.R = torch.eye(81).unsqueeze(0).to(device)
@staticmethod
def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_, R):
assert img1.size() == img2.size()
num_pixel = img1.size(1) * img1.size(2) * img2.size(3)
mat_k = gaussian_radial_basis_function(img1, mu_x, sigma=1)
mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)
mat_k_mul_mat_l = mat_k * mat_l
h_hat = (1 - beta) * (mat_k_mul_mat_l @ mat_k_mul_mat_l.transpose(1, 2)) / num_pixel
h_hat += beta * ((mat_k @ mat_k.transpose(1, 2)) * (mat_l @ mat_l.transpose(1, 2))) / (num_pixel ** 2)
small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)
alpha = (h_hat + lambda_ * R).inverse() @ small_h_hat
rSMI = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
return rSMI.squeeze()
def forward(self, fake, real):
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
if self.mi_to_loss_way == "reciprocal":
return 1/rSMI.mean()
return -rSMI.mean()
if __name__ == '__main__':
mg = MGCLoss(device=torch.device("cpu"))
my = MyLoss().to("cuda")
imy = ImporveMyLoss()
from data.transform import transform_pipeline
pipeline = transform_pipeline(
['Load', 'ToTensor', {'Normalize': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}}])
img_a1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_1.jpg")
img_a2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_2.jpg")
img_a3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_3.jpg")
img_b1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_1.jpg")
img_b2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_2.jpg")
img_b3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_3.jpg")
img_a1.requires_grad_(True)
img_a2.requires_grad_(True)
img_a3.requires_grad_(True)
# print("MyLoss")
# l1 = my(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
# l2 = my(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
# l3 = my(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
# l = (l1+l2+l3)/3
# l.backward()
# print(img_a1.grad[0][0][0:10])
# print(img_a2.grad[0][0][0:10])
# print(img_a3.grad[0][0][0:10])
#
# img_a1.grad = None
# img_a2.grad = None
# img_a3.grad = None
#
# print("---")
# l = my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# l.backward()
# print(img_a1.grad[0][0][0:10])
# print(img_a2.grad[0][0][0:10])
# print(img_a3.grad[0][0][0:10])
# img_a1.grad = None
# img_a2.grad = None
# img_a3.grad = None
print("MGCLoss")
l1 = mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
l2 = mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
l3 = mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
l = (l1 + l2 + l3) / 3
l.backward()
print(img_a1.grad[0][0][0:10])
print(img_a2.grad[0][0][0:10])
print(img_a3.grad[0][0][0:10])
img_a1.grad = None
img_a2.grad = None
img_a3.grad = None
print("---")
l = mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
l.backward()
print(img_a1.grad[0][0][0:10])
print(img_a2.grad[0][0][0:10])
print(img_a3.grad[0][0][0:10])
# print("\nMGCLoss")
# mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
# mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
# mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
#
# print("---")
# mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
#
# import pprofile
#
# profiler = pprofile.Profile()
# with profiler:
# iter_times = 1000
# for _ in range(iter_times):
# mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# for _ in range(iter_times):
# my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# for _ in range(iter_times):
# imy(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
# profiler.print_stats()

View File

@@ -1,8 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models.vgg as vgg
# Sequential(
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (1): ReLU(inplace=True)
# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (3): ReLU(inplace=True)
# (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (6): ReLU(inplace=True)
# (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (8): ReLU(inplace=True)
# (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (11): ReLU(inplace=True)
# (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (13): ReLU(inplace=True)
# (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (15): ReLU(inplace=True)
# (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (17): ReLU(inplace=True)
# (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (20): ReLU(inplace=True)
# (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (22): ReLU(inplace=True)
# (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (24): ReLU(inplace=True)
# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (26): ReLU(inplace=True)
# (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (29): ReLU(inplace=True)
# (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (31): ReLU(inplace=True)
# (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (33): ReLU(inplace=True)
# (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
# (35): ReLU(inplace=True)
# (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# )
class PerceptualVGG(nn.Module):
"""VGG network used in calculating perceptual loss.
In this implementation, we allow users to choose whether use normalization
@@ -14,15 +58,15 @@ class PerceptualVGG(nn.Module):
list contains the name each layer in `vgg.feature`. An example
of this list is ['4', '10'].
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image.
norm_image_with_imagenet_param (bool): If True, normalize the input image.
Importantly, the input feature must in the range [0, 1].
Default: True.
"""
def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
def __init__(self, layer_name_list, vgg_type='vgg19', norm_image_with_imagenet_param=True):
super(PerceptualVGG, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
self.use_input_norm = norm_image_with_imagenet_param
# get vgg model and load pretrained vgg weight
# remove _vgg from attributes to avoid `find_unused_parameters` bug
@@ -74,7 +118,7 @@ class PerceptualLoss(nn.Module):
in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
norm_image_with_imagenet_param (bool): If True, normalize the input image in vgg.
Default: True.
perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
loss will be calculated.
@@ -87,22 +131,24 @@ class PerceptualLoss(nn.Module):
Importantly, the input image must be in range [-1, 1].
"""
def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
def __init__(self, layer_weights, vgg_type='vgg19', norm_image_with_imagenet_param=True, perceptual_loss=True,
style_loss=False, norm_img=True, criterion='L1'):
super(PerceptualLoss, self).__init__()
self.norm_img = norm_img
assert perceptual_loss ^ style_loss, "There must be one and only one true in style or perceptual"
self.perceptual_loss = perceptual_loss
self.style_loss = style_loss
self.layer_weights = layer_weights
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
use_input_norm=use_input_norm)
norm_image_with_imagenet_param=norm_image_with_imagenet_param)
if criterion == 'L1':
self.criterion = torch.nn.L1Loss()
elif criterion == "L2":
self.criterion = torch.nn.MSELoss()
else:
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
self.percep_criterion, self.style_criterion = self.set_criterion(criterion)
def set_criterion(self, criterion: str):
assert criterion in ["NL1", "NL2", "L1", "L2"]
norm = F.instance_norm if criterion.startswith("N") else lambda x: x
fn = F.l1_loss if criterion.endswith("L1") else F.mse_loss
return lambda x, t: fn(norm(x), norm(t)), lambda x, t: fn(x, t)
def forward(self, x, gt):
"""Forward function.
@@ -124,22 +170,16 @@ class PerceptualLoss(nn.Module):
if self.perceptual_loss:
percep_loss = 0
for k in x_features.keys():
percep_loss += self.criterion(
x_features[k], gt_features[k]) * self.layer_weights[k]
else:
percep_loss = None
percep_loss += self.percep_criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
return percep_loss
# calculate style loss
if self.style_loss:
style_loss = 0
for k in x_features.keys():
style_loss += self.criterion(
self._gram_mat(x_features[k]),
self._gram_mat(gt_features[k])) * self.layer_weights[k]
else:
style_loss = None
return percep_loss, style_loss
style_loss += self.style_criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
self.layer_weights[k]
return style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.

View File

@@ -1,4 +1,5 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
@@ -10,7 +11,7 @@ class GANLoss(nn.Module):
self.fake_label_val = fake_label_val
self.loss_type = loss_type
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
"""
gan loss forward
:param prediction: network prediction
@@ -37,3 +38,20 @@ class GANLoss(nn.Module):
return loss
else:
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
if isinstance(prediction, torch.Tensor):
# origin
return self.single_forward(prediction, target_is_real, is_discriminator)
elif isinstance(prediction, list):
# for multi scale discriminator, e.g. MultiScaleDiscriminator
loss = 0
for p in prediction:
loss += self.single_forward(p[-1], target_is_real, is_discriminator)
return loss
elif isinstance(prediction, tuple):
# for single discriminator set `need_intermediate_feature` true
return self.single_forward(prediction[-1], target_is_real, is_discriminator)
else:
raise NotImplementedError(f"not support discriminator output: {prediction}")

26
main.py
View File

@@ -1,16 +1,15 @@
from pathlib import Path
from importlib import import_module
import torch
import ignite
import ignite.distributed as idist
from ignite.utils import manual_seed
from util.misc import setup_logger
from pathlib import Path
import fire
import ignite
import ignite.distributed as idist
import torch
from ignite.utils import manual_seed
from omegaconf import OmegaConf
from util.misc import setup_logger
def log_basic_info(logger, config):
logger.info(f"Train {config.name}")
@@ -28,15 +27,14 @@ def log_basic_info(logger, config):
def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
if setup_random_seed:
manual_seed(config.misc.random_seed + idist.get_rank())
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else Path(config.output_dir)
config.output_dir = str(output_dir)
if setup_output_dir and config.resume_from is None:
if output_dir.exists():
# assert not any(output_dir.iterdir()), "output_dir must be empty"
contains = list(output_dir.iterdir())
assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \
f"output_dir must by empty or only contains config.yml, but now got {len(contains)} files"
assert len(list(output_dir.glob("events*"))) == 0, f"{output_dir} containers tensorboard event"
if (output_dir / "train.log").exists() and idist.get_rank() == 0:
(output_dir / "train.log").unlink()
else:
if idist.get_rank() == 0:
output_dir.mkdir(parents=True)
@@ -65,6 +63,8 @@ def run(task, config: str, *omega_options, **kwargs):
backup_config = kwargs.get("backup_config", False)
setup_output_dir = kwargs.get("setup_output_dir", False)
setup_random_seed = kwargs.get("setup_random_seed", False)
assert torch.backends.cudnn.enabled
torch.backends.cudnn.benchmark = True
with idist.Parallel(backend=backend) as parallel:
parallel.run(running, conf, task, backup_config=backup_config, setup_output_dir=setup_output_dir,
setup_random_seed=setup_random_seed)

View File

@@ -1,224 +0,0 @@
import torch
import torch.nn as nn
from .residual_generator import ResidualBlock
from model.registry import MODEL
from torchvision.models import vgg19
from model.normalization import select_norm_layer
class VGG19StyleEncoder(nn.Module):
def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
vgg19_layers=(0, 5, 10, 19)):
super().__init__()
self.vgg19_layers = vgg19_layers
self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
self.vgg19.requires_grad_(False)
norm_layer = select_norm_layer(norm_type)
self.conv0 = nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
bias=True),
norm_layer(base_channels),
nn.ReLU(True),
)
self.conv = nn.ModuleList([
nn.Sequential(
nn.Conv2d(base_channels * (2 ** i), base_channels * (2 ** i), kernel_size=4, stride=2, padding=1,
padding_mode=padding_mode, bias=True),
norm_layer(base_channels),
nn.ReLU(True),
) for i in range(1, 4)
])
self.pool = nn.AdaptiveAvgPool2d(1)
self.conv1x1 = nn.Conv2d(base_channels * (2 ** 4), style_dim, kernel_size=1, stride=1, padding=0)
def fixed_style_features(self, x):
features = []
for i in range(len(self.vgg19)):
x = self.vgg19[i](x)
if i in self.vgg19_layers:
features.append(x)
return features
def forward(self, x):
fsf = self.fixed_style_features(x)
x = self.conv0(x)
for i, l in enumerate(self.conv):
x = l(torch.cat([x, fsf[i]], dim=1))
x = self.pool(torch.cat([x, fsf[-1]], dim=1))
x = self.conv1x1(x)
return x.view(x.size(0), -1)
class ContentEncoder(nn.Module):
def __init__(self, in_channels, base_channels=64, num_blocks=8, padding_mode='reflect', norm_type="IN"):
super().__init__()
norm_layer = select_norm_layer(norm_type)
self.start_conv = nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
bias=True),
norm_layer(num_features=base_channels),
nn.ReLU(inplace=True)
)
# down sampling
submodules = []
num_down_sampling = 2
for i in range(num_down_sampling):
multiple = 2 ** i
submodules += [
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
kernel_size=4, stride=2, padding=1, bias=True),
norm_layer(num_features=base_channels * multiple * 2),
nn.ReLU(inplace=True)
]
self.encoder = nn.Sequential(*submodules)
res_block_channels = num_down_sampling ** 2 * base_channels
self.resnet = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
def forward(self, x):
x = self.start_conv(x)
x = self.encoder(x)
x = self.resnet(x)
return x
class Decoder(nn.Module):
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
norm_type="LN"):
super(Decoder, self).__init__()
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
res_block_channels = (2 ** 2) * base_channels
self.resnet = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
# up sampling
submodules = []
for i in range(num_down_sampling):
multiple = 2 ** (num_down_sampling - i)
submodules += [
nn.Upsample(scale_factor=2),
nn.Conv2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=5, stride=1,
padding=2, padding_mode=padding_mode, bias=use_bias),
norm_layer(num_features=base_channels * multiple // 2),
nn.ReLU(inplace=True),
]
self.decoder = nn.Sequential(*submodules)
self.end_conv = nn.Sequential(
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
nn.Tanh()
)
def forward(self, x):
x = self.resnet(x)
x = self.decoder(x)
x = self.end_conv(x)
return x
class Fusion(nn.Module):
def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
super().__init__()
norm_layer = select_norm_layer(norm_type)
self.start_fc = nn.Sequential(
nn.Linear(in_features, base_features),
norm_layer(base_features),
nn.ReLU(True),
)
self.fcs = nn.Sequential(*[
nn.Sequential(
nn.Linear(base_features, base_features),
norm_layer(base_features),
nn.ReLU(True),
) for _ in range(n_blocks - 2)
])
self.end_fc = nn.Sequential(
nn.Linear(base_features, out_features),
)
def forward(self, x):
x = self.start_fc(x)
x = self.fcs(x)
return self.end_fc(x)
@MODEL.register_module("TAHG-Generator")
class Generator(nn.Module):
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
base_channels=64, padding_mode="reflect"):
super(Generator, self).__init__()
self.num_blocks = num_blocks
self.style_encoders = nn.ModuleDict({
"a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
padding_mode=padding_mode, norm_type="NONE"),
"b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
padding_mode=padding_mode, norm_type="NONE")
})
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
padding_mode=padding_mode, norm_type="IN")
res_block_channels = 2 ** 2 * base_channels
self.adain_res = nn.ModuleList([
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
])
self.decoders = nn.ModuleDict({
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
})
self.fc = nn.Sequential(
nn.Linear(style_dim, style_dim),
nn.ReLU(True),
)
self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
norm_type="NONE")
def forward(self, content_img, style_img, which_decoder: str = "a"):
x = self.content_encoder(content_img)
styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img)))
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
for i, ar in enumerate(self.adain_res):
ar.norm1.set_style(styles[2 * i])
ar.norm2.set_style(styles[2 * i + 1])
x = ar(x)
return self.decoders[which_decoder](x)
@MODEL.register_module("TAHG-Discriminator")
class Discriminator(nn.Module):
def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
padding_mode="reflect"):
super(Discriminator, self).__init__()
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
sequence = [nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
bias=use_bias),
norm_layer(num_features=base_channels),
nn.ReLU(inplace=True)
)]
# stacked intermediate layers,
# gradually increasing the number of filters
multiple_now = 1
for n in range(1, num_down_sampling + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** n, 4)
sequence += [
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
padding=1, stride=2, bias=use_bias),
norm_layer(base_channels * multiple_now),
nn.LeakyReLU(0.2, inplace=True)
]
for _ in range(num_blocks):
sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
self.model = nn.Sequential(*sequence)
def forward(self, x):
return self.model(x)

View File

@@ -1,237 +0,0 @@
import torch
import torch.nn as nn
from .residual_generator import ResidualBlock
from model.registry import MODEL
class RhoClipper(object):
def __init__(self, clip_min, clip_max):
self.clip_min = clip_min
self.clip_max = clip_max
assert clip_min < clip_max
def __call__(self, module):
if hasattr(module, 'rho'):
w = module.rho.data
w = w.clamp(self.clip_min, self.clip_max)
module.rho.data = w
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False):
assert (num_blocks >= 0)
super(Generator, self).__init__()
self.input_channels = in_channels
self.output_channels = out_channels
self.base_channels = base_channels
self.num_blocks = num_blocks
self.img_size = img_size
self.light = light
down_encoder = [nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
padding_mode="reflect", bias=False),
nn.InstanceNorm2d(base_channels),
nn.ReLU(True)]
n_down_sampling = 2
for i in range(n_down_sampling):
mult = 2 ** i
down_encoder += [nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3, stride=2,
padding=1, bias=False, padding_mode="reflect"),
nn.InstanceNorm2d(base_channels * mult * 2),
nn.ReLU(True)]
# Down-Sampling Bottleneck
mult = 2 ** n_down_sampling
for i in range(num_blocks):
# TODO: change ResnetBlock to ResidualBlock, check use_bias param
down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)]
self.down_encoder = nn.Sequential(*down_encoder)
# Class Activation Map
self.gap_fc = nn.Linear(base_channels * mult, 1, bias=False)
self.gmp_fc = nn.Linear(base_channels * mult, 1, bias=False)
self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
self.relu = nn.ReLU(True)
# Gamma, Beta block
if self.light:
fc = [nn.Linear(base_channels * mult, base_channels * mult, bias=False),
nn.ReLU(True),
nn.Linear(base_channels * mult, base_channels * mult, bias=False),
nn.ReLU(True)]
else:
fc = [
nn.Linear(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, bias=False),
nn.ReLU(True),
nn.Linear(base_channels * mult, base_channels * mult, bias=False),
nn.ReLU(True)]
self.fc = nn.Sequential(*fc)
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
# Up-Sampling Bottleneck
self.up_bottleneck = nn.ModuleList(
[ResnetAdaILNBlock(base_channels * mult, use_bias=False) for _ in range(num_blocks)])
# Up-Sampling
up_decoder = []
for i in range(n_down_sampling):
mult = 2 ** (n_down_sampling - i)
up_decoder += [nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3, stride=1,
padding=1, padding_mode="reflect", bias=False),
ILN(base_channels * mult // 2),
nn.ReLU(True)]
up_decoder += [nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
padding_mode="reflect", bias=False),
nn.Tanh()]
self.up_decoder = nn.Sequential(*up_decoder)
def forward(self, x):
x = self.down_encoder(x)
gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)
gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
x = torch.cat([gap, gmp], 1)
x = self.relu(self.conv1x1(x))
heatmap = torch.sum(x, dim=1, keepdim=True)
if self.light:
x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
x_ = self.fc(x_.view(x_.shape[0], -1))
else:
x_ = self.fc(x.view(x.shape[0], -1))
gamma, beta = self.gamma(x_), self.beta(x_)
for ub in self.up_bottleneck:
x = ub(x, gamma, beta)
x = self.up_decoder(x)
return x, cam_logit, heatmap
class ResnetAdaILNBlock(nn.Module):
def __init__(self, dim, use_bias):
super(ResnetAdaILNBlock, self).__init__()
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
self.norm1 = AdaILN(dim)
self.relu1 = nn.ReLU(True)
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
self.norm2 = AdaILN(dim)
def forward(self, x, gamma, beta):
out = self.conv1(x)
out = self.norm1(out, gamma, beta)
out = self.relu1(out)
out = self.conv2(out)
out = self.norm2(out, gamma, beta)
return out + x
def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True)
out_in = (x - in_mean) / torch.sqrt(in_var + eps)
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
out_ln = (x - ln_mean) / torch.sqrt(ln_var + eps)
out = rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - rho.expand(x.shape[0], -1, -1, -1)) * out_ln
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
return out
class AdaILN(nn.Module):
def __init__(self, num_features, eps=1e-5, default_rho=0.9):
super(AdaILN, self).__init__()
self.eps = eps
self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.rho.data.fill_(default_rho)
def forward(self, x, gamma, beta):
return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)
class ILN(nn.Module):
def __init__(self, num_features, eps=1e-5):
super(ILN, self).__init__()
self.eps = eps
self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.gamma = nn.Parameter(torch.Tensor(1, num_features))
self.beta = nn.Parameter(torch.Tensor(1, num_features))
self.rho.data.fill_(0.0)
self.gamma.data.fill_(1.0)
self.beta.data.fill_(0.0)
def forward(self, x):
return instance_layer_normalization(
x, self.gamma.expand(x.shape[0], -1), self.beta.expand(x.shape[0], -1), self.rho, self.eps)
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, num_blocks=5):
super(Discriminator, self).__init__()
encoder = [self.build_conv_block(in_channels, base_channels)]
for i in range(1, num_blocks - 2):
mult = 2 ** (i - 1)
encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2))
mult = 2 ** (num_blocks - 2 - 1)
encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2, stride=1))
self.encoder = nn.Sequential(*encoder)
# Class Activation Map
mult = 2 ** (num_blocks - 2)
self.gap_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
self.gmp_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
self.leaky_relu = nn.LeakyReLU(0.2, True)
self.conv = nn.utils.spectral_norm(
nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect"))
@staticmethod
def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"):
return nn.Sequential(*[
nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
bias=True, padding=padding, padding_mode=padding_mode)),
nn.LeakyReLU(0.2, True),
])
def forward(self, x, return_heatmap=False):
x = self.encoder(x)
batch_size = x.size(0)
gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) # B x C x 1 x 1, avg of per channel
gap_logit = self.gap_fc(gap.view(batch_size, -1))
gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)
gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
gmp_logit = self.gmp_fc(gmp.view(batch_size, -1))
gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
x = torch.cat([gap, gmp], 1)
x = self.leaky_relu(self.conv1x1(x))
if return_heatmap:
heatmap = torch.sum(x, dim=1, keepdim=True)
return self.conv(x), cam_logit, heatmap
else:
return self.conv(x), cam_logit

View File

@@ -1,180 +0,0 @@
import torch
import torch.nn as nn
from model.registry import MODEL
from model.normalization import select_norm_layer
class GANImageBuffer(object):
"""This class implements an image buffer that stores previously
generated images.
This buffer allows us to update the discriminator using a history of
generated images rather than the ones produced by the latest generator
to reduce model oscillation.
Args:
buffer_size (int): The size of image buffer. If buffer_size = 0,
no buffer will be created.
buffer_ratio (float): The chance / possibility to use the images
previously stored in the buffer.
"""
def __init__(self, buffer_size, buffer_ratio=0.5):
self.buffer_size = buffer_size
# create an empty buffer
if self.buffer_size > 0:
self.img_num = 0
self.image_buffer = []
self.buffer_ratio = buffer_ratio
def query(self, images):
"""Query current image batch using a history of generated images.
Args:
images (Tensor): Current image batch without history information.
"""
if self.buffer_size == 0: # if the buffer size is 0, do nothing
return images
return_images = []
for image in images:
image = torch.unsqueeze(image.data, 0)
# if the buffer is not full, keep inserting current images
if self.img_num < self.buffer_size:
self.img_num = self.img_num + 1
self.image_buffer.append(image)
return_images.append(image)
else:
use_buffer = torch.rand(1) < self.buffer_ratio
# by self.buffer_ratio, the buffer will return a previously
# stored image, and insert the current image into the buffer
if use_buffer:
random_id = torch.randint(0, self.buffer_size, (1,)).item()
image_tmp = self.image_buffer[random_id].clone()
self.image_buffer[random_id] = image
return_images.append(image_tmp)
# by (1 - self.buffer_ratio), the buffer will return the
# current image
else:
return_images.append(image)
# collect all the images and return
return_images = torch.cat(return_images, 0)
return return_images
@MODEL.register_module()
class ResidualBlock(nn.Module):
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
super(ResidualBlock, self).__init__()
if use_bias is None:
# Only for IN, use bias since it does not have affine parameters.
use_bias = norm_type == "IN"
norm_layer = select_norm_layer(norm_type)
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
bias=use_bias)
self.norm1 = norm_layer(num_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
bias=use_bias)
self.norm2 = norm_layer(num_channels)
def forward(self, x):
res = x
x = self.relu1(self.norm1(self.conv1(x)))
x = self.norm2(self.conv2(x))
return x + res
@MODEL.register_module()
class ResGenerator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
norm_type="IN"):
super(ResGenerator, self).__init__()
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
self.start_conv = nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
bias=use_bias),
norm_layer(num_features=base_channels),
nn.ReLU(inplace=True)
)
# down sampling
submodules = []
num_down_sampling = 2
for i in range(num_down_sampling):
multiple = 2 ** i
submodules += [
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(num_features=base_channels * multiple * 2),
nn.ReLU(inplace=True)
]
self.encoder = nn.Sequential(*submodules)
res_block_channels = num_down_sampling ** 2 * base_channels
self.resnet_middle = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
range(num_blocks)])
# up sampling
submodules = []
for i in range(num_down_sampling):
multiple = 2 ** (num_down_sampling - i)
submodules += [
nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
padding=1, output_padding=1, bias=use_bias),
norm_layer(num_features=base_channels * multiple // 2),
nn.ReLU(inplace=True),
]
self.decoder = nn.Sequential(*submodules)
self.end_conv = nn.Sequential(
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
nn.Tanh()
)
def forward(self, x):
x = self.encoder(self.start_conv(x))
x = self.resnet_middle(x)
return self.end_conv(self.decoder(x))
@MODEL.register_module()
class PatchDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
super(PatchDiscriminator, self).__init__()
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
kernel_size = 4
padding = 1
sequence = [
nn.Conv2d(in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding),
nn.LeakyReLU(0.2, inplace=True),
]
# stacked intermediate layers,
# gradually increasing the number of filters
multiple_now = 1
for n in range(1, num_conv):
multiple_prev = multiple_now
multiple_now = min(2 ** n, 8)
sequence += [
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=kernel_size,
padding=padding, stride=2, bias=use_bias),
norm_layer(base_channels * multiple_now),
nn.LeakyReLU(0.2, inplace=True)
]
multiple_prev = multiple_now
multiple_now = min(2 ** num_conv, 8)
sequence += [
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size, stride=1,
padding=padding, bias=use_bias),
norm_layer(base_channels * multiple_now),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding)
]
self.model = nn.Sequential(*sequence)
def forward(self, x):
return self.model(x)

View File

@@ -1,5 +1,7 @@
from model.registry import MODEL
import model.GAN.residual_generator
import model.GAN.TAHG
import model.GAN.UGATIT
import model.fewshot
from model.registry import MODEL, NORMALIZATION
import model.base.normalization
import model.image_translation.UGATIT
import model.image_translation.CycleGAN
import model.image_translation.pix2pixHD
import model.image_translation.GauGAN
import model.image_translation.TSIT

0
model/base/__init__.py Normal file
View File

128
model/base/module.py Normal file
View File

@@ -0,0 +1,128 @@
import torch.nn as nn
from model.registry import NORMALIZATION
_DO_NO_THING_FUNC = lambda x: x
def _use_bias_checker(norm_type):
return norm_type not in ["IN", "BN", "AdaIN", "FADE", "SPADE"]
def _normalization(norm, num_features, additional_kwargs=None):
if norm == "NONE":
return _DO_NO_THING_FUNC
if additional_kwargs is None:
additional_kwargs = {}
kwargs = dict(_type=norm, num_features=num_features)
kwargs.update(additional_kwargs)
return NORMALIZATION.build_with(kwargs)
def _activation(activation, inplace=True):
if activation == "NONE":
return _DO_NO_THING_FUNC
elif activation == "ReLU":
return nn.ReLU(inplace=inplace)
elif activation == "LeakyReLU":
return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
elif activation == "Tanh":
return nn.Tanh()
else:
raise NotImplementedError(f"{activation} not valid")
class LinearBlock(nn.Module):
def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
bias = _use_bias_checker(norm_type) if bias is None else bias
self.linear = nn.Linear(in_features, out_features, bias)
self.normalization = _normalization(norm_type, out_features)
self.activation = _activation(activation_type)
def forward(self, x):
return self.activation(self.normalization(self.linear(x)))
class Conv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, bias=None,
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
pre_activation=False, use_transpose_conv=False, **conv_kwargs):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
self.pre_activation = pre_activation
if use_transpose_conv:
# Only "zeros" padding mode is supported for ConvTranspose2d
conv_kwargs["padding_mode"] = "zeros"
conv = nn.ConvTranspose2d
else:
conv = nn.Conv2d
if pre_activation:
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
self.activation = _activation(activation_type, inplace=False)
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
else:
# if caller not set bias, set bias automatically.
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
self.activation = _activation(activation_type)
def forward(self, x):
if self.pre_activation:
return self.convolution(self.activation(self.normalization(x)))
return self.activation(self.normalization(self.convolution(x)))
class ResidualBlock(nn.Module):
def __init__(self, in_channels,
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
"""
Residual Conv Block
:param in_channels:
:param out_channels:
:param padding_mode:
:param activation_type:
:param norm_type:
:param out_activation_type:
:param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
"""
super().__init__()
self.norm_type = norm_type
if out_channels is None:
out_channels = in_channels
if out_activation_type is None:
# if not specify `out_activation_type`, using default `out_activation_type`
# `out_activation_type` default mode:
# "NONE" for not full pre-activation
# `norm_type` for full pre-activation
out_activation_type = "NONE" if not pre_activation else norm_type
self.learn_skip_connection = in_channels != out_channels
conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
padding_mode=padding_mode)
self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
if self.learn_skip_connection:
conv_param['kernel_size'] = 1
conv_param['padding'] = 0
self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
def forward(self, x):
res = x if not self.learn_skip_connection else self.res_conv(x)
return self.conv2(self.conv1(x)) + res

143
model/base/normalization.py Normal file
View File

@@ -0,0 +1,143 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import NORMALIZATION
from model.base.module import Conv2dBlock
_VALID_NORM_AND_ABBREVIATION = dict(
IN="InstanceNorm2d",
BN="BatchNorm2d",
)
for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
NORMALIZATION.register_module(module=getattr(nn, name), name=abbr)
@NORMALIZATION.register_module("ADE")
class AdaptiveDenormalization(nn.Module):
def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
super().__init__()
self.num_features = num_features
self.base_norm_type = base_norm_type
self.norm = self.base_norm(num_features)
self.gamma = None
self.gamma_bias = gamma_bias
self.beta = None
self.have_set_condition = False
def base_norm(self, num_features):
if self.base_norm_type == "IN":
return nn.InstanceNorm2d(num_features, affine=False)
elif self.base_norm_type == "BN":
return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
def set_condition(self, gamma, beta):
self.gamma, self.beta = gamma, beta
self.have_set_condition = True
def forward(self, x):
assert self.have_set_condition
x = self.norm(x)
x = (self.gamma + self.gamma_bias) * x + self.beta
self.have_set_condition = False
return x
#
# def __repr__(self):
# return f"{self.__class__.__name__}(num_features={self.num_features}, " \
# f"base_norm_type={self.base_norm_type})"
@NORMALIZATION.register_module("AdaIN")
class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
def __init__(self, num_features: int):
super().__init__(num_features, "IN")
self.num_features = num_features
def set_style(self, style):
style = style.view(*style.size(), 1, 1)
gamma, beta = style.chunk(2, 1)
super().set_condition(gamma, beta)
@NORMALIZATION.register_module("FADE")
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
def __init__(self, num_features: int, condition_in_channels,
base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0):
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
padding_mode=padding_mode)
self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
padding_mode=padding_mode)
def set_feature(self, feature):
gamma = self.gamma_conv(feature)
beta = self.beta_conv(feature)
super().set_condition(gamma, beta)
@NORMALIZATION.register_module("SPADE")
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
def set_condition_image(self, condition_image):
feature = self.base_conv_block(condition_image)
gamma = self.gamma_conv(feature)
beta = self.beta_conv(feature)
super().set_condition(gamma, beta)
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
return out
@NORMALIZATION.register_module("ILN")
class ILN(nn.Module):
def __init__(self, num_features, eps=1e-5):
super(ILN, self).__init__()
self.eps = eps
self.rho = nn.Parameter(torch.Tensor(num_features))
self.gamma = nn.Parameter(torch.Tensor(num_features))
self.beta = nn.Parameter(torch.Tensor(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.zeros_(self.rho)
nn.init.ones_(self.gamma)
nn.init.zeros_(self.beta)
def forward(self, x):
return _instance_layer_normalization(
x, self.gamma.view(1, -1), self.beta.view(1, -1), self.rho.view(1, -1, 1, 1), self.eps)
@NORMALIZATION.register_module("AdaILN")
class AdaILN(nn.Module):
def __init__(self, num_features, eps=1e-5, default_rho=0.9):
super(AdaILN, self).__init__()
self.eps = eps
self.rho = nn.Parameter(torch.Tensor(num_features))
self.rho.data.fill_(default_rho)
self.gamma = None
self.beta = None
self.have_set_condition = False
def set_condition(self, gamma, beta):
self.gamma, self.beta = gamma, beta
self.have_set_condition = True
def forward(self, x):
assert self.have_set_condition
out = _instance_layer_normalization(x, self.gamma, self.beta, self.rho.view(1, -1, 1, 1), self.eps)
self.have_set_condition = False
return out

View File

@@ -1,105 +0,0 @@
import math
import torch.nn as nn
from .registry import MODEL
# --- gaussian initialize ---
def init_layer(l):
# Initialization using fan-in
if isinstance(l, nn.Conv2d):
n = l.kernel_size[0] * l.kernel_size[1] * l.out_channels
l.weight.data.normal_(0, math.sqrt(2.0 / float(n)))
elif isinstance(l, nn.BatchNorm2d):
l.weight.data.fill_(1)
l.bias.data.fill_(0)
elif isinstance(l, nn.Linear):
l.bias.data.fill_(0)
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
return x.view(x.size(0), -1)
class SimpleBlock(nn.Module):
def __init__(self, in_channels, out_channels, half_res, leakyrelu=False):
super(SimpleBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.block = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2 if half_res else 1, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
)
self.relu = nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True)
if in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, 2 if half_res else 1, bias=False),
nn.BatchNorm2d(out_channels)
)
else:
self.shortcut = nn.Identity()
def forward(self, x):
o = self.block(x)
return self.relu(o + self.shortcut(x))
class ResNet(nn.Module):
def __init__(self, block, layers, dims, num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
super().__init__()
assert len(layers) == 4, 'Can have only four stages'
self.inplanes = 64
self.start = nn.Sequential(
nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(self.inplanes),
nn.ReLU(inplace=True) if not leakyrelu else nn.LeakyReLU(0.2, inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
trunk = []
in_channels = self.inplanes
for i in range(4):
for j in range(layers[i]):
half_res = i >= 1 and j == 0
trunk.append(block(in_channels, dims[i], half_res, leakyrelu))
in_channels = dims[i]
if flatten:
trunk.append(nn.AvgPool2d(7))
trunk.append(Flatten())
if num_classes is not None:
if classifier_type == "linear":
trunk.append(nn.Linear(in_channels, num_classes))
elif classifier_type == "distlinear":
pass
else:
raise ValueError(f"invalid classifier_type:{classifier_type}")
self.trunk = nn.Sequential(*trunk)
self.apply(init_layer)
def forward(self, x):
return self.trunk(self.start(x))
@MODEL.register_module()
def resnet10(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)
@MODEL.register_module()
def resnet18(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)
@MODEL.register_module()
def resnet34(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)

View File

@@ -0,0 +1,151 @@
import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock
class Encoder(nn.Module):
def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
padding_mode='reflect', activation_type="ReLU",
down_conv_norm_type="IN", down_conv_kernel_size=3,
res_norm_type="IN", pre_activation=False):
super().__init__()
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
activation_type=activation_type, norm_type=down_conv_norm_type
)]
multiple_now = 1
for i in range(1, num_conv + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode="zeros",
activation_type=activation_type, norm_type=down_conv_norm_type
))
self.out_channels = multiple_now * base_channels
sequence += [
ResidualBlock(
self.out_channels,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type=res_norm_type,
pre_activation=pre_activation
) for _ in range(num_res)
]
self.sequence = nn.Sequential(*sequence)
def forward(self, x):
return self.sequence(x)
class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
activation_type="ReLU", padding_mode='reflect',
up_conv_kernel_size=5, up_conv_norm_type="LN",
res_norm_type="AdaIN", pre_activation=False, use_transpose_conv=False):
super().__init__()
self.residual_blocks = nn.ModuleList([
ResidualBlock(
in_channels,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type=res_norm_type,
pre_activation=pre_activation
) for _ in range(num_residual_blocks)
])
sequence = list()
channels = in_channels
padding = (up_conv_kernel_size - 1) // 2
for i in range(num_up_sampling):
if use_transpose_conv:
sequence.append(Conv2dBlock(
channels, channels // 2, kernel_size=up_conv_kernel_size, stride=2,
padding=padding, output_padding=padding,
padding_mode=padding_mode,
activation_type=activation_type, norm_type=up_conv_norm_type,
use_transpose_conv=True
))
else:
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
padding=padding, padding_mode=padding_mode,
activation_type=activation_type, norm_type=up_conv_norm_type),
))
channels = channels // 2
sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
self.up_sequence = nn.Sequential(*sequence)
def forward(self, x):
for i, blk in enumerate(self.residual_blocks):
x = blk(x)
return self.up_sequence(x)
@MODEL.register_module("CycleGAN-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
padding_mode='reflect', norm_type="IN", pre_activation=False, use_transpose_conv=True):
super().__init__()
self.encoder = Encoder(in_channels, base_channels, num_conv=2, num_res=num_blocks,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type=norm_type, res_norm_type=norm_type, pre_activation=pre_activation)
self.decoder = Decoder(self.encoder.out_channels, out_channels, num_up_sampling=2, num_residual_blocks=0,
padding_mode=padding_mode, activation_type=activation_type,
up_conv_kernel_size=3, up_conv_norm_type=norm_type,
pre_activation=pre_activation, use_transpose_conv=use_transpose_conv)
def forward(self, x):
return self.decoder(self.encoder(x))
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
super().__init__()
self.need_intermediate_feature = need_intermediate_feature
kernel_size = 4
padding = (kernel_size - 1) // 2
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
)]
multiple_now = 1
for i in range(1, num_conv):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 3)
stride = 1 if i == num_conv - 1 else 2
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
))
sequence.append(nn.Conv2d(
base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding, padding_mode=padding_mode))
if self.need_intermediate_feature:
self.sequence = nn.ModuleList(sequence)
else:
self.sequence = nn.Sequential(*sequence)
def forward(self, x):
if self.need_intermediate_feature:
intermediate_feature = []
for layer in self.sequence:
x = layer(x)
intermediate_feature.append(x)
return tuple(intermediate_feature)
else:
return self.sequence(x)
if __name__ == '__main__':
g = Generator(**dict(in_channels=3, out_channels=3))
print(g)
pd = PatchDiscriminator(**dict(in_channels=3, base_channels=64, num_conv=4))
print(pd)

View File

@@ -0,0 +1,183 @@
from collections import OrderedDict
from functools import partial
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
from model import MODEL
class StyleEncoder(nn.Module):
def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
super().__init__()
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
)]
multiple_now = 0
max_multiple = 3
for i in range(1, num_conv + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** max_multiple)
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
))
self.sequence = nn.Sequential(*sequence)
self.fc_avg = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
self.fc_var = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
def forward(self, x):
x = self.sequence(x)
x = x.view(x.size(0), -1)
return self.fc_avg(x), self.fc_var(x)
class ImprovedSPADEGenerator(nn.Module):
def __init__(self, in_channels, out_channels, output_size, have_style_input, style_dim, start_size=(4, 4),
base_channels=64, padding_mode='reflect', activation_type="LeakyReLU", pre_activation=False):
super().__init__()
assert output_size in (128, 256, 512, 1024)
self.output_size = output_size
kernel_size = 3
if have_style_input:
self.style_converter = nn.Sequential(
LinearBlock(style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
LinearBlock(2 * style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
)
base_conv = partial(
Conv2dBlock,
pre_activation=pre_activation, activation_type=activation_type,
norm_type="AdaIN" if have_style_input else "NONE",
kernel_size=kernel_size, padding=(kernel_size - 1) // 2, padding_mode=padding_mode
)
base_residual_block = partial(
ResidualBlock,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type="SPADE",
pre_activation=True,
additional_norm_kwargs=dict(
condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
)
)
sequence = OrderedDict()
channels = (2 ** 4) * base_channels
sequence["block_head"] = nn.Sequential(OrderedDict([
("conv_input", base_conv(in_channels=in_channels, out_channels=channels)),
("conv_style", base_conv(in_channels=channels, out_channels=channels)),
("res_a", base_residual_block(in_channels=channels, out_channels=channels)),
("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
("up", nn.Upsample(scale_factor=2, mode='nearest'))
]))
for i in range(4, 9 - min(int(math.log(self.output_size, 2)), 8), -1):
channels = (2 ** (i - 1)) * base_channels
sequence[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
("res_a", base_residual_block(in_channels=channels * 2, out_channels=channels)),
("conv_style", base_conv(in_channels=channels, out_channels=channels)),
("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
("up", nn.Upsample(scale_factor=2, mode='nearest'))
]))
self.sequence = nn.Sequential(sequence)
# channels = 2*base_channels when output size is 256, 512, 1024
# channels = 5*base_channels when output size is 128
out_modules = OrderedDict()
out_modules["out_1"] = nn.Sequential(
Conv2dBlock(
channels, out_channels, kernel_size=5, stride=1, padding=2,
pre_activation=pre_activation,
padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
),
nn.Tanh()
)
for i in range(int(math.log(self.output_size, 2)) - 8):
channels = channels // 2
out_modules[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
("res_a", base_residual_block(in_channels=2 * channels, out_channels=channels)),
("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
("up", nn.Upsample(scale_factor=2, mode='nearest'))
]))
out_modules[f"out_{i + 2}"] = nn.Sequential(
Conv2dBlock(
channels, out_channels, kernel_size=5, stride=1, padding=2,
pre_activation=pre_activation,
padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
),
nn.Tanh()
)
self.out_modules = nn.ModuleDict(out_modules)
def forward(self, seg, style=None):
pass
@MODEL.register_module()
class SPADEGenerator(nn.Module):
def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
padding_mode='reflect', activation_type="LeakyReLU"):
super().__init__()
self.sx, self.sy = start_size
self.use_vae = use_vae
self.num_z_dim = num_z_dim
if use_vae:
self.input_converter = nn.Linear(num_z_dim, 16 * base_channels * self.sx * self.sy)
else:
self.input_converter = nn.Conv2d(in_channels, 16 * base_channels, kernel_size=3, padding=1)
sequence = []
multiple_now = 16
for i in range(num_blocks - 1, -1, -1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 4)
if i != num_blocks - 1:
sequence.append(nn.Upsample(scale_factor=2))
sequence.append(ResidualBlock(
base_channels * multiple_prev,
out_channels=base_channels * multiple_now,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type="SPADE",
pre_activation=True,
additional_norm_kwargs=dict(
condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
)
))
self.sequence = nn.Sequential(*sequence)
self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")
def forward(self, seg, z=None):
if self.use_vae:
if z is None:
z = torch.randn(seg.size(0), self.num_z_dim, device=seg.device)
x = self.input_converter(z).view(seg.size(0), -1, self.sx, self.sy)
else:
x = self.input_converter(F.interpolate(seg, size=(self.sx, self.sy)))
for blk in self.sequence:
if isinstance(blk, ResidualBlock):
downsampling_seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
blk.conv1.normalization.set_condition_image(downsampling_seg)
blk.conv2.normalization.set_condition_image(downsampling_seg)
if blk.learn_skip_connection:
blk.res_conv.normalization.set_condition_image(downsampling_seg)
x = blk(x)
return self.output_converter(x)
if __name__ == '__main__':
g = SPADEGenerator(3, 3, 7, False, 256)
print(g)
print(g(torch.randn(2, 3, 256, 256)).size())

View File

@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
from model import MODEL
from model.base.module import LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class StyleEncoder(nn.Module):
def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE",
pre_activation=False):
super().__init__()
self.down_encoder = Encoder(
in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type=norm_type, down_conv_kernel_size=4, pre_activation=pre_activation,
)
sequence = list()
sequence.append(nn.AdaptiveAvgPool2d(1))
# conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim, kernel_size=1, stride=1, padding=0))
self.sequence = nn.Sequential(*sequence)
def forward(self, image):
return self.sequence(image).view(image.size(0), -1)
class MLPFusion(nn.Module):
def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
super().__init__()
sequence = [LinearBlock(in_features, base_features, activation_type=activation_type, norm_type=norm_type)]
sequence += [
LinearBlock(base_features, base_features, activation_type=activation_type, norm_type=norm_type)
for _ in range(n_blocks - 2)
]
sequence.append(LinearBlock(base_features, out_features, activation_type=activation_type, norm_type=norm_type))
self.sequence = nn.Sequential(*sequence)
def forward(self, x):
return self.sequence(x)
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, style_dim=8,
num_mlp_base_feature=256, num_mlp_blocks=3,
max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
padding_mode='reflect', activation_type="ReLU", pre_activation=False):
super().__init__()
self.content_encoder = Encoder(
in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
max_down_sampling_multiple=num_content_down_sampling,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type="IN", down_conv_kernel_size=4,
res_norm_type="IN", pre_activation=pre_activation
)
self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
max_down_sampling_multiple, padding_mode, activation_type,
norm_type="NONE", pre_activation=pre_activation)
content_channels = base_channels * (2 ** max_down_sampling_multiple)
self.fusion = MLPFusion(style_dim, decoder_num_residual_blocks * 2 * content_channels * 2,
num_mlp_base_feature, num_mlp_blocks, activation_type,
norm_type="NONE")
self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
activation_type=activation_type, padding_mode=padding_mode,
up_conv_kernel_size=5, up_conv_norm_type="LN",
res_norm_type="AdaIN", pre_activation=pre_activation)
def encode(self, x):
return self.content_encoder(x), self.style_encoder(x)
def decode(self, content, style):
as_param_style = torch.chunk(self.fusion(style), 2 * len(self.decoder.residual_blocks), dim=1)
# set style for decoder
for i, blk in enumerate(self.decoder.residual_blocks):
blk.conv1.normalization.set_style(as_param_style[2 * i])
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
self.decoder(content)
def forward(self, x):
content, style = self.encode(x)
return self.decode(content, style)

View File

@@ -0,0 +1,98 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from model import MODEL
from model.base.module import ResidualBlock, Conv2dBlock
class Interpolation(nn.Module):
def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
super(Interpolation, self).__init__()
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
recompute_scale_factor=False)
def __repr__(self):
return f"Interpolation(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
padding_mode='reflect', activation_type="LeakyReLU"):
super().__init__()
self.input_layer = Conv2dBlock(
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
activation_type=activation_type, norm_type="IN",
)
multiple_now = 1
stream_sequence = []
for i in range(1, num_blocks + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 4)
stream_sequence.append(nn.Sequential(
Interpolation(scale_factor=0.5, mode="nearest"),
ResidualBlock(
multiple_prev * base_channels, out_channels=multiple_now * base_channels,
padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
))
self.down_sequence = nn.ModuleList(stream_sequence)
sequence = []
multiple_now = 16
for i in range(num_blocks - 1, -1, -1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 4)
sequence.append(nn.Sequential(
ResidualBlock(
base_channels * multiple_prev,
out_channels=base_channels * multiple_now,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type="FADE",
pre_activation=True,
additional_norm_kwargs=dict(
condition_in_channels=base_channels * multiple_prev, base_norm_type="BN",
padding_mode="zeros", gamma_bias=0.0
)
),
Interpolation(scale_factor=2, mode="nearest")
))
self.up_sequence = nn.Sequential(*sequence)
self.output_layer = Conv2dBlock(
base_channels, out_channels, kernel_size=3, stride=1, padding=1,
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
)
def forward(self, x, z=None):
c = self.input_layer(x)
contents = []
for blk in self.down_sequence:
c = blk(c)
contents.append(c)
if z is None:
# for image 256x256, z size: [batch_size, 1024, 2, 2]
z = torch.randn(size=contents[-1].size(), device=contents[-1].device)
for blk in self.up_sequence:
res = blk[0]
c = contents.pop()
res.conv1.normalization.set_feature(c)
res.conv2.normalization.set_feature(c)
if res.learn_skip_connection:
res.res_conv.normalization.set_feature(c)
return self.output_layer(self.up_sequence(z))
if __name__ == '__main__':
g = Generator(3, 3).cuda()
img = torch.randn(2, 3, 256, 256).cuda()
print(g(img).size())

View File

@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class CAMClassifier(nn.Module):
def __init__(self, in_channels, activation_type="ReLU"):
super(CAMClassifier, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.avg_fc = nn.Linear(in_channels, 1, bias=False)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.max_fc = nn.Linear(in_channels, 1, bias=False)
self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, kernel_size=1, stride=1, bias=True,
activation_type=activation_type, norm_type="NONE")
def forward(self, x):
avg_logit = self.avg_fc(self.avg_pool(x).view(x.size(0), -1))
max_logit = self.max_fc(self.max_pool(x).view(x.size(0), -1))
return self.fusion_conv(torch.cat(
[x * self.avg_fc.weight.unsqueeze(2).unsqueeze(3), x * self.max_fc.weight.unsqueeze(2).unsqueeze(3)],
dim=1
)), torch.cat([avg_logit, max_logit], 1)
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
super(Generator, self).__init__()
self.light = light
n_down_sampling = 2
self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type,
pre_activation=pre_activation)
mult = 2 ** n_down_sampling
self.cam = CAMClassifier(base_channels * mult, activation_type)
# Gamma, Beta block
if self.light:
self.fc = nn.Sequential(
LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE"),
LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE")
)
else:
self.fc = nn.Sequential(
LinearBlock(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, False,
"ReLU", "NONE"),
LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE")
)
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
self.decoder = Decoder(
base_channels * mult, out_channels, n_down_sampling, num_blocks,
activation_type=activation_type, padding_mode=padding_mode,
up_conv_kernel_size=3, up_conv_norm_type="ILN",
res_norm_type="AdaILN", pre_activation=pre_activation
)
def forward(self, x):
x = self.encoder(x)
x, cam_logit = self.cam(x)
heatmap = torch.sum(x, dim=1, keepdim=True)
if self.light:
x_ = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
x_ = self.fc(x_.view(x_.shape[0], -1))
else:
x_ = self.fc(x.view(x.shape[0], -1))
gamma, beta = self.gamma(x_), self.beta(x_)
for blk in self.decoder.residual_blocks:
blk.conv1.normalization.set_condition(gamma, beta)
blk.conv2.normalization.set_condition(gamma, beta)
return self.decoder(x), cam_logit, heatmap
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, num_blocks=5,
activation_type="LeakyReLU", norm_type="NONE", padding_mode='reflect'):
super().__init__()
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
)]
sequence += [Conv2dBlock(
base_channels * (2 ** i), base_channels * (2 ** i) * 2,
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type) for i in range(num_blocks - 3)]
sequence.append(
Conv2dBlock(base_channels * (2 ** (num_blocks - 3)), base_channels * (2 ** (num_blocks - 2)),
kernel_size=4, stride=1, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type)
)
self.sequence = nn.Sequential(*sequence)
mult = 2 ** (num_blocks - 2)
self.cam = CAMClassifier(base_channels * mult, activation_type)
self.conv = nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False,
padding_mode="reflect")
def forward(self, x, return_heatmap=False):
x = self.sequence(x)
x, cam_logit = self.cam(x)
if return_heatmap:
heatmap = torch.sum(x, dim=1, keepdim=True)
return self.conv(x), cam_logit, heatmap
else:
return self.conv(x), cam_logit

View File

View File

@@ -0,0 +1,29 @@
import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"):
super().__init__()
assert down_sample_method in ["avg", "bilinear"]
self.down_sample_method = down_sample_method
self.discriminator_list = nn.ModuleList([
MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
])
def down_sample(self, x):
if self.down_sample_method == "avg":
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
if self.down_sample_method == "bilinear":
return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
def forward(self, x):
results = []
for discriminator in self.discriminator_list:
results.append(discriminator(x))
x = self.down_sample(x)
return results

View File

@@ -1,75 +0,0 @@
import torch.nn as nn
import functools
import torch
def select_norm_layer(norm_type):
if norm_type == "BN":
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
elif norm_type == "IN":
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
elif norm_type == "LN":
return functools.partial(LayerNorm2d, affine=True)
elif norm_type == "NONE":
return lambda num_features: nn.Identity()
elif norm_type == "AdaIN":
return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
else:
raise NotImplemented(f'normalization layer {norm_type} is not found')
class LayerNorm2d(nn.Module):
def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
super().__init__()
self.num_features = num_features
self.eps = eps
self.affine = affine
if self.affine:
self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.reset_parameters()
def reset_parameters(self):
if self.affine:
nn.init.uniform_(self.channel_gamma)
nn.init.zeros_(self.channel_beta)
def forward(self, x):
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
if self.affine:
return self.channel_gamma * x + self.channel_beta
return x
def __repr__(self):
return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
class AdaptiveInstanceNorm2d(nn.Module):
def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
affine: bool = False, track_running_stats: bool = False):
super().__init__()
self.num_features = num_features
self.affine = affine
self.track_running_stats = track_running_stats
self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
self.gamma = None
self.beta = None
self.have_set_style = False
def set_style(self, style):
style = style.view(*style.size(), 1, 1)
self.gamma, self.beta = style.chunk(2, 1)
self.have_set_style = True
def forward(self, x):
assert self.have_set_style
x = self.norm(x)
x = self.gamma * x + self.beta
self.have_set_style = False
return x
def __repr__(self):
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
f"affine={self.affine}, track_running_stats={self.track_running_stats})"

View File

@@ -1,3 +1,4 @@
from util.registry import Registry
MODEL = Registry("model")
NORMALIZATION = Registry("normalization")

View File

@@ -65,7 +65,8 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
elif classname.find('BatchNorm2d') != -1:
# BatchNorm Layer's weight is not a matrix;
# only normal distribution applies.
normal_init(m, 1.0, init_gain)
if m.weight is not None:
normal_init(m, 1.0, init_gain)
assert isinstance(module, nn.Module)
module.apply(init_func)

8
run.sh
View File

@@ -5,16 +5,18 @@ TASK=$2
GPUS=$3
MORE_ARG=${*:4}
RANDOM_MASTER=$(shuf -i 2000-65000 -n 1)
_command="print(len('${GPUS}'.split(',')))"
GPU_COUNT=$(python3 -c "${_command}")
echo "GPU_COUNT:${GPU_COUNT}"
echo CUDA_VISIBLE_DEVICES=$GPUS \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG"
CUDA_VISIBLE_DEVICES=$GPUS \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
--master_port=${RANDOM_MASTER} \
main.py "$TASK" "$CONFIG" $MORE_ARG --backup_config --setup_output_dir --setup_random_seed

46
tool/dump_tensorboard.py Normal file
View File

@@ -0,0 +1,46 @@
#!/usr/bin/env python
# edit from https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2
import argparse
import pathlib
import cv2
import numpy as np
from tensorboard.backend.event_processing import event_accumulator
def save(outdir: pathlib.Path, tag, event_acc):
events = event_acc.Images(tag)
for index, event in enumerate(events):
s = np.frombuffer(event.encoded_image_string, dtype=np.uint8)
image = cv2.imdecode(s, cv2.IMREAD_COLOR)
outpath = outdir / f"{tag.replace('/', '_')}@{index}.png"
cv2.imwrite(outpath.as_posix(), image)
# ffmpeg -framerate 1 -i ./tmp/test_b/%04d.jpg -vcodec mpeg4 test_b.mp4
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--path', type=str, required=True)
parser.add_argument('--outdir', type=str, required=True)
parser.add_argument("--tag", type=str, required=False)
args = parser.parse_args()
event_acc = event_accumulator.EventAccumulator(args.path, size_guidance={'images': 0})
event_acc.Reload()
outdir = pathlib.Path(args.outdir)
outdir.mkdir(exist_ok=True, parents=True)
if args.tag is None:
for tag in event_acc.Tags()['images']:
save(outdir, tag, event_acc)
else:
assert args.tag in event_acc.Tags()['images'], f"{args.tag} not in {event_acc.Tags()['images']}"
save(outdir, args.tag, event_acc)
if __name__ == '__main__':
main()

32
tool/encoder_distance.py Normal file
View File

@@ -0,0 +1,32 @@
from pathlib import Path
import torch
#
# data = {}
#
# for i in range(1, 422 + 1):
# _, names = torch.load(f"/tmp/pt/batch{i}.pt")
# generated = torch.load(f"/tmp/pt/generated{i}.pt")
# print(len(names))
# for j, n in enumerate(names):
# data[Path(names[j]).stem] = generated[j]
#
# torch.save(data, "/tmp/data.pt")
data = torch.load("/tmp/data.pt")
videos = sorted(list(set([k.split("@")[0] for k in data.keys()])))
for idx in range(len(videos)):
print(videos[idx])
videos_data = {}
for k in data:
if k.startswith(videos[idx]):
videos_data[int(k.split("@")[-1])] = data[k]
to_save = []
for i in range(2, len(videos_data) + 1):
to_save.append(torch.mean(torch.abs(videos_data[i] - videos_data[1])).cpu())
torch.save(to_save, f"{videos[idx]}.pt")
print(f"{videos[idx]}.pt")

14
tool/inspect_model.py Normal file
View File

@@ -0,0 +1,14 @@
import sys
import torch
from omegaconf import OmegaConf
from engine.util.build import build_model
config = OmegaConf.load(sys.argv[1])
generator = build_model(config.model.generator)
ckp = torch.load(sys.argv[2], map_location="cpu")
generator.module.load_state_dict(ckp["generator_main"])

View File

@@ -0,0 +1,13 @@
from pathlib import Path
import sys
from collections import defaultdict
from itertools import permutations
pids = defaultdict(list)
for p in Path(sys.argv[1]).glob("*.jpg"):
pids[p.stem[:7]].append(p.stem)
data = []
for p in pids:
data.extend(list(permutations(pids[p], 2)))

69
tool/verify_loss.py Normal file
View File

@@ -0,0 +1,69 @@
import torch
from torch.utils.data import DataLoader
from ignite.utils import convert_tensor
from omegaconf import OmegaConf
from data.dataset import SingleFolderDataset
from loss.I2I.perceptual_loss import PerceptualLoss
import ignite.distributed as idist
CONFIG = """
loss:
perceptual:
layer_weights:
"1": 0.03125
"6": 0.0625
"11": 0.125
"20": 0.25
"29": 1
criterion: 'NL2'
style_loss: False
perceptual_loss: True
match_data:
root: "/tmp/generated/"
pipeline:
- Load
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
not_match_data:
root: "/data/i2i/selfie2anime/trainB/"
pipeline:
- Load
- ToTensor
- Normalize:
mean: [ 0.5, 0.5, 0.5 ]
std: [ 0.5, 0.5, 0.5 ]
"""
config = OmegaConf.create(CONFIG)
dataset = SingleFolderDataset(**config.match_data)
data_loader = DataLoader(dataset, 1, False, num_workers=1)
perceptual_loss = PerceptualLoss(**config.loss.perceptual).to("cuda:0")
pls = []
for batch in data_loader:
with torch.no_grad():
batch = convert_tensor(batch, "cuda:0")
x, t = torch.chunk(batch, 2, -1)
pl, _ = perceptual_loss(x, t)
print(pl)
pls.append(pl)
torch.save(torch.stack(pls).cpu(), "verify_loss.match.pt")
dataset = SingleFolderDataset(**config.not_match_data)
data_loader = DataLoader(dataset, 4, False, num_workers=1)
pls = []
for batch in data_loader:
with torch.no_grad():
batch = convert_tensor(batch, "cuda:0")
for i, j in [(0, 1), (1, 2), (2, 3), (3, 0)]:
x, t = batch[i].unsqueeze(dim=0), batch[j].unsqueeze(dim=0)
pl, _ = perceptual_loss(x, t)
print(pl)
pls.append(pl)
torch.save(torch.stack(pls).cpu(), "verify_loss.not_match.pt")

View File

@@ -1,27 +0,0 @@
import torch
import torch.optim as optim
import ignite.distributed as idist
from omegaconf import OmegaConf
from model import MODEL
from util.distributed import auto_model
def build_model(cfg, distributed_args=None):
cfg = OmegaConf.to_container(cfg)
model_distributed_config = cfg.pop("_distributed", dict())
model = MODEL.build_with(cfg)
if model_distributed_config.get("bn_to_syncbn"):
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
return auto_model(model, **distributed_args)
def build_optimizer(params, cfg):
assert "_type" in cfg
cfg = OmegaConf.to_container(cfg)
optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
return idist.auto_optim(optimizer)

View File

@@ -1,66 +0,0 @@
import torch
import torch.nn as nn
from ignite.distributed import utils as idist
from ignite.distributed.comp_models import native as idist_native
from ignite.utils import setup_logger
def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module:
"""Helper method to adapt provided model for non-distributed and distributed configurations (supporting
all available backends from :meth:`~ignite.distributed.utils.available_backends()`).
Internally, we perform to following:
- send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
- wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
- wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available.
Examples:
.. code-block:: python
import ignite.distribted as idist
model = idist.auto_model(model)
In addition with NVidia/Apex, it can be used in the following way:
.. code-block:: python
import ignite.distribted as idist
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model = idist.auto_model(model)
Args:
model (torch.nn.Module): model to adapt.
Returns:
torch.nn.Module
.. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
.. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
"""
logger = setup_logger(__name__ + ".auto_model")
# Put model's parameters to device if its parameters are not on the device
device = idist.device()
if not all([p.device == device for p in model.parameters()]):
model.to(device)
# distributed data parallel model
if idist.get_world_size() > 1:
if idist.backend() == idist_native.NCCL:
lrank = idist.get_local_rank()
logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs)
elif idist.backend() == idist_native.GLOO:
logger.info("Apply torch DistributedDataParallel on model")
model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs)
# not distributed but multiple GPUs reachable so data parallel model
elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
logger.info("Apply torch DataParallel on model")
model = torch.nn.parallel.DataParallel(model, **additional_kwargs)
return model

View File

@@ -1,26 +1,34 @@
import torchvision.utils
from matplotlib.pyplot import get_cmap
import torch
import warnings
from torch.nn.functional import interpolate
import numpy as np
import cv2
def attention_colored_map(attentions, size=None, cmap_name="jet"):
def attention_colored_map(attentions, size=None):
assert attentions.dim() == 4 and attentions.size(1) == 1
device = attentions.device
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
attentions -= min_attentions
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
if size is not None and attentions.size()[-2:] != size:
attentions = attentions.detach().cpu().numpy()
attentions = (attentions * 255).astype(np.uint8)
need_resize = False
if size is not None and attentions.shape[-2:] != size:
assert len(size) == 2, "for interpolate, size must be (x, y), have two dim"
attentions = interpolate(attentions, size, mode="bilinear", align_corners=False)
cmap = get_cmap(cmap_name)
ca = cmap(attentions.squeeze(1).cpu())[:, :, :, :3]
return torch.from_numpy(ca).permute(0, 3, 1, 2).contiguous()
need_resize = True
subs = []
for sub in attentions:
sub = cv2.resize(sub[0], size) if need_resize else sub[0] # numpy.array shape=size
subs.append(cv2.applyColorMap(sub, cv2.COLORMAP_JET)) # append a (size[0], size[1], 3) numpy array
subs = np.stack(subs) # (batch_size, size[0], size[1], 3)
return torch.from_numpy(subs).permute(0, 3, 1, 2).contiguous().to(device).float() / 255
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
def fuse_attention_map(images, attentions, alpha=0.5):
"""
:param images: B x H x W
@@ -35,7 +43,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
if attentions.size(1) != 1:
warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
return images
colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device)
colored_attentions = attention_colored_map(attentions, images.size()[-2:])
return images * alpha + colored_attentions * (1 - alpha)

View File

@@ -1,6 +1,25 @@
import importlib
import logging
from typing import Optional
import pkgutil
from pathlib import Path
from typing import Optional
def import_submodules(package, recursive=True):
""" Import all submodules of a module, recursively, including subpackages
:param package: package (name or actual module)
:type package: str | module
:rtype: dict[str, types.ModuleType]
"""
if isinstance(package, str):
package = importlib.import_module(package)
results = {}
for loader, name, is_pkg in pkgutil.walk_packages(package.__path__):
full_name = package.__name__ + '.' + name
results[name] = importlib.import_module(full_name)
if recursive and is_pkg:
results.update(import_submodules(full_name))
return results
def setup_logger(
@@ -8,7 +27,6 @@ def setup_logger(
level: int = logging.INFO,
logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
filepath: Optional[str] = None,
file_level: int = logging.DEBUG,
distributed_rank: Optional[int] = None,
) -> logging.Logger:
"""Setups logger: name, level, format etc.
@@ -18,7 +36,6 @@ def setup_logger(
level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
file_level (int): Optional logging level for logging file.
distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
If None, distributed_rank is initialized to the rank of process.

View File

@@ -1,8 +1,10 @@
import inspect
from omegaconf.dictconfig import DictConfig
from omegaconf import OmegaConf
import warnings
from types import ModuleType
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
class _Registry:
def __init__(self, name):
@@ -51,6 +53,10 @@ class _Registry:
else:
raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
for invalid_key in [k for k in args.keys() if k.startswith("_")]:
warnings.warn(f"got param start with `_`: {invalid_key}, will remove it")
args.pop(invalid_key)
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError('default_args must be a dict or None, '
f'but got {type(default_args)}')
@@ -130,8 +136,11 @@ class Registry(_Registry):
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
if self._module_dict[module_name] == module_class:
warnings.warn(f'{module_name} is already registered in {self.name}, but is the same class')
return
raise KeyError(f'{module_name}:{self._module_dict[module_name]} is already registered in {self.name}'
f'so {module_class} can not be registered')
self._module_dict[module_name] = module_class
def register_module(self, name=None, force=False, module=None):