Compare commits
39 Commits
e71e8d95d0
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 8998c30c23 | |||
| 0bec02bf6d | |||
| f7b7b78669 | |||
| 376f5caeb7 | |||
| 0019d4034c | |||
| 0927fa3de5 | |||
| 611901cbdf | |||
| a6ffab1445 | |||
| 7b05b45156 | |||
| 2de00d0245 | |||
| 74a7cfb2d8 | |||
| 436bca88b4 | |||
| 6070f08835 | |||
| 06b2abd19a | |||
| 9c08b4cd09 | |||
| 04c6366c07 | |||
| 6ea13df465 | |||
| 776fe40199 | |||
| f67bcdf161 | |||
| 16f18ab2e2 | |||
| 0f2b67e215 | |||
| acf243cb12 | |||
| fbea96f6d7 | |||
| ca55318253 | |||
| b01016edb5 | |||
| 61e04de8a5 | |||
| 2ff4a91057 | |||
| f70658eaed | |||
| 340a344e91 | |||
| 85b5c3f589 | |||
| 72d09aa483 | |||
| 7ea9c6d0d5 | |||
| 87cbcf34d3 | |||
| 97ded53b30 | |||
| ab545843bf | |||
| e3c760d0c5 | |||
| 39c754374c | |||
| 2469bf15fe | |||
| 14d4247112 |
11
.idea/deployment.xml
generated
11
.idea/deployment.xml
generated
@@ -1,11 +1,11 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="22d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="14d">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping local="$PROJECT_DIR$" web="/" />
|
||||
<mapping deploy="raycv" local="$PROJECT_DIR$" web="/" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
@@ -16,6 +16,13 @@
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="21d">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
<mapping deploy="/raycv" local="$PROJECT_DIR$" web="" />
|
||||
</mappings>
|
||||
</serverdata>
|
||||
</paths>
|
||||
<paths name="22d">
|
||||
<serverdata>
|
||||
<mappings>
|
||||
|
||||
2
.idea/misc.xml
generated
2
.idea/misc.xml
generated
@@ -1,4 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="22d-base" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="14d-python" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
2
.idea/raycv.iml
generated
2
.idea/raycv.iml
generated
@@ -2,7 +2,7 @@
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="22d-base" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="14d-python" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
|
||||
8
.idea/sshConfigs.xml
generated
Normal file
8
.idea/sshConfigs.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="SshConfigs">
|
||||
<configs>
|
||||
<sshConfig host="222.20.77.126" id="38d32db7-46b2-4b95-a40c-d17e8eeca6c1" keyPath="C:\Users\wr\.ssh\sg_id_rsa" port="50001" nameFormat="DESCRIPTIVE" username="dancer" />
|
||||
</configs>
|
||||
</component>
|
||||
</project>
|
||||
@@ -1,51 +0,0 @@
|
||||
name: cross-domain-1
|
||||
engine: crossdomain
|
||||
result_dir: ./result
|
||||
|
||||
distributed:
|
||||
model:
|
||||
# broadcast_buffers: False
|
||||
|
||||
misc:
|
||||
random_seed: 1004
|
||||
|
||||
checkpoints:
|
||||
interval: 2000
|
||||
|
||||
log:
|
||||
logger:
|
||||
level: 20 # DEBUG(10) INFO(20)
|
||||
|
||||
model:
|
||||
_type: resnet10
|
||||
|
||||
baseline:
|
||||
plusplus: False
|
||||
optimizers:
|
||||
_type: Adam
|
||||
data:
|
||||
dataloader:
|
||||
batch_size: 1200
|
||||
shuffle: True
|
||||
num_workers: 16
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
train:
|
||||
path: /data/few-shot/mini_imagenet_full_size/train
|
||||
lmdb_path: /data/few-shot/lmdb/mini-ImageNet/train.lmdb
|
||||
pipeline:
|
||||
- Load
|
||||
- RandomResizedCrop:
|
||||
size: [224, 224]
|
||||
- ColorJitter:
|
||||
brightness: 0.4
|
||||
contrast: 0.4
|
||||
saturation: 0.4
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
|
||||
|
||||
163
configs/synthesizers/CyCleGAN.yml
Normal file
163
configs/synthesizers/CyCleGAN.yml
Normal file
@@ -0,0 +1,163 @@
|
||||
name: huawei-cycylegan-7
|
||||
engine: CycleGAN
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 4 # log image `image` times per epoch
|
||||
test:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: CycleGAN-Generator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
base_channels: 64
|
||||
num_blocks: 9
|
||||
use_transpose_conv: False
|
||||
pre_activation: True
|
||||
# discriminator:
|
||||
# _type: MultiScaleDiscriminator
|
||||
# _add_spectral_norm: True
|
||||
# num_scale: 2
|
||||
# down_sample_method: "bilinear"
|
||||
# discriminator_cfg:
|
||||
# _type: PatchDiscriminator
|
||||
# in_channels: 3
|
||||
# base_channels: 64
|
||||
# num_conv: 4
|
||||
# need_intermediate_feature: True
|
||||
discriminator:
|
||||
_type: PatchDiscriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_conv: 4
|
||||
need_intermediate_feature: False
|
||||
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: hinge
|
||||
weight: 1.0
|
||||
real_label_val: 1
|
||||
fake_label_val: 0.0
|
||||
cycle:
|
||||
level: 1
|
||||
weight: 10.0
|
||||
id:
|
||||
level: 1
|
||||
weight: 10.0
|
||||
mgc:
|
||||
weight: 1
|
||||
fm:
|
||||
weight: 0
|
||||
edge:
|
||||
weight: 0
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 1e-4
|
||||
betas: [ 0.5, 0.999 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 4e-4
|
||||
betas: [ 0.5, 0.999 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
data:
|
||||
train:
|
||||
scheduler:
|
||||
start_proportion: 0.5
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/face2cartoon/all_face"
|
||||
root_b: "/data/selfie2anime/trainB/"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- RandomCrop:
|
||||
size: [ 178, 178 ]
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
which: video_dataset
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/face2cartoon/test/human"
|
||||
root_b: "/data/face2cartoon/test/anime"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
video_dataset:
|
||||
_type: SingleFolderDataset
|
||||
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
|
||||
with_path: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
167
configs/synthesizers/GauGAN.yml
Normal file
167
configs/synthesizers/GauGAN.yml
Normal file
@@ -0,0 +1,167 @@
|
||||
name: huawei-GauGAN-3
|
||||
engine: GauGAN
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 4 # log image `image` times per epoch
|
||||
test:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: SPADEGenerator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
num_blocks: 7
|
||||
use_vae: False
|
||||
num_z_dim: 256
|
||||
# discriminator:
|
||||
# _type: MultiScaleDiscriminator
|
||||
# _add_spectral_norm: True
|
||||
# num_scale: 2
|
||||
# down_sample_method: "bilinear"
|
||||
# discriminator_cfg:
|
||||
# _type: PatchDiscriminator
|
||||
# in_channels: 3
|
||||
# base_channels: 64
|
||||
# num_conv: 4
|
||||
# need_intermediate_feature: True
|
||||
discriminator:
|
||||
_type: PatchDiscriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_conv: 4
|
||||
need_intermediate_feature: True
|
||||
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: hinge
|
||||
weight: 1.0
|
||||
real_label_val: 1
|
||||
fake_label_val: 0.0
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L1'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 2
|
||||
mgc:
|
||||
weight: 5
|
||||
fm:
|
||||
weight: 5
|
||||
edge:
|
||||
weight: 0
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 1e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 4e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
data:
|
||||
train:
|
||||
scheduler:
|
||||
start_proportion: 0.5
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/face2cartoon/all_face"
|
||||
root_b: "/data/selfie2anime/trainB/"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- RandomCrop:
|
||||
size: [ 178, 178 ]
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
which: video_dataset
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/face2cartoon/test/human"
|
||||
root_b: "/data/face2cartoon/test/anime"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
video_dataset:
|
||||
_type: SingleFolderDataset
|
||||
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
|
||||
with_path: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
132
configs/synthesizers/MUNIT.yml
Normal file
132
configs/synthesizers/MUNIT.yml
Normal file
@@ -0,0 +1,132 @@
|
||||
name: MUNIT-edges2shoes
|
||||
engine: MUNIT
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 2 # log image `image` times per epoch
|
||||
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: MUNIT-Generator
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
base_channels: 64
|
||||
num_sampling: 2
|
||||
num_style_dim: 8
|
||||
num_style_conv: 4
|
||||
num_content_res_blocks: 4
|
||||
num_decoder_res_blocks: 4
|
||||
num_fusion_dim: 256
|
||||
num_fusion_blocks: 3
|
||||
|
||||
discriminator:
|
||||
_type: MultiScaleDiscriminator
|
||||
num_scale: 2
|
||||
discriminator_cfg:
|
||||
_type: PatchDiscriminator
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
use_spectral: True
|
||||
need_intermediate_feature: True
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: lsgan
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
weight: 1.0
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L1'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 0
|
||||
recon:
|
||||
level: 1
|
||||
style:
|
||||
weight: 1
|
||||
content:
|
||||
weight: 1
|
||||
image:
|
||||
weight: 10
|
||||
cycle:
|
||||
weight: 0
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 0.0001
|
||||
betas: [ 0.5, 0.999 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 4e-4
|
||||
betas: [ 0.5, 0.999 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
data:
|
||||
train:
|
||||
scheduler:
|
||||
start_proportion: 0.5
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: True
|
||||
num_workers: 1
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/edges2shoes/trainA"
|
||||
root_b: "/data/i2i/edges2shoes/trainB"
|
||||
random_pair: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
which: dataset
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/edges2shoes/testA"
|
||||
root_b: "/data/i2i/edges2shoes/testB"
|
||||
random_pair: False
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
@@ -1,70 +1,99 @@
|
||||
name: TAHG
|
||||
engine: TAHG
|
||||
name: TAFG-vox2
|
||||
engine: TAFG
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
distributed:
|
||||
model:
|
||||
# broadcast_buffers: False
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 4 # log image `image` times per epoch
|
||||
test:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
checkpoint:
|
||||
epoch_interval: 1 # one checkpoint every 1 epoch
|
||||
n_saved: 2
|
||||
|
||||
interval:
|
||||
print_per_iteration: 10 # print once per 10 iteration
|
||||
tensorboard:
|
||||
scalar: 100
|
||||
image: 2
|
||||
random_seed: 1004
|
||||
add_new_loss_epoch: -1
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: TAHG-Generator
|
||||
_type: TAFG-Generator
|
||||
_bn_to_sync_bn: False
|
||||
style_in_channels: 3
|
||||
content_in_channels: 1
|
||||
num_blocks: 4
|
||||
content_in_channels: 24
|
||||
use_spectral_norm: False
|
||||
style_encoder_type: StyleEncoder
|
||||
num_style_conv: 4
|
||||
style_dim: 8
|
||||
num_adain_blocks: 4
|
||||
num_res_blocks: 4
|
||||
|
||||
discriminator:
|
||||
_type: TAHG-Discriminator
|
||||
in_channels: 3
|
||||
_type: MultiScaleDiscriminator
|
||||
num_scale: 2
|
||||
discriminator_cfg:
|
||||
_type: PatchDiscriminator
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
use_spectral: True
|
||||
need_intermediate_feature: True
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: lsgan
|
||||
loss_type: hinge
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
weight: 1.0
|
||||
edge:
|
||||
criterion: 'L1'
|
||||
hed_pretrained_model_path: "./network-bsds500.pytorch"
|
||||
weight: 1
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"3": 1.0
|
||||
# "0": 1.0
|
||||
# "5": 1.0
|
||||
# "10": 1.0
|
||||
# "19": 1.0
|
||||
criterion: 'L2'
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L1'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 0
|
||||
style:
|
||||
layer_weights:
|
||||
"3": 1
|
||||
criterion: 'L1'
|
||||
style_loss: True
|
||||
perceptual_loss: False
|
||||
weight: 20
|
||||
weight: 10
|
||||
recon:
|
||||
level: 1
|
||||
weight: 10
|
||||
style_recon:
|
||||
level: 1
|
||||
weight: 1
|
||||
content_recon:
|
||||
level: 1
|
||||
weight: 1
|
||||
edge:
|
||||
weight: 5
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
cycle:
|
||||
level: 1
|
||||
weight: 10
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 0.0001
|
||||
betas: [ 0.5, 0.999 ]
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 1e-4
|
||||
betas: [ 0.5, 0.999 ]
|
||||
lr: 4e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
data:
|
||||
@@ -74,9 +103,9 @@ data:
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 160
|
||||
batch_size: 8
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
num_workers: 1
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
@@ -84,20 +113,22 @@ data:
|
||||
root_a: "/data/i2i/VoxCeleb2Anime/trainA"
|
||||
root_b: "/data/i2i/VoxCeleb2Anime/trainB"
|
||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
||||
edge_type: "hed"
|
||||
size: [128, 128]
|
||||
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
||||
edge_type: "landmark_hed"
|
||||
size: [ 128, 128 ]
|
||||
random_pair: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [128, 128]
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
which: video_dataset
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
batch_size: 1
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
@@ -107,13 +138,14 @@ data:
|
||||
root_a: "/data/i2i/VoxCeleb2Anime/testA"
|
||||
root_b: "/data/i2i/VoxCeleb2Anime/testB"
|
||||
edges_path: "/data/i2i/VoxCeleb2Anime/edges"
|
||||
edge_type: "hed"
|
||||
landmarks_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
||||
edge_type: "landmark_hed"
|
||||
random_pair: False
|
||||
size: [128, 128]
|
||||
size: [ 128, 128 ]
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [128, 128]
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
@@ -125,7 +157,7 @@ data:
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
165
configs/synthesizers/TSIT.yml
Normal file
165
configs/synthesizers/TSIT.yml
Normal file
@@ -0,0 +1,165 @@
|
||||
name: huawei-TSIT-1
|
||||
engine: GauGAN
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 4 # log image `image` times per epoch
|
||||
test:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: TSIT-Generator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
num_blocks: 7
|
||||
# discriminator:
|
||||
# _type: MultiScaleDiscriminator
|
||||
# _add_spectral_norm: True
|
||||
# num_scale: 2
|
||||
# down_sample_method: "bilinear"
|
||||
# discriminator_cfg:
|
||||
# _type: PatchDiscriminator
|
||||
# in_channels: 3
|
||||
# base_channels: 64
|
||||
# num_conv: 4
|
||||
# need_intermediate_feature: True
|
||||
discriminator:
|
||||
_type: PatchDiscriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_conv: 4
|
||||
need_intermediate_feature: True
|
||||
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: hinge
|
||||
weight: 1.0
|
||||
real_label_val: 1
|
||||
fake_label_val: 0.0
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L1'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 1
|
||||
mgc:
|
||||
weight: 5
|
||||
fm:
|
||||
weight: 1
|
||||
edge:
|
||||
weight: 0
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 1e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 4e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
data:
|
||||
train:
|
||||
scheduler:
|
||||
start_proportion: 0.5
|
||||
target_lr: 0
|
||||
buffer_size: 0
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/face2cartoon/all_face"
|
||||
root_b: "/data/selfie2anime/trainB/"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- RandomCrop:
|
||||
size: [ 178, 178 ]
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
which: video_dataset
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/face2cartoon/test/human"
|
||||
root_b: "/data/face2cartoon/test/anime"
|
||||
random_pair: True
|
||||
pipeline_a:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
pipeline_b:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
video_dataset:
|
||||
_type: SingleFolderDataset
|
||||
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
|
||||
with_path: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
@@ -1,24 +1,20 @@
|
||||
name: VoxCeleb2Anime
|
||||
engine: UGATIT
|
||||
name: selfie2anime-vox2
|
||||
engine: U-GAT-IT
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
distributed:
|
||||
model:
|
||||
# broadcast_buffers: False
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
checkpoint:
|
||||
epoch_interval: 1 # one checkpoint every 1 epoch
|
||||
n_saved: 2
|
||||
|
||||
interval:
|
||||
print_per_iteration: 10 # print once per 10 iteration
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 10
|
||||
image: 500
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 2 # log image `image` times per epoch
|
||||
|
||||
model:
|
||||
generator:
|
||||
@@ -27,7 +23,7 @@ model:
|
||||
out_channels: 3
|
||||
base_channels: 64
|
||||
num_blocks: 4
|
||||
img_size: 128
|
||||
img_size: 256
|
||||
light: True
|
||||
local_discriminator:
|
||||
_type: UGATIT-Discriminator
|
||||
@@ -74,7 +70,7 @@ data:
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 20
|
||||
batch_size: 6
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
@@ -87,9 +83,9 @@ data:
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 135, 135 ]
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [ 128, 128 ]
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
@@ -97,7 +93,7 @@ data:
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
batch_size: 4
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
@@ -110,7 +106,7 @@ data:
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 128, 128 ]
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
@@ -122,7 +118,7 @@ data:
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 128, 128 ]
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
name: selfie2anime
|
||||
engine: UGATIT
|
||||
engine: U-GAT-IT
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
distributed:
|
||||
model:
|
||||
# broadcast_buffers: False
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
checkpoint:
|
||||
epoch_interval: 1 # one checkpoint every 1 epoch
|
||||
n_saved: 2
|
||||
|
||||
interval:
|
||||
print_per_iteration: 10 # print once per 10 iteration
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 10
|
||||
image: 500
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 4 # log image `image` times per epoch
|
||||
test:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: UGATIT-Generator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
base_channels: 64
|
||||
@@ -31,11 +31,13 @@ model:
|
||||
light: True
|
||||
local_discriminator:
|
||||
_type: UGATIT-Discriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_blocks: 5
|
||||
global_discriminator:
|
||||
_type: UGATIT-Discriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_blocks: 7
|
||||
@@ -54,17 +56,19 @@ loss:
|
||||
weight: 10.0
|
||||
cam:
|
||||
weight: 1000
|
||||
mgc:
|
||||
weight: 0
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 0.0001
|
||||
betas: [0.5, 0.999]
|
||||
betas: [ 0.5, 0.999 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 1e-4
|
||||
betas: [0.5, 0.999]
|
||||
betas: [ 0.5, 0.999 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
data:
|
||||
@@ -74,7 +78,7 @@ data:
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 4
|
||||
batch_size: 1
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
@@ -87,17 +91,18 @@ data:
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [286, 286]
|
||||
size: [ 286, 286 ]
|
||||
- RandomCrop:
|
||||
size: [256, 256]
|
||||
size: [ 256, 256 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
std: [0.5, 0.5, 0.5]
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
which: video_dataset
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
batch_size: 1
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
@@ -110,11 +115,11 @@ data:
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [256, 256]
|
||||
size: [ 256, 256 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [0.5, 0.5, 0.5]
|
||||
std: [0.5, 0.5, 0.5]
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
video_dataset:
|
||||
_type: SingleFolderDataset
|
||||
root: "/data/i2i/VoxCeleb2Anime/test_video_frames/"
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
name: horse2zebra
|
||||
engine: cyclegan
|
||||
result_dir: ./result
|
||||
max_iteration: 16600
|
||||
|
||||
distributed:
|
||||
model:
|
||||
# broadcast_buffers: False
|
||||
|
||||
misc:
|
||||
random_seed: 324
|
||||
|
||||
checkpoints:
|
||||
interval: 2000
|
||||
|
||||
log:
|
||||
logger:
|
||||
level: 20 # DEBUG(10) INFO(20)
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: ResGenerator
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
base_channels: 64
|
||||
num_blocks: 9
|
||||
padding_mode: reflect
|
||||
norm_type: IN
|
||||
use_dropout: False
|
||||
discriminator:
|
||||
_type: PatchDiscriminator
|
||||
# _distributed:
|
||||
# bn_to_syncbn: False
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_conv: 3
|
||||
norm_type: IN
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: lsgan
|
||||
weight: 1.0
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
cycle:
|
||||
level: 1
|
||||
weight: 10.0
|
||||
id:
|
||||
level: 1
|
||||
weight: 0
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 2e-4
|
||||
betas: [0.5, 0.999]
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 2e-4
|
||||
betas: [0.5, 0.999]
|
||||
|
||||
data:
|
||||
train:
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 16
|
||||
shuffle: True
|
||||
num_workers: 4
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/horse2zebra/trainA"
|
||||
root_b: "/data/i2i/horse2zebra/trainB"
|
||||
random_pair: True
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [286, 286]
|
||||
- RandomCrop:
|
||||
size: [256, 256]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
scheduler:
|
||||
start: 8300
|
||||
target_lr: 0
|
||||
test:
|
||||
dataloader:
|
||||
batch_size: 4
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: GenerationUnpairedDataset
|
||||
root_a: "/data/i2i/horse2zebra/testA"
|
||||
root_b: "/data/i2i/horse2zebra/testB"
|
||||
random_pair: False
|
||||
pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [256, 256]
|
||||
- ToTensor
|
||||
171
configs/synthesizers/talking_anime.yml
Normal file
171
configs/synthesizers/talking_anime.yml
Normal file
@@ -0,0 +1,171 @@
|
||||
name: talking_anime
|
||||
engine: talking_anime
|
||||
result_dir: ./result
|
||||
max_pairs: 1000000
|
||||
|
||||
handler:
|
||||
clear_cuda_cache: True
|
||||
set_epoch_for_dist_sampler: True
|
||||
checkpoint:
|
||||
epoch_interval: 1 # checkpoint once per `epoch_interval` epoch
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 100 # log image `image` times per epoch
|
||||
test:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
misc:
|
||||
random_seed: 1004
|
||||
|
||||
loss:
|
||||
gan:
|
||||
loss_type: hinge
|
||||
real_label_val: 1.0
|
||||
fake_label_val: 0.0
|
||||
weight: 1.0
|
||||
fm:
|
||||
level: 1
|
||||
weight: 1
|
||||
style:
|
||||
layer_weights:
|
||||
"3": 1
|
||||
criterion: 'L1'
|
||||
style_loss: True
|
||||
perceptual_loss: False
|
||||
weight: 10
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'L1'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
weight: 0
|
||||
context:
|
||||
layer_weights:
|
||||
#"13": 1
|
||||
"22": 1
|
||||
weight: 5
|
||||
recon:
|
||||
level: 1
|
||||
weight: 10
|
||||
edge:
|
||||
weight: 5
|
||||
hed_pretrained_model_path: ./network-bsds500.pytorch
|
||||
|
||||
model:
|
||||
face_generator:
|
||||
_type: TAFG-SingleGenerator
|
||||
_bn_to_sync_bn: False
|
||||
style_in_channels: 3
|
||||
content_in_channels: 1
|
||||
use_spectral_norm: True
|
||||
style_encoder_type: VGG19StyleEncoder
|
||||
num_style_conv: 4
|
||||
style_dim: 512
|
||||
num_adain_blocks: 8
|
||||
num_res_blocks: 8
|
||||
anime_generator:
|
||||
_type: TAFG-ResGenerator
|
||||
_bn_to_sync_bn: False
|
||||
in_channels: 6
|
||||
use_spectral_norm: True
|
||||
num_res_blocks: 8
|
||||
|
||||
discriminator:
|
||||
_type: MultiScaleDiscriminator
|
||||
num_scale: 2
|
||||
discriminator_cfg:
|
||||
_type: PatchDiscriminator
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
use_spectral: True
|
||||
need_intermediate_feature: True
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
_type: Adam
|
||||
lr: 0.0001
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
discriminator:
|
||||
_type: Adam
|
||||
lr: 4e-4
|
||||
betas: [ 0, 0.9 ]
|
||||
weight_decay: 0.0001
|
||||
|
||||
data:
|
||||
train:
|
||||
scheduler:
|
||||
start_proportion: 0.5
|
||||
target_lr: 0
|
||||
dataloader:
|
||||
batch_size: 8
|
||||
shuffle: True
|
||||
num_workers: 1
|
||||
pin_memory: True
|
||||
drop_last: True
|
||||
dataset:
|
||||
_type: PoseFacesWithSingleAnime
|
||||
root_face: "/data/i2i/VoxCeleb2Anime/trainA"
|
||||
root_anime: "/data/i2i/VoxCeleb2Anime/trainB"
|
||||
landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
||||
num_face: 2
|
||||
img_size: [ 128, 128 ]
|
||||
with_order: False
|
||||
face_pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
anime_pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 144, 144 ]
|
||||
- RandomCrop:
|
||||
size: [ 128, 128 ]
|
||||
- RandomHorizontalFlip
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
test:
|
||||
which: dataset
|
||||
dataloader:
|
||||
batch_size: 1
|
||||
shuffle: False
|
||||
num_workers: 1
|
||||
pin_memory: False
|
||||
drop_last: False
|
||||
dataset:
|
||||
_type: PoseFacesWithSingleAnime
|
||||
root_face: "/data/i2i/VoxCeleb2Anime/testA"
|
||||
root_anime: "/data/i2i/VoxCeleb2Anime/testB"
|
||||
landmark_path: "/data/i2i/VoxCeleb2Anime/landmarks"
|
||||
num_face: 2
|
||||
img_size: [ 128, 128 ]
|
||||
with_order: False
|
||||
face_pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
anime_pipeline:
|
||||
- Load
|
||||
- Resize:
|
||||
size: [ 128, 128 ]
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
210
data/dataset.py
210
data/dataset.py
@@ -1,210 +0,0 @@
|
||||
import os
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
from PIL import Image
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.transforms import functional as F
|
||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
|
||||
|
||||
import lmdb
|
||||
from tqdm import tqdm
|
||||
|
||||
from .transform import transform_pipeline
|
||||
from .registry import DATASET
|
||||
|
||||
|
||||
def default_transform_way(transform, sample):
|
||||
return [transform(sample[0]), *sample[1:]]
|
||||
|
||||
|
||||
class LMDBDataset(Dataset):
|
||||
def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
|
||||
**lmdb_kwargs):
|
||||
self.path = lmdb_path
|
||||
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
|
||||
lock=False, **lmdb_kwargs)
|
||||
|
||||
with self.db.begin(write=False) as txn:
|
||||
self._len = pickle.loads(txn.get(b"$$len$$"))
|
||||
self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
|
||||
if pipeline is None:
|
||||
self.not_done_pipeline = []
|
||||
else:
|
||||
self.not_done_pipeline = self._remain_pipeline(pipeline)
|
||||
self.transform = transform_pipeline(self.not_done_pipeline)
|
||||
self.transform_way = transform_way
|
||||
essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
|
||||
for ea in essential_attr:
|
||||
setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))
|
||||
|
||||
def _remain_pipeline(self, pipeline):
|
||||
for i, dp in enumerate(self.done_pipeline):
|
||||
if pipeline[i] != dp:
|
||||
raise ValueError(
|
||||
f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
|
||||
return pipeline[len(self.done_pipeline):]
|
||||
|
||||
def __repr__(self):
|
||||
return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __getitem__(self, idx):
|
||||
with self.db.begin(write=False) as txn:
|
||||
sample = pickle.loads(txn.get("{}".format(idx).encode()))
|
||||
sample = self.transform_way(self.transform, sample)
|
||||
return sample
|
||||
|
||||
@staticmethod
|
||||
def lmdbify(dataset, done_pipeline, lmdb_path):
|
||||
env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
|
||||
with env.begin(write=True) as txn:
|
||||
for i in tqdm(range(len(dataset)), ncols=0):
|
||||
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
|
||||
txn.put(b"$$len$$", pickle.dumps(len(dataset)))
|
||||
txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
|
||||
essential_attr = getattr(dataset, "essential_attr", list())
|
||||
txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
|
||||
for ea in essential_attr:
|
||||
txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class ImprovedImageFolder(ImageFolder):
|
||||
def __init__(self, root, pipeline):
|
||||
super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
|
||||
self.classes_list = defaultdict(list)
|
||||
self.essential_attr = ["classes_list"]
|
||||
for i in range(len(self)):
|
||||
self.classes_list[self.samples[i][-1]].append(i)
|
||||
assert len(self.classes_list) == len(self.classes)
|
||||
|
||||
|
||||
class EpisodicDataset(Dataset):
|
||||
def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
|
||||
self.origin = origin_dataset
|
||||
self.num_class = num_class
|
||||
assert self.num_class < len(self.origin.classes_list)
|
||||
self.num_query = num_query # K
|
||||
self.num_support = num_support # K
|
||||
self.num_episodes = num_episodes
|
||||
|
||||
def _fetch_list_data(self, id_list):
|
||||
return [self.origin[i][0] for i in id_list]
|
||||
|
||||
def __len__(self):
|
||||
return self.num_episodes
|
||||
|
||||
def __getitem__(self, _):
|
||||
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
|
||||
support_set = []
|
||||
query_set = []
|
||||
target_set = []
|
||||
for tag, c in enumerate(random_classes):
|
||||
image_list = self.origin.classes_list[c]
|
||||
|
||||
if len(image_list) >= self.num_query + self.num_support:
|
||||
# have enough images belong to this class
|
||||
idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
|
||||
else:
|
||||
idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
|
||||
|
||||
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
|
||||
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
|
||||
support_set.extend(support)
|
||||
query_set.extend(query)
|
||||
target_set.extend([tag] * self.num_query)
|
||||
return {
|
||||
"support": torch.stack(support_set),
|
||||
"query": torch.stack(query_set),
|
||||
"target": torch.tensor(target_set)
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class SingleFolderDataset(Dataset):
|
||||
def __init__(self, root, pipeline, with_path=False):
|
||||
assert os.path.isdir(root)
|
||||
self.root = root
|
||||
samples = []
|
||||
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
|
||||
for fn in sorted(fns):
|
||||
path = os.path.join(r, fn)
|
||||
if has_file_allowed_extension(path, IMG_EXTENSIONS):
|
||||
samples.append(path)
|
||||
self.samples = samples
|
||||
self.pipeline = transform_pipeline(pipeline)
|
||||
self.with_path = with_path
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if not self.with_path:
|
||||
return self.pipeline(self.samples[idx])
|
||||
else:
|
||||
return self.pipeline(self.samples[idx]), self.samples[idx]
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SingleFolderDataset root={self.root} len={len(self)}>"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDataset(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline):
|
||||
self.A = SingleFolderDataset(root_a, pipeline)
|
||||
self.B = SingleFolderDataset(root_b, pipeline)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
return dict(a=self.A[a_idx], b=self.B[b_idx])
|
||||
|
||||
def __len__(self):
|
||||
return max(len(self.A), len(self.B))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
|
||||
self.edge_type = edge_type
|
||||
self.size = size
|
||||
self.edges_path = Path(edges_path)
|
||||
assert self.edges_path.exists()
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def get_edge(self, origin_path):
|
||||
op = Path(origin_path)
|
||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
|
||||
img = Image.open(edge_path).resize(self.size)
|
||||
return F.to_tensor(img)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
output = dict()
|
||||
output["a"], path_a = self.A[a_idx]
|
||||
output["b"], path_b = self.B[b_idx]
|
||||
output["edge_a"] = self.get_edge(path_a)
|
||||
output["edge_b"] = self.get_edge(path_b)
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
return max(len(self.A), len(self.B))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
3
data/dataset/__init__.py
Normal file
3
data/dataset/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from util.misc import import_submodules
|
||||
|
||||
__all__ = import_submodules(__name__).keys()
|
||||
63
data/dataset/few-shot.py
Normal file
63
data/dataset/few-shot.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
from data.registry import DATASET
|
||||
from data.transform import transform_pipeline
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class ImprovedImageFolder(ImageFolder):
|
||||
def __init__(self, root, pipeline):
|
||||
super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
|
||||
self.classes_list = defaultdict(list)
|
||||
self.essential_attr = ["classes_list"]
|
||||
for i in range(len(self)):
|
||||
self.classes_list[self.samples[i][-1]].append(i)
|
||||
assert len(self.classes_list) == len(self.classes)
|
||||
|
||||
|
||||
class EpisodicDataset(Dataset):
|
||||
def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
|
||||
self.origin = origin_dataset
|
||||
self.num_class = num_class
|
||||
assert self.num_class < len(self.origin.classes_list)
|
||||
self.num_query = num_query # K
|
||||
self.num_support = num_support # K
|
||||
self.num_episodes = num_episodes
|
||||
|
||||
def _fetch_list_data(self, id_list):
|
||||
return [self.origin[i][0] for i in id_list]
|
||||
|
||||
def __len__(self):
|
||||
return self.num_episodes
|
||||
|
||||
def __getitem__(self, _):
|
||||
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
|
||||
support_set = []
|
||||
query_set = []
|
||||
target_set = []
|
||||
for tag, c in enumerate(random_classes):
|
||||
image_list = self.origin.classes_list[c]
|
||||
|
||||
if len(image_list) >= self.num_query + self.num_support:
|
||||
# have enough images belong to this class
|
||||
idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
|
||||
else:
|
||||
idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
|
||||
|
||||
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
|
||||
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
|
||||
support_set.extend(support)
|
||||
query_set.extend(query)
|
||||
target_set.extend([tag] * self.num_query)
|
||||
return {
|
||||
"support": torch.stack(support_set),
|
||||
"query": torch.stack(query_set),
|
||||
"target": torch.tensor(target_set)
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
|
||||
62
data/dataset/image_translation.py
Normal file
62
data/dataset/image_translation.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
|
||||
|
||||
from data.registry import DATASET
|
||||
from data.transform import transform_pipeline
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class SingleFolderDataset(Dataset):
|
||||
def __init__(self, root, pipeline, with_path=False):
|
||||
assert os.path.isdir(root)
|
||||
self.root = root
|
||||
samples = []
|
||||
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
|
||||
for fn in sorted(fns):
|
||||
path = os.path.join(r, fn)
|
||||
if has_file_allowed_extension(path, IMG_EXTENSIONS):
|
||||
samples.append(path)
|
||||
self.samples = samples
|
||||
self.pipeline = transform_pipeline(pipeline)
|
||||
self.with_path = with_path
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
output = dict(img=self.pipeline(self.samples[idx]))
|
||||
if self.with_path:
|
||||
output["path"] = self.samples[idx]
|
||||
return output
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SingleFolderDataset root={self.root} len={len(self)} with_path={self.with_path}>"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDataset(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline_a, pipeline_b, with_path=False):
|
||||
self.A = SingleFolderDataset(root_a, pipeline_a, with_path)
|
||||
self.B = SingleFolderDataset(root_b, pipeline_b, with_path)
|
||||
self.with_path = with_path
|
||||
self.random_pair = random_pair
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
output_a = self.A[a_idx]
|
||||
output_b = self.B[b_idx]
|
||||
output = dict(a=output_a["img"], b=output_b["img"])
|
||||
if self.with_path:
|
||||
output["a_path"] = output_a["path"]
|
||||
output["b_path"] = output_b["path"]
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
return max(len(self.A), len(self.B))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
65
data/dataset/lmdb.py
Normal file
65
data/dataset/lmdb.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import lmdb
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from data.transform import transform_pipeline
|
||||
|
||||
|
||||
def default_transform_way(transform, sample):
|
||||
return [transform(sample[0]), *sample[1:]]
|
||||
|
||||
|
||||
class LMDBDataset(Dataset):
|
||||
def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
|
||||
**lmdb_kwargs):
|
||||
self.path = lmdb_path
|
||||
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
|
||||
lock=False, **lmdb_kwargs)
|
||||
|
||||
with self.db.begin(write=False) as txn:
|
||||
self._len = pickle.loads(txn.get(b"$$len$$"))
|
||||
self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
|
||||
if pipeline is None:
|
||||
self.not_done_pipeline = []
|
||||
else:
|
||||
self.not_done_pipeline = self._remain_pipeline(pipeline)
|
||||
self.transform = transform_pipeline(self.not_done_pipeline)
|
||||
self.transform_way = transform_way
|
||||
essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
|
||||
for ea in essential_attr:
|
||||
setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))
|
||||
|
||||
def _remain_pipeline(self, pipeline):
|
||||
for i, dp in enumerate(self.done_pipeline):
|
||||
if pipeline[i] != dp:
|
||||
raise ValueError(
|
||||
f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
|
||||
return pipeline[len(self.done_pipeline):]
|
||||
|
||||
def __repr__(self):
|
||||
return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __getitem__(self, idx):
|
||||
with self.db.begin(write=False) as txn:
|
||||
sample = pickle.loads(txn.get("{}".format(idx).encode()))
|
||||
sample = self.transform_way(self.transform, sample)
|
||||
return sample
|
||||
|
||||
@staticmethod
|
||||
def lmdbify(dataset, done_pipeline, lmdb_path):
|
||||
env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
|
||||
with env.begin(write=True) as txn:
|
||||
for i in tqdm(range(len(dataset)), ncols=0):
|
||||
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
|
||||
txn.put(b"$$len$$", pickle.dumps(len(dataset)))
|
||||
txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
|
||||
essential_attr = getattr(dataset, "essential_attr", list())
|
||||
txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
|
||||
for ea in essential_attr:
|
||||
txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
|
||||
122
data/dataset/pose_transfer.py
Normal file
122
data/dataset/pose_transfer.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from collections import defaultdict
|
||||
from itertools import permutations, combinations
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
from data.registry import DATASET
|
||||
from data.transform import transform_pipeline
|
||||
from data.util import dlib_landmark
|
||||
|
||||
|
||||
def normalize_tensor(tensor):
|
||||
tensor = tensor.float()
|
||||
tensor -= tensor.min()
|
||||
tensor /= tensor.max()
|
||||
return tensor
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
|
||||
with_path=False):
|
||||
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
|
||||
self.edge_type = edge_type
|
||||
self.size = size
|
||||
self.edges_path = Path(edges_path)
|
||||
self.landmarks_path = Path(landmarks_path)
|
||||
assert self.edges_path.exists()
|
||||
assert self.landmarks_path.exists()
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||
self.random_pair = random_pair
|
||||
self.with_path = with_path
|
||||
|
||||
def get_edge(self, origin_path):
|
||||
op = Path(origin_path)
|
||||
if self.edge_type.startswith("landmark_"):
|
||||
edge_type = self.edge_type.lstrip("landmark_")
|
||||
use_landmark = op.parent.name.endswith("A")
|
||||
else:
|
||||
edge_type = self.edge_type
|
||||
use_landmark = False
|
||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
|
||||
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
|
||||
if not use_landmark:
|
||||
return origin_edge
|
||||
else:
|
||||
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
|
||||
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
|
||||
|
||||
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
|
||||
part_labels = normalize_tensor(torch.from_numpy(part_labels))
|
||||
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
|
||||
# edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
|
||||
# edges = part_edge + edges
|
||||
return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
output = dict(a={}, b={})
|
||||
output["a"]["img"], output["a"]["path"] = self.A[a_idx]
|
||||
output["b"]["img"], output["b"]["path"] = self.B[b_idx]
|
||||
for p in "ab":
|
||||
output[p]["edge"] = self.get_edge(output[p]["path"])
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
return max(len(self.A), len(self.B))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class PoseFacesWithSingleAnime(Dataset):
|
||||
def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
|
||||
with_order=True):
|
||||
self.num_face = num_face
|
||||
self.landmark_path = Path(landmark_path)
|
||||
self.with_order = with_order
|
||||
self.root_face = Path(root_face)
|
||||
self.root_anime = Path(root_anime)
|
||||
self.img_size = img_size
|
||||
self.face_samples = self.iter_folders()
|
||||
self.face_pipeline = transform_pipeline(face_pipeline)
|
||||
self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)
|
||||
|
||||
def iter_folders(self):
|
||||
pics_per_person = defaultdict(list)
|
||||
for p in self.root_face.glob("*.jpg"):
|
||||
pics_per_person[p.stem[:7]].append(p.stem)
|
||||
data = []
|
||||
for p in pics_per_person:
|
||||
if len(pics_per_person[p]) >= self.num_face:
|
||||
if self.with_order:
|
||||
data.extend(list(combinations(pics_per_person[p], self.num_face)))
|
||||
else:
|
||||
data.extend(list(permutations(pics_per_person[p], self.num_face)))
|
||||
return data
|
||||
|
||||
def read_pose(self, pose_txt):
|
||||
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
|
||||
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
|
||||
part_labels = normalize_tensor(torch.from_numpy(part_labels))
|
||||
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
|
||||
return torch.cat([part_labels, part_edge, dist_tensor])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.face_samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
output = dict()
|
||||
output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
|
||||
for i, f in enumerate(self.face_samples[idx]):
|
||||
output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
|
||||
output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
|
||||
output[f"stem_{i}"] = f
|
||||
return output
|
||||
@@ -28,7 +28,7 @@ class Load:
|
||||
|
||||
|
||||
def transform_pipeline(pipeline_description):
|
||||
if len(pipeline_description) == 0:
|
||||
if pipeline_description is None or len(pipeline_description) == 0:
|
||||
return lambda x: x
|
||||
transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
|
||||
return transforms.Compose(transform_list)
|
||||
|
||||
105
engine/CycleGAN.py
Normal file
105
engine/CycleGAN.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from itertools import chain
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from engine.util.container import GANImageBuffer, LossContainer
|
||||
from engine.util.loss import pixel_loss, gan_loss, feature_match_loss
|
||||
from loss.I2I.edge_loss import EdgeLoss
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class CycleGANEngineKernel(EngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.gan_loss = gan_loss(config.loss.gan)
|
||||
self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
|
||||
self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
|
||||
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
|
||||
self.edge_loss = LossContainer(config.loss.edge.weight, EdgeLoss(
|
||||
"HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(idist.device()))
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
generators = dict(
|
||||
a2b=build_model(self.config.model.generator),
|
||||
b2a=build_model(self.config.model.generator)
|
||||
)
|
||||
discriminators = dict(
|
||||
a=build_model(self.config.model.discriminator),
|
||||
b=build_model(self.config.model.discriminator)
|
||||
)
|
||||
self.logger.debug(discriminators["a"])
|
||||
self.logger.debug(generators["a2b"])
|
||||
|
||||
for m in chain(generators.values(), discriminators.values()):
|
||||
generation_init_weights(m)
|
||||
|
||||
return generators, discriminators
|
||||
|
||||
def setup_after_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
def setup_before_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(False)
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
images = dict()
|
||||
with torch.set_grad_enabled(not inference):
|
||||
images["a2b"] = self.generators["a2b"](batch["a"])
|
||||
images["b2a"] = self.generators["b2a"](batch["b"])
|
||||
images["a2b2a"] = self.generators["b2a"](images["a2b"])
|
||||
images["b2a2b"] = self.generators["a2b"](images["b2a"])
|
||||
if self.id_loss.weight > 0:
|
||||
images["a2a"] = self.generators["b2a"](batch["a"])
|
||||
images["b2b"] = self.generators["a2b"](batch["b"])
|
||||
return images
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
for ph in "ab":
|
||||
loss[f"cycle_{ph}"] = self.cycle_loss(generated["a2b2a" if ph == "a" else "b2a2b"], batch[ph])
|
||||
loss[f"id_{ph}"] = self.id_loss(generated[f"{ph}2{ph}"], batch[ph])
|
||||
loss[f"mgc_{ph}"] = self.mgc_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph])
|
||||
prediction_fake = self.discriminators[ph](generated["a2b" if ph == "b" else "b2a"])
|
||||
loss[f"gan_{ph}"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
|
||||
if self.fm_loss.weight > 0:
|
||||
prediction_real = self.discriminators[ph](batch[ph])
|
||||
loss[f"feature_match_{ph}"] = self.fm_loss(prediction_fake, prediction_real)
|
||||
loss[f"edge_{ph}"] = self.edge_loss(generated["a2b" if ph == "a" else "b2a"], batch[ph], gt_is_edge=False)
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
for phase in "ab":
|
||||
generated_image = self.image_buffers[phase].query(generated["b2a" if phase == "a" else "a2b"].detach())
|
||||
loss[f"gan_{phase}"] = (self.gan_loss(self.discriminators[phase](generated_image), False,
|
||||
is_discriminator=True) +
|
||||
self.gan_loss(self.discriminators[phase](batch[phase]), True,
|
||||
is_discriminator=True)) / 2
|
||||
return loss
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
|
||||
"""
|
||||
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
:param batch:
|
||||
:param generated: dict of images
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
return dict(
|
||||
a=[batch["a"].detach(), generated["a2b"].detach(), generated["a2b2a"].detach()],
|
||||
b=[batch["b"].detach(), generated["b2a"].detach(), generated["b2a2b"].detach()],
|
||||
)
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
kernel = CycleGANEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
86
engine/GauGAN.py
Normal file
86
engine/GauGAN.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from itertools import chain
|
||||
|
||||
import torch
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from engine.util.container import GANImageBuffer, LossContainer
|
||||
from engine.util.loss import gan_loss, feature_match_loss, perceptual_loss
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MGCLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class GauGANEngineKernel(EngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
self.gan_loss = gan_loss(config.loss.gan)
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MGCLoss("opposite"))
|
||||
self.fm_loss = LossContainer(config.loss.fm.weight, feature_match_loss(1, "same"))
|
||||
self.perceptual_loss = LossContainer(config.loss.perceptual.weight, perceptual_loss(config.loss.perceptual))
|
||||
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
generators = dict(
|
||||
main=build_model(self.config.model.generator)
|
||||
)
|
||||
discriminators = dict(
|
||||
b=build_model(self.config.model.discriminator)
|
||||
)
|
||||
self.logger.debug(discriminators["b"])
|
||||
self.logger.debug(generators["main"])
|
||||
|
||||
for m in chain(generators.values(), discriminators.values()):
|
||||
generation_init_weights(m)
|
||||
|
||||
return generators, discriminators
|
||||
|
||||
def setup_after_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
def setup_before_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(False)
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
images = dict()
|
||||
with torch.set_grad_enabled(not inference):
|
||||
images["a2b"] = self.generators["main"](batch["a"])
|
||||
return images
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
prediction_fake = self.discriminators["b"](generated["a2b"])
|
||||
loss["gan"] = self.config.loss.gan.weight * self.gan_loss(prediction_fake, True)
|
||||
loss["mgc"] = self.mgc_loss(generated["a2b"], batch["a"])
|
||||
loss["perceptual"] = self.perceptual_loss(generated["a2b"], batch["a"])
|
||||
if self.fm_loss.weight > 0:
|
||||
prediction_real = self.discriminators["b"](batch["b"])
|
||||
loss["feature_match"] = self.fm_loss(prediction_fake, prediction_real)
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
generated_image = self.image_buffers["b"].query(generated["a2b"].detach())
|
||||
loss["b"] = (self.gan_loss(self.discriminators["b"](generated_image), False, is_discriminator=True) +
|
||||
self.gan_loss(self.discriminators["b"](batch["b"]), True, is_discriminator=True)) / 2
|
||||
return loss
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
|
||||
"""
|
||||
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
:param batch:
|
||||
:param generated: dict of images
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
return dict(
|
||||
a=[batch["a"].detach(), generated["a2b"].detach()],
|
||||
)
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
kernel = GauGANEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
154
engine/MUNIT.py
Normal file
154
engine/MUNIT.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
|
||||
|
||||
def mse_loss(x, target_flag):
|
||||
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
|
||||
def bce_loss(x, target_flag):
|
||||
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
|
||||
class MUNITEngineKernel(EngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
|
||||
self.train_generator_first = False
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
generators = dict(
|
||||
a=build_model(self.config.model.generator),
|
||||
b=build_model(self.config.model.generator)
|
||||
)
|
||||
discriminators = dict(
|
||||
a=build_model(self.config.model.discriminator),
|
||||
b=build_model(self.config.model.discriminator)
|
||||
)
|
||||
self.logger.debug(discriminators["a"])
|
||||
self.logger.debug(generators["a"])
|
||||
|
||||
return generators, discriminators
|
||||
|
||||
def setup_after_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
def setup_before_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(False)
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
styles = dict()
|
||||
contents = dict()
|
||||
images = dict()
|
||||
with torch.set_grad_enabled(not inference):
|
||||
for phase in "ab":
|
||||
contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
|
||||
images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
|
||||
styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
|
||||
|
||||
for phase in ("a2b", "b2a"):
|
||||
# images["a2b"] = Gb.decode(content_a, random_style_b)
|
||||
images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
|
||||
# contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
|
||||
contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
|
||||
if self.config.loss.recon.cycle.weight > 0:
|
||||
images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
|
||||
return dict(styles=styles, contents=contents, images=images)
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
for phase in "ab":
|
||||
loss[f"recon_image_{phase}"] = self.config.loss.recon.image.weight * self.recon_loss(
|
||||
batch[phase], generated["images"]["{0}2{0}".format(phase)])
|
||||
loss[f"recon_content_{phase}"] = self.config.loss.recon.content.weight * self.recon_loss(
|
||||
generated["contents"][phase], generated["contents"]["a2b" if phase == "a" else "b2a"])
|
||||
loss[f"recon_style_{phase}"] = self.config.loss.recon.style.weight * self.recon_loss(
|
||||
generated["styles"][f"random_{phase}"], generated["styles"]["b2a" if phase == "a" else "a2b"])
|
||||
pred_fake = self.discriminators[phase](generated["images"]["b2a" if phase == "a" else "a2b"])
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
|
||||
if self.config.loss.recon.cycle.weight > 0:
|
||||
loss[f"recon_cycle_{phase}"] = self.config.loss.recon.cycle.weight * self.recon_loss(
|
||||
batch[phase], generated["images"]["a2b2a" if phase == "a" else "b2a2b"])
|
||||
if self.config.loss.perceptual.weight > 0:
|
||||
loss[f"perceptual_{phase}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
|
||||
batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
for phase in ("a2b", "b2a"):
|
||||
pred_real = self.discriminators[phase[-1]](batch[phase[-1]])
|
||||
pred_fake = self.discriminators[phase[-1]](generated["images"][phase].detach())
|
||||
loss[f"gan_{phase[-1]}"] = 0
|
||||
for i in range(len(pred_fake)):
|
||||
loss[f"gan_{phase[-1]}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
|
||||
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
|
||||
return loss
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
|
||||
"""
|
||||
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
:param batch:
|
||||
:param generated: dict of images
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
generated = {img: generated["images"][img].detach() for img in generated["images"]}
|
||||
images = dict()
|
||||
for phase in "ab":
|
||||
images[phase] = [batch[phase].detach(), generated["{0}2{0}".format(phase)],
|
||||
generated["a2b" if phase == "a" else "b2a"]]
|
||||
if self.config.loss.recon.cycle.weight > 0:
|
||||
images[phase].append(generated["a2b2a" if phase == "a" else "b2a2b"])
|
||||
return images
|
||||
|
||||
|
||||
class MUNITTestEngineKernel(TestEngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
def build_generators(self) -> dict:
|
||||
generators = dict(
|
||||
a=build_model(self.config.model.generator),
|
||||
b=build_model(self.config.model.generator)
|
||||
)
|
||||
return generators
|
||||
|
||||
def to_load(self):
|
||||
return {f"generator_{k}": self.generators[k] for k in self.generators}
|
||||
|
||||
def inference(self, batch):
|
||||
with torch.no_grad():
|
||||
fake, _, _ = self.generators["a2b"](batch[0])
|
||||
return fake.detach()
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
if task == "train":
|
||||
kernel = MUNITEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
elif task == "test":
|
||||
kernel = MUNITTestEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
else:
|
||||
raise NotImplemented
|
||||
212
engine/TAFG.py
Normal file
212
engine/TAFG.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from itertools import chain
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from loss.I2I.edge_loss import EdgeLoss
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class TAFGEngineKernel(EngineKernel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
|
||||
style_loss_cfg = OmegaConf.to_container(config.loss.style)
|
||||
style_loss_cfg.pop("weight")
|
||||
self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
|
||||
self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
|
||||
self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
|
||||
self.content_recon_loss = nn.L1Loss() if config.loss.content_recon.level == 1 else nn.MSELoss()
|
||||
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
|
||||
self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
|
||||
idist.device())
|
||||
|
||||
def _process_batch(self, batch, inference=False):
|
||||
# batch["b"] = batch["b"] if inference else batch["b"][0].expand(batch["a"].size())
|
||||
return batch
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
generators = dict(
|
||||
main=build_model(self.config.model.generator)
|
||||
)
|
||||
discriminators = dict(
|
||||
a=build_model(self.config.model.discriminator),
|
||||
b=build_model(self.config.model.discriminator)
|
||||
)
|
||||
self.logger.debug(discriminators["a"])
|
||||
self.logger.debug(generators["main"])
|
||||
|
||||
for m in chain(generators.values(), discriminators.values()):
|
||||
generation_init_weights(m)
|
||||
|
||||
return generators, discriminators
|
||||
|
||||
def setup_after_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
def setup_before_g(self):
|
||||
for discriminator in self.discriminators.values():
|
||||
discriminator.requires_grad_(False)
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
generator = self.generators["main"]
|
||||
batch = self._process_batch(batch, inference)
|
||||
|
||||
styles = dict()
|
||||
contents = dict()
|
||||
images = dict()
|
||||
with torch.set_grad_enabled(not inference):
|
||||
contents["a"], styles["a"] = generator.encode(batch["a"]["edge"], batch["a"]["img"], "a", "a")
|
||||
contents["b"], styles["b"] = generator.encode(batch["b"]["edge"], batch["b"]["img"], "b", "b")
|
||||
for ph in "ab":
|
||||
images[f"{ph}2{ph}"] = generator.decode(contents[ph], styles[ph], ph)
|
||||
|
||||
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
|
||||
styles[f"random_b"] = torch.randn_like(styles["b"]).to(idist.device())
|
||||
images["a2b"] = generator.decode(contents["a"], styles["random_b"], "b")
|
||||
contents["recon_a"], styles["recon_b"] = generator.encode(self.edge_loss.edge_extractor(images["a2b"]),
|
||||
images["a2b"], "b", "b")
|
||||
images["cycle_b"] = generator.decode(contents["b"], styles["recon_b"], "b")
|
||||
images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
|
||||
return dict(styles=styles, contents=contents, images=images)
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
batch = self._process_batch(batch)
|
||||
loss = dict()
|
||||
|
||||
for ph in "ab":
|
||||
loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
|
||||
generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
|
||||
|
||||
pred_fake = self.discriminators[ph](generated["images"][f"{ph}2{ph}"])
|
||||
loss[f"gan_{ph}"] = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
|
||||
|
||||
if self.engine.state.epoch == self.config.misc.add_new_loss_epoch:
|
||||
self.generators["main"].style_converters.requires_grad_(False)
|
||||
self.generators["main"].style_encoders.requires_grad_(False)
|
||||
|
||||
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
|
||||
|
||||
pred_fake = self.discriminators[ph](generated["images"]["a2b"])
|
||||
loss["gan_a2b"] = 0
|
||||
for sub_pred_fake in pred_fake:
|
||||
# last output is actual prediction
|
||||
loss["gan_a2b"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
|
||||
|
||||
loss["recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
|
||||
generated["contents"]["a"], generated["contents"]["recon_a"]
|
||||
)
|
||||
loss["recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
|
||||
generated["styles"]["random_b"], generated["styles"]["recon_b"]
|
||||
)
|
||||
|
||||
if self.config.loss.perceptual.weight > 0:
|
||||
loss["perceptual_a"] = self.config.loss.perceptual.weight * self.perceptual_loss(
|
||||
batch["a"]["img"], generated["images"]["a2b"]
|
||||
)
|
||||
|
||||
if self.config.loss.cycle.weight > 0:
|
||||
loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
|
||||
batch["a"]["img"], generated["images"][f"cycle_a"]
|
||||
)
|
||||
|
||||
# for ph in "ab":
|
||||
#
|
||||
# if self.config.loss.style.weight > 0:
|
||||
# loss[f"style_{ph}"] = self.config.loss.style.weight * self.style_loss(
|
||||
# batch[ph]["img"], generated["images"][f"a2{ph}"]
|
||||
# )
|
||||
|
||||
if self.config.loss.edge.weight > 0:
|
||||
loss["edge_a"] = self.config.loss.edge.weight * self.edge_loss(
|
||||
generated["images"]["a2b"], batch["a"]["edge"][:, 0:1, :, :]
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
loss = dict()
|
||||
|
||||
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
|
||||
for phase in self.discriminators.keys():
|
||||
pred_real = self.discriminators[phase](batch[phase]["img"])
|
||||
pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
|
||||
pred_fake_2 = self.discriminators[phase](generated["images"]["a2b"].detach())
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for i in range(len(pred_fake)):
|
||||
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True) +
|
||||
self.gan_loss(pred_fake_2[i][-1], False, is_discriminator=True)
|
||||
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 3
|
||||
else:
|
||||
for phase in self.discriminators.keys():
|
||||
pred_real = self.discriminators[phase](batch[phase]["img"])
|
||||
pred_fake = self.discriminators[phase](generated["images"][f"{phase}2{phase}"].detach())
|
||||
loss[f"gan_{phase}"] = 0
|
||||
for i in range(len(pred_fake)):
|
||||
loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
|
||||
+ self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
|
||||
return loss
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
|
||||
"""
|
||||
returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
:param batch:
|
||||
:param generated: dict of images
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
batch = self._process_batch(batch)
|
||||
if self.engine.state.epoch > self.config.misc.add_new_loss_epoch:
|
||||
return dict(
|
||||
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
|
||||
batch["a"]["img"].detach(),
|
||||
generated["images"]["a2a"].detach(),
|
||||
generated["images"]["a2b"].detach(),
|
||||
generated["images"]["cycle_a"].detach(),
|
||||
],
|
||||
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
|
||||
batch["b"]["img"].detach(),
|
||||
generated["images"]["b2b"].detach(),
|
||||
generated["images"]["cycle_b"].detach()]
|
||||
)
|
||||
else:
|
||||
return dict(
|
||||
a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
|
||||
batch["a"]["img"].detach(),
|
||||
generated["images"]["a2a"].detach(),
|
||||
],
|
||||
b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
|
||||
batch["b"]["img"].detach(),
|
||||
generated["images"]["b2b"].detach(),
|
||||
]
|
||||
)
|
||||
|
||||
def change_engine(self, config, trainer):
|
||||
pass
|
||||
# @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
|
||||
# def change_config(engine):
|
||||
# with read_write(config):
|
||||
# config.loss.perceptual.weight = 5
|
||||
|
||||
|
||||
def run(task, config, _):
|
||||
kernel = TAFGEngineKernel(config)
|
||||
run_kernel(task, config, kernel)
|
||||
245
engine/TAHG.py
245
engine/TAHG.py
@@ -1,245 +0,0 @@
|
||||
from itertools import chain
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.utils import convert_tensor
|
||||
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
|
||||
from omegaconf import OmegaConf, read_write
|
||||
|
||||
import data
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
from model.GAN.residual_generator import GANImageBuffer
|
||||
from loss.I2I.edge_loss import EdgeLoss
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from util.image import make_2d_grid
|
||||
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from util.build import build_model, build_optimizer
|
||||
|
||||
|
||||
def build_lr_schedulers(optimizers, config):
|
||||
g_milestones_values = [
|
||||
(0, config.optimizers.generator.lr),
|
||||
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
d_milestones_values = [
|
||||
(0, config.optimizers.discriminator.lr),
|
||||
(int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
|
||||
(config.max_iteration, config.data.train.scheduler.target_lr)
|
||||
]
|
||||
return dict(
|
||||
g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
|
||||
d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
|
||||
)
|
||||
|
||||
|
||||
def get_trainer(config, logger, train_data_loader):
|
||||
generator = build_model(config.model.generator, config.distributed.model)
|
||||
discriminators = dict(
|
||||
a=build_model(config.model.discriminator, config.distributed.model),
|
||||
b=build_model(config.model.discriminator, config.distributed.model),
|
||||
)
|
||||
generation_init_weights(generator)
|
||||
for m in discriminators.values():
|
||||
generation_init_weights(m)
|
||||
|
||||
logger.debug(discriminators["a"])
|
||||
logger.debug(generator)
|
||||
|
||||
optimizers = dict(
|
||||
g=build_optimizer(generator.parameters(), config.optimizers.generator),
|
||||
d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
|
||||
)
|
||||
logger.info(f"build optimizers:\n{optimizers}")
|
||||
|
||||
lr_schedulers = build_lr_schedulers(optimizers, config)
|
||||
logger.info(f"build lr_schedulers:\n{lr_schedulers}")
|
||||
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
|
||||
edge_loss_cfg = OmegaConf.to_container(config.loss.edge)
|
||||
edge_loss_cfg.pop("weight")
|
||||
edge_loss = EdgeLoss(**edge_loss_cfg).to(idist.device())
|
||||
|
||||
perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
|
||||
recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
|
||||
|
||||
image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
|
||||
|
||||
def _step(engine, batch):
|
||||
batch = convert_tensor(batch, idist.device())
|
||||
real = dict(a=batch["a"], b=batch["b"])
|
||||
fake = dict(
|
||||
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
||||
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
||||
)
|
||||
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
||||
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
||||
|
||||
optimizers["g"].zero_grad()
|
||||
loss_g = dict()
|
||||
for d in "ab":
|
||||
discriminators[d].requires_grad_(False)
|
||||
pred_fake = discriminators[d](fake[d])
|
||||
loss_g[f"gan_{d}"] = config.loss.gan.weight * gan_loss(pred_fake, True)
|
||||
_, t = perceptual_loss(fake[d], real[d])
|
||||
loss_g[f"perceptual_{d}"] = config.loss.perceptual.weight * t
|
||||
loss_g[f"edge_{d}"] = config.loss.edge.weight * edge_loss(fake[d], batch["edge_a"])
|
||||
loss_g["recon_a"] = config.loss.recon.weight * recon_loss(fake["a"], real["a"])
|
||||
loss_g["recon_b"] = config.loss.recon.weight * recon_loss(rec_b, real["b"])
|
||||
loss_g["recon_bb"] = config.loss.recon.weight * recon_loss(rec_bb, real["b"])
|
||||
sum(loss_g.values()).backward()
|
||||
optimizers["g"].step()
|
||||
|
||||
for discriminator in discriminators.values():
|
||||
discriminator.requires_grad_(True)
|
||||
|
||||
optimizers["d"].zero_grad()
|
||||
loss_d = dict()
|
||||
for k in discriminators.keys():
|
||||
pred_real = discriminators[k](real[k])
|
||||
pred_fake = discriminators[k](image_buffers[k].query(fake[k].detach()))
|
||||
loss_d[f"gan_{k}"] = (gan_loss(pred_real, True, is_discriminator=True) +
|
||||
gan_loss(pred_fake, False, is_discriminator=True)) / 2
|
||||
sum(loss_d.values()).backward()
|
||||
optimizers["d"].step()
|
||||
|
||||
generated_img = {f"real_{k}": real[k].detach() for k in real}
|
||||
generated_img["rec_b"] = rec_b.detach()
|
||||
generated_img["rec_bb"] = rec_b.detach()
|
||||
generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
|
||||
generated_img.update({f"edge_{k}": batch[f"edge_{k}"].expand(-1, 3, -1, -1).detach() for k in "ab"})
|
||||
return {
|
||||
"loss": {
|
||||
"g": {ln: loss_g[ln].mean().item() for ln in loss_g},
|
||||
"d": {ln: loss_d[ln].mean().item() for ln in loss_d},
|
||||
},
|
||||
"img": generated_img
|
||||
}
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
for lr_shd in lr_schedulers.values():
|
||||
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)
|
||||
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
|
||||
RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
|
||||
|
||||
to_save = dict(trainer=trainer)
|
||||
to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
|
||||
to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
|
||||
to_save.update({"generator": generator})
|
||||
to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
|
||||
setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
|
||||
end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))
|
||||
|
||||
def output_transform(output):
|
||||
loss = dict()
|
||||
for tl in output["loss"]:
|
||||
if isinstance(output["loss"][tl], dict):
|
||||
for l in output["loss"][tl]:
|
||||
loss[f"{tl}_{l}"] = output["loss"][tl][l]
|
||||
else:
|
||||
loss[tl] = output["loss"][tl]
|
||||
return loss
|
||||
|
||||
iter_per_epoch = len(train_data_loader)
|
||||
tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform, iter_per_epoch)
|
||||
if tensorboard_handler is not None:
|
||||
tensorboard_handler.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
|
||||
event_name=Events.ITERATION_STARTED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||
)
|
||||
|
||||
@trainer.on(Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.image, 1)))
|
||||
def show_images(engine):
|
||||
output = engine.state.output
|
||||
image_order = dict(
|
||||
a=["edge_a", "real_a", "fake_a", "fake_b"],
|
||||
b=["edge_b", "real_b", "rec_b", "rec_bb"]
|
||||
)
|
||||
for k in "ab":
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"train/{k}",
|
||||
make_2d_grid([output["img"][o] for o in image_order[k]]),
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(config.misc.random_seed)
|
||||
random_start = torch.randperm(len(engine.state.test_dataset) - 11, generator=g).tolist()[0]
|
||||
test_images = dict(
|
||||
a=[[], [], [], []],
|
||||
b=[[], [], [], []]
|
||||
)
|
||||
for i in range(random_start, random_start + 10):
|
||||
batch = convert_tensor(engine.state.test_dataset[i], idist.device())
|
||||
for k in batch:
|
||||
batch[k] = batch[k].view(1, *batch[k].size())
|
||||
|
||||
real = dict(a=batch["a"], b=batch["b"])
|
||||
fake = dict(
|
||||
a=generator(content_img=batch["edge_a"], style_img=real["a"], which_decoder="a"),
|
||||
b=generator(content_img=batch["edge_a"], style_img=real["b"], which_decoder="b"),
|
||||
)
|
||||
rec_b = generator(content_img=batch["edge_b"], style_img=real["b"], which_decoder="b")
|
||||
rec_bb = generator(content_img=batch["edge_b"], style_img=fake["b"], which_decoder="b")
|
||||
|
||||
test_images["a"][0].append(batch["edge_a"])
|
||||
test_images["a"][1].append(batch["a"])
|
||||
test_images["a"][2].append(fake["a"])
|
||||
test_images["a"][3].append(fake["b"])
|
||||
test_images["b"][0].append(batch["edge_b"])
|
||||
test_images["b"][1].append(batch["b"])
|
||||
test_images["b"][2].append(rec_b)
|
||||
test_images["b"][3].append(rec_bb)
|
||||
for n in "ab":
|
||||
tensorboard_handler.writer.add_image(
|
||||
f"test/{n}",
|
||||
make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
|
||||
engine.state.iteration
|
||||
)
|
||||
|
||||
return trainer
|
||||
|
||||
|
||||
def run(task, config, logger):
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
logger.info(f"start task {task}")
|
||||
with read_write(config):
|
||||
config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)
|
||||
|
||||
if task == "train":
|
||||
train_dataset = data.DATASET.build_with(config.data.train.dataset)
|
||||
logger.info(f"train with dataset:\n{train_dataset}")
|
||||
train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
|
||||
trainer = get_trainer(config, logger, train_data_loader)
|
||||
if idist.get_rank() == 0:
|
||||
test_dataset = data.DATASET.build_with(config.data.test.dataset)
|
||||
trainer.state.test_dataset = test_dataset
|
||||
try:
|
||||
trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
|
||||
except Exception:
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
return NotImplemented(f"invalid task: {task}")
|
||||
119
engine/TSIT.py
Normal file
119
engine/TSIT.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from itertools import chain
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class TSITEngineKernel(EngineKernel):
    """Training kernel for TSIT image-to-image translation.

    Uses one generator ("main") and one discriminator ("b"); trains with a
    weighted perceptual loss plus a multi-scale GAN loss.
    """

    def __init__(self, config):
        super().__init__(config)
        # "weight" is popped because the loss classes do not accept it;
        # the weight is applied manually inside criterion_generators.
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())

        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())

        # feature-matching loss: L1 when level == 1, otherwise MSE.
        # NOTE(review): fm_loss is built but not used anywhere in this class —
        # confirm whether feature matching was meant to be part of the criteria.
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()

    def build_models(self) -> (dict, dict):
        """Build one generator ("main") and one discriminator ("b"),
        apply weight init to all, and return (generators, discriminators)."""
        generators = dict(
            main=build_model(self.config.model.generator)
        )
        discriminators = dict(
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["b"])
        self.logger.debug(generators["main"])

        for m in chain(generators.values(), discriminators.values()):
            generation_init_weights(m)

        return generators, discriminators

    def setup_after_g(self):
        # re-enable discriminator grads after the generator step
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # freeze discriminators while the generator is being updated
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Translate domain-a images to domain b; grads disabled when inference."""
        with torch.set_grad_enabled(not inference):
            fake = dict(
                b=self.generators["main"](content_img=batch["a"])
            )
        return fake

    def criterion_generators(self, batch, generated) -> dict:
        """Generator losses: weighted perceptual loss + multi-scale GAN loss."""
        loss = dict()
        # NOTE(review): the perceptual target is batch["a"] (the source image),
        # i.e. content preservation, not similarity to domain b — confirm intended.
        loss["perceptual"] = self.perceptual_loss(generated["b"], batch["a"]) * self.config.loss.perceptual.weight
        # iterating the string "b" yields the single phase key "b"
        for phase in "b":
            pred_fake = self.discriminators[phase](generated[phase])
            loss[f"gan_{phase}"] = 0
            for sub_pred_fake in pred_fake:
                # last output is actual prediction
                loss[f"gan_{phase}"] += self.config.loss.gan.weight * self.gan_loss(sub_pred_fake[-1], True)
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Discriminator GAN losses, averaged over the real/fake pair per scale."""
        loss = dict()
        for phase in self.discriminators.keys():
            pred_real = self.discriminators[phase](batch[phase])
            # detach so D gradients do not flow into the generator
            pred_fake = self.discriminators[phase](generated[phase].detach())
            loss[f"gan_{phase}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        return dict(
            b=[batch["a"].detach(), batch["b"].detach(), generated["b"].detach()]
        )
|
||||
|
||||
|
||||
class TSITTestEngineKernel(TestEngineKernel):
    """Inference-only kernel for TSIT: builds the "main" generator and runs
    content/style translation without gradient tracking."""

    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        """Build the single "main" generator from the model config."""
        return {"main": build_model(self.config.model.generator)}

    def to_load(self):
        """Checkpoint keys for the generators, prefixed with "generator_"."""
        return {f"generator_{name}": gen for name, gen in self.generators.items()}

    def inference(self, batch):
        """Translate one content/style pair; returns {"a": fake_image}."""
        with torch.no_grad():
            fake = self.generators["main"](content_img=batch["a"][0], style_img=batch["b"][0])
        return {"a": fake.detach()}
|
||||
|
||||
|
||||
def run(task, config, _):
    """Dispatch a TSIT task.

    :param task: "train" or "test"
    :param config: run configuration
    :param _: unused (kept for interface compatibility with other engines)
    :raises NotImplementedError: for any other task name
    """
    if task == "train":
        kernel = TSITEngineKernel(config)
        run_kernel(task, config, kernel)
    elif task == "test":
        kernel = TSITTestEngineKernel(config)
        run_kernel(task, config, kernel)
    else:
        # was `raise NotImplemented`: NotImplemented is a constant, not an
        # exception, so raising it produced a confusing TypeError. Raise the
        # proper exception type with a helpful message instead.
        raise NotImplementedError(f"invalid task: {task}")
|
||||
150
engine/U-GAT-IT.py
Normal file
150
engine/U-GAT-IT.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import torch
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
from engine.util.container import LossContainer
|
||||
from engine.util.loss import bce_loss, mse_loss, pixel_loss, gan_loss
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
|
||||
from util.image import attention_colored_map
|
||||
|
||||
|
||||
class RhoClipper(object):
    """Callable meant for ``module.apply``: clamps a module's ``rho``
    parameter data into ``[clip_min, clip_max]``; modules without a ``rho``
    attribute are left untouched."""

    def __init__(self, clip_min, clip_max):
        self.clip_min = clip_min
        self.clip_max = clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        # only modules that actually carry a `rho` parameter are clipped
        if hasattr(module, 'rho'):
            module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)
|
||||
|
||||
|
||||
class UGATITEngineKernel(EngineKernel):
    """Training kernel for U-GAT-IT: two generators (a2b, b2a), four
    discriminators (local/global per domain), trained with cycle, identity,
    CAM, GAN and MGC losses."""

    def __init__(self, config):
        super().__init__(config)

        self.gan_loss = gan_loss(config.loss.gan)
        self.cycle_loss = LossContainer(config.loss.cycle.weight, pixel_loss(config.loss.cycle.level))
        self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
        self.id_loss = LossContainer(config.loss.id.weight, pixel_loss(config.loss.id.level))
        self.bce_loss = LossContainer(self.config.loss.cam.weight, bce_loss)
        self.mse_loss = LossContainer(self.config.loss.gan.weight, mse_loss)

        # keeps the AdaILN/ILN rho parameters inside [0, 1] after each G step
        self.rho_clipper = RhoClipper(0, 1)
        # U-GAT-IT updates the discriminators before the generators
        self.train_generator_first = False

    def build_models(self) -> (dict, dict):
        """Build a2b/b2a generators and local/global discriminators per domain."""
        generators = dict(
            a2b=build_model(self.config.model.generator),
            b2a=build_model(self.config.model.generator)
        )
        discriminators = dict(
            la=build_model(self.config.model.local_discriminator),
            lb=build_model(self.config.model.local_discriminator),
            ga=build_model(self.config.model.global_discriminator),
            gb=build_model(self.config.model.global_discriminator),
        )
        self.logger.debug(discriminators["ga"])
        self.logger.debug(generators["a2b"])

        return generators, discriminators

    def setup_after_g(self):
        # clip rho after the G step, then re-enable discriminator grads
        for generator in self.generators.values():
            generator.apply(self.rho_clipper)
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        # freeze discriminators while the generators are being updated
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run translation, cycle reconstruction and identity mapping.

        Returns a dict of three dicts: "images" (a2b, b2a, a2b2a, b2a2b,
        a2a, b2b), "heatmap" (attention maps), "cam_pred" (CAM logits).
        """
        images = dict()
        heatmap = dict()
        cam_pred = dict()

        with torch.set_grad_enabled(not inference):
            images["a2b"], cam_pred["a2b"], heatmap["a2b"] = self.generators["a2b"](batch["a"])
            images["b2a"], cam_pred["b2a"], heatmap["b2a"] = self.generators["b2a"](batch["b"])
            images["a2b2a"], _, heatmap["a2b2a"] = self.generators["b2a"](images["a2b"])
            images["b2a2b"], _, heatmap["b2a2b"] = self.generators["a2b"](images["b2a"])
            images["a2a"], cam_pred["a2a"], heatmap["a2a"] = self.generators["b2a"](batch["a"])
            images["b2b"], cam_pred["b2b"], heatmap["b2b"] = self.generators["a2b"](batch["b"])
        return dict(images=images, heatmap=heatmap, cam_pred=cam_pred)

    def criterion_generators(self, batch, generated) -> dict:
        """Cycle / identity / MGC / GAN / CAM losses for both domains."""
        loss = dict()
        for phase in "ab":
            cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
            loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
            loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
            loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
            for dk in "lg":
                # discriminator "x" judges images IN domain x, i.e. a2b for phase b
                generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
                pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
                loss[f"gan_{phase}_{dk}"] = self.config.loss.gan.weight * self.gan_loss(pred_fake, True)
                loss[f"gan_cam_{phase}_{dk}"] = self.mse_loss(cam_pred, True)
        # CAM classifier: translated image -> True target, identity -> False
        for t, f in [("a2b", "b2b"), ("b2a", "a2a")]:
            loss[f"cam_{t[-1]}"] = self.bce_loss(generated["cam_pred"][t], True) + \
                                   self.bce_loss(generated["cam_pred"][f], False)
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """GAN + CAM losses for the four discriminators (fakes detached)."""
        loss = dict()
        for phase in "ab":
            for level in "gl":
                generated_image = generated["images"]["b2a" if phase == "a" else "a2b"].detach()
                pred_fake, cam_fake_pred = self.discriminators[level + phase](generated_image)
                pred_real, cam_real_pred = self.discriminators[level + phase](batch[phase])
                loss[f"gan_{phase}_{level}"] = self.gan_loss(pred_real, True, is_discriminator=True) + self.gan_loss(
                    pred_fake, False, is_discriminator=True)
                # NOTE(review): this calls the plain imported mse_loss, not
                # self.mse_loss, so the D-side CAM loss is unweighted while the
                # G-side CAM loss is weighted — confirm this asymmetry is intended.
                loss[f"cam_{phase}_{level}"] = mse_loss(cam_fake_pred, False) + mse_loss(cam_real_pred, True)
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        attention_a = attention_colored_map(generated["heatmap"]["a2b"].detach(), batch["a"].size()[-2:])
        attention_b = attention_colored_map(generated["heatmap"]["b2a"].detach(), batch["b"].size()[-2:])
        generated = {img: generated["images"][img].detach() for img in generated["images"]}
        return {
            "a": [batch["a"], attention_a, generated["a2b"], generated["a2a"], generated["a2b2a"]],
            "b": [batch["b"], attention_b, generated["b2a"], generated["b2b"], generated["b2a2b"]],
        }
|
||||
|
||||
|
||||
class UGATITTestEngineKernel(TestEngineKernel):
    """Inference-only kernel for U-GAT-IT: a -> b translation with the a2b
    generator only."""

    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        """Build only the a2b generator needed for inference."""
        return {"a2b": build_model(self.config.model.generator)}

    def to_load(self):
        """Checkpoint keys for the generators, prefixed with "generator_"."""
        return {f"generator_{name}": gen for name, gen in self.generators.items()}

    def inference(self, batch):
        """Translate one batch; the generator's extra CAM/heatmap outputs are
        discarded and only the fake image is returned."""
        with torch.no_grad():
            fake, _, _ = self.generators["a2b"](batch[0])
        return fake.detach()
|
||||
|
||||
|
||||
def run(task, config, _):
    """Dispatch a U-GAT-IT task.

    :param task: "train" or "test"
    :param config: run configuration
    :param _: unused (kept for interface compatibility with other engines)
    :raises NotImplementedError: for any other task name
    """
    if task == "train":
        kernel = UGATITEngineKernel(config)
        run_kernel(task, config, kernel)
    elif task == "test":
        kernel = UGATITTestEngineKernel(config)
        run_kernel(task, config, kernel)
    else:
        # was `raise NotImplemented`: NotImplemented is a constant, not an
        # exception, so raising it produced a confusing TypeError. Raise the
        # proper exception type with a helpful message instead.
        raise NotImplementedError(f"invalid task: {task}")
|
||||
320
engine/UGATIT.py
320
engine/UGATIT.py
@@ -1,320 +0,0 @@
|
||||
from itertools import chain
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.utils import convert_tensor
|
||||
from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
|
||||
from omegaconf import OmegaConf, read_write
|
||||
|
||||
import data
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
from model.GAN.residual_generator import GANImageBuffer
|
||||
from model.GAN.UGATIT import RhoClipper
|
||||
from util.image import make_2d_grid, fuse_attention_map, attention_colored_map
|
||||
from util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from util.build import build_model, build_optimizer
|
||||
|
||||
|
||||
def build_lr_schedulers(optimizers, config):
    """Create linear-decay LR schedulers for the "g" and "d" optimizers.

    Each schedule holds the initial LR until
    start_proportion * max_iteration, then decays linearly to
    scheduler.target_lr at max_iteration.
    """
    g_milestones_values = [
        (0, config.optimizers.generator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.generator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    d_milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (int(config.data.train.scheduler.start_proportion * config.max_iteration), config.optimizers.discriminator.lr),
        (config.max_iteration, config.data.train.scheduler.target_lr)
    ]
    return dict(
        g=PiecewiseLinear(optimizers["g"], param_name="lr", milestones_values=g_milestones_values),
        d=PiecewiseLinear(optimizers["d"], param_name="lr", milestones_values=d_milestones_values)
    )
|
||||
|
||||
|
||||
def get_trainer(config, logger):
    """Build the full UGATIT training Engine: models, optimizers, LR
    schedulers, losses, the per-iteration step function, checkpoint handlers
    and tensorboard image previews."""
    generators = dict(
        a2b=build_model(config.model.generator, config.distributed.model),
        b2a=build_model(config.model.generator, config.distributed.model),
    )
    discriminators = dict(
        la=build_model(config.model.local_discriminator, config.distributed.model),
        lb=build_model(config.model.local_discriminator, config.distributed.model),
        ga=build_model(config.model.global_discriminator, config.distributed.model),
        gb=build_model(config.model.global_discriminator, config.distributed.model),
    )
    for m in chain(generators.values(), discriminators.values()):
        generation_init_weights(m)

    logger.debug(discriminators["ga"])
    logger.debug(generators["a2b"])

    # one shared optimizer per side, over all generators / all discriminators
    optimizers = dict(
        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
    )
    logger.info(f"build optimizers:\n{optimizers}")

    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"build lr_schedulers:\n{lr_schedulers}")

    # GANLoss does not accept "weight"; the weight is applied in the criteria
    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    # NOTE(review): id_loss selects its level from config.loss.cycle.level, not
    # config.loss.id.level — likely a copy/paste slip; confirm intent.
    id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()

    def mse_loss(x, target_flag):
        # LSGAN-style target: all-ones when target_flag else all-zeros
        return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

    def bce_loss(x, target_flag):
        return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))

    # history buffers of generated images used for the discriminator update
    image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in discriminators.keys()}
    rho_clipper = RhoClipper(0, 1)

    def criterion_generator(name, real, fake, rec, identity, cam_g_pred, cam_g_id_pred, discriminator_l,
                            discriminator_g):
        # freeze discriminators while scoring generator outputs
        discriminator_g.requires_grad_(False)
        discriminator_l.requires_grad_(False)
        pred_fake_g, cam_gd_pred = discriminator_g(fake)
        pred_fake_l, cam_ld_pred = discriminator_l(fake)
        return {
            f"cycle_{name}": config.loss.cycle.weight * cycle_loss(real, rec),
            f"id_{name}": config.loss.id.weight * id_loss(real, identity),
            f"cam_{name}": config.loss.cam.weight * (bce_loss(cam_g_pred, True) + bce_loss(cam_g_id_pred, False)),
            f"gan_l_{name}": config.loss.gan.weight * gan_loss(pred_fake_l, True),
            f"gan_g_{name}": config.loss.gan.weight * gan_loss(pred_fake_g, True),
            f"gan_cam_g_{name}": config.loss.gan.weight * mse_loss(cam_gd_pred, True),
            f"gan_cam_l_{name}": config.loss.gan.weight * mse_loss(cam_ld_pred, True),
        }

    def criterion_discriminator(name, discriminator, real, fake):
        pred_real, cam_real = discriminator(real)
        pred_fake, cam_fake = discriminator(fake)
        # TODO: origin do not divide 2, but I think it better to divide 2.
        loss_gan = gan_loss(pred_real, True, is_discriminator=True) + gan_loss(pred_fake, False, is_discriminator=True)
        loss_cam = mse_loss(cam_real, True) + mse_loss(cam_fake, False)
        return {f"gan_{name}": loss_gan, f"cam_{name}": loss_cam}

    def _step(engine, real):
        real = convert_tensor(real, idist.device())

        fake = dict()
        cam_generator_pred = dict()
        rec = dict()
        identity = dict()
        cam_identity_pred = dict()
        heatmap = dict()

        # forward passes: translation, cycle reconstruction, identity mapping
        fake["b"], cam_generator_pred["a"], heatmap["a2b"] = generators["a2b"](real["a"])
        fake["a"], cam_generator_pred["b"], heatmap["b2a"] = generators["b2a"](real["b"])
        rec["a"], _, heatmap["a2b2a"] = generators["b2a"](fake["b"])
        rec["b"], _, heatmap["b2a2b"] = generators["a2b"](fake["a"])
        identity["a"], cam_identity_pred["a"], heatmap["a2a"] = generators["b2a"](real["a"])
        identity["b"], cam_identity_pred["b"], heatmap["b2b"] = generators["a2b"](real["b"])

        # ---- generator update ----
        optimizers["g"].zero_grad()
        loss_g = dict()
        for n in ["a", "b"]:
            loss_g.update(criterion_generator(n, real[n], fake[n], rec[n], identity[n], cam_generator_pred[n],
                                              cam_identity_pred[n], discriminators["l" + n], discriminators["g" + n]))
        sum(loss_g.values()).backward()
        optimizers["g"].step()
        # clip rho parameters, then unfreeze discriminators for their update
        for generator in generators.values():
            generator.apply(rho_clipper)
        for discriminator in discriminators.values():
            discriminator.requires_grad_(True)

        # ---- discriminator update (fakes come from the history buffers) ----
        optimizers["d"].zero_grad()
        loss_d = dict()
        for k in discriminators.keys():
            n = k[-1]  # "a" or "b"
            loss_d.update(
                criterion_discriminator(k, discriminators[k], real[n], image_buffers[k].query(fake[n].detach())))
        sum(loss_d.values()).backward()
        optimizers["d"].step()

        # detach everything that goes into the engine output for logging
        for h in heatmap:
            heatmap[h] = heatmap[h].detach()
        generated_img = {f"real_{k}": real[k].detach() for k in real}
        generated_img.update({f"fake_{k}": fake[k].detach() for k in fake})
        generated_img.update({f"id_{k}": identity[k].detach() for k in identity})
        generated_img.update({f"rec_{k}": rec[k].detach() for k in rec})

        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d": {ln: loss_d[ln].mean().item() for ln in loss_d},
            },
            "img": {
                "heatmap": heatmap,
                "generated": generated_img
            }
        }

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")

    # everything checkpointed/restored by the common handlers
    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update({f"generator_{k}": generators[k] for k in generators})
    to_save.update({f"discriminator_{k}": discriminators[k] for k in discriminators})
    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=True, set_epoch_for_dist_sampler=True,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    def output_transform(output):
        # flatten {"g": {...}, "d": {...}} into {"g_name": v, ...} for tensorboard
        loss = dict()
        for tl in output["loss"]:
            if isinstance(output["loss"][tl], dict):
                for l in output["loss"][tl]:
                    loss[f"{tl}_{l}"] = output["loss"][tl][l]
            else:
                loss[tl] = output["loss"][tl]
        return loss

    tensorboard_handler = setup_tensorboard_handler(trainer, config, output_transform)
    if tensorboard_handler is not None:
        tensorboard_handler.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizers["g"], tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=config.interval.tensorboard.scalar)
        )

    @trainer.on(Events.ITERATION_COMPLETED(every=config.interval.tensorboard.image))
    def show_images(engine):
        output = engine.state.output
        image_order = dict(
            a=["real_a", "fake_b", "rec_a", "id_a"],
            b=["real_b", "fake_a", "rec_b", "id_b"]
        )
        # overlay attention heatmaps on the real images before logging
        output["img"]["generated"]["real_a"] = fuse_attention_map(
            output["img"]["generated"]["real_a"], output["img"]["heatmap"]["a2b"])
        output["img"]["generated"]["real_b"] = fuse_attention_map(
            output["img"]["generated"]["real_b"], output["img"]["heatmap"]["b2a"])
        for k in "ab":
            tensorboard_handler.writer.add_image(
                f"train/{k}",
                make_2d_grid([output["img"]["generated"][o] for o in image_order[k]]),
                engine.state.iteration
            )

        with torch.no_grad():
            # deterministic preview window of 10 test samples
            g = torch.Generator()
            g.manual_seed(config.misc.random_seed)
            random_start = torch.randperm(len(engine.state.test_dataset)-11, generator=g).tolist()[0]
            test_images = dict(
                a=[[], [], [], []],
                b=[[], [], [], []]
            )
            for i in range(random_start, random_start+10):
                batch = convert_tensor(engine.state.test_dataset[i], idist.device())

                # NOTE(review): real_b is reshaped with batch["a"].size() — this
                # only works when domains a and b share image shapes; likely
                # should be batch["b"].size().
                real_a, real_b = batch["a"].view(1, *batch["a"].size()), batch["b"].view(1, *batch["a"].size())
                fake_b, _, heatmap_a2b = generators["a2b"](real_a)
                fake_a, _, heatmap_b2a = generators["b2a"](real_b)
                rec_a = generators["b2a"](fake_b)[0]
                rec_b = generators["a2b"](fake_a)[0]

                for idx, im in enumerate(
                        [attention_colored_map(heatmap_a2b, real_a.size()[-2:]), real_a, fake_b, rec_a]):
                    test_images["a"][idx].append(im.cpu())
                for idx, im in enumerate(
                        [attention_colored_map(heatmap_b2a, real_b.size()[-2:]), real_b, fake_a, rec_b]):
                    test_images["b"][idx].append(im.cpu())
            for n in "ab":
                tensorboard_handler.writer.add_image(
                    f"test/{n}",
                    make_2d_grid([torch.cat(ti) for ti in test_images[n]]),
                    engine.state.iteration
                )

    return trainer
|
||||
|
||||
|
||||
def get_tester(config, logger):
    """Build the test Engine: runs the a2b generator over the dataloader and
    saves side-by-side (real, fake) image grids named after the source files."""
    generator_a2b = build_model(config.model.generator, config.distributed.model)

    def _step(engine, batch):
        # batch carries (images, paths); paths are kept to name output files
        real_a, path = convert_tensor(batch, idist.device())
        with torch.no_grad():
            fake_b = generator_a2b(real_a)[0]
        return {"path": path, "img": [real_a.detach(), fake_b.detach()]}

    tester = Engine(_step)
    tester.logger = logger

    # checkpoint restoration is handled by setup_common_handlers via to_save
    to_load = dict(generator_a2b=generator_a2b)
    setup_common_handlers(tester, config, use_profiler=False, to_save=to_load)

    @tester.on(Events.STARTED)
    def mkdir(engine):
        # only rank 0 creates the output directory; all ranks record the path
        img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
        engine.state.img_output_dir = Path(img_output_dir)
        if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
            engine.logger.info(f"mkdir {engine.state.img_output_dir}")
            engine.state.img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        img_tensors = engine.state.output["img"]
        paths = engine.state.output["path"]
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            # keep the source image's filename for the saved grid
            image_name = Path(paths[i]).name
            torchvision.utils.save_image([img[i] for img in img_tensors], engine.state.img_output_dir / image_name,
                                         nrow=len(img_tensors))

    return tester
|
||||
|
||||
|
||||
def run(task, config, logger):
    """Entry point for the (legacy) UGATIT engine.

    :param task: "train" or "test"
    :param config: run config; max_iteration is derived from max_pairs here
    :param logger: run logger
    :raises NotImplementedError: for any other task name
    """
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    with read_write(config):
        # translate the pair budget into an iteration budget
        config.max_iteration = ceil(config.max_pairs / config.data.train.dataloader.batch_size)

    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        if idist.get_rank() == 0:
            # the test dataset feeds the periodic image previews on rank 0 only
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        engine, loader, max_epochs = trainer, train_data_loader, ceil(config.max_iteration / len(train_data_loader))
    elif task == "test":
        assert config.resume_from is not None
        test_dataset = data.DATASET.build_with(config.data.test.video_dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        engine, loader, max_epochs = get_tester(config, logger), test_data_loader, 1
    else:
        # was `return NotImplemented(f"...")`: calling the NotImplemented
        # constant raises a confusing TypeError; raise the proper exception.
        raise NotImplementedError(f"invalid task: {task}")

    try:
        engine.run(loader, max_epochs=max_epochs)
    except Exception:
        # print instead of re-raising so sibling distributed workers are not
        # killed mid-checkpoint (deduplicated from the two original branches)
        import traceback
        print(traceback.format_exc())
|
||||
311
engine/base/i2i.py
Normal file
311
engine/base/i2i.py
Normal file
@@ -0,0 +1,311 @@
|
||||
import logging
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torchvision
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.utils import convert_tensor
|
||||
from math import ceil
|
||||
from omegaconf import read_write, OmegaConf
|
||||
|
||||
import data
|
||||
from engine.util.build import build_optimizer
|
||||
from engine.util.handler import setup_common_handlers, setup_tensorboard_handler
|
||||
from util.image import make_2d_grid
|
||||
|
||||
|
||||
def build_lr_schedulers(optimizers, config):
    """Build PiecewiseLinear LR schedulers for the "g" and "d" optimizers.

    Each schedule holds the initial LR until
    start_proportion * max_iteration, then decays linearly to
    scheduler.target_lr at max_iteration.
    """
    # TODO: support more scheduler type
    decay_start = int(config.data.train.scheduler.start_proportion * config.max_iteration)
    target_lr = config.data.train.scheduler.target_lr

    def milestones(initial_lr):
        # flat until decay_start, then linear decay down to target_lr
        return [(0, initial_lr), (decay_start, initial_lr), (config.max_iteration, target_lr)]

    return {
        "g": PiecewiseLinear(optimizers["g"], param_name="lr",
                             milestones_values=milestones(config.optimizers.generator.lr)),
        "d": PiecewiseLinear(optimizers["d"], param_name="lr",
                             milestones_values=milestones(config.optimizers.discriminator.lr)),
    }
|
||||
|
||||
|
||||
class TestEngineKernel(object):
    """Abstract inference-only kernel.

    Subclasses build the generator dict and implement inference(); the
    generic test runner drives the loop and loads checkpoints via to_load().
    """

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(config.name)
        self.generators = self.build_generators()

    def build_generators(self) -> dict:
        """Return a name -> generator module dict. Must be overridden."""
        # was `raise NotImplemented`: NotImplemented is a constant, not an
        # exception, so raising it produced a TypeError instead of signaling
        # "abstract method" — raise NotImplementedError as elsewhere in i2i.py.
        raise NotImplementedError

    def to_load(self):
        """Checkpoint keys for the generators, prefixed with "generator_"."""
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        """Run the generators on one batch. Must be overridden."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class EngineKernel(object):
    """Abstract training kernel for I2I GAN engines.

    Subclasses supply the models and loss criteria; get_trainer() drives the
    generic G/D optimization loop through this interface.
    """

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(config.name)
        self.generators, self.discriminators = self.build_models()
        # when True the loop updates G before D ("simultaneous" update);
        # subclasses may flip this to update D first (e.g. U-GAT-IT).
        self.train_generator_first = True
        self.engine = None

    def bind_engine(self, engine):
        # gives the kernel access to engine state (iteration, output, ...)
        self.engine = engine

    def build_models(self) -> (dict, dict):
        """Return (generators, discriminators) as name -> module dicts."""
        raise NotImplementedError

    def to_save(self):
        """Objects to checkpoint, keyed generator_* / discriminator_*."""
        to_save = {}
        to_save.update({f"generator_{k}": self.generators[k] for k in self.generators})
        to_save.update({f"discriminator_{k}": self.discriminators[k] for k in self.discriminators})
        return to_save

    def setup_after_g(self):
        """Hook run right after the generator optimizer step."""
        raise NotImplementedError

    def setup_before_g(self):
        """Hook run right before generator losses are computed."""
        raise NotImplementedError

    def forward(self, batch, inference=False) -> dict:
        """Run the generators on `batch`; no grad tracking when inference."""
        raise NotImplementedError

    def criterion_generators(self, batch, generated) -> dict:
        """Return name -> loss tensor for the generator update."""
        raise NotImplementedError

    def criterion_discriminators(self, batch, generated) -> dict:
        """Return name -> loss tensor for the discriminator update."""
        raise NotImplementedError

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        raise NotImplementedError

    def change_engine(self, config, engine: Engine):
        # optional hook for subclasses to attach extra handlers to the engine
        pass
|
||||
|
||||
|
||||
def _remove_no_grad_loss(loss_dict):
|
||||
need_to_pop = []
|
||||
for k in loss_dict:
|
||||
if not isinstance(loss_dict[k], torch.Tensor):
|
||||
need_to_pop.append(k)
|
||||
for k in need_to_pop:
|
||||
loss_dict.pop(k)
|
||||
return loss_dict
|
||||
|
||||
|
||||
def get_trainer(config, kernel: EngineKernel):
    """Assemble the generic GAN training Engine around an EngineKernel:
    optimizers, LR schedulers, the G/D step function, checkpointing and
    tensorboard train/test image previews."""
    logger = logging.getLogger(config.name)
    generators, discriminators = kernel.generators, kernel.discriminators
    # one shared optimizer per side, over all generators / all discriminators
    optimizers = dict(
        g=build_optimizer(chain(*[m.parameters() for m in generators.values()]), config.optimizers.generator),
        d=build_optimizer(chain(*[m.parameters() for m in discriminators.values()]), config.optimizers.discriminator),
    )
    logger.info(f"build optimizers:\n{optimizers}")

    lr_schedulers = build_lr_schedulers(optimizers, config)
    logger.info(f"build lr_schedulers:\n{lr_schedulers}")

    # number of iterations between image previews (at least 1)
    iteration_per_image = max(config.iterations_per_epoch // config.handler.tensorboard.image, 1)

    def train_generators(batch, generated):
        # kernel hooks bracket the G step (typically freeze/unfreeze D)
        kernel.setup_before_g()
        optimizers["g"].zero_grad()
        loss_g = kernel.criterion_generators(batch, generated)
        sum(loss_g.values()).backward()
        optimizers["g"].step()
        kernel.setup_after_g()
        return loss_g

    def train_discriminators(batch, generated):
        optimizers["d"].zero_grad()
        loss_d = kernel.criterion_discriminators(batch, generated)
        sum(loss_d.values()).backward()
        optimizers["d"].step()
        return loss_d

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())

        generated = kernel.forward(batch)

        if kernel.train_generator_first:
            # simultaneous, train G with simultaneous D
            loss_g = train_generators(batch, generated)
            loss_d = train_discriminators(batch, generated)
        else:
            # update discriminators first, not simultaneous.
            # train G with updated discriminators
            loss_d = train_discriminators(batch, generated)
            loss_g = train_generators(batch, generated)

        # only materialize preview images on iterations that will log them
        if engine.state.iteration % iteration_per_image == 0:
            return {
                "loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
                "img": kernel.intermediate_images(batch, generated)
            }
        return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}

    trainer = Engine(_step)
    trainer.logger = logger
    for lr_shd in lr_schedulers.values():
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_shd)

    kernel.change_engine(config, trainer)
    kernel.bind_engine(trainer)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")
    # everything checkpointed/restored by the common handlers
    to_save = dict(trainer=trainer)
    to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
    to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
    to_save.update(kernel.to_save())
    setup_common_handlers(trainer, config, to_save=to_save, clear_cuda_cache=config.handler.clear_cuda_cache,
                          set_epoch_for_dist_sampler=config.handler.set_epoch_for_dist_sampler,
                          end_event=Events.ITERATION_COMPLETED(once=config.max_iteration))

    tensorboard_handler = setup_tensorboard_handler(trainer, config, optimizers, step_type="item")
    if tensorboard_handler is not None:
        basic_image_event = Events.ITERATION_COMPLETED(
            every=iteration_per_image)
        # the tensorboard x-axis counts processed image pairs, not iterations
        pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()

        @trainer.on(basic_image_event)
        def show_images(engine):
            output = engine.state.output
            test_images = {}

            # log training previews and pre-build one bucket list per column
            for k in output["img"]:
                image_list = output["img"][k]
                tensorboard_handler.writer.add_image(f"train/{k}", make_2d_grid(image_list, range=(-1, 1)),
                                                     engine.state.iteration * pairs_per_iteration)
                test_images[k] = []
                for i in range(len(image_list)):
                    test_images[k].append([])

            # deterministic preview sampling; reshuffled per epoch when
            # handler.test.random is set (conditional applies to the sum)
            g = torch.Generator()
            g.manual_seed(config.misc.random_seed + engine.state.epoch
                          if config.handler.test.random else config.misc.random_seed)
            random_start = \
                torch.randperm(len(engine.state.test_dataset) - config.handler.test.images, generator=g).tolist()[0]
            for i in range(random_start, random_start + config.handler.test.images):
                batch = convert_tensor(engine.state.test_dataset[i], idist.device())
                # dataset samples are unbatched; add a leading batch dim
                # (one nesting level of dicts is supported)
                for k in batch:
                    if isinstance(batch[k], torch.Tensor):
                        batch[k] = batch[k].unsqueeze(0)
                    elif isinstance(batch[k], dict):
                        for kk in batch[k]:
                            if isinstance(batch[k][kk], torch.Tensor):
                                batch[k][kk] = batch[k][kk].unsqueeze(0)

                generated = kernel.forward(batch, inference=True)
                images = kernel.intermediate_images(batch, generated)

                for k in test_images:
                    for j in range(len(images[k])):
                        test_images[k][j].append(images[k][j])
            for k in test_images:
                tensorboard_handler.writer.add_image(
                    f"test/{k}",
                    make_2d_grid([torch.cat(ti) for ti in test_images[k]], range=(-1, 1)),
                    engine.state.iteration * pairs_per_iteration
                )
    return trainer
|
||||
|
||||
|
||||
def save_images_helper(output_dir, paths, images_list):
    """Write one composite image per sample.

    For each source path, the i-th slice of every batched tensor in
    ``images_list`` is laid out in a single row and saved into
    ``output_dir`` under the source file's name.

    :param output_dir: destination directory (str or Path)
    :param paths: per-sample source paths; only the file name is reused
    :param images_list: batched image tensors sharing the same batch size
    """
    out_root = Path(output_dir)
    for idx, src_path in enumerate(paths):
        row = [batch_img[idx] for batch_img in images_list]
        torchvision.utils.save_image(
            row,
            out_root / Path(src_path).name,
            nrow=len(row),
            padding=0,
            normalize=True,
            range=(-1, 1),
        )
|
||||
|
||||
|
||||
def get_tester(config, kernel: TestEngineKernel):
    """Build the test Engine: run the kernel's inference and save result images.

    :param config: run configuration (provides name, output dirs, resume info)
    :param kernel: kernel exposing ``inference`` and ``to_load``
    :return: configured ignite ``Engine``
    """
    logger = logging.getLogger(config.name)

    def _step(engine, batch):
        # Move the batch to the current device and run pure inference.
        batch = convert_tensor(batch, idist.device())
        return {"batch": batch, "generated": kernel.inference(batch)}

    tester = Engine(_step)
    tester.logger = logger

    setup_common_handlers(tester, config, use_profiler=True, to_save=kernel.to_load())

    @tester.on(Events.STARTED)
    def mkdir(engine):
        # Fall back to <output_dir>/test_images when no explicit dir is set.
        img_output_dir = config.img_output_dir or f"{config.output_dir}/test_images"
        engine.state.img_output_dir = Path(img_output_dir)
        if idist.get_rank() == 0 and not engine.state.img_output_dir.exists():
            engine.logger.info(f"mkdir {engine.state.img_output_dir}")
            engine.state.img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        # BUG FIX: previously wrote to config.img_output_dir directly, which may
        # be None when the STARTED handler used the fallback directory; use the
        # directory actually created on engine.state instead.
        output_dir = engine.state.img_output_dir
        if engine.state.dataloader.dataset.__class__.__name__ == "SingleFolderDataset":
            images, paths = engine.state.output["batch"]
            save_images_helper(output_dir, paths, [images, engine.state.output["generated"]])
        else:
            for k in engine.state.output['generated']:
                images, paths = engine.state.output["batch"][k]
                save_images_helper(output_dir, paths, [images, engine.state.output["generated"][k]])

    return tester
|
||||
|
||||
|
||||
def run_kernel(task, config, kernel):
    """Dispatch a kernel to training or testing.

    Computes ``config.max_iteration`` from ``max_pairs`` and the effective
    (world-size-scaled) batch size, then builds the dataset/dataloader and
    runs the trainer or tester.

    :param task: "train" or "test"
    :param config: run configuration (mutated: max_iteration, iterations_per_epoch)
    :param kernel: engine kernel passed to get_trainer / get_tester
    :raises NotImplementedError: for an unknown task
    """
    logger = logging.getLogger(config.name)
    with read_write(config):
        real_batch_size = config.data.train.dataloader.batch_size * idist.get_world_size()
        config.max_iteration = ceil(config.max_pairs / real_batch_size)

    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        dataloader_kwargs = OmegaConf.to_container(config.data.train.dataloader)
        # Scale the batch size to the total across all processes;
        # idist.auto_dataloader divides it back per process.
        dataloader_kwargs["batch_size"] = dataloader_kwargs["batch_size"] * idist.get_world_size()
        train_data_loader = idist.auto_dataloader(train_dataset, **dataloader_kwargs)
        with read_write(config):
            config.iterations_per_epoch = len(train_data_loader)

        trainer = get_trainer(config, kernel)
        if idist.get_rank() == 0:
            # Only rank 0 renders test images during training.
            test_dataset = data.DATASET.build_with(config.data.test.dataset)
            trainer.state.test_dataset = test_dataset
        try:
            trainer.run(train_data_loader, max_epochs=ceil(config.max_iteration / len(train_data_loader)))
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "test":
        assert config.resume_from is not None
        test_dataset = data.DATASET.build_with(config.data.test[config.data.test.which])
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, kernel)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        # BUG FIX: `return NotImplemented(...)` tried to call the non-callable
        # NotImplemented singleton (TypeError); raise the proper exception.
        raise NotImplementedError(f"invalid task: {task}")
|
||||
@@ -1,85 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.contrib.metrics.gpu_info import GpuInfo
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, global_step_from_engine, OutputHandler, \
|
||||
WeightsScalarHandler, GradsHistHandler, WeightsHistHandler, GradsScalarHandler
|
||||
from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
|
||||
from ignite.metrics import Accuracy, Loss, RunningAverage
|
||||
from ignite.contrib.engines.common import save_best_model_by_val_score
|
||||
from ignite.contrib.handlers import ProgressBar
|
||||
|
||||
from util.build import build_model, build_optimizer
|
||||
from util.handler import setup_common_handlers
|
||||
from data.transform import transform_pipeline
|
||||
from data.dataset import LMDBDataset
|
||||
|
||||
|
||||
def warmup_trainer(config, logger):
    """Build a supervised warm-up trainer with metrics and tensorboard logging.

    :param config: run configuration (model, optimizer, output_dir)
    :param logger: logger attached to the trainer
    :return: configured ignite trainer Engine
    """
    model = build_model(config.model, config.distributed.model)
    optimizer = build_optimizer(model.parameters(), config.baseline.optimizers)
    loss_fn = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model, optimizer, loss_fn, idist.device(), non_blocking=True,
                                        output_transform=lambda x, y, y_pred, loss: (loss.item(), y_pred, y))
    trainer.logger = logger

    # Output tuple is (loss_item, y_pred, y): index 0 feeds the loss average,
    # indices 1/2 feed accuracy.
    RunningAverage(output_transform=lambda out: out[0]).attach(trainer, "loss")
    Accuracy(output_transform=lambda out: (out[1], out[2])).attach(trainer, "acc")
    ProgressBar(ncols=0).attach(trainer)

    if idist.get_rank() == 0:
        GpuInfo().attach(trainer, name='gpu')

        # NOTE(review): the tensorboard section is assumed to be rank-0 only,
        # matching the sibling trainers in this project — confirm against VCS.
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="train",
                metric_names='all',
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.EPOCH_COMPLETED
        )
        # Scalar summaries every 10 epochs, histograms every 25.
        tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))
        tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
                         event_name=Events.EPOCH_COMPLETED(every=10))
        tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=25))

        @trainer.on(Events.COMPLETED)
        def _():
            tb_logger.close()

    to_save = dict(model=model, optimizer=optimizer, trainer=trainer)
    setup_common_handlers(trainer, config.output_dir, print_interval_event=Events.EPOCH_COMPLETED, to_save=to_save,
                          save_interval_event=Events.EPOCH_COMPLETED(every=25), n_saved=5,
                          metrics_to_print=["loss", "acc"])
    return trainer
|
||||
|
||||
|
||||
def run(task, config, logger):
    """Entry point for the baseline warm-up experiments.

    :param task: one of "warmup", "protonet-wo", "protonet-w"
    :param config: run configuration
    :param logger: run logger
    :raises ValueError: for an unknown task
    """
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "warmup":
        train_dataset = LMDBDataset(config.baseline.data.dataset.train.lmdb_path,
                                    pipeline=config.baseline.data.dataset.train.pipeline)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.baseline.data.dataloader)
        trainer = warmup_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=400)
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "protonet-wo":
        pass
    elif task == "protonet-w":
        pass
    else:
        # BUG FIX: previously `return ValueError(...)` constructed the exception
        # and returned it silently instead of raising it.
        raise ValueError(f"invalid task: {task}")
|
||||
@@ -1,268 +0,0 @@
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.utils
|
||||
|
||||
import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.contrib.handlers.param_scheduler import PiecewiseLinear
|
||||
from ignite.metrics import RunningAverage
|
||||
from ignite.contrib.handlers import ProgressBar
|
||||
from ignite.utils import convert_tensor
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler, OutputHandler
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import data
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
from model.GAN.residual_generator import GANImageBuffer
|
||||
from util.image import make_2d_grid
|
||||
from util.handler import setup_common_handlers
|
||||
from util.build import build_model, build_optimizer
|
||||
|
||||
|
||||
def get_trainer(config, logger):
    """Build the CycleGAN trainer: two generators, two discriminators.

    Naming follows the CycleGAN paper: G_A maps A->B, G_B maps B->A;
    D_A judges domain-B images, D_B judges domain-A images.

    :param config: run configuration (models, optimizers, losses, checkpoints)
    :param logger: logger attached to the trainer
    :return: configured ignite trainer Engine
    """
    generator_a = build_model(config.model.generator, config.distributed.model)
    generator_b = build_model(config.model.generator, config.distributed.model)
    discriminator_a = build_model(config.model.discriminator, config.distributed.model)
    discriminator_b = build_model(config.model.discriminator, config.distributed.model)
    for m in [generator_b, generator_a, discriminator_b, discriminator_a]:
        generation_init_weights(m)
    logger.info(discriminator_a)
    logger.info(generator_a)

    # One optimizer drives both generators, another both discriminators.
    optimizer_g = build_optimizer(itertools.chain(generator_b.parameters(), generator_a.parameters()),
                                  config.optimizers.generator)
    optimizer_d = build_optimizer(itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()),
                                  config.optimizers.discriminator)

    # Constant lr for 100 epochs, then linear decay to target_lr at epoch 200.
    milestones_values = [
        (0, config.optimizers.generator.lr),
        (100, config.optimizers.generator.lr),
        (200, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_g = PiecewiseLinear(optimizer_g, param_name="lr", milestones_values=milestones_values)

    milestones_values = [
        (0, config.optimizers.discriminator.lr),
        (100, config.optimizers.discriminator.lr),
        (200, config.data.train.scheduler.target_lr)
    ]
    lr_scheduler_d = PiecewiseLinear(optimizer_d, param_name="lr", milestones_values=milestones_values)

    gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
    gan_loss_cfg.pop("weight")
    gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
    cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
    id_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()

    # History buffers stabilize discriminator training (default size 50).
    image_buffer_a = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)
    image_buffer_b = GANImageBuffer(config.data.train.buffer_size if config.data.train.buffer_size is not None else 50)

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]

        fake_b = generator_a(real_a)  # G_A(A)
        rec_a = generator_b(fake_b)  # G_B(G_A(A))
        fake_a = generator_b(real_b)  # G_B(B)
        rec_b = generator_a(fake_a)  # G_A(G_B(B))

        # --- generator update (discriminators frozen) ---
        optimizer_g.zero_grad()
        discriminator_a.requires_grad_(False)
        discriminator_b.requires_grad_(False)
        loss_g = dict(
            cycle_a=config.loss.cycle.weight * cycle_loss(rec_a, real_a),
            cycle_b=config.loss.cycle.weight * cycle_loss(rec_b, real_b),
            gan_a=config.loss.gan.weight * gan_loss(discriminator_a(fake_b), True),
            gan_b=config.loss.gan.weight * gan_loss(discriminator_b(fake_a), True)
        )
        if config.loss.id.weight > 0:
            # BUG FIX: these two assignments previously ended with a trailing
            # comma, which stored 1-tuples in loss_g and made both
            # sum(loss_g.values()).backward() and .mean().item() crash
            # whenever the identity loss was enabled.
            loss_g["id_a"] = config.loss.id.weight * id_loss(generator_a(real_b), real_b)  # G_A(B)
            loss_g["id_b"] = config.loss.id.weight * id_loss(generator_b(real_a), real_a)  # G_B(A)
        sum(loss_g.values()).backward()
        optimizer_g.step()

        # --- discriminator update ---
        discriminator_a.requires_grad_(True)
        discriminator_b.requires_grad_(True)
        optimizer_d.zero_grad()
        loss_d_a = dict(
            real=gan_loss(discriminator_a(real_b), True, is_discriminator=True),
            fake=gan_loss(discriminator_a(image_buffer_a.query(fake_b.detach())), False, is_discriminator=True),
        )
        loss_d_b = dict(
            real=gan_loss(discriminator_b(real_a), True, is_discriminator=True),
            fake=gan_loss(discriminator_b(image_buffer_b.query(fake_a.detach())), False, is_discriminator=True),
        )
        (sum(loss_d_a.values()) * 0.5).backward()
        (sum(loss_d_b.values()) * 0.5).backward()
        optimizer_d.step()

        return {
            "loss": {
                "g": {ln: loss_g[ln].mean().item() for ln in loss_g},
                "d_a": {ln: loss_d_a[ln].mean().item() for ln in loss_d_a},
                "d_b": {ln: loss_d_b[ln].mean().item() for ln in loss_d_b},
            },
            "img": [
                real_a.detach(),
                fake_b.detach(),
                rec_a.detach(),
                real_b.detach(),
                fake_a.detach(),
                rec_b.detach()
            ]
        }

    trainer = Engine(_step)
    trainer.logger = logger
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_g)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler_d)

    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_a"].values())).attach(trainer, "loss_d_a")
    RunningAverage(output_transform=lambda x: sum(x["loss"]["d_b"].values())).attach(trainer, "loss_d_b")

    to_save = dict(
        generator_a=generator_a, generator_b=generator_b, discriminator_a=discriminator_a,
        discriminator_b=discriminator_b, optimizer_d=optimizer_d, optimizer_g=optimizer_g, trainer=trainer,
        lr_scheduler_d=lr_scheduler_d, lr_scheduler_g=lr_scheduler_g
    )

    setup_common_handlers(trainer, config.output_dir, resume_from=config.resume_from, n_saved=5,
                          filename_prefix=config.name, to_save=to_save,
                          print_interval_event=Events.ITERATION_COMPLETED(every=10) | Events.COMPLETED,
                          metrics_to_print=["loss_g", "loss_d_a", "loss_d_b"],
                          save_interval_event=Events.ITERATION_COMPLETED(
                              every=config.checkpoints.interval) | Events.COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED(once=config.max_iteration))
    def terminate(engine):
        engine.terminate()

    if idist.get_rank() == 0:
        # Create a logger
        tb_logger = TensorboardLogger(log_dir=config.output_dir)
        tb_writer = tb_logger.writer

        # Attach the logger to the trainer to log training loss at each iteration
        def global_step_transform(*args, **kwargs):
            return trainer.state.iteration

        def output_transform(output):
            # Flatten the nested loss dict into "<group>_<name>" scalars.
            loss = dict()
            for tl in output["loss"]:
                if isinstance(output["loss"][tl], dict):
                    for l in output["loss"][tl]:
                        loss[f"{tl}_{l}"] = output["loss"][tl][l]
                else:
                    loss[tl] = output["loss"][tl]
            return loss

        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="loss",
                metric_names=["loss_g", "loss_d_a", "loss_d_b"],
                global_step_transform=global_step_transform,
                output_transform=output_transform
            ),
            event_name=Events.ITERATION_COMPLETED(every=50)
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer_g, tag="optimizer_g"),
            event_name=Events.ITERATION_STARTED(every=50)
        )

        @trainer.on(Events.ITERATION_COMPLETED(every=config.checkpoints.interval))
        def show_images(engine):
            tb_writer.add_image("train/img", make_2d_grid(engine.state.output["img"]), engine.state.iteration)

        @trainer.on(Events.COMPLETED)
        @idist.one_rank_only()
        def _():
            # We need to close the logger when we are done
            tb_logger.close()

    return trainer
|
||||
|
||||
|
||||
def get_tester(config, logger):
    """Build the CycleGAN tester: run both generators and save image grids.

    :param config: run configuration (models, output_dir, resume_from)
    :param logger: logger attached to the tester
    :return: configured ignite ``Engine``
    """
    generator_a = build_model(config.model.generator, config.distributed.model)
    generator_b = build_model(config.model.generator, config.distributed.model)

    def _step(engine, batch):
        batch = convert_tensor(batch, idist.device())
        real_a, real_b = batch["a"], batch["b"]
        with torch.no_grad():
            fake_b = generator_a(real_a)  # G_A(A)
            rec_a = generator_b(fake_b)  # G_B(G_A(A))
            fake_a = generator_b(real_b)  # G_B(B)
            rec_b = generator_a(fake_a)  # G_A(G_B(B))
        return [
            real_a.detach(),
            fake_b.detach(),
            rec_a.detach(),
            real_b.detach(),
            fake_a.detach(),
            rec_b.detach()
        ]

    tester = Engine(_step)
    tester.logger = logger
    # BUG FIX: `idist.get_rank == 0` compared the function object itself to 0
    # (always False), so the progress bar was never attached; call the function.
    if idist.get_rank() == 0:
        ProgressBar(ncols=0).attach(tester)
    to_load = dict(generator_a=generator_a, generator_b=generator_b)
    setup_common_handlers(tester, use_profiler=False, to_save=to_load, resume_from=config.resume_from)

    @tester.on(Events.STARTED)
    @idist.one_rank_only()
    def mkdir(engine):
        img_output_dir = Path(config.output_dir) / "test_images"
        if not img_output_dir.exists():
            engine.logger.info(f"mkdir {img_output_dir}")
            img_output_dir.mkdir()

    @tester.on(Events.ITERATION_COMPLETED)
    def save_images(engine):
        # One jpg per sample: the 6 stages laid out in a single row.
        img_tensors = engine.state.output
        batch_size = img_tensors[0].size(0)
        for i in range(batch_size):
            torchvision.utils.save_image([img[i] for img in img_tensors],
                                         Path(config.output_dir) / f"test_images/{engine.state.iteration}_{i}.jpg",
                                         nrow=len(img_tensors))

    return tester
|
||||
|
||||
|
||||
def run(task, config, logger):
    """Entry point for the CycleGAN experiments.

    :param task: "train" or "test"
    :param config: run configuration
    :param logger: run logger
    :raises NotImplementedError: for an unknown task
    """
    assert torch.backends.cudnn.enabled
    torch.backends.cudnn.benchmark = True
    logger.info(f"start task {task}")
    if task == "train":
        train_dataset = data.DATASET.build_with(config.data.train.dataset)
        logger.info(f"train with dataset:\n{train_dataset}")
        train_data_loader = idist.auto_dataloader(train_dataset, **config.data.train.dataloader)
        trainer = get_trainer(config, logger)
        try:
            trainer.run(train_data_loader, max_epochs=config.max_iteration // len(train_data_loader) + 1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    elif task == "test":
        assert config.resume_from is not None
        test_dataset = data.DATASET.build_with(config.data.test.dataset)
        logger.info(f"test with dataset:\n{test_dataset}")
        test_data_loader = idist.auto_dataloader(test_dataset, **config.data.test.dataloader)
        tester = get_tester(config, logger)
        try:
            tester.run(test_data_loader, max_epochs=1)
        except Exception:
            import traceback
            print(traceback.format_exc())
    else:
        # BUG FIX: `return NotImplemented(...)` tried to call the non-callable
        # NotImplemented singleton (TypeError); raise the proper exception.
        raise NotImplementedError(f"invalid task: {task}")
|
||||
@@ -1,9 +0,0 @@
|
||||
from data.dataset import EpisodicDataset, LMDBDataset
|
||||
|
||||
|
||||
def prototypical_trainer(config, logger):
    """Build a prototypical-network trainer.

    TODO: not implemented yet; currently a placeholder that returns None.
    """
    pass


def prototypical_dataloader(config):
    """Build the episodic dataloader for prototypical-network training.

    TODO: not implemented yet; currently a placeholder that returns None.
    """
    pass
|
||||
153
engine/talking_anime.py
Normal file
153
engine/talking_anime.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from itertools import chain
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.base.i2i import EngineKernel, run_kernel
|
||||
from engine.util.build import build_model
|
||||
from loss.I2I.context_loss import ContextLoss
|
||||
from loss.I2I.edge_loss import EdgeLoss
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.weight_init import generation_init_weights
|
||||
|
||||
|
||||
class TAEngineKernel(EngineKernel):
    """Engine kernel for the talking-anime task.

    Wires up two generators (anime pose transfer, face reconstruction) and two
    multi-scale discriminators, plus the loss ensemble (perceptual, style, GAN,
    context, reconstruction, feature-matching, edge).
    """

    def __init__(self, config):
        super().__init__(config)

        def _module_cfg(loss_cfg):
            # The "weight" entry is applied by the kernel, not by the module.
            cfg = OmegaConf.to_container(loss_cfg)
            cfg.pop("weight")
            return cfg

        device = idist.device()
        self.perceptual_loss = PerceptualLoss(**_module_cfg(config.loss.perceptual)).to(device)
        self.style_loss = PerceptualLoss(**_module_cfg(config.loss.style)).to(device)
        self.gan_loss = GANLoss(**_module_cfg(config.loss.gan)).to(device)
        self.context_loss = ContextLoss(**_module_cfg(config.loss.context)).to(device)

        # level == 1 selects L1, anything else MSE.
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()

        self.edge_loss = EdgeLoss(
            "HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(device)

    def build_models(self) -> (dict, dict):
        """Instantiate and weight-initialize all generators and discriminators."""
        generators = dict(
            anime=build_model(self.config.model.anime_generator),
            face=build_model(self.config.model.face_generator)
        )
        discriminators = dict(
            anime=build_model(self.config.model.discriminator),
            face=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["face"])
        self.logger.debug(generators["face"])

        for module in chain(generators.values(), discriminators.values()):
            generation_init_weights(module)

        return generators, discriminators

    def setup_after_g(self):
        """Re-enable discriminator gradients after the generator step."""
        for disc in self.discriminators.values():
            disc.requires_grad_(True)

    def setup_before_g(self):
        """Freeze discriminators while the generators are being updated."""
        for disc in self.discriminators.values():
            disc.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        """Run both generators; gradients are disabled when inference=True."""
        with torch.set_grad_enabled(not inference):
            fake_anime = self.generators["anime"](
                torch.cat([batch["face_1"], torch.flip(batch["anime_img"], dims=[3])], dim=1))
            fake_face = self.generators["face"](fake_anime.mean(dim=1, keepdim=True), batch["face_0"])

        return dict(fake_anime=fake_anime, fake_face=fake_face)

    def cal_gan_and_fm_loss(self, discriminator, generated_img, match_img=None):
        """Compute GAN loss (and optionally feature-matching loss) for one image.

        :param discriminator: multi-scale discriminator; returns, per scale, a
            list of intermediate features whose last element is the prediction
        :param generated_img: generator output to score
        :param match_img: real image for feature matching; None skips FM loss
        :return: (loss_gan, loss_fm) — loss_fm is 0 when match_img is None
        """
        pred_fake = discriminator(generated_img)
        loss_gan = 0
        for scale_outputs in pred_fake:
            # last output is actual prediction
            loss_gan += self.config.loss.gan.weight * self.gan_loss(scale_outputs[-1], True)

        if match_img is None:
            # do not cal feature match loss
            return loss_gan, 0

        pred_real = discriminator(match_img)
        loss_fm = 0
        num_scale_discriminator = len(pred_fake)
        for i in range(num_scale_discriminator):
            # last output is the final prediction, so we exclude it
            for j in range(len(pred_fake[i]) - 1):
                loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
        loss_fm = self.config.loss.fm.weight * loss_fm
        return loss_gan, loss_fm

    def criterion_generators(self, batch, generated) -> dict:
        """Return the per-term generator loss dict (already weighted)."""
        loss = dict()
        loss["face_style"] = self.config.loss.style.weight * self.style_loss(
            generated["fake_face"], batch["face_1"]
        )
        loss["face_recon"] = self.config.loss.recon.weight * self.recon_loss(
            generated["fake_face"], batch["face_1"]
        )
        loss["face_gan"], loss["face_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["face"], generated["fake_face"], batch["face_1"])
        loss["anime_gan"], loss["anime_fm"] = self.cal_gan_and_fm_loss(
            self.discriminators["anime"], generated["fake_anime"], batch["anime_img"])

        loss["anime_edge"] = self.config.loss.edge.weight * self.edge_loss(
            generated["fake_anime"], batch["face_1"], gt_is_edge=False,
        )
        # Optional terms are skipped entirely when their weight is zero.
        if self.config.loss.perceptual.weight > 0:
            loss["anime_perceptual"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                generated["fake_anime"], batch["anime_img"]
            )
        if self.config.loss.context.weight > 0:
            loss["anime_context"] = self.config.loss.context.weight * self.context_loss(
                generated["fake_anime"], batch["anime_img"],
            )

        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        """Return the per-discriminator hinge/GAN loss dict."""
        loss = dict()
        real = {"anime": "anime_img", "face": "face_1"}
        for phase in self.discriminators.keys():
            pred_real = self.discriminators[phase](batch[real[phase]])
            pred_fake = self.discriminators[phase](generated[f"fake_{phase}"].detach())
            loss[f"gan_{phase}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                         + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        images = [batch["face_0"], batch["face_1"], batch["anime_img"], generated["fake_anime"].detach(),
                  generated["fake_face"].detach()]
        return dict(
            b=list(images)
        )
|
||||
|
||||
|
||||
def run(task, config, _):
    """Entry point: build the talking-anime kernel and dispatch via run_kernel.

    :param task: "train" or "test", forwarded to run_kernel
    :param config: run configuration
    :param _: unused (kept for a uniform entry-point signature)
    """
    run_kernel(task, config, TAEngineKernel(config))
|
||||
0
engine/util/__init__.py
Normal file
0
engine/util/__init__.py
Normal file
33
engine/util/build.py
Normal file
33
engine/util/build.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model import MODEL
|
||||
|
||||
|
||||
def add_spectral_norm(module):
    """Wrap linear/conv layers with spectral normalization (idempotent).

    Modules that are not Linear/Conv2d/ConvTranspose2d — or that already carry
    the `weight_u` buffer added by a previous wrap — are returned unchanged.
    """
    normalizable = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)
    already_wrapped = hasattr(module, 'weight_u')
    if not isinstance(module, normalizable) or already_wrapped:
        return module
    return nn.utils.spectral_norm(module)
|
||||
|
||||
|
||||
def build_model(cfg):
    """Build a model from config and adapt it for distributed execution.

    Recognized meta-keys (removed before instantiation):
      _bn_to_sync_bn     -- convert BatchNorm layers to SyncBatchNorm
      _add_spectral_norm -- apply spectral norm to linear/conv layers

    :param cfg: OmegaConf node describing the model
    :return: model wrapped by idist.auto_model
    """
    cfg = OmegaConf.to_container(cfg)
    use_sync_bn = cfg.pop("_bn_to_sync_bn", False)
    use_spectral_norm = cfg.pop("_add_spectral_norm", False)
    model = MODEL.build_with(cfg)
    if use_sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if use_spectral_norm:
        model.apply(add_spectral_norm)
    return idist.auto_model(model)
|
||||
|
||||
|
||||
def build_optimizer(params, cfg):
    """Build a torch optimizer from config and adapt it for distributed runs.

    ``cfg._type`` names the torch.optim class; the remaining keys are passed
    through as optimizer keyword arguments.

    :param params: iterable of parameters to optimize
    :param cfg: OmegaConf node with ``_type`` plus optimizer kwargs
    :return: optimizer wrapped by idist.auto_optim
    """
    assert "_type" in cfg
    cfg = OmegaConf.to_container(cfg)
    optimizer_cls = getattr(optim, cfg.pop("_type"))
    return idist.auto_optim(optimizer_cls(params=params, **cfg))
|
||||
66
engine/util/container.py
Normal file
66
engine/util/container.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
|
||||
|
||||
class LossContainer:
    """Couple a loss callable with a scalar weight.

    Calling the container evaluates the wrapped loss scaled by ``weight``;
    a non-positive weight short-circuits to 0.0 without invoking the loss.
    """

    def __init__(self, weight, loss):
        self.weight = weight  # scalar multiplier; <= 0 disables the loss
        self.loss = loss      # underlying loss callable

    def __call__(self, *args, **kwargs):
        if self.weight <= 0:
            return 0.0
        return self.weight * self.loss(*args, **kwargs)
|
||||
|
||||
|
||||
class GANImageBuffer:
    """History buffer of previously generated images.

    Feeding the discriminator a mix of current and historical generator
    outputs reduces model oscillation (as in the CycleGAN training recipe).

    Args:
        buffer_size (int): capacity of the buffer; 0 disables buffering and
            makes ``query`` a pass-through.
        buffer_ratio (float): probability of swapping an incoming image with
            a stored one once the buffer is full.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        # With a zero-sized buffer no state is kept at all.
        if self.buffer_size > 0:
            self.img_num = 0
            self.image_buffer = []
            self.buffer_ratio = buffer_ratio

    def query(self, images):
        """Mix the current batch with the image history.

        Args:
            images (Tensor): current batch of generated images.

        Returns:
            Tensor: same-shaped batch where, once the buffer is full, each
            image is replaced by a stored one with probability
            ``buffer_ratio`` (the stored slot then receives the new image).
        """
        if self.buffer_size == 0:
            # buffering disabled: hand the batch back untouched
            return images

        out = []
        for img in images:
            img = torch.unsqueeze(img.data, 0)
            if self.img_num < self.buffer_size:
                # still filling: store and pass the image through
                self.img_num += 1
                self.image_buffer.append(img)
                out.append(img)
                continue
            if torch.rand(1) < self.buffer_ratio:
                # swap: emit a random stored image, keep the new one
                slot = torch.randint(0, self.buffer_size, (1,)).item()
                stored = self.image_buffer[slot].clone()
                self.image_buffer[slot] = img
                out.append(stored)
            else:
                # pass-through: emit the current image unchanged
                out.append(img)
        return torch.cat(out, 0)
|
||||
@@ -7,7 +7,7 @@ import ignite.distributed as idist
|
||||
from ignite.engine import Events, Engine
|
||||
from ignite.handlers import Checkpoint, DiskSaver, TerminateOnNan
|
||||
from ignite.contrib.handlers import BasicTimeProfiler, ProgressBar
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler
|
||||
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler, OptimizerParamsHandler
|
||||
|
||||
|
||||
def empty_cuda_cache(_):
|
||||
@@ -16,6 +16,16 @@ def empty_cuda_cache(_):
|
||||
gc.collect()
|
||||
|
||||
|
||||
def step_transform_maker(stype: str, pairs_per_iteration=None):
|
||||
assert stype in ["item", "iteration", "epoch"]
|
||||
if stype == "item":
|
||||
return lambda engine, _: engine.state.iteration * pairs_per_iteration
|
||||
if stype == "iteration":
|
||||
return lambda engine, _: engine.state.iteration
|
||||
if stype == "epoch":
|
||||
return lambda engine, _: engine.state.epoch
|
||||
|
||||
|
||||
def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_cache=False, use_profiler=True,
|
||||
to_save=None, end_event=None, set_epoch_for_dist_sampler=False):
|
||||
"""
|
||||
@@ -41,9 +51,10 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
trainer.logger.debug(f"set_epoch {engine.state.epoch - 1} for DistributedSampler")
|
||||
trainer.state.dataloader.sampler.set_epoch(engine.state.epoch - 1)
|
||||
|
||||
@trainer.on(Events.STARTED | Events.EPOCH_COMPLETED(once=1))
|
||||
trainer.logger.info(f"data loader length: {config.iterations_per_epoch} iterations per epoch")
|
||||
|
||||
@trainer.on(Events.EPOCH_COMPLETED(once=1))
|
||||
def print_info(engine):
|
||||
engine.logger.info(f"data loader length: {len(engine.state.dataloader)}")
|
||||
engine.logger.info(f"- GPU util: \n{torch.cuda.memory_summary(0)}")
|
||||
|
||||
if stop_on_nan:
|
||||
@@ -66,7 +77,7 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
|
||||
if to_save is not None:
|
||||
checkpoint_handler = Checkpoint(to_save, DiskSaver(dirname=config.output_dir, require_empty=False),
|
||||
n_saved=config.checkpoint.n_saved, filename_prefix=config.name)
|
||||
n_saved=config.handler.checkpoint.n_saved, filename_prefix=config.name)
|
||||
if config.resume_from is not None:
|
||||
@trainer.on(Events.STARTED)
|
||||
def resume(engine):
|
||||
@@ -74,11 +85,13 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
if not checkpoint_path.exists():
|
||||
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
|
||||
trainer.logger.info(f"load state_dict for {ckp.keys()}")
|
||||
trainer.logger.info(f"load state_dict for {to_save.keys()}")
|
||||
Checkpoint.load_objects(to_load=to_save, checkpoint=ckp)
|
||||
engine.logger.info(f"resume from a checkpoint {checkpoint_path}")
|
||||
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=config.checkpoint.epoch_interval) | Events.COMPLETED,
|
||||
checkpoint_handler)
|
||||
trainer.add_event_handler(
|
||||
Events.EPOCH_COMPLETED(every=config.handler.checkpoint.epoch_interval) | Events.COMPLETED,
|
||||
checkpoint_handler
|
||||
)
|
||||
trainer.logger.debug(f"add checkpoint handler to save {to_save.keys()} periodically")
|
||||
if end_event is not None:
|
||||
trainer.logger.debug(f"engine will stop on {end_event}")
|
||||
@@ -88,17 +101,48 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
engine.terminate()
|
||||
|
||||
|
||||
def setup_tensorboard_handler(trainer: Engine, config, output_transform, iter_per_epoch):
|
||||
if config.interval.tensorboard is None:
|
||||
def setup_tensorboard_handler(trainer: Engine, config, optimizers, step_type="item"):
|
||||
if config.handler.tensorboard is None:
|
||||
return None
|
||||
if idist.get_rank() == 0:
|
||||
# Create a logger
|
||||
tb_logger = TensorboardLogger(log_dir=config.output_dir)
|
||||
basic_event = Events.ITERATION_COMPLETED(every=max(iter_per_epoch // config.interval.tensorboard.scalar, 1))
|
||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="metric", metric_names="all"),
|
||||
event_name=basic_event)
|
||||
tb_logger.attach(trainer, log_handler=OutputHandler(tag="train", output_transform=output_transform),
|
||||
event_name=basic_event)
|
||||
tb_writer = tb_logger.writer
|
||||
pairs_per_iteration = config.data.train.dataloader.batch_size * idist.get_world_size()
|
||||
global_step_transform = step_transform_maker(step_type, pairs_per_iteration)
|
||||
|
||||
basic_event = Events.ITERATION_COMPLETED(
|
||||
every=max(config.iterations_per_epoch // config.handler.tensorboard.scalar, 1))
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OutputHandler(
|
||||
tag="metric", metric_names="all",
|
||||
global_step_transform=global_step_transform
|
||||
),
|
||||
event_name=basic_event
|
||||
)
|
||||
|
||||
@trainer.on(basic_event)
|
||||
def log_loss(engine):
|
||||
global_step = global_step_transform(engine, None)
|
||||
output_loss = engine.state.output["loss"]
|
||||
for total_loss in output_loss:
|
||||
if isinstance(output_loss[total_loss], dict):
|
||||
for ln in output_loss[total_loss]:
|
||||
tb_writer.add_scalar(f"train_{total_loss}/{ln}", output_loss[total_loss][ln], global_step)
|
||||
else:
|
||||
tb_writer.add_scalar(f"train/{total_loss}", output_loss[total_loss], global_step)
|
||||
|
||||
if isinstance(optimizers, dict):
|
||||
for name in optimizers:
|
||||
tb_logger.attach(
|
||||
trainer,
|
||||
log_handler=OptimizerParamsHandler(optimizers[name], tag=f"optimizer_{name}"),
|
||||
event_name=Events.ITERATION_STARTED
|
||||
)
|
||||
else:
|
||||
tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizers, tag=f"optimizer"),
|
||||
event_name=Events.ITERATION_STARTED)
|
||||
|
||||
@trainer.on(Events.COMPLETED)
|
||||
@idist.one_rank_only()
|
||||
48
engine/util/loss.py
Normal file
48
engine/util/loss.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
from loss.gan import GANLoss
|
||||
|
||||
|
||||
def gan_loss(config):
|
||||
gan_loss_cfg = OmegaConf.to_container(config)
|
||||
gan_loss_cfg.pop("weight")
|
||||
return GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
|
||||
|
||||
def perceptual_loss(config):
|
||||
perceptual_loss_cfg = OmegaConf.to_container(config)
|
||||
perceptual_loss_cfg.pop("weight")
|
||||
return PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
|
||||
|
||||
|
||||
def pixel_loss(level):
|
||||
return nn.L1Loss() if level == 1 else nn.MSELoss()
|
||||
|
||||
|
||||
def mse_loss(x, target_flag):
|
||||
return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
|
||||
def bce_loss(x, target_flag):
|
||||
return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))
|
||||
|
||||
|
||||
def feature_match_loss(level, weight_policy):
|
||||
compare_loss = pixel_loss(level)
|
||||
assert weight_policy in ["same", "exponential_decline"]
|
||||
|
||||
def fm_loss(generated_features, target_features):
|
||||
num_scale = len(generated_features)
|
||||
loss = torch.zeros(1, device=idist.device())
|
||||
for s_i in range(num_scale):
|
||||
for i in range(len(generated_features[s_i]) - 1):
|
||||
weight = 1 if weight_policy == "same" else 2 ** i
|
||||
loss += weight * compare_loss(generated_features[s_i][i], target_features[s_i][i].detach()) / num_scale
|
||||
return loss
|
||||
|
||||
return fm_loss
|
||||
@@ -6,7 +6,6 @@ channels:
|
||||
dependencies:
|
||||
- python=3.8
|
||||
- numpy
|
||||
- ipython
|
||||
- tqdm
|
||||
- pyyaml
|
||||
- pytorch=1.6.*
|
||||
@@ -17,6 +16,6 @@ dependencies:
|
||||
- omegaconf
|
||||
- python-lmdb
|
||||
- fire
|
||||
# - opencv
|
||||
- opencv
|
||||
# - jupyterlab
|
||||
|
||||
|
||||
44
loss/I2I/context_loss.py
Normal file
44
loss/I2I/context_loss.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from .perceptual_loss import PerceptualVGG
|
||||
|
||||
|
||||
class ContextLoss(nn.Module):
|
||||
def __init__(self, layer_weights, h=0.1, vgg_type='vgg19', norm_image_with_imagenet_param=True, norm_img=True,
|
||||
eps=1e-5):
|
||||
super(ContextLoss, self).__init__()
|
||||
self.eps = eps
|
||||
self.h = h
|
||||
self.layer_weights = layer_weights
|
||||
self.norm_img = norm_img
|
||||
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
|
||||
norm_image_with_imagenet_param=norm_image_with_imagenet_param)
|
||||
|
||||
def single_forward(self, source_feature, target_feature):
|
||||
mean_target_feature = target_feature.mean(dim=[2, 3], keepdim=True)
|
||||
source_feature = (source_feature - mean_target_feature).view(*source_feature.size()[:2], -1) # NxCxHW
|
||||
target_feature = (target_feature - mean_target_feature).view(*source_feature.size()[:2], -1) # NxCxHW
|
||||
source_feature = F.normalize(source_feature, p=2, dim=1)
|
||||
target_feature = F.normalize(target_feature, p=2, dim=1)
|
||||
cosine_distance = (1 - torch.bmm(source_feature.transpose(1, 2), target_feature)) / 2 # NxHWxHW
|
||||
rel_distance = cosine_distance / (cosine_distance.min(2, keepdim=True)[0] + self.eps)
|
||||
w = torch.exp((1 - rel_distance) / self.h)
|
||||
cx = w.div(w.sum(dim=2, keepdim=True))
|
||||
cx = cx.max(dim=1, keepdim=True)[0].mean(dim=2)
|
||||
return -torch.log(cx).mean()
|
||||
|
||||
def forward(self, x, gt):
|
||||
if self.norm_img:
|
||||
x = (x + 1.) * 0.5
|
||||
gt = (gt + 1.) * 0.5
|
||||
|
||||
# extract vgg features
|
||||
x_features = self.vgg(x)
|
||||
gt_features = self.vgg(gt.detach())
|
||||
|
||||
loss = 0
|
||||
for k in x_features.keys():
|
||||
loss += self.single_forward(x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
return loss
|
||||
229
loss/I2I/minimal_geometry_distortion_constraint_loss.py
Normal file
229
loss/I2I/minimal_geometry_distortion_constraint_loss.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def gaussian_radial_basis_function(x, mu, sigma):
|
||||
# (kernel_size) -> (batch_size, kernel_size, c*h*w)
|
||||
mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
|
||||
# (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
|
||||
x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1)
|
||||
return torch.exp((x - mu).pow(2) / (2 * sigma ** 2))
|
||||
|
||||
|
||||
class ImporveMyLoss(torch.nn.Module):
|
||||
def __init__(self, device=idist.device()):
|
||||
super().__init__()
|
||||
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).to(device)
|
||||
self.x_mu_list = mu.repeat(9).view(-1, 81)
|
||||
self.y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
|
||||
self.R = torch.eye(81).to(device)
|
||||
|
||||
def batch_ERSMI(self, I1, I2):
|
||||
batch_size = I1.shape[0]
|
||||
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
|
||||
if I2.shape[1] == 1 and I1.shape[1] != 1:
|
||||
I2 = I2.repeat(1, 3, 1, 1)
|
||||
|
||||
def kernel_F(y, mu_list, sigma):
|
||||
tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1) # [81, 784]
|
||||
tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
|
||||
tmp_y = tmp_mu - tmp_y
|
||||
mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
|
||||
return mat_L
|
||||
|
||||
mat_K = kernel_F(I1, self.x_mu_list, 1)
|
||||
mat_L = kernel_F(I2, self.y_mu_list, 1)
|
||||
mat_k_l = mat_K * mat_L
|
||||
|
||||
H1 = (mat_K @ mat_K.transpose(1, 2)) * (mat_L @ mat_L.transpose(1, 2)) / (img_size ** 2)
|
||||
h_hat = mat_k_l @ mat_k_l.transpose(1, 2) / img_size
|
||||
small_h_hat = mat_K.sum(2).view(batch_size, -1, 1) * mat_L.sum(2).view(batch_size, -1, 1) / (img_size ** 2)
|
||||
h_hat = 0.5 * H1 + 0.5 * h_hat
|
||||
alpha = (h_hat + 0.05 * self.R).inverse() @ small_h_hat
|
||||
|
||||
ersmi = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
|
||||
|
||||
ersmi = -ersmi.squeeze().mean()
|
||||
return ersmi
|
||||
|
||||
def forward(self, fakeI, realI):
|
||||
return self.batch_ERSMI(fakeI, realI)
|
||||
|
||||
|
||||
class MyLoss(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyLoss, self).__init__()
|
||||
|
||||
def forward(self, fakeI, realI):
|
||||
fakeI = fakeI.cuda()
|
||||
realI = realI.cuda()
|
||||
|
||||
def batch_ERSMI(I1, I2):
|
||||
batch_size = I1.shape[0]
|
||||
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
|
||||
if I2.shape[1] == 1 and I1.shape[1] != 1:
|
||||
I2 = I2.repeat(1, 3, 1, 1)
|
||||
|
||||
def kernel_F(y, mu_list, sigma):
|
||||
tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1).cuda() # [81, 784]
|
||||
tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
|
||||
tmp_y = tmp_mu - tmp_y
|
||||
mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
|
||||
return mat_L
|
||||
|
||||
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).cuda()
|
||||
|
||||
x_mu_list = mu.repeat(9).view(-1, 81)
|
||||
y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
|
||||
|
||||
mat_K = kernel_F(I1, x_mu_list, 1)
|
||||
mat_L = kernel_F(I2, y_mu_list, 1)
|
||||
|
||||
H1 = ((mat_K.matmul(mat_K.transpose(1, 2))).mul(mat_L.matmul(mat_L.transpose(1, 2))) / (
|
||||
img_size ** 2)).cuda()
|
||||
H2 = ((mat_K.mul(mat_L)).matmul((mat_K.mul(mat_L)).transpose(1, 2)) / img_size).cuda()
|
||||
h2 = ((mat_K.sum(2).view(batch_size, -1, 1)).mul(mat_L.sum(2).view(batch_size, -1, 1)) / (
|
||||
img_size ** 2)).cuda()
|
||||
H2 = 0.5 * H1 + 0.5 * H2
|
||||
tmp = H2 + 0.05 * torch.eye(len(H2[0])).cuda()
|
||||
alpha = (tmp.inverse())
|
||||
|
||||
alpha = alpha.matmul(h2)
|
||||
ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
|
||||
alpha) - 1).squeeze()
|
||||
|
||||
ersmi = -ersmi.mean()
|
||||
return ersmi
|
||||
|
||||
batch_loss = batch_ERSMI(fakeI, realI)
|
||||
return batch_loss
|
||||
|
||||
|
||||
class MGCLoss(nn.Module):
|
||||
"""
|
||||
Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
|
||||
"""
|
||||
|
||||
def __init__(self, mi_to_loss_way="opposite", beta=0.5, lambda_=0.05, device=idist.device()):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.lambda_ = lambda_
|
||||
assert mi_to_loss_way in ["opposite", "reciprocal"]
|
||||
self.mi_to_loss_way = mi_to_loss_way
|
||||
mu_y, mu_x = torch.meshgrid([torch.arange(-1, 1.25, 0.25), torch.arange(-1, 1.25, 0.25)])
|
||||
self.mu_x = mu_x.flatten().to(device)
|
||||
self.mu_y = mu_y.flatten().to(device)
|
||||
self.R = torch.eye(81).unsqueeze(0).to(device)
|
||||
|
||||
@staticmethod
|
||||
def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_, R):
|
||||
assert img1.size() == img2.size()
|
||||
|
||||
num_pixel = img1.size(1) * img1.size(2) * img2.size(3)
|
||||
|
||||
mat_k = gaussian_radial_basis_function(img1, mu_x, sigma=1)
|
||||
mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)
|
||||
|
||||
mat_k_mul_mat_l = mat_k * mat_l
|
||||
h_hat = (1 - beta) * (mat_k_mul_mat_l @ mat_k_mul_mat_l.transpose(1, 2)) / num_pixel
|
||||
h_hat += beta * ((mat_k @ mat_k.transpose(1, 2)) * (mat_l @ mat_l.transpose(1, 2))) / (num_pixel ** 2)
|
||||
small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)
|
||||
|
||||
alpha = (h_hat + lambda_ * R).inverse() @ small_h_hat
|
||||
rSMI = 2 * alpha.transpose(1, 2) @ small_h_hat - alpha.transpose(1, 2) @ h_hat @ alpha - 1
|
||||
return rSMI.squeeze()
|
||||
|
||||
def forward(self, fake, real):
|
||||
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_, self.R)
|
||||
if self.mi_to_loss_way == "reciprocal":
|
||||
return 1/rSMI.mean()
|
||||
return -rSMI.mean()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mg = MGCLoss(device=torch.device("cpu"))
|
||||
my = MyLoss().to("cuda")
|
||||
imy = ImporveMyLoss()
|
||||
|
||||
from data.transform import transform_pipeline
|
||||
|
||||
pipeline = transform_pipeline(
|
||||
['Load', 'ToTensor', {'Normalize': {'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]}}])
|
||||
|
||||
img_a1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_1.jpg")
|
||||
img_a2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_2.jpg")
|
||||
img_a3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id00022-twCPGo2rtCo-00294_3.jpg")
|
||||
img_b1 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_1.jpg")
|
||||
img_b2 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_2.jpg")
|
||||
img_b3 = pipeline("/data/i2i/VoxCeleb2Anime/trainA/id01222-2gHw81dNQiA-00005_3.jpg")
|
||||
|
||||
img_a1.requires_grad_(True)
|
||||
img_a2.requires_grad_(True)
|
||||
img_a3.requires_grad_(True)
|
||||
|
||||
# print("MyLoss")
|
||||
# l1 = my(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
|
||||
# l2 = my(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
|
||||
# l3 = my(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
|
||||
# l = (l1+l2+l3)/3
|
||||
# l.backward()
|
||||
# print(img_a1.grad[0][0][0:10])
|
||||
# print(img_a2.grad[0][0][0:10])
|
||||
# print(img_a3.grad[0][0][0:10])
|
||||
#
|
||||
# img_a1.grad = None
|
||||
# img_a2.grad = None
|
||||
# img_a3.grad = None
|
||||
#
|
||||
# print("---")
|
||||
# l = my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
|
||||
# l.backward()
|
||||
# print(img_a1.grad[0][0][0:10])
|
||||
# print(img_a2.grad[0][0][0:10])
|
||||
# print(img_a3.grad[0][0][0:10])
|
||||
# img_a1.grad = None
|
||||
# img_a2.grad = None
|
||||
# img_a3.grad = None
|
||||
|
||||
print("MGCLoss")
|
||||
l1 = mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
|
||||
l2 = mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
|
||||
l3 = mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
|
||||
l = (l1 + l2 + l3) / 3
|
||||
l.backward()
|
||||
print(img_a1.grad[0][0][0:10])
|
||||
print(img_a2.grad[0][0][0:10])
|
||||
print(img_a3.grad[0][0][0:10])
|
||||
|
||||
img_a1.grad = None
|
||||
img_a2.grad = None
|
||||
img_a3.grad = None
|
||||
|
||||
print("---")
|
||||
l = mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
|
||||
l.backward()
|
||||
print(img_a1.grad[0][0][0:10])
|
||||
print(img_a2.grad[0][0][0:10])
|
||||
print(img_a3.grad[0][0][0:10])
|
||||
|
||||
# print("\nMGCLoss")
|
||||
# mg(img_a1.unsqueeze(0), img_b1.unsqueeze(0))
|
||||
# mg(img_a2.unsqueeze(0), img_b2.unsqueeze(0))
|
||||
# mg(img_a3.unsqueeze(0), img_b3.unsqueeze(0))
|
||||
#
|
||||
# print("---")
|
||||
# mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
|
||||
#
|
||||
# import pprofile
|
||||
#
|
||||
# profiler = pprofile.Profile()
|
||||
# with profiler:
|
||||
# iter_times = 1000
|
||||
# for _ in range(iter_times):
|
||||
# mg(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
|
||||
# for _ in range(iter_times):
|
||||
# my(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
|
||||
# for _ in range(iter_times):
|
||||
# imy(torch.stack([img_a1, img_a2, img_a3]), torch.stack([img_b1, img_b2, img_b3]))
|
||||
# profiler.print_stats()
|
||||
@@ -1,8 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models.vgg as vgg
|
||||
|
||||
|
||||
# Sequential(
|
||||
# (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (1): ReLU(inplace=True)
|
||||
# (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (3): ReLU(inplace=True)
|
||||
# (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
|
||||
# (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (6): ReLU(inplace=True)
|
||||
# (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (8): ReLU(inplace=True)
|
||||
# (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
|
||||
# (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (11): ReLU(inplace=True)
|
||||
# (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (13): ReLU(inplace=True)
|
||||
# (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (15): ReLU(inplace=True)
|
||||
# (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (17): ReLU(inplace=True)
|
||||
# (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
|
||||
# (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (20): ReLU(inplace=True)
|
||||
# (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (22): ReLU(inplace=True)
|
||||
# (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (24): ReLU(inplace=True)
|
||||
# (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (26): ReLU(inplace=True)
|
||||
# (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
|
||||
# (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (29): ReLU(inplace=True)
|
||||
# (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (31): ReLU(inplace=True)
|
||||
# (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (33): ReLU(inplace=True)
|
||||
# (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
|
||||
# (35): ReLU(inplace=True)
|
||||
# (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
|
||||
# )
|
||||
class PerceptualVGG(nn.Module):
|
||||
"""VGG network used in calculating perceptual loss.
|
||||
In this implementation, we allow users to choose whether use normalization
|
||||
@@ -14,15 +58,15 @@ class PerceptualVGG(nn.Module):
|
||||
list contains the name each layer in `vgg.feature`. An example
|
||||
of this list is ['4', '10'].
|
||||
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
|
||||
use_input_norm (bool): If True, normalize the input image.
|
||||
norm_image_with_imagenet_param (bool): If True, normalize the input image.
|
||||
Importantly, the input feature must in the range [0, 1].
|
||||
Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
|
||||
def __init__(self, layer_name_list, vgg_type='vgg19', norm_image_with_imagenet_param=True):
|
||||
super(PerceptualVGG, self).__init__()
|
||||
self.layer_name_list = layer_name_list
|
||||
self.use_input_norm = use_input_norm
|
||||
self.use_input_norm = norm_image_with_imagenet_param
|
||||
|
||||
# get vgg model and load pretrained vgg weight
|
||||
# remove _vgg from attributes to avoid `find_unused_parameters` bug
|
||||
@@ -74,7 +118,7 @@ class PerceptualLoss(nn.Module):
|
||||
in calculating losses.
|
||||
vgg_type (str): The type of vgg network used as feature extractor.
|
||||
Default: 'vgg19'.
|
||||
use_input_norm (bool): If True, normalize the input image in vgg.
|
||||
norm_image_with_imagenet_param (bool): If True, normalize the input image in vgg.
|
||||
Default: True.
|
||||
perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
|
||||
loss will be calculated.
|
||||
@@ -87,22 +131,24 @@ class PerceptualLoss(nn.Module):
|
||||
Importantly, the input image must be in range [-1, 1].
|
||||
"""
|
||||
|
||||
def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
|
||||
def __init__(self, layer_weights, vgg_type='vgg19', norm_image_with_imagenet_param=True, perceptual_loss=True,
|
||||
style_loss=False, norm_img=True, criterion='L1'):
|
||||
super(PerceptualLoss, self).__init__()
|
||||
self.norm_img = norm_img
|
||||
assert perceptual_loss ^ style_loss, "There must be one and only one true in style or perceptual"
|
||||
self.perceptual_loss = perceptual_loss
|
||||
self.style_loss = style_loss
|
||||
self.layer_weights = layer_weights
|
||||
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
|
||||
use_input_norm=use_input_norm)
|
||||
norm_image_with_imagenet_param=norm_image_with_imagenet_param)
|
||||
|
||||
if criterion == 'L1':
|
||||
self.criterion = torch.nn.L1Loss()
|
||||
elif criterion == "L2":
|
||||
self.criterion = torch.nn.MSELoss()
|
||||
else:
|
||||
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
|
||||
self.percep_criterion, self.style_criterion = self.set_criterion(criterion)
|
||||
|
||||
def set_criterion(self, criterion: str):
|
||||
assert criterion in ["NL1", "NL2", "L1", "L2"]
|
||||
norm = F.instance_norm if criterion.startswith("N") else lambda x: x
|
||||
fn = F.l1_loss if criterion.endswith("L1") else F.mse_loss
|
||||
return lambda x, t: fn(norm(x), norm(t)), lambda x, t: fn(x, t)
|
||||
|
||||
def forward(self, x, gt):
|
||||
"""Forward function.
|
||||
@@ -124,22 +170,16 @@ class PerceptualLoss(nn.Module):
|
||||
if self.perceptual_loss:
|
||||
percep_loss = 0
|
||||
for k in x_features.keys():
|
||||
percep_loss += self.criterion(
|
||||
x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
else:
|
||||
percep_loss = None
|
||||
percep_loss += self.percep_criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
return percep_loss
|
||||
|
||||
# calculate style loss
|
||||
if self.style_loss:
|
||||
style_loss = 0
|
||||
for k in x_features.keys():
|
||||
style_loss += self.criterion(
|
||||
self._gram_mat(x_features[k]),
|
||||
self._gram_mat(gt_features[k])) * self.layer_weights[k]
|
||||
else:
|
||||
style_loss = None
|
||||
|
||||
return percep_loss, style_loss
|
||||
style_loss += self.style_criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
|
||||
self.layer_weights[k]
|
||||
return style_loss
|
||||
|
||||
def _gram_mat(self, x):
|
||||
"""Calculate Gram matrix.
|
||||
|
||||
20
loss/gan.py
20
loss/gan.py
@@ -1,4 +1,5 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
@@ -10,7 +11,7 @@ class GANLoss(nn.Module):
|
||||
self.fake_label_val = fake_label_val
|
||||
self.loss_type = loss_type
|
||||
|
||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
def single_forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
"""
|
||||
gan loss forward
|
||||
:param prediction: network prediction
|
||||
@@ -37,3 +38,20 @@ class GANLoss(nn.Module):
|
||||
return loss
|
||||
else:
|
||||
raise NotImplementedError(f'GAN type {self.loss_type} is not implemented.')
|
||||
|
||||
def forward(self, prediction, target_is_real: bool, is_discriminator=False):
|
||||
if isinstance(prediction, torch.Tensor):
|
||||
# origin
|
||||
return self.single_forward(prediction, target_is_real, is_discriminator)
|
||||
elif isinstance(prediction, list):
|
||||
# for multi scale discriminator, e.g. MultiScaleDiscriminator
|
||||
loss = 0
|
||||
for p in prediction:
|
||||
loss += self.single_forward(p[-1], target_is_real, is_discriminator)
|
||||
return loss
|
||||
elif isinstance(prediction, tuple):
|
||||
# for single discriminator set `need_intermediate_feature` true
|
||||
return self.single_forward(prediction[-1], target_is_real, is_discriminator)
|
||||
else:
|
||||
raise NotImplementedError(f"not support discriminator output: {prediction}")
|
||||
|
||||
|
||||
26
main.py
26
main.py
@@ -1,16 +1,15 @@
|
||||
from pathlib import Path
|
||||
from importlib import import_module
|
||||
|
||||
import torch
|
||||
|
||||
import ignite
|
||||
import ignite.distributed as idist
|
||||
from ignite.utils import manual_seed
|
||||
from util.misc import setup_logger
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import ignite
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
from ignite.utils import manual_seed
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from util.misc import setup_logger
|
||||
|
||||
|
||||
def log_basic_info(logger, config):
|
||||
logger.info(f"Train {config.name}")
|
||||
@@ -28,15 +27,14 @@ def log_basic_info(logger, config):
|
||||
def running(local_rank, config, task, backup_config=False, setup_output_dir=False, setup_random_seed=False):
|
||||
if setup_random_seed:
|
||||
manual_seed(config.misc.random_seed + idist.get_rank())
|
||||
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else config.output_dir
|
||||
output_dir = Path(config.result_dir) / config.name if config.output_dir is None else Path(config.output_dir)
|
||||
config.output_dir = str(output_dir)
|
||||
|
||||
if setup_output_dir and config.resume_from is None:
|
||||
if output_dir.exists():
|
||||
# assert not any(output_dir.iterdir()), "output_dir must be empty"
|
||||
contains = list(output_dir.iterdir())
|
||||
assert (len(contains) == 0) or (len(contains) == 1 and contains[0].name == "config.yml"), \
|
||||
f"output_dir must by empty or only contains config.yml, but now got {len(contains)} files"
|
||||
assert len(list(output_dir.glob("events*"))) == 0, f"{output_dir} containers tensorboard event"
|
||||
if (output_dir / "train.log").exists() and idist.get_rank() == 0:
|
||||
(output_dir / "train.log").unlink()
|
||||
else:
|
||||
if idist.get_rank() == 0:
|
||||
output_dir.mkdir(parents=True)
|
||||
@@ -65,6 +63,8 @@ def run(task, config: str, *omega_options, **kwargs):
|
||||
backup_config = kwargs.get("backup_config", False)
|
||||
setup_output_dir = kwargs.get("setup_output_dir", False)
|
||||
setup_random_seed = kwargs.get("setup_random_seed", False)
|
||||
assert torch.backends.cudnn.enabled
|
||||
torch.backends.cudnn.benchmark = True
|
||||
with idist.Parallel(backend=backend) as parallel:
|
||||
parallel.run(running, conf, task, backup_config=backup_config, setup_output_dir=setup_output_dir,
|
||||
setup_random_seed=setup_random_seed)
|
||||
|
||||
@@ -1,224 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .residual_generator import ResidualBlock
|
||||
from model.registry import MODEL
|
||||
from torchvision.models import vgg19
|
||||
from model.normalization import select_norm_layer
|
||||
|
||||
|
||||
class VGG19StyleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
|
||||
vgg19_layers=(0, 5, 10, 19)):
|
||||
super().__init__()
|
||||
self.vgg19_layers = vgg19_layers
|
||||
self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
|
||||
self.vgg19.requires_grad_(False)
|
||||
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
|
||||
self.conv0 = nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
||||
bias=True),
|
||||
norm_layer(base_channels),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
self.conv = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
nn.Conv2d(base_channels * (2 ** i), base_channels * (2 ** i), kernel_size=4, stride=2, padding=1,
|
||||
padding_mode=padding_mode, bias=True),
|
||||
norm_layer(base_channels),
|
||||
nn.ReLU(True),
|
||||
) for i in range(1, 4)
|
||||
])
|
||||
self.pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv1x1 = nn.Conv2d(base_channels * (2 ** 4), style_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def fixed_style_features(self, x):
|
||||
features = []
|
||||
for i in range(len(self.vgg19)):
|
||||
x = self.vgg19[i](x)
|
||||
if i in self.vgg19_layers:
|
||||
features.append(x)
|
||||
return features
|
||||
|
||||
def forward(self, x):
|
||||
fsf = self.fixed_style_features(x)
|
||||
x = self.conv0(x)
|
||||
for i, l in enumerate(self.conv):
|
||||
x = l(torch.cat([x, fsf[i]], dim=1))
|
||||
x = self.pool(torch.cat([x, fsf[-1]], dim=1))
|
||||
x = self.conv1x1(x)
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
class ContentEncoder(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, num_blocks=8, padding_mode='reflect', norm_type="IN"):
|
||||
super().__init__()
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
|
||||
self.start_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
|
||||
bias=True),
|
||||
norm_layer(num_features=base_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
# down sampling
|
||||
submodules = []
|
||||
num_down_sampling = 2
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** i
|
||||
submodules += [
|
||||
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
||||
kernel_size=4, stride=2, padding=1, bias=True),
|
||||
norm_layer(num_features=base_channels * multiple * 2),
|
||||
nn.ReLU(inplace=True)
|
||||
]
|
||||
self.encoder = nn.Sequential(*submodules)
|
||||
res_block_channels = num_down_sampling ** 2 * base_channels
|
||||
self.resnet = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.start_conv(x)
|
||||
x = self.encoder(x)
|
||||
x = self.resnet(x)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
|
||||
norm_type="LN"):
|
||||
super(Decoder, self).__init__()
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
res_block_channels = (2 ** 2) * base_channels
|
||||
|
||||
self.resnet = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
submodules = []
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** (num_down_sampling - i)
|
||||
submodules += [
|
||||
nn.Upsample(scale_factor=2),
|
||||
nn.Conv2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=5, stride=1,
|
||||
padding=2, padding_mode=padding_mode, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
self.decoder = nn.Sequential(*submodules)
|
||||
self.end_conv = nn.Sequential(
|
||||
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.resnet(x)
|
||||
x = self.decoder(x)
|
||||
x = self.end_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Fusion(nn.Module):
|
||||
def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
|
||||
super().__init__()
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
self.start_fc = nn.Sequential(
|
||||
nn.Linear(in_features, base_features),
|
||||
norm_layer(base_features),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
self.fcs = nn.Sequential(*[
|
||||
nn.Sequential(
|
||||
nn.Linear(base_features, base_features),
|
||||
norm_layer(base_features),
|
||||
nn.ReLU(True),
|
||||
) for _ in range(n_blocks - 2)
|
||||
])
|
||||
self.end_fc = nn.Sequential(
|
||||
nn.Linear(base_features, out_features),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.start_fc(x)
|
||||
x = self.fcs(x)
|
||||
return self.end_fc(x)
|
||||
|
||||
|
||||
@MODEL.register_module("TAHG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512, num_blocks=8,
|
||||
base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.style_encoders = nn.ModuleDict({
|
||||
"a": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE"),
|
||||
"b": VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim,
|
||||
padding_mode=padding_mode, norm_type="NONE")
|
||||
})
|
||||
self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=num_blocks,
|
||||
padding_mode=padding_mode, norm_type="IN")
|
||||
res_block_channels = 2 ** 2 * base_channels
|
||||
self.adain_res = nn.ModuleList([
|
||||
ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_blocks)
|
||||
])
|
||||
self.decoders = nn.ModuleDict({
|
||||
"a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode),
|
||||
"b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=num_blocks, padding_mode=padding_mode)
|
||||
})
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(style_dim, style_dim),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
|
||||
norm_type="NONE")
|
||||
|
||||
def forward(self, content_img, style_img, which_decoder: str = "a"):
|
||||
x = self.content_encoder(content_img)
|
||||
styles = self.fusion(self.fc(self.style_encoders[which_decoder](style_img)))
|
||||
styles = torch.chunk(styles, self.num_blocks * 2, dim=1)
|
||||
for i, ar in enumerate(self.adain_res):
|
||||
ar.norm1.set_style(styles[2 * i])
|
||||
ar.norm2.set_style(styles[2 * i + 1])
|
||||
x = ar(x)
|
||||
return self.decoders[which_decoder](x)
|
||||
|
||||
|
||||
@MODEL.register_module("TAHG-Discriminator")
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
|
||||
padding_mode="reflect"):
|
||||
super(Discriminator, self).__init__()
|
||||
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
sequence = [nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
|
||||
bias=use_bias),
|
||||
norm_layer(num_features=base_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)]
|
||||
# stacked intermediate layers,
|
||||
# gradually increasing the number of filters
|
||||
multiple_now = 1
|
||||
for n in range(1, num_down_sampling + 1):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** n, 4)
|
||||
sequence += [
|
||||
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
|
||||
padding=1, stride=2, bias=use_bias),
|
||||
norm_layer(base_channels * multiple_now),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
]
|
||||
for _ in range(num_blocks):
|
||||
sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
|
||||
self.model = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
@@ -1,237 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .residual_generator import ResidualBlock
|
||||
from model.registry import MODEL
|
||||
|
||||
|
||||
class RhoClipper(object):
|
||||
def __init__(self, clip_min, clip_max):
|
||||
self.clip_min = clip_min
|
||||
self.clip_max = clip_max
|
||||
assert clip_min < clip_max
|
||||
|
||||
def __call__(self, module):
|
||||
if hasattr(module, 'rho'):
|
||||
w = module.rho.data
|
||||
w = w.clamp(self.clip_min, self.clip_max)
|
||||
module.rho.data = w
|
||||
|
||||
|
||||
@MODEL.register_module("UGATIT-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False):
|
||||
assert (num_blocks >= 0)
|
||||
super(Generator, self).__init__()
|
||||
self.input_channels = in_channels
|
||||
self.output_channels = out_channels
|
||||
self.base_channels = base_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.img_size = img_size
|
||||
self.light = light
|
||||
|
||||
down_encoder = [nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
|
||||
padding_mode="reflect", bias=False),
|
||||
nn.InstanceNorm2d(base_channels),
|
||||
nn.ReLU(True)]
|
||||
|
||||
n_down_sampling = 2
|
||||
for i in range(n_down_sampling):
|
||||
mult = 2 ** i
|
||||
down_encoder += [nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3, stride=2,
|
||||
padding=1, bias=False, padding_mode="reflect"),
|
||||
nn.InstanceNorm2d(base_channels * mult * 2),
|
||||
nn.ReLU(True)]
|
||||
|
||||
# Down-Sampling Bottleneck
|
||||
mult = 2 ** n_down_sampling
|
||||
for i in range(num_blocks):
|
||||
# TODO: change ResnetBlock to ResidualBlock, check use_bias param
|
||||
down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)]
|
||||
self.down_encoder = nn.Sequential(*down_encoder)
|
||||
|
||||
# Class Activation Map
|
||||
self.gap_fc = nn.Linear(base_channels * mult, 1, bias=False)
|
||||
self.gmp_fc = nn.Linear(base_channels * mult, 1, bias=False)
|
||||
self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
|
||||
self.relu = nn.ReLU(True)
|
||||
|
||||
# Gamma, Beta block
|
||||
if self.light:
|
||||
fc = [nn.Linear(base_channels * mult, base_channels * mult, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.Linear(base_channels * mult, base_channels * mult, bias=False),
|
||||
nn.ReLU(True)]
|
||||
else:
|
||||
fc = [
|
||||
nn.Linear(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.Linear(base_channels * mult, base_channels * mult, bias=False),
|
||||
nn.ReLU(True)]
|
||||
self.fc = nn.Sequential(*fc)
|
||||
|
||||
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
|
||||
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
|
||||
|
||||
# Up-Sampling Bottleneck
|
||||
self.up_bottleneck = nn.ModuleList(
|
||||
[ResnetAdaILNBlock(base_channels * mult, use_bias=False) for _ in range(num_blocks)])
|
||||
|
||||
# Up-Sampling
|
||||
up_decoder = []
|
||||
for i in range(n_down_sampling):
|
||||
mult = 2 ** (n_down_sampling - i)
|
||||
up_decoder += [nn.Upsample(scale_factor=2, mode='nearest'),
|
||||
nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3, stride=1,
|
||||
padding=1, padding_mode="reflect", bias=False),
|
||||
ILN(base_channels * mult // 2),
|
||||
nn.ReLU(True)]
|
||||
|
||||
up_decoder += [nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
|
||||
padding_mode="reflect", bias=False),
|
||||
nn.Tanh()]
|
||||
self.up_decoder = nn.Sequential(*up_decoder)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.down_encoder(x)
|
||||
|
||||
gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
|
||||
gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
|
||||
gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
|
||||
gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
|
||||
gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
|
||||
|
||||
x = torch.cat([gap, gmp], 1)
|
||||
x = self.relu(self.conv1x1(x))
|
||||
|
||||
heatmap = torch.sum(x, dim=1, keepdim=True)
|
||||
|
||||
if self.light:
|
||||
x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
|
||||
x_ = self.fc(x_.view(x_.shape[0], -1))
|
||||
else:
|
||||
x_ = self.fc(x.view(x.shape[0], -1))
|
||||
gamma, beta = self.gamma(x_), self.beta(x_)
|
||||
|
||||
for ub in self.up_bottleneck:
|
||||
x = ub(x, gamma, beta)
|
||||
|
||||
x = self.up_decoder(x)
|
||||
return x, cam_logit, heatmap
|
||||
|
||||
|
||||
class ResnetAdaILNBlock(nn.Module):
|
||||
def __init__(self, dim, use_bias):
|
||||
super(ResnetAdaILNBlock, self).__init__()
|
||||
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
|
||||
self.norm1 = AdaILN(dim)
|
||||
self.relu1 = nn.ReLU(True)
|
||||
|
||||
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
|
||||
self.norm2 = AdaILN(dim)
|
||||
|
||||
def forward(self, x, gamma, beta):
|
||||
out = self.conv1(x)
|
||||
out = self.norm1(out, gamma, beta)
|
||||
out = self.relu1(out)
|
||||
out = self.conv2(out)
|
||||
out = self.norm2(out, gamma, beta)
|
||||
|
||||
return out + x
|
||||
|
||||
|
||||
def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
|
||||
in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True)
|
||||
out_in = (x - in_mean) / torch.sqrt(in_var + eps)
|
||||
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
|
||||
out_ln = (x - ln_mean) / torch.sqrt(ln_var + eps)
|
||||
out = rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - rho.expand(x.shape[0], -1, -1, -1)) * out_ln
|
||||
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
|
||||
return out
|
||||
|
||||
|
||||
class AdaILN(nn.Module):
|
||||
def __init__(self, num_features, eps=1e-5, default_rho=0.9):
|
||||
super(AdaILN, self).__init__()
|
||||
self.eps = eps
|
||||
self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
|
||||
self.rho.data.fill_(default_rho)
|
||||
|
||||
def forward(self, x, gamma, beta):
|
||||
return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)
|
||||
|
||||
|
||||
class ILN(nn.Module):
|
||||
def __init__(self, num_features, eps=1e-5):
|
||||
super(ILN, self).__init__()
|
||||
self.eps = eps
|
||||
self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
|
||||
self.gamma = nn.Parameter(torch.Tensor(1, num_features))
|
||||
self.beta = nn.Parameter(torch.Tensor(1, num_features))
|
||||
self.rho.data.fill_(0.0)
|
||||
self.gamma.data.fill_(1.0)
|
||||
self.beta.data.fill_(0.0)
|
||||
|
||||
def forward(self, x):
|
||||
return instance_layer_normalization(
|
||||
x, self.gamma.expand(x.shape[0], -1), self.beta.expand(x.shape[0], -1), self.rho, self.eps)
|
||||
|
||||
|
||||
@MODEL.register_module("UGATIT-Discriminator")
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, num_blocks=5):
|
||||
super(Discriminator, self).__init__()
|
||||
encoder = [self.build_conv_block(in_channels, base_channels)]
|
||||
|
||||
for i in range(1, num_blocks - 2):
|
||||
mult = 2 ** (i - 1)
|
||||
encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2))
|
||||
|
||||
mult = 2 ** (num_blocks - 2 - 1)
|
||||
encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2, stride=1))
|
||||
|
||||
self.encoder = nn.Sequential(*encoder)
|
||||
|
||||
# Class Activation Map
|
||||
mult = 2 ** (num_blocks - 2)
|
||||
self.gap_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
|
||||
self.gmp_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
|
||||
self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
|
||||
self.leaky_relu = nn.LeakyReLU(0.2, True)
|
||||
|
||||
self.conv = nn.utils.spectral_norm(
|
||||
nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect"))
|
||||
|
||||
@staticmethod
|
||||
def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"):
|
||||
return nn.Sequential(*[
|
||||
nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
|
||||
bias=True, padding=padding, padding_mode=padding_mode)),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
])
|
||||
|
||||
def forward(self, x, return_heatmap=False):
|
||||
x = self.encoder(x)
|
||||
batch_size = x.size(0)
|
||||
|
||||
gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) # B x C x 1 x 1, avg of per channel
|
||||
gap_logit = self.gap_fc(gap.view(batch_size, -1))
|
||||
gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
|
||||
gmp_logit = self.gmp_fc(gmp.view(batch_size, -1))
|
||||
gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
|
||||
|
||||
x = torch.cat([gap, gmp], 1)
|
||||
x = self.leaky_relu(self.conv1x1(x))
|
||||
|
||||
if return_heatmap:
|
||||
heatmap = torch.sum(x, dim=1, keepdim=True)
|
||||
return self.conv(x), cam_logit, heatmap
|
||||
else:
|
||||
return self.conv(x), cam_logit
|
||||
@@ -1,180 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from model.registry import MODEL
|
||||
from model.normalization import select_norm_layer
|
||||
|
||||
|
||||
class GANImageBuffer(object):
|
||||
"""This class implements an image buffer that stores previously
|
||||
generated images.
|
||||
This buffer allows us to update the discriminator using a history of
|
||||
generated images rather than the ones produced by the latest generator
|
||||
to reduce model oscillation.
|
||||
Args:
|
||||
buffer_size (int): The size of image buffer. If buffer_size = 0,
|
||||
no buffer will be created.
|
||||
buffer_ratio (float): The chance / possibility to use the images
|
||||
previously stored in the buffer.
|
||||
"""
|
||||
|
||||
def __init__(self, buffer_size, buffer_ratio=0.5):
|
||||
self.buffer_size = buffer_size
|
||||
# create an empty buffer
|
||||
if self.buffer_size > 0:
|
||||
self.img_num = 0
|
||||
self.image_buffer = []
|
||||
self.buffer_ratio = buffer_ratio
|
||||
|
||||
def query(self, images):
|
||||
"""Query current image batch using a history of generated images.
|
||||
Args:
|
||||
images (Tensor): Current image batch without history information.
|
||||
"""
|
||||
if self.buffer_size == 0: # if the buffer size is 0, do nothing
|
||||
return images
|
||||
return_images = []
|
||||
for image in images:
|
||||
image = torch.unsqueeze(image.data, 0)
|
||||
# if the buffer is not full, keep inserting current images
|
||||
if self.img_num < self.buffer_size:
|
||||
self.img_num = self.img_num + 1
|
||||
self.image_buffer.append(image)
|
||||
return_images.append(image)
|
||||
else:
|
||||
use_buffer = torch.rand(1) < self.buffer_ratio
|
||||
# by self.buffer_ratio, the buffer will return a previously
|
||||
# stored image, and insert the current image into the buffer
|
||||
if use_buffer:
|
||||
random_id = torch.randint(0, self.buffer_size, (1,)).item()
|
||||
image_tmp = self.image_buffer[random_id].clone()
|
||||
self.image_buffer[random_id] = image
|
||||
return_images.append(image_tmp)
|
||||
# by (1 - self.buffer_ratio), the buffer will return the
|
||||
# current image
|
||||
else:
|
||||
return_images.append(image)
|
||||
# collect all the images and return
|
||||
return_images = torch.cat(return_images, 0)
|
||||
return return_images
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
|
||||
super(ResidualBlock, self).__init__()
|
||||
if use_bias is None:
|
||||
# Only for IN, use bias since it does not have affine parameters.
|
||||
use_bias = norm_type == "IN"
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
bias=use_bias)
|
||||
self.norm1 = norm_layer(num_channels)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
bias=use_bias)
|
||||
self.norm2 = norm_layer(num_channels)
|
||||
|
||||
def forward(self, x):
|
||||
res = x
|
||||
x = self.relu1(self.norm1(self.conv1(x)))
|
||||
x = self.norm2(self.conv2(x))
|
||||
return x + res
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class ResGenerator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
|
||||
norm_type="IN"):
|
||||
super(ResGenerator, self).__init__()
|
||||
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
self.start_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
|
||||
bias=use_bias),
|
||||
norm_layer(num_features=base_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
# down sampling
|
||||
submodules = []
|
||||
num_down_sampling = 2
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** i
|
||||
submodules += [
|
||||
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
||||
kernel_size=3, stride=2, padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple * 2),
|
||||
nn.ReLU(inplace=True)
|
||||
]
|
||||
self.encoder = nn.Sequential(*submodules)
|
||||
|
||||
res_block_channels = num_down_sampling ** 2 * base_channels
|
||||
self.resnet_middle = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
|
||||
range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
submodules = []
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** (num_down_sampling - i)
|
||||
submodules += [
|
||||
nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
|
||||
padding=1, output_padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
self.decoder = nn.Sequential(*submodules)
|
||||
|
||||
self.end_conv = nn.Sequential(
|
||||
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(self.start_conv(x))
|
||||
x = self.resnet_middle(x)
|
||||
return self.end_conv(self.decoder(x))
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class PatchDiscriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, num_conv=3, norm_type="IN"):
|
||||
super(PatchDiscriminator, self).__init__()
|
||||
assert num_conv >= 0, f'Number of conv blocks must be non-negative, but got {num_conv}.'
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
kernel_size = 4
|
||||
padding = 1
|
||||
sequence = [
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
]
|
||||
|
||||
# stacked intermediate layers,
|
||||
# gradually increasing the number of filters
|
||||
multiple_now = 1
|
||||
for n in range(1, num_conv):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** n, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=kernel_size,
|
||||
padding=padding, stride=2, bias=use_bias),
|
||||
norm_layer(base_channels * multiple_now),
|
||||
nn.LeakyReLU(0.2, inplace=True)
|
||||
]
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** num_conv, 8)
|
||||
sequence += [
|
||||
nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size, stride=1,
|
||||
padding=padding, bias=use_bias),
|
||||
norm_layer(base_channels * multiple_now),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding)
|
||||
]
|
||||
self.model = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
@@ -1,5 +1,7 @@
|
||||
from model.registry import MODEL
|
||||
import model.GAN.residual_generator
|
||||
import model.GAN.TAHG
|
||||
import model.GAN.UGATIT
|
||||
import model.fewshot
|
||||
from model.registry import MODEL, NORMALIZATION
|
||||
import model.base.normalization
|
||||
import model.image_translation.UGATIT
|
||||
import model.image_translation.CycleGAN
|
||||
import model.image_translation.pix2pixHD
|
||||
import model.image_translation.GauGAN
|
||||
import model.image_translation.TSIT
|
||||
0
model/base/__init__.py
Normal file
0
model/base/__init__.py
Normal file
128
model/base/module.py
Normal file
128
model/base/module.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from model.registry import NORMALIZATION
|
||||
|
||||
_DO_NO_THING_FUNC = lambda x: x
|
||||
|
||||
|
||||
def _use_bias_checker(norm_type):
|
||||
return norm_type not in ["IN", "BN", "AdaIN", "FADE", "SPADE"]
|
||||
|
||||
|
||||
def _normalization(norm, num_features, additional_kwargs=None):
|
||||
if norm == "NONE":
|
||||
return _DO_NO_THING_FUNC
|
||||
|
||||
if additional_kwargs is None:
|
||||
additional_kwargs = {}
|
||||
kwargs = dict(_type=norm, num_features=num_features)
|
||||
kwargs.update(additional_kwargs)
|
||||
return NORMALIZATION.build_with(kwargs)
|
||||
|
||||
|
||||
def _activation(activation, inplace=True):
|
||||
if activation == "NONE":
|
||||
return _DO_NO_THING_FUNC
|
||||
elif activation == "ReLU":
|
||||
return nn.ReLU(inplace=inplace)
|
||||
elif activation == "LeakyReLU":
|
||||
return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
|
||||
elif activation == "Tanh":
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplementedError(f"{activation} not valid")
|
||||
|
||||
|
||||
class LinearBlock(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"):
|
||||
super().__init__()
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.activation_type = activation_type
|
||||
|
||||
bias = _use_bias_checker(norm_type) if bias is None else bias
|
||||
self.linear = nn.Linear(in_features, out_features, bias)
|
||||
|
||||
self.normalization = _normalization(norm_type, out_features)
|
||||
self.activation = _activation(activation_type)
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(self.normalization(self.linear(x)))
|
||||
|
||||
|
||||
class Conv2dBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, bias=None,
|
||||
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None,
|
||||
pre_activation=False, use_transpose_conv=False, **conv_kwargs):
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
self.activation_type = activation_type
|
||||
self.pre_activation = pre_activation
|
||||
|
||||
if use_transpose_conv:
|
||||
# Only "zeros" padding mode is supported for ConvTranspose2d
|
||||
conv_kwargs["padding_mode"] = "zeros"
|
||||
conv = nn.ConvTranspose2d
|
||||
else:
|
||||
conv = nn.Conv2d
|
||||
|
||||
if pre_activation:
|
||||
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
|
||||
self.activation = _activation(activation_type, inplace=False)
|
||||
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
|
||||
else:
|
||||
# if caller not set bias, set bias automatically.
|
||||
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
|
||||
self.convolution = conv(in_channels, out_channels, **conv_kwargs)
|
||||
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
|
||||
self.activation = _activation(activation_type)
|
||||
|
||||
def forward(self, x):
|
||||
if self.pre_activation:
|
||||
return self.convolution(self.activation(self.normalization(x)))
|
||||
return self.activation(self.normalization(self.convolution(x)))
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels,
|
||||
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
|
||||
out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
|
||||
"""
|
||||
Residual Conv Block
|
||||
:param in_channels:
|
||||
:param out_channels:
|
||||
:param padding_mode:
|
||||
:param activation_type:
|
||||
:param norm_type:
|
||||
:param out_activation_type:
|
||||
:param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
if out_activation_type is None:
|
||||
# if not specify `out_activation_type`, using default `out_activation_type`
|
||||
# `out_activation_type` default mode:
|
||||
# "NONE" for not full pre-activation
|
||||
# `norm_type` for full pre-activation
|
||||
out_activation_type = "NONE" if not pre_activation else norm_type
|
||||
|
||||
self.learn_skip_connection = in_channels != out_channels
|
||||
|
||||
conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
|
||||
additional_norm_kwargs=additional_norm_kwargs, pre_activation=pre_activation,
|
||||
padding_mode=padding_mode)
|
||||
|
||||
self.conv1 = Conv2dBlock(in_channels, in_channels, **conv_param)
|
||||
self.conv2 = Conv2dBlock(in_channels, out_channels, **conv_param)
|
||||
|
||||
if self.learn_skip_connection:
|
||||
conv_param['kernel_size'] = 1
|
||||
conv_param['padding'] = 0
|
||||
self.res_conv = Conv2dBlock(in_channels, out_channels, **conv_param)
|
||||
|
||||
def forward(self, x):
|
||||
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||
return self.conv2(self.conv1(x)) + res
|
||||
143
model/base/normalization.py
Normal file
143
model/base/normalization.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import NORMALIZATION
|
||||
from model.base.module import Conv2dBlock
|
||||
|
||||
_VALID_NORM_AND_ABBREVIATION = dict(
|
||||
IN="InstanceNorm2d",
|
||||
BN="BatchNorm2d",
|
||||
)
|
||||
|
||||
for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
|
||||
NORMALIZATION.register_module(module=getattr(nn, name), name=abbr)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("ADE")
|
||||
class AdaptiveDenormalization(nn.Module):
|
||||
def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.base_norm_type = base_norm_type
|
||||
self.norm = self.base_norm(num_features)
|
||||
self.gamma = None
|
||||
self.gamma_bias = gamma_bias
|
||||
self.beta = None
|
||||
self.have_set_condition = False
|
||||
|
||||
def base_norm(self, num_features):
|
||||
if self.base_norm_type == "IN":
|
||||
return nn.InstanceNorm2d(num_features, affine=False)
|
||||
elif self.base_norm_type == "BN":
|
||||
return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
|
||||
|
||||
def set_condition(self, gamma, beta):
|
||||
self.gamma, self.beta = gamma, beta
|
||||
self.have_set_condition = True
|
||||
|
||||
def forward(self, x):
|
||||
assert self.have_set_condition
|
||||
x = self.norm(x)
|
||||
x = (self.gamma + self.gamma_bias) * x + self.beta
|
||||
self.have_set_condition = False
|
||||
return x
|
||||
#
|
||||
# def __repr__(self):
|
||||
# return f"{self.__class__.__name__}(num_features={self.num_features}, " \
|
||||
# f"base_norm_type={self.base_norm_type})"
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("AdaIN")
|
||||
class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
|
||||
def __init__(self, num_features: int):
|
||||
super().__init__(num_features, "IN")
|
||||
self.num_features = num_features
|
||||
|
||||
def set_style(self, style):
|
||||
style = style.view(*style.size(), 1, 1)
|
||||
gamma, beta = style.chunk(2, 1)
|
||||
super().set_condition(gamma, beta)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("FADE")
|
||||
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
def __init__(self, num_features: int, condition_in_channels,
|
||||
base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0):
|
||||
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
|
||||
self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
|
||||
padding_mode=padding_mode)
|
||||
self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
|
||||
padding_mode=padding_mode)
|
||||
|
||||
def set_feature(self, feature):
|
||||
gamma = self.gamma_conv(feature)
|
||||
beta = self.beta_conv(feature)
|
||||
super().set_condition(gamma, beta)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("SPADE")
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
    """SPADE: a small conv head embeds a condition image (e.g. a segmentation
    map), then two 3x3 convs produce per-pixel (gamma, beta)."""

    def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
                 activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
        super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
        # Shared embedding of the condition image before the gamma/beta heads.
        self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
                                           kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
        head_kwargs = dict(kernel_size=3, padding=1, padding_mode=padding_mode)
        self.beta_conv = nn.Conv2d(base_channels, num_features, **head_kwargs)
        self.gamma_conv = nn.Conv2d(base_channels, num_features, **head_kwargs)

    def set_condition_image(self, condition_image):
        """Embed `condition_image`, derive (gamma, beta), and stash them for the
        next forward pass of the base normalization."""
        embedded = self.base_conv_block(condition_image)
        super().set_condition(self.gamma_conv(embedded), self.beta_conv(embedded))
|
||||
|
||||
|
||||
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
|
||||
out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
|
||||
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
|
||||
return out
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("ILN")
class ILN(nn.Module):
    """Instance-Layer Normalization with learnable per-channel mixing weight
    `rho` and affine parameters `gamma`/`beta`."""

    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.Tensor(num_features))
        self.gamma = nn.Parameter(torch.Tensor(num_features))
        self.beta = nn.Parameter(torch.Tensor(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        """rho=0 (pure layer norm at init), gamma=1, beta=0 (identity affine)."""
        nn.init.zeros_(self.rho)
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, x):
        flat_gamma = self.gamma.view(1, -1)
        flat_beta = self.beta.view(1, -1)
        mix = self.rho.view(1, -1, 1, 1)
        return _instance_layer_normalization(x, flat_gamma, flat_beta, mix, self.eps)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("AdaILN")
class AdaILN(nn.Module):
    """Adaptive ILN: mixing weight `rho` is learned, while (gamma, beta) must be
    injected through :meth:`set_condition` before every forward call."""

    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super(AdaILN, self).__init__()
        self.eps = eps
        self.rho = nn.Parameter(torch.Tensor(num_features))
        self.rho.data.fill_(default_rho)

        self.gamma = None
        self.beta = None
        self.have_set_condition = False

    def set_condition(self, gamma, beta):
        """Stash externally computed (gamma, beta) for the next forward."""
        self.gamma = gamma
        self.beta = beta
        self.have_set_condition = True

    def forward(self, x):
        assert self.have_set_condition
        result = _instance_layer_normalization(x, self.gamma, self.beta, self.rho.view(1, -1, 1, 1), self.eps)
        # One-shot condition: force callers to call set_condition before each forward.
        self.have_set_condition = False
        return result
|
||||
105
model/fewshot.py
105
model/fewshot.py
@@ -1,105 +0,0 @@
|
||||
import math
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from .registry import MODEL
|
||||
|
||||
|
||||
# --- gaussian initialize ---
|
||||
def init_layer(l):
    """Gaussian initialization helper, intended for use with Module.apply.

    Conv2d weights get N(0, sqrt(2/n)) where n = kernel_area * out_channels
    (the comment in the original said "fan-in", but the formula uses
    out_channels — kept as-is for behavioral parity). BatchNorm2d is set to
    identity; Linear only has its bias zeroed.
    """
    if isinstance(l, nn.Conv2d):
        n = l.kernel_size[0] * l.kernel_size[1] * l.out_channels
        std = math.sqrt(2.0 / float(n))
        l.weight.data.normal_(0, std)
    elif isinstance(l, nn.BatchNorm2d):
        l.weight.data.fill_(1)
        l.bias.data.fill_(0)
    elif isinstance(l, nn.Linear):
        l.bias.data.fill_(0)
|
||||
|
||||
|
||||
class Flatten(nn.Module):
    """Collapse every dimension after the batch axis into one."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)
|
||||
|
||||
|
||||
class SimpleBlock(nn.Module):
    """Two-conv residual block (conv-BN-act-conv-BN plus shortcut).

    When `half_res` is True the first conv uses stride 2, halving the spatial
    resolution; the shortcut then matches via a strided 1x1 conv. When channel
    counts already match, the shortcut is the identity.
    """

    def __init__(self, in_channels, out_channels, half_res, leakyrelu=False):
        super(SimpleBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        def make_act():
            # Fresh activation instance each time (inplace modules shouldn't be shared).
            return nn.LeakyReLU(0.2, inplace=True) if leakyrelu else nn.ReLU(inplace=True)

        first_stride = 2 if half_res else 1
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=first_stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            make_act(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = make_act()
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, first_stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = self.shortcut(x)
        return self.relu(self.block(x) + residual)
|
||||
|
||||
|
||||
class ResNet(nn.Module):
    """Four-stage ResNet backbone for few-shot learning.

    `layers` gives the number of blocks per stage and `dims` the channel width
    per stage. With `flatten` the trunk ends in AvgPool + Flatten; a linear
    classifier is appended only when `num_classes` is given and
    `classifier_type` is "linear" ("distlinear" is a registered but
    unimplemented placeholder).
    """

    def __init__(self, block, layers, dims, num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
        super().__init__()
        assert len(layers) == 4, 'Can have only four stages'
        self.inplanes = 64

        stem_act = nn.LeakyReLU(0.2, inplace=True) if leakyrelu else nn.ReLU(inplace=True)
        self.start = nn.Sequential(
            nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.inplanes),
            stem_act,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        trunk = []
        channels = self.inplanes
        for stage, depth in enumerate(layers):
            for block_idx in range(depth):
                # Stages after the first halve resolution at their first block.
                shrink = stage >= 1 and block_idx == 0
                trunk.append(block(channels, dims[stage], shrink, leakyrelu))
                channels = dims[stage]
        if flatten:
            trunk.append(nn.AvgPool2d(7))
            trunk.append(Flatten())
        if num_classes is not None:
            if classifier_type == "linear":
                trunk.append(nn.Linear(channels, num_classes))
            elif classifier_type == "distlinear":
                pass
            else:
                raise ValueError(f"invalid classifier_type:{classifier_type}")
        self.trunk = nn.Sequential(*trunk)
        self.apply(init_layer)

    def forward(self, x):
        return self.trunk(self.start(x))
|
||||
|
||||
|
||||
@MODEL.register_module()
def resnet10(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    """ResNet-10: one SimpleBlock per stage, widths 64/128/256/512."""
    return ResNet(SimpleBlock, [1, 1, 1, 1], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)


@MODEL.register_module()
def resnet18(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    """ResNet-18: two SimpleBlocks per stage, widths 64/128/256/512."""
    return ResNet(SimpleBlock, [2, 2, 2, 2], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)


@MODEL.register_module()
def resnet34(num_classes=None, classifier_type="linear", flatten=True, leakyrelu=False):
    """ResNet-34: 3/4/6/3 SimpleBlocks per stage, widths 64/128/256/512."""
    return ResNet(SimpleBlock, [3, 4, 6, 3], [64, 128, 256, 512], num_classes, classifier_type, flatten, leakyrelu)
|
||||
151
model/image_translation/CycleGAN.py
Normal file
151
model/image_translation/CycleGAN.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from model import MODEL
|
||||
from model.base.module import Conv2dBlock, ResidualBlock
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Down-sampling encoder: 7x7 stem, `num_conv` stride-2 convolutions, then
    `num_res` residual blocks (CycleGAN/MUNIT style).

    Channel width doubles at each down-sampling step, capped at
    base_channels * 2**max_down_sampling_multiple; the final width is exposed
    as ``self.out_channels`` for downstream decoders.
    """

    def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
                 padding_mode='reflect', activation_type="ReLU",
                 down_conv_norm_type="IN", down_conv_kernel_size=3,
                 res_norm_type="IN", pre_activation=False):
        super().__init__()

        # 7x7 stem keeps spatial size (stride 1, padding 3).
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=down_conv_norm_type
        )]
        multiple_now = 1
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
            # NOTE(review): strided convs hard-code padding_mode="zeros" and
            # ignore the `padding_mode` argument — presumably intentional for
            # stride-2 convs, but confirm.
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode="zeros",
                activation_type=activation_type, norm_type=down_conv_norm_type
            ))
        self.out_channels = multiple_now * base_channels
        sequence += [
            ResidualBlock(
                self.out_channels,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type=res_norm_type,
                pre_activation=pre_activation
            ) for _ in range(num_res)
        ]
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        return self.sequence(x)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    """Up-sampling decoder: residual blocks followed by `num_up_sampling`
    up-scaling steps (transpose conv or Upsample+conv) and a final 7x7 Tanh
    projection to `out_channels`.

    The residual blocks are kept in a ModuleList so callers (e.g. MUNIT,
    UGATIT) can reach into their normalization layers to inject style.
    """

    def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                 activation_type="ReLU", padding_mode='reflect',
                 up_conv_kernel_size=5, up_conv_norm_type="LN",
                 res_norm_type="AdaIN", pre_activation=False, use_transpose_conv=False):
        super().__init__()
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(
                in_channels,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type=res_norm_type,
                pre_activation=pre_activation
            ) for _ in range(num_residual_blocks)
        ])

        up_layers = []
        channels = in_channels
        half_pad = (up_conv_kernel_size - 1) // 2
        for _ in range(num_up_sampling):
            if use_transpose_conv:
                up_layers.append(Conv2dBlock(
                    channels, channels // 2, kernel_size=up_conv_kernel_size, stride=2,
                    padding=half_pad, output_padding=half_pad,
                    padding_mode=padding_mode,
                    activation_type=activation_type, norm_type=up_conv_norm_type,
                    use_transpose_conv=True
                ))
            else:
                # Nearest-neighbor upsample followed by a stride-1 conv.
                up_layers.append(nn.Sequential(
                    nn.Upsample(scale_factor=2),
                    Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
                                padding=half_pad, padding_mode=padding_mode,
                                activation_type=activation_type, norm_type=up_conv_norm_type),
                ))
            channels = channels // 2
        up_layers.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
                                     padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))

        self.up_sequence = nn.Sequential(*up_layers)

    def forward(self, x):
        for residual_block in self.residual_blocks:
            x = residual_block(x)
        return self.up_sequence(x)
|
||||
|
||||
|
||||
@MODEL.register_module("CycleGAN-Generator")
class Generator(nn.Module):
    """CycleGAN generator: encoder (2 down-sampling convs + `num_blocks`
    residual blocks) followed by a decoder with 2 up-sampling steps."""

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, activation_type="ReLU",
                 padding_mode='reflect', norm_type="IN", pre_activation=False, use_transpose_conv=True):
        super().__init__()
        common = dict(padding_mode=padding_mode, activation_type=activation_type,
                      pre_activation=pre_activation)
        self.encoder = Encoder(in_channels, base_channels, num_conv=2, num_res=num_blocks,
                               down_conv_norm_type=norm_type, res_norm_type=norm_type, **common)
        self.decoder = Decoder(self.encoder.out_channels, out_channels, num_up_sampling=2, num_residual_blocks=0,
                               up_conv_kernel_size=3, up_conv_norm_type=norm_type,
                               use_transpose_conv=use_transpose_conv, **common)

    def forward(self, x):
        encoded = self.encoder(x)
        return self.decoder(encoded)
|
||||
|
||||
|
||||
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator: stride-2 4x4 convs doubling channels (capped at
    8x), one stride-1 conv, and a 1-channel patch-logit head.

    With `need_intermediate_feature` the forward returns every layer's output
    (for feature-matching losses) instead of just the final logits.
    """

    def __init__(self, in_channels, base_channels=64, num_conv=4, need_intermediate_feature=False,
                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature
        kernel_size = 4
        padding = (kernel_size - 1) // 2
        layers = [Conv2dBlock(
            in_channels, base_channels, kernel_size=kernel_size, stride=2, padding=padding, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type
        )]

        mult = 1
        for step in range(1, num_conv):
            prev_mult = mult
            mult = min(2 ** step, 2 ** 3)
            # Last conv of the pyramid keeps resolution (stride 1).
            layers.append(Conv2dBlock(
                prev_mult * base_channels, mult * base_channels,
                kernel_size=kernel_size, stride=1 if step == num_conv - 1 else 2,
                padding=padding, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=norm_type
            ))
        layers.append(nn.Conv2d(
            base_channels * mult, 1, kernel_size, stride=1, padding=padding, padding_mode=padding_mode))
        if self.need_intermediate_feature:
            # ModuleList so forward can collect per-layer outputs.
            self.sequence = nn.ModuleList(layers)
        else:
            self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        if not self.need_intermediate_feature:
            return self.sequence(x)
        collected = []
        for layer in self.sequence:
            x = layer(x)
            collected.append(x)
        return tuple(collected)
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build the generator and discriminator and print their layouts.
    g = Generator(**dict(in_channels=3, out_channels=3))
    print(g)
    pd = PatchDiscriminator(**dict(in_channels=3, base_channels=64, num_conv=4))
    print(pd)
|
||||
183
model/image_translation/GauGAN.py
Normal file
183
model/image_translation/GauGAN.py
Normal file
@@ -0,0 +1,183 @@
|
||||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model.base.module import ResidualBlock, Conv2dBlock, LinearBlock
|
||||
from model import MODEL
|
||||
|
||||
class StyleEncoder(nn.Module):
    """GauGAN VAE-style encoder: a conv pyramid followed by two linear heads
    producing the mean (`fc_avg`) and log-variance (`fc_var`) of the style code.

    The fc heads assume the trunk ends at base_channels * 2**3 channels with
    spatial size `end_size` — i.e. num_conv >= 3 and an input sized so that
    num_conv halvings reach end_size (TODO confirm against callers).
    """

    def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type
        )]
        # BUG FIX: the multiplier must start at 1 — the stem above outputs
        # base_channels (multiple 1). It previously started at 0, which made the
        # first strided conv expect 0 input channels and broke the channel flow.
        multiple_now = 1
        max_multiple = 3
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** max_multiple)
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=norm_type
            ))
        self.sequence = nn.Sequential(*sequence)
        self.fc_avg = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
        self.fc_var = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)

    def forward(self, x):
        """Return (mu, logvar), each of shape (batch, style_dim)."""
        x = self.sequence(x)
        x = x.view(x.size(0), -1)
        return self.fc_avg(x), self.fc_var(x)
|
||||
|
||||
|
||||
class ImprovedSPADEGenerator(nn.Module):
    """Work-in-progress SPADE generator variant with optional style input.

    Builds a pyramid of SPADE residual blocks (conditioned on the segmentation
    map) with interleaved "conv_style" blocks whose normalization is AdaIN when
    `have_style_input` is set, plus multi-resolution Tanh output heads in
    ``self.out_modules``.

    NOTE(review): ``forward`` is an unimplemented stub (`pass`) — this module
    cannot be used for inference as-is.
    """

    def __init__(self, in_channels, out_channels, output_size, have_style_input, style_dim, start_size=(4, 4),
                 base_channels=64, padding_mode='reflect', activation_type="LeakyReLU", pre_activation=False):
        super().__init__()

        # Only power-of-two output resolutions from 128 to 1024 are supported.
        assert output_size in (128, 256, 512, 1024)
        self.output_size = output_size

        kernel_size = 3

        if have_style_input:
            # MLP that expands the style code; presumably feeds the AdaIN convs
            # (no wiring visible here — TODO confirm once forward is written).
            self.style_converter = nn.Sequential(
                LinearBlock(style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
                LinearBlock(2 * style_dim, 2 * style_dim, activation_type=activation_type, norm_type="NONE"),
            )

        # Conv block template; AdaIN-normalized only when a style input exists.
        base_conv = partial(
            Conv2dBlock,
            pre_activation=pre_activation, activation_type=activation_type,
            norm_type="AdaIN" if have_style_input else "NONE",
            kernel_size=kernel_size, padding=(kernel_size - 1) // 2, padding_mode=padding_mode
        )

        # SPADE residual block template conditioned on the raw segmentation map.
        base_residual_block = partial(
            ResidualBlock,
            padding_mode=padding_mode,
            activation_type=activation_type,
            norm_type="SPADE",
            pre_activation=True,
            additional_norm_kwargs=dict(
                condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
                activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
            )
        )

        sequence = OrderedDict()
        # Head starts at 16x base_channels.
        channels = (2 ** 4) * base_channels
        sequence["block_head"] = nn.Sequential(OrderedDict([
            ("conv_input", base_conv(in_channels=in_channels, out_channels=channels)),
            ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
            ("res_a", base_residual_block(in_channels=channels, out_channels=channels)),
            ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
            ("up", nn.Upsample(scale_factor=2, mode='nearest'))
        ]))

        # Halve channels and double resolution per block, down to the width
        # implied by output_size (capped for sizes > 256).
        for i in range(4, 9 - min(int(math.log(self.output_size, 2)), 8), -1):
            channels = (2 ** (i - 1)) * base_channels
            sequence[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
                ("res_a", base_residual_block(in_channels=channels * 2, out_channels=channels)),
                ("conv_style", base_conv(in_channels=channels, out_channels=channels)),
                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
            ]))
        self.sequence = nn.Sequential(sequence)
        # channels = 2*base_channels when output size is 256, 512, 1024
        # channels = 5*base_channels when output size is 128
        out_modules = OrderedDict()
        out_modules["out_1"] = nn.Sequential(
            Conv2dBlock(
                channels, out_channels, kernel_size=5, stride=1, padding=2,
                pre_activation=pre_activation,
                padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
            ),
            nn.Tanh()
        )
        # Extra up-sampling + output heads for 512/1024 outputs (log2 - 8 of them).
        for i in range(int(math.log(self.output_size, 2)) - 8):
            channels = channels // 2
            out_modules[f"block_{2 * channels}"] = nn.Sequential(OrderedDict([
                ("res_a", base_residual_block(in_channels=2 * channels, out_channels=channels)),
                ("res_b", base_residual_block(in_channels=channels, out_channels=channels)),
                ("up", nn.Upsample(scale_factor=2, mode='nearest'))
            ]))
            out_modules[f"out_{i + 2}"] = nn.Sequential(
                Conv2dBlock(
                    channels, out_channels, kernel_size=5, stride=1, padding=2,
                    pre_activation=pre_activation,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"
                ),
                nn.Tanh()
            )
        self.out_modules = nn.ModuleDict(out_modules)

    def forward(self, seg, style=None):
        # TODO: unimplemented — returns None.
        pass
|
||||
|
||||
@MODEL.register_module()
class SPADEGenerator(nn.Module):
    """GauGAN/SPADE generator.

    The input is either a random/encoded latent `z` (VAE mode) projected to a
    (16*base_channels, sx, sy) tensor, or the segmentation map itself resized
    to `start_size` and projected by a 3x3 conv. A chain of SPADE residual
    blocks (with 2x upsampling between them) is conditioned on the
    segmentation map at each resolution, ending in a Tanh conv head.
    """

    def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.sx, self.sy = start_size
        self.use_vae = use_vae
        self.num_z_dim = num_z_dim
        if use_vae:
            # Latent vector -> start feature map via a linear layer.
            self.input_converter = nn.Linear(num_z_dim, 16 * base_channels * self.sx * self.sy)
        else:
            # Segmentation map (resized in forward) -> start feature map.
            self.input_converter = nn.Conv2d(in_channels, 16 * base_channels, kernel_size=3, padding=1)

        sequence = []

        # Channel multiplier shrinks from 16x toward 1x across the blocks.
        multiple_now = 16
        for i in range(num_blocks - 1, -1, -1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            # No upsample before the first (highest-multiplier) block.
            if i != num_blocks - 1:
                sequence.append(nn.Upsample(scale_factor=2))
            sequence.append(ResidualBlock(
                base_channels * multiple_prev,
                out_channels=base_channels * multiple_now,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type="SPADE",
                pre_activation=True,
                additional_norm_kwargs=dict(
                    condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
                    activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
                )
            ))
        self.sequence = nn.Sequential(*sequence)
        self.output_converter = Conv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
                                            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE")

    def forward(self, seg, z=None):
        """Generate an image from segmentation map `seg` (and latent `z` in VAE mode)."""
        if self.use_vae:
            if z is None:
                # Sample a latent when none is provided.
                z = torch.randn(seg.size(0), self.num_z_dim, device=seg.device)
            x = self.input_converter(z).view(seg.size(0), -1, self.sx, self.sy)
        else:
            x = self.input_converter(F.interpolate(seg, size=(self.sx, self.sy)))
        for blk in self.sequence:
            if isinstance(blk, ResidualBlock):
                # Feed each SPADE norm the segmentation map resized to the
                # current feature resolution BEFORE running the block.
                downsampling_seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
                blk.conv1.normalization.set_condition_image(downsampling_seg)
                blk.conv2.normalization.set_condition_image(downsampling_seg)
                if blk.learn_skip_connection:
                    blk.res_conv.normalization.set_condition_image(downsampling_seg)
            x = blk(x)
        return self.output_converter(x)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: non-VAE SPADE generator driven by a random 256x256 "segmentation" map.
    g = SPADEGenerator(3, 3, 7, False, 256)
    print(g)
    print(g(torch.randn(2, 3, 256, 256)).size())
|
||||
89
model/image_translation/MUNIT.py
Normal file
89
model/image_translation/MUNIT.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from model import MODEL
|
||||
from model.base.module import LinearBlock
|
||||
from model.image_translation.CycleGAN import Encoder, Decoder
|
||||
|
||||
|
||||
class StyleEncoder(nn.Module):
    """MUNIT style encoder: down-sampling conv trunk (no residual blocks),
    global average pool, then a 1x1 conv acting as a fully connected layer to
    produce a flat style code of length `out_dim`."""

    def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
                 max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE",
                 pre_activation=False):
        super().__init__()
        self.down_encoder = Encoder(
            in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
            padding_mode=padding_mode, activation_type=activation_type,
            down_conv_norm_type=norm_type, down_conv_kernel_size=4, pre_activation=pre_activation,
        )
        sequence = list()
        sequence.append(nn.AdaptiveAvgPool2d(1))
        # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
        sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim, kernel_size=1, stride=1, padding=0))
        self.sequence = nn.Sequential(*sequence)

    def forward(self, image):
        """Return a (batch_size, out_dim) style vector."""
        return self.sequence(image).view(image.size(0), -1)
|
||||
|
||||
|
||||
class MLPFusion(nn.Module):
    """Plain MLP: input layer, `n_blocks - 2` hidden layers of `base_features`,
    and an output layer (used in MUNIT to map a style code to AdaIN params)."""

    def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
        super().__init__()

        layer_kwargs = dict(activation_type=activation_type, norm_type=norm_type)
        layers = [LinearBlock(in_features, base_features, **layer_kwargs)]
        for _ in range(n_blocks - 2):
            layers.append(LinearBlock(base_features, base_features, **layer_kwargs))
        layers.append(LinearBlock(base_features, out_features, **layer_kwargs))
        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        return self.sequence(x)
|
||||
|
||||
|
||||
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
    """MUNIT generator: separate content and style encoders; the style code is
    mapped by an MLP to per-block AdaIN (gamma, beta) pairs which are injected
    into the decoder's residual blocks before decoding the content."""

    def __init__(self, in_channels, out_channels, base_channels=64, style_dim=8,
                 num_mlp_base_feature=256, num_mlp_blocks=3,
                 max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
                 encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
                 padding_mode='reflect', activation_type="ReLU", pre_activation=False):
        super().__init__()
        self.content_encoder = Encoder(
            in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
            max_down_sampling_multiple=num_content_down_sampling,
            padding_mode=padding_mode, activation_type=activation_type,
            down_conv_norm_type="IN", down_conv_kernel_size=4,
            res_norm_type="IN", pre_activation=pre_activation
        )

        self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
                                          max_down_sampling_multiple, padding_mode, activation_type,
                                          norm_type="NONE", pre_activation=pre_activation)

        content_channels = base_channels * (2 ** max_down_sampling_multiple)

        # 2 AdaIN layers per residual block, each needing 2 * content_channels params.
        self.fusion = MLPFusion(style_dim, decoder_num_residual_blocks * 2 * content_channels * 2,
                                num_mlp_base_feature, num_mlp_blocks, activation_type,
                                norm_type="NONE")

        # NOTE(review): the decoder's first argument is its input channel count;
        # passing image `in_channels` here looks inconsistent with the content
        # encoder's output width (`content_channels`) and with the fusion sizing
        # above — confirm against Decoder and the training entry point.
        self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
                               activation_type=activation_type, padding_mode=padding_mode,
                               up_conv_kernel_size=5, up_conv_norm_type="LN",
                               res_norm_type="AdaIN", pre_activation=pre_activation)

    def encode(self, x):
        """Return (content_code, style_code) for image batch `x`."""
        return self.content_encoder(x), self.style_encoder(x)

    def decode(self, content, style):
        """Decode `content` stylized by `style`; returns the generated image."""
        as_param_style = torch.chunk(self.fusion(style), 2 * len(self.decoder.residual_blocks), dim=1)
        # Inject one (gamma|beta) chunk per AdaIN layer before decoding.
        for i, blk in enumerate(self.decoder.residual_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        # BUG FIX: the decoded image was computed but never returned, so
        # decode()/forward() silently returned None.
        return self.decoder(content)

    def forward(self, x):
        content, style = self.encode(x)
        return self.decode(content, style)
|
||||
98
model/image_translation/TSIT.py
Normal file
98
model/image_translation/TSIT.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
|
||||
from model import MODEL
|
||||
from model.base.module import ResidualBlock, Conv2dBlock
|
||||
|
||||
|
||||
class Interpolation(nn.Module):
    """Module wrapper around F.interpolate so resizing can sit inside
    nn.Sequential pipelines."""

    def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
        super(Interpolation, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            # Keep the exact scale factor instead of recomputing it from sizes.
            recompute_scale_factor=False,
        )

    def __repr__(self):
        return (f"Interpolation(scale_factor={self.scale_factor}, "
                f"mode={self.mode}, align_corners={self.align_corners})")
|
||||
|
||||
|
||||
@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
    """TSIT generator: a down-sampling content stream collects multi-scale
    features, which condition the FADE normalizations of a symmetric
    up-sampling stream that starts from noise `z`."""

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=7,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        self.input_layer = Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type="IN",
        )
        # Content stream: halve resolution and (up to 16x) double channels per block.
        multiple_now = 1
        stream_sequence = []
        for i in range(1, num_blocks + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            stream_sequence.append(nn.Sequential(
                Interpolation(scale_factor=0.5, mode="nearest"),
                ResidualBlock(
                    multiple_prev * base_channels, out_channels=multiple_now * base_channels,
                    padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
            ))
        self.down_sequence = nn.ModuleList(stream_sequence)

        # Main stream: mirror image of the content stream, FADE-normalized,
        # conditioned on the matching-resolution content feature.
        sequence = []
        multiple_now = 16
        for i in range(num_blocks - 1, -1, -1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            sequence.append(nn.Sequential(
                ResidualBlock(
                    base_channels * multiple_prev,
                    out_channels=base_channels * multiple_now,
                    padding_mode=padding_mode,
                    activation_type=activation_type,
                    norm_type="FADE",
                    pre_activation=True,
                    additional_norm_kwargs=dict(
                        condition_in_channels=base_channels * multiple_prev, base_norm_type="BN",
                        padding_mode="zeros", gamma_bias=0.0
                    )
                ),
                Interpolation(scale_factor=2, mode="nearest")
            ))
        self.up_sequence = nn.Sequential(*sequence)

        self.output_layer = Conv2dBlock(
            base_channels, out_channels, kernel_size=3, stride=1, padding=1,
            padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"
        )

    def forward(self, x, z=None):
        c = self.input_layer(x)
        # Collect one content feature per scale.
        contents = []
        for blk in self.down_sequence:
            c = blk(c)
            contents.append(c)
        if z is None:
            # for image 256x256, z size: [batch_size, 1024, 2, 2]
            z = torch.randn(size=contents[-1].size(), device=contents[-1].device)

        # Prime every FADE norm with its matching-scale content feature
        # (deepest first, popping from the end) BEFORE running the stream.
        for blk in self.up_sequence:
            res = blk[0]
            c = contents.pop()
            res.conv1.normalization.set_feature(c)
            res.conv2.normalization.set_feature(c)
            if res.learn_skip_connection:
                res.res_conv.normalization.set_feature(c)
        return self.output_layer(self.up_sequence(z))
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test. NOTE(review): requires a CUDA device — fails on CPU-only hosts.
    g = Generator(3, 3).cuda()
    img = torch.randn(2, 3, 256, 256).cuda()
    print(g(img).size())
|
||||
|
||||
|
||||
125
model/image_translation/UGATIT.py
Normal file
125
model/image_translation/UGATIT.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from model import MODEL
|
||||
from model.base.module import Conv2dBlock, LinearBlock
|
||||
from model.image_translation.CycleGAN import Encoder, Decoder
|
||||
|
||||
|
||||
class CAMClassifier(nn.Module):
    """UGATIT Class-Activation-Map head.

    Produces (attended_feature_map, cam_logits): the input is reweighted by the
    avg-pool and max-pool classifier weights, concatenated, and fused by a 1x1
    conv; the two pooled logits are concatenated for the auxiliary CAM loss.
    """

    def __init__(self, in_channels, activation_type="ReLU"):
        super(CAMClassifier, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.avg_fc = nn.Linear(in_channels, 1, bias=False)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.max_fc = nn.Linear(in_channels, 1, bias=False)
        self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, kernel_size=1, stride=1, bias=True,
                                       activation_type=activation_type, norm_type="NONE")

    def forward(self, x):
        batch = x.size(0)
        avg_logit = self.avg_fc(self.avg_pool(x).view(batch, -1))
        max_logit = self.max_fc(self.max_pool(x).view(batch, -1))

        # Reuse the classifier weights as per-channel attention maps.
        avg_weight = self.avg_fc.weight.unsqueeze(2).unsqueeze(3)
        max_weight = self.max_fc.weight.unsqueeze(2).unsqueeze(3)
        attended = torch.cat([x * avg_weight, x * max_weight], dim=1)
        cam_logit = torch.cat([avg_logit, max_logit], 1)
        return self.fusion_conv(attended), cam_logit
|
||||
|
||||
|
||||
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """UGATIT generator: encoder -> CAM attention -> MLP producing (gamma, beta)
    for the AdaILN-normalized decoder residual blocks.

    forward returns (generated_image, cam_logit, heatmap).
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
                 activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
        super(Generator, self).__init__()

        # "light" variant global-pools before the fc stack to cut parameters.
        self.light = light

        n_down_sampling = 2
        self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
                               padding_mode=padding_mode, activation_type=activation_type,
                               down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type,
                               pre_activation=pre_activation)
        mult = 2 ** n_down_sampling
        self.cam = CAMClassifier(base_channels * mult, activation_type)

        # Gamma, Beta block
        if self.light:
            self.fc = nn.Sequential(
                LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE"),
                LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE")
            )
        else:
            # Full variant flattens the (img_size/mult)^2 feature map directly,
            # so it only works for inputs of exactly img_size x img_size.
            self.fc = nn.Sequential(
                LinearBlock(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, False,
                            "ReLU", "NONE"),
                LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE")
            )

        self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
        self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)

        self.decoder = Decoder(
            base_channels * mult, out_channels, n_down_sampling, num_blocks,
            activation_type=activation_type, padding_mode=padding_mode,
            up_conv_kernel_size=3, up_conv_norm_type="ILN",
            res_norm_type="AdaILN", pre_activation=pre_activation
        )

    def forward(self, x):
        x = self.encoder(x)

        x, cam_logit = self.cam(x)

        # Channel-wise sum of the attended features, used as a visualization map.
        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
            x_ = self.fc(x_.view(x_.shape[0], -1))
        else:
            x_ = self.fc(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)

        # Prime every AdaILN in the decoder with the shared (gamma, beta)
        # before running it (each norm consumes them once per forward).
        for blk in self.decoder.residual_blocks:
            blk.conv1.normalization.set_condition(gamma, beta)
            blk.conv2.normalization.set_condition(gamma, beta)
        return self.decoder(x), cam_logit, heatmap
|
||||
|
||||
|
||||
@MODEL.register_module("UGATIT-Discriminator")
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, num_blocks=5,
|
||||
activation_type="LeakyReLU", norm_type="NONE", padding_mode='reflect'):
|
||||
super().__init__()
|
||||
|
||||
sequence = [Conv2dBlock(
|
||||
in_channels, base_channels, kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
|
||||
activation_type=activation_type, norm_type=norm_type
|
||||
)]
|
||||
|
||||
sequence += [Conv2dBlock(
|
||||
base_channels * (2 ** i), base_channels * (2 ** i) * 2,
|
||||
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
|
||||
activation_type=activation_type, norm_type=norm_type) for i in range(num_blocks - 3)]
|
||||
|
||||
sequence.append(
|
||||
Conv2dBlock(base_channels * (2 ** (num_blocks - 3)), base_channels * (2 ** (num_blocks - 2)),
|
||||
kernel_size=4, stride=1, padding=1, padding_mode=padding_mode,
|
||||
activation_type=activation_type, norm_type=norm_type)
|
||||
)
|
||||
self.sequence = nn.Sequential(*sequence)
|
||||
|
||||
mult = 2 ** (num_blocks - 2)
|
||||
self.cam = CAMClassifier(base_channels * mult, activation_type)
|
||||
self.conv = nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False,
|
||||
padding_mode="reflect")
|
||||
|
||||
def forward(self, x, return_heatmap=False):
|
||||
x = self.sequence(x)
|
||||
|
||||
x, cam_logit = self.cam(x)
|
||||
|
||||
if return_heatmap:
|
||||
heatmap = torch.sum(x, dim=1, keepdim=True)
|
||||
return self.conv(x), cam_logit, heatmap
|
||||
else:
|
||||
return self.conv(x), cam_logit
|
||||
0
model/image_translation/__init__.py
Normal file
0
model/image_translation/__init__.py
Normal file
29
model/image_translation/pix2pixHD.py
Normal file
29
model/image_translation/pix2pixHD.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import MODEL
|
||||
|
||||
|
||||
@MODEL.register_module()
|
||||
class MultiScaleDiscriminator(nn.Module):
|
||||
def __init__(self, num_scale, discriminator_cfg, down_sample_method="avg"):
|
||||
super().__init__()
|
||||
assert down_sample_method in ["avg", "bilinear"]
|
||||
self.down_sample_method = down_sample_method
|
||||
|
||||
self.discriminator_list = nn.ModuleList([
|
||||
MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
|
||||
])
|
||||
|
||||
def down_sample(self, x):
|
||||
if self.down_sample_method == "avg":
|
||||
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
|
||||
if self.down_sample_method == "bilinear":
|
||||
return F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
|
||||
|
||||
def forward(self, x):
|
||||
results = []
|
||||
for discriminator in self.discriminator_list:
|
||||
results.append(discriminator(x))
|
||||
x = self.down_sample(x)
|
||||
return results
|
||||
@@ -1,75 +0,0 @@
|
||||
import torch.nn as nn
|
||||
import functools
|
||||
import torch
|
||||
|
||||
|
||||
def select_norm_layer(norm_type):
|
||||
if norm_type == "BN":
|
||||
return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
||||
elif norm_type == "IN":
|
||||
return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
||||
elif norm_type == "LN":
|
||||
return functools.partial(LayerNorm2d, affine=True)
|
||||
elif norm_type == "NONE":
|
||||
return lambda num_features: nn.Identity()
|
||||
elif norm_type == "AdaIN":
|
||||
return functools.partial(AdaptiveInstanceNorm2d, affine=False, track_running_stats=False)
|
||||
else:
|
||||
raise NotImplemented(f'normalization layer {norm_type} is not found')
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
|
||||
def __init__(self, num_features, eps: float = 1e-5, affine: bool = True):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
self.affine = affine
|
||||
if self.affine:
|
||||
self.channel_gamma = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
|
||||
self.channel_beta = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.affine:
|
||||
nn.init.uniform_(self.channel_gamma)
|
||||
nn.init.zeros_(self.channel_beta)
|
||||
|
||||
def forward(self, x):
|
||||
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
|
||||
x = (x - ln_mean) / torch.sqrt(ln_var + self.eps)
|
||||
if self.affine:
|
||||
return self.channel_gamma * x + self.channel_beta
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(num_features={self.num_features}, affine={self.affine})"
|
||||
|
||||
|
||||
class AdaptiveInstanceNorm2d(nn.Module):
|
||||
def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1,
|
||||
affine: bool = False, track_running_stats: bool = False):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.affine = affine
|
||||
self.track_running_stats = track_running_stats
|
||||
self.norm = nn.InstanceNorm2d(num_features, eps, momentum, affine, track_running_stats)
|
||||
|
||||
self.gamma = None
|
||||
self.beta = None
|
||||
self.have_set_style = False
|
||||
|
||||
def set_style(self, style):
|
||||
style = style.view(*style.size(), 1, 1)
|
||||
self.gamma, self.beta = style.chunk(2, 1)
|
||||
self.have_set_style = True
|
||||
|
||||
def forward(self, x):
|
||||
assert self.have_set_style
|
||||
x = self.norm(x)
|
||||
x = self.gamma * x + self.beta
|
||||
self.have_set_style = False
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
|
||||
f"affine={self.affine}, track_running_stats={self.track_running_stats})"
|
||||
@@ -1,3 +1,4 @@
|
||||
from util.registry import Registry
|
||||
|
||||
MODEL = Registry("model")
|
||||
NORMALIZATION = Registry("normalization")
|
||||
@@ -65,7 +65,8 @@ def generation_init_weights(module, init_type='normal', init_gain=0.02):
|
||||
elif classname.find('BatchNorm2d') != -1:
|
||||
# BatchNorm Layer's weight is not a matrix;
|
||||
# only normal distribution applies.
|
||||
normal_init(m, 1.0, init_gain)
|
||||
if m.weight is not None:
|
||||
normal_init(m, 1.0, init_gain)
|
||||
|
||||
assert isinstance(module, nn.Module)
|
||||
module.apply(init_func)
|
||||
|
||||
8
run.sh
8
run.sh
@@ -5,16 +5,18 @@ TASK=$2
|
||||
GPUS=$3
|
||||
MORE_ARG=${*:4}
|
||||
|
||||
RANDOM_MASTER=$(shuf -i 2000-65000 -n 1)
|
||||
|
||||
_command="print(len('${GPUS}'.split(',')))"
|
||||
GPU_COUNT=$(python3 -c "${_command}")
|
||||
|
||||
echo "GPU_COUNT:${GPU_COUNT}"
|
||||
|
||||
echo CUDA_VISIBLE_DEVICES=$GPUS \
|
||||
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
|
||||
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
|
||||
main.py "$TASK" "$CONFIG" --backup_config --setup_output_dir --setup_random_seed "$MORE_ARG"
|
||||
|
||||
CUDA_VISIBLE_DEVICES=$GPUS \
|
||||
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
|
||||
PYTHONPATH=.:$PYTHONPATH OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node="$GPU_COUNT" \
|
||||
--master_port=${RANDOM_MASTER} \
|
||||
main.py "$TASK" "$CONFIG" $MORE_ARG --backup_config --setup_output_dir --setup_random_seed
|
||||
|
||||
|
||||
46
tool/dump_tensorboard.py
Normal file
46
tool/dump_tensorboard.py
Normal file
@@ -0,0 +1,46 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# edit from https://gist.github.com/hysts/81a0d30ac4f33dfa0c8859383aec42c2
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from tensorboard.backend.event_processing import event_accumulator
|
||||
|
||||
|
||||
def save(outdir: pathlib.Path, tag, event_acc):
|
||||
events = event_acc.Images(tag)
|
||||
|
||||
for index, event in enumerate(events):
|
||||
s = np.frombuffer(event.encoded_image_string, dtype=np.uint8)
|
||||
image = cv2.imdecode(s, cv2.IMREAD_COLOR)
|
||||
outpath = outdir / f"{tag.replace('/', '_')}@{index}.png"
|
||||
cv2.imwrite(outpath.as_posix(), image)
|
||||
|
||||
|
||||
# ffmpeg -framerate 1 -i ./tmp/test_b/%04d.jpg -vcodec mpeg4 test_b.mp4
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--path', type=str, required=True)
|
||||
parser.add_argument('--outdir', type=str, required=True)
|
||||
parser.add_argument("--tag", type=str, required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
event_acc = event_accumulator.EventAccumulator(args.path, size_guidance={'images': 0})
|
||||
event_acc.Reload()
|
||||
|
||||
outdir = pathlib.Path(args.outdir)
|
||||
outdir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
if args.tag is None:
|
||||
for tag in event_acc.Tags()['images']:
|
||||
save(outdir, tag, event_acc)
|
||||
else:
|
||||
assert args.tag in event_acc.Tags()['images'], f"{args.tag} not in {event_acc.Tags()['images']}"
|
||||
save(outdir, args.tag, event_acc)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
32
tool/encoder_distance.py
Normal file
32
tool/encoder_distance.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
#
|
||||
# data = {}
|
||||
#
|
||||
# for i in range(1, 422 + 1):
|
||||
# _, names = torch.load(f"/tmp/pt/batch{i}.pt")
|
||||
# generated = torch.load(f"/tmp/pt/generated{i}.pt")
|
||||
# print(len(names))
|
||||
# for j, n in enumerate(names):
|
||||
# data[Path(names[j]).stem] = generated[j]
|
||||
#
|
||||
# torch.save(data, "/tmp/data.pt")
|
||||
|
||||
data = torch.load("/tmp/data.pt")
|
||||
|
||||
videos = sorted(list(set([k.split("@")[0] for k in data.keys()])))
|
||||
|
||||
for idx in range(len(videos)):
|
||||
print(videos[idx])
|
||||
videos_data = {}
|
||||
for k in data:
|
||||
if k.startswith(videos[idx]):
|
||||
videos_data[int(k.split("@")[-1])] = data[k]
|
||||
|
||||
to_save = []
|
||||
for i in range(2, len(videos_data) + 1):
|
||||
to_save.append(torch.mean(torch.abs(videos_data[i] - videos_data[1])).cpu())
|
||||
|
||||
torch.save(to_save, f"{videos[idx]}.pt")
|
||||
print(f"{videos[idx]}.pt")
|
||||
14
tool/inspect_model.py
Normal file
14
tool/inspect_model.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import sys
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from engine.util.build import build_model
|
||||
|
||||
config = OmegaConf.load(sys.argv[1])
|
||||
|
||||
|
||||
generator = build_model(config.model.generator)
|
||||
|
||||
ckp = torch.load(sys.argv[2], map_location="cpu")
|
||||
|
||||
generator.module.load_state_dict(ckp["generator_main"])
|
||||
13
tool/process/permutation_face.py
Normal file
13
tool/process/permutation_face.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from itertools import permutations
|
||||
|
||||
pids = defaultdict(list)
|
||||
for p in Path(sys.argv[1]).glob("*.jpg"):
|
||||
pids[p.stem[:7]].append(p.stem)
|
||||
|
||||
data = []
|
||||
for p in pids:
|
||||
data.extend(list(permutations(pids[p], 2)))
|
||||
|
||||
69
tool/verify_loss.py
Normal file
69
tool/verify_loss.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from ignite.utils import convert_tensor
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from data.dataset import SingleFolderDataset
|
||||
from loss.I2I.perceptual_loss import PerceptualLoss
|
||||
|
||||
import ignite.distributed as idist
|
||||
|
||||
CONFIG = """
|
||||
loss:
|
||||
perceptual:
|
||||
layer_weights:
|
||||
"1": 0.03125
|
||||
"6": 0.0625
|
||||
"11": 0.125
|
||||
"20": 0.25
|
||||
"29": 1
|
||||
criterion: 'NL2'
|
||||
style_loss: False
|
||||
perceptual_loss: True
|
||||
match_data:
|
||||
root: "/tmp/generated/"
|
||||
pipeline:
|
||||
- Load
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
not_match_data:
|
||||
root: "/data/i2i/selfie2anime/trainB/"
|
||||
pipeline:
|
||||
- Load
|
||||
- ToTensor
|
||||
- Normalize:
|
||||
mean: [ 0.5, 0.5, 0.5 ]
|
||||
std: [ 0.5, 0.5, 0.5 ]
|
||||
"""
|
||||
|
||||
config = OmegaConf.create(CONFIG)
|
||||
dataset = SingleFolderDataset(**config.match_data)
|
||||
data_loader = DataLoader(dataset, 1, False, num_workers=1)
|
||||
|
||||
perceptual_loss = PerceptualLoss(**config.loss.perceptual).to("cuda:0")
|
||||
|
||||
pls = []
|
||||
for batch in data_loader:
|
||||
with torch.no_grad():
|
||||
batch = convert_tensor(batch, "cuda:0")
|
||||
x, t = torch.chunk(batch, 2, -1)
|
||||
pl, _ = perceptual_loss(x, t)
|
||||
print(pl)
|
||||
pls.append(pl)
|
||||
|
||||
torch.save(torch.stack(pls).cpu(), "verify_loss.match.pt")
|
||||
|
||||
dataset = SingleFolderDataset(**config.not_match_data)
|
||||
data_loader = DataLoader(dataset, 4, False, num_workers=1)
|
||||
pls = []
|
||||
for batch in data_loader:
|
||||
with torch.no_grad():
|
||||
batch = convert_tensor(batch, "cuda:0")
|
||||
for i, j in [(0, 1), (1, 2), (2, 3), (3, 0)]:
|
||||
x, t = batch[i].unsqueeze(dim=0), batch[j].unsqueeze(dim=0)
|
||||
pl, _ = perceptual_loss(x, t)
|
||||
print(pl)
|
||||
pls.append(pl)
|
||||
torch.save(torch.stack(pls).cpu(), "verify_loss.not_match.pt")
|
||||
@@ -1,27 +0,0 @@
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import ignite.distributed as idist
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model import MODEL
|
||||
from util.distributed import auto_model
|
||||
|
||||
|
||||
def build_model(cfg, distributed_args=None):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
model_distributed_config = cfg.pop("_distributed", dict())
|
||||
model = MODEL.build_with(cfg)
|
||||
|
||||
if model_distributed_config.get("bn_to_syncbn"):
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
|
||||
distributed_args = {} if distributed_args is None or idist.get_world_size() == 1 else distributed_args
|
||||
return auto_model(model, **distributed_args)
|
||||
|
||||
|
||||
def build_optimizer(params, cfg):
|
||||
assert "_type" in cfg
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
optimizer = getattr(optim, cfg.pop("_type"))(params=params, **cfg)
|
||||
return idist.auto_optim(optimizer)
|
||||
@@ -1,66 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ignite.distributed import utils as idist
|
||||
from ignite.distributed.comp_models import native as idist_native
|
||||
from ignite.utils import setup_logger
|
||||
|
||||
|
||||
def auto_model(model: nn.Module, **additional_kwargs) -> nn.Module:
|
||||
"""Helper method to adapt provided model for non-distributed and distributed configurations (supporting
|
||||
all available backends from :meth:`~ignite.distributed.utils.available_backends()`).
|
||||
|
||||
Internally, we perform to following:
|
||||
|
||||
- send model to current :meth:`~ignite.distributed.utils.device()` if model's parameters are not on the device.
|
||||
- wrap the model to `torch DistributedDataParallel`_ for native torch distributed if world size is larger than 1.
|
||||
- wrap the model to `torch DataParallel`_ if no distributed context found and more than one CUDA devices available.
|
||||
|
||||
Examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ignite.distribted as idist
|
||||
|
||||
model = idist.auto_model(model)
|
||||
|
||||
In addition with NVidia/Apex, it can be used in the following way:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ignite.distribted as idist
|
||||
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
|
||||
model = idist.auto_model(model)
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model to adapt.
|
||||
|
||||
Returns:
|
||||
torch.nn.Module
|
||||
|
||||
.. _torch DistributedDataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel
|
||||
.. _torch DataParallel: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
|
||||
"""
|
||||
logger = setup_logger(__name__ + ".auto_model")
|
||||
|
||||
# Put model's parameters to device if its parameters are not on the device
|
||||
device = idist.device()
|
||||
if not all([p.device == device for p in model.parameters()]):
|
||||
model.to(device)
|
||||
|
||||
# distributed data parallel model
|
||||
if idist.get_world_size() > 1:
|
||||
if idist.backend() == idist_native.NCCL:
|
||||
lrank = idist.get_local_rank()
|
||||
logger.info("Apply torch DistributedDataParallel on model, device id: {}".format(lrank))
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[lrank, ], **additional_kwargs)
|
||||
elif idist.backend() == idist_native.GLOO:
|
||||
logger.info("Apply torch DistributedDataParallel on model")
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, **additional_kwargs)
|
||||
|
||||
# not distributed but multiple GPUs reachable so data parallel model
|
||||
elif torch.cuda.device_count() > 1 and "cuda" in idist.device().type:
|
||||
logger.info("Apply torch DataParallel on model")
|
||||
model = torch.nn.parallel.DataParallel(model, **additional_kwargs)
|
||||
|
||||
return model
|
||||
@@ -1,26 +1,34 @@
|
||||
import torchvision.utils
|
||||
from matplotlib.pyplot import get_cmap
|
||||
import torch
|
||||
import warnings
|
||||
from torch.nn.functional import interpolate
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def attention_colored_map(attentions, size=None, cmap_name="jet"):
|
||||
def attention_colored_map(attentions, size=None):
|
||||
assert attentions.dim() == 4 and attentions.size(1) == 1
|
||||
device = attentions.device
|
||||
|
||||
min_attentions = attentions.view(attentions.size(0), -1).min(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
attentions -= min_attentions
|
||||
attentions /= attentions.view(attentions.size(0), -1).max(1, keepdim=True)[0].view(attentions.size(0), 1, 1, 1)
|
||||
|
||||
if size is not None and attentions.size()[-2:] != size:
|
||||
attentions = attentions.detach().cpu().numpy()
|
||||
attentions = (attentions * 255).astype(np.uint8)
|
||||
need_resize = False
|
||||
if size is not None and attentions.shape[-2:] != size:
|
||||
assert len(size) == 2, "for interpolate, size must be (x, y), have two dim"
|
||||
attentions = interpolate(attentions, size, mode="bilinear", align_corners=False)
|
||||
cmap = get_cmap(cmap_name)
|
||||
ca = cmap(attentions.squeeze(1).cpu())[:, :, :, :3]
|
||||
return torch.from_numpy(ca).permute(0, 3, 1, 2).contiguous()
|
||||
need_resize = True
|
||||
|
||||
subs = []
|
||||
for sub in attentions:
|
||||
sub = cv2.resize(sub[0], size) if need_resize else sub[0] # numpy.array shape=size
|
||||
subs.append(cv2.applyColorMap(sub, cv2.COLORMAP_JET)) # append a (size[0], size[1], 3) numpy array
|
||||
subs = np.stack(subs) # (batch_size, size[0], size[1], 3)
|
||||
return torch.from_numpy(subs).permute(0, 3, 1, 2).contiguous().to(device).float() / 255
|
||||
|
||||
|
||||
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||
def fuse_attention_map(images, attentions, alpha=0.5):
|
||||
"""
|
||||
|
||||
:param images: B x H x W
|
||||
@@ -35,7 +43,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||
if attentions.size(1) != 1:
|
||||
warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
|
||||
return images
|
||||
colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device)
|
||||
colored_attentions = attention_colored_map(attentions, images.size()[-2:])
|
||||
return images * alpha + colored_attentions * (1 - alpha)
|
||||
|
||||
|
||||
|
||||
23
util/misc.py
23
util/misc.py
@@ -1,6 +1,25 @@
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Optional
|
||||
import pkgutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
def import_submodules(package, recursive=True):
|
||||
""" Import all submodules of a module, recursively, including subpackages
|
||||
|
||||
:param package: package (name or actual module)
|
||||
:type package: str | module
|
||||
:rtype: dict[str, types.ModuleType]
|
||||
"""
|
||||
if isinstance(package, str):
|
||||
package = importlib.import_module(package)
|
||||
results = {}
|
||||
for loader, name, is_pkg in pkgutil.walk_packages(package.__path__):
|
||||
full_name = package.__name__ + '.' + name
|
||||
results[name] = importlib.import_module(full_name)
|
||||
if recursive and is_pkg:
|
||||
results.update(import_submodules(full_name))
|
||||
return results
|
||||
|
||||
|
||||
def setup_logger(
|
||||
@@ -8,7 +27,6 @@ def setup_logger(
|
||||
level: int = logging.INFO,
|
||||
logger_format: str = "%(asctime)s %(name)s %(levelname)s: %(message)s",
|
||||
filepath: Optional[str] = None,
|
||||
file_level: int = logging.DEBUG,
|
||||
distributed_rank: Optional[int] = None,
|
||||
) -> logging.Logger:
|
||||
"""Setups logger: name, level, format etc.
|
||||
@@ -18,7 +36,6 @@ def setup_logger(
|
||||
level (int): logging level, e.g. CRITICAL, ERROR, WARNING, INFO, DEBUG
|
||||
logger_format (str): logging format. By default, `%(asctime)s %(name)s %(levelname)s: %(message)s`
|
||||
filepath (str, optional): Optional logging file path. If not None, logs are written to the file.
|
||||
file_level (int): Optional logging level for logging file.
|
||||
distributed_rank (int, optional): Optional, rank in distributed configuration to avoid logger setup for workers.
|
||||
If None, distributed_rank is initialized to the rank of process.
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import inspect
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from omegaconf import OmegaConf
|
||||
import warnings
|
||||
from types import ModuleType
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
|
||||
class _Registry:
|
||||
def __init__(self, name):
|
||||
@@ -51,6 +53,10 @@ class _Registry:
|
||||
else:
|
||||
raise TypeError(f'cfg must be a dict or a str, but got {type(cfg)}')
|
||||
|
||||
for invalid_key in [k for k in args.keys() if k.startswith("_")]:
|
||||
warnings.warn(f"got param start with `_`: {invalid_key}, will remove it")
|
||||
args.pop(invalid_key)
|
||||
|
||||
if not (isinstance(default_args, dict) or default_args is None):
|
||||
raise TypeError('default_args must be a dict or None, '
|
||||
f'but got {type(default_args)}')
|
||||
@@ -130,8 +136,11 @@ class Registry(_Registry):
|
||||
if module_name is None:
|
||||
module_name = module_class.__name__
|
||||
if not force and module_name in self._module_dict:
|
||||
raise KeyError(f'{module_name} is already registered '
|
||||
f'in {self.name}')
|
||||
if self._module_dict[module_name] == module_class:
|
||||
warnings.warn(f'{module_name} is already registered in {self.name}, but is the same class')
|
||||
return
|
||||
raise KeyError(f'{module_name}:{self._module_dict[module_name]} is already registered in {self.name}'
|
||||
f'so {module_class} can not be registered')
|
||||
self._module_dict[module_name] = module_class
|
||||
|
||||
def register_module(self, name=None, force=False, module=None):
|
||||
|
||||
Reference in New Issue
Block a user