Compare commits

...

2 Commits

SHA1       Message    Date
61e04de8a5 TAFG       2020-09-17 09:34:53 +08:00
2ff4a91057 add MUNIT  2020-09-14 22:30:05 +08:00

13 changed files with 662 additions and 278 deletions

.idea/deployment.xml (generated)

@@ -5,7 +5,7 @@
     <paths name="14d">
       <serverdata>
         <mappings>
-          <mapping local="$PROJECT_DIR$" web="/" />
+          <mapping deploy="raycv" local="$PROJECT_DIR$" web="/" />
         </mappings>
       </serverdata>
     </paths>

New file: MUNIT-edges2shoes training config (132 additions)

@@ -0,0 +1,132 @@
name: MUNIT-edges2shoes
engine: MUNIT
result_dir: ./result
max_pairs: 1000000
handler:
  clear_cuda_cache: True
  set_epoch_for_dist_sampler: True
  checkpoint:
    epoch_interval: 1  # checkpoint once per `epoch_interval` epoch
    n_saved: 2
  tensorboard:
    scalar: 100  # log scalar `scalar` times per epoch
    image: 2  # log image `image` times per epoch
misc:
  random_seed: 324
model:
  generator:
    _type: MUNIT-Generator
    in_channels: 3
    out_channels: 3
    base_channels: 64
    num_sampling: 2
    num_style_dim: 8
    num_style_conv: 4
    num_content_res_blocks: 4
    num_decoder_res_blocks: 4
    num_fusion_dim: 256
    num_fusion_blocks: 3
  discriminator:
    _type: MultiScaleDiscriminator
    num_scale: 2
    discriminator_cfg:
      _type: PatchDiscriminator
      in_channels: 3
      base_channels: 64
      use_spectral: True
      need_intermediate_feature: True
loss:
  gan:
    loss_type: lsgan
    real_label_val: 1.0
    fake_label_val: 0.0
    weight: 1.0
  perceptual:
    layer_weights:
      "1": 0.03125
      "6": 0.0625
      "11": 0.125
      "20": 0.25
      "29": 1
    criterion: 'L1'
    style_loss: False
    perceptual_loss: True
    weight: 0
  recon:
    level: 1
    style:
      weight: 1
    content:
      weight: 1
    image:
      weight: 10
    cycle:
      weight: 0
optimizers:
  generator:
    _type: Adam
    lr: 0.0001
    betas: [ 0.5, 0.999 ]
    weight_decay: 0.0001
  discriminator:
    _type: Adam
    lr: 4e-4
    betas: [ 0.5, 0.999 ]
    weight_decay: 0.0001
data:
  train:
    scheduler:
      start_proportion: 0.5
      target_lr: 0
    buffer_size: 50
    dataloader:
      batch_size: 1
      shuffle: True
      num_workers: 1
      pin_memory: True
      drop_last: True
    dataset:
      _type: GenerationUnpairedDataset
      root_a: "/data/i2i/edges2shoes/trainA"
      root_b: "/data/i2i/edges2shoes/trainB"
      random_pair: True
      pipeline:
        - Load
        - Resize:
            size: [ 286, 286 ]
        - RandomCrop:
            size: [ 256, 256 ]
        - RandomHorizontalFlip
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
  test:
    which: dataset
    dataloader:
      batch_size: 8
      shuffle: False
      num_workers: 1
      pin_memory: False
      drop_last: False
    dataset:
      _type: GenerationUnpairedDataset
      root_a: "/data/i2i/edges2shoes/testA"
      root_b: "/data/i2i/edges2shoes/testB"
      random_pair: False
      pipeline:
        - Load
        - Resize:
            size: [ 256, 256 ]
        - ToTensor
        - Normalize:
            mean: [ 0.5, 0.5, 0.5 ]
            std: [ 0.5, 0.5, 0.5 ]
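
A minimal sketch of how the engine side consumes this nesting (OmegaConf is already used in engine/MUNIT.py below; the config path here is hypothetical):

import torch.nn as nn
from omegaconf import OmegaConf

config = OmegaConf.load("MUNIT-edges2shoes.yaml")  # hypothetical path

# dotted access mirrors the YAML nesting, e.g. the reconstruction weights
# read by MUNITEngineKernel.criterion_generators:
assert config.loss.recon.image.weight == 10
assert config.loss.recon.cycle.weight == 0  # so the cycle branch is skipped in forward()

# `level` selects L1 vs. L2 reconstruction, exactly as in the kernel __init__:
recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()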

Modified: TAFG training config

@@ -1,4 +1,4 @@
-name: TAFG
+name: TAFG-vox2
 engine: TAFG
 result_dir: ./result
 max_pairs: 1500000
@@ -11,11 +11,11 @@ handler:
     n_saved: 2
   tensorboard:
     scalar: 100  # log scalar `scalar` times per epoch
-    image: 2  # log image `image` times per epoch
+    image: 4  # log image `image` times per epoch
 misc:
-  random_seed: 324
+  random_seed: 123
 model:
   generator:
@@ -24,7 +24,9 @@ model:
     style_in_channels: 3
     content_in_channels: 24
     num_adain_blocks: 8
-    num_res_blocks: 0
+    num_res_blocks: 8
+    use_spectral_norm: True
+    style_use_fc: False
   discriminator:
     _type: MultiScaleDiscriminator
     num_scale: 2
@@ -51,26 +53,22 @@ loss:
     criterion: 'L1'
     style_loss: False
     perceptual_loss: True
-    weight: 10
-  style:
-    layer_weights:
-      "3": 1
-    criterion: 'L1'
-    style_loss: True
-    perceptual_loss: False
-    weight: 10
-  fm:
-    level: 1
-    weight: 10
+    weight: 0
   recon:
     level: 1
     weight: 10
   style_recon:
     level: 1
-    weight: 0
+    weight: 5
+  content_recon:
+    level: 1
+    weight: 10
   edge:
     weight: 10
     hed_pretrained_model_path: ./network-bsds500.pytorch
+  cycle:
+    level: 1
+    weight: 10
 optimizers:
   generator:
@@ -91,9 +89,9 @@ data:
       target_lr: 0
     buffer_size: 50
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: True
-      num_workers: 2
+      num_workers: 1
       pin_memory: True
       drop_last: True
     dataset:
@@ -116,7 +114,7 @@ data:
   test:
     which: video_dataset
     dataloader:
-      batch_size: 8
+      batch_size: 1
       shuffle: False
       num_workers: 1
       pin_memory: False
@@ -145,7 +143,7 @@ data:
       pipeline:
         - Load
         - Resize:
-            size: [ 256, 256 ]
+            size: [ 128, 128 ]
         - ToTensor
         - Normalize:
             mean: [ 0.5, 0.5, 0.5 ]

Modified: GenerationUnpairedDatasetWithEdge dataset

@@ -203,7 +203,7 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
         op = Path(origin_path)
         if self.edge_type.startswith("landmark_"):
             edge_type = self.edge_type.lstrip("landmark_")
-            use_landmark = True
+            use_landmark = op.parent.name.endswith("A")
         else:
             edge_type = self.edge_type
             use_landmark = False
@@ -225,14 +225,11 @@ class GenerationUnpairedDatasetWithEdge(Dataset):
     def __getitem__(self, idx):
         a_idx = idx % len(self.A)
         b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
-        if self.with_path:
-            output = {"a": self.A[a_idx], "b": self.B[b_idx]}
-            output["edge_a"] = output["a"][1]
-            return output
-        output = dict()
-        output["a"], path_a = self.A[a_idx]
-        output["b"], path_b = self.B[b_idx]
-        output["edge_a"] = self.get_edge(path_a)
+        output = dict(a={}, b={})
+        output["a"]["img"], output["a"]["path"] = self.A[a_idx]
+        output["b"]["img"], output["b"]["path"] = self.B[b_idx]
+        for p in "ab":
+            output[p]["edge"] = self.get_edge(output[p]["path"])
         return output

     def __len__(self):
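
Each sample now carries a nested per-domain dict instead of flat keys. A minimal sketch of the structure downstream code (e.g. TAFGEngineKernel.forward) now expects; shapes and paths are illustrative, the real sizes come from the Resize/ToTensor pipeline:

import torch

sample = {
    "a": {"img": torch.rand(3, 128, 128), "path": "trainA/0001.jpg", "edge": torch.rand(3, 128, 128)},
    "b": {"img": torch.rand(3, 128, 128), "path": "trainB/0042.jpg", "edge": torch.rand(1, 128, 128)},
}
# old layout for comparison: {"a": img_a, "b": img_b, "edge_a": edge_a}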

engine/MUNIT.py (new file, 154 additions)

@@ -0,0 +1,154 @@
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from loss.I2I.perceptual_loss import PerceptualLoss
from loss.gan import GANLoss


def mse_loss(x, target_flag):
    return F.mse_loss(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))


def bce_loss(x, target_flag):
    return F.binary_cross_entropy_with_logits(x, torch.ones_like(x) if target_flag else torch.zeros_like(x))


class MUNITEngineKernel(EngineKernel):
    def __init__(self, config):
        super().__init__(config)
        perceptual_loss_cfg = OmegaConf.to_container(config.loss.perceptual)
        perceptual_loss_cfg.pop("weight")
        self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
        gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
        gan_loss_cfg.pop("weight")
        self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
        self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
        self.train_generator_first = False

    def build_models(self) -> (dict, dict):
        generators = dict(
            a=build_model(self.config.model.generator),
            b=build_model(self.config.model.generator)
        )
        discriminators = dict(
            a=build_model(self.config.model.discriminator),
            b=build_model(self.config.model.discriminator)
        )
        self.logger.debug(discriminators["a"])
        self.logger.debug(generators["a"])
        return generators, discriminators

    def setup_after_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(True)

    def setup_before_g(self):
        for discriminator in self.discriminators.values():
            discriminator.requires_grad_(False)

    def forward(self, batch, inference=False) -> dict:
        styles = dict()
        contents = dict()
        images = dict()
        with torch.set_grad_enabled(not inference):
            for phase in "ab":
                contents[phase], styles[phase] = self.generators[phase].encode(batch[phase])
                images["{0}2{0}".format(phase)] = self.generators[phase].decode(contents[phase], styles[phase])
                styles[f"random_{phase}"] = torch.randn_like(styles[phase]).to(idist.device())
            for phase in ("a2b", "b2a"):
                # images["a2b"] = Gb.decode(content_a, random_style_b)
                images[phase] = self.generators[phase[-1]].decode(contents[phase[0]], styles[f"random_{phase[-1]}"])
                # contents["a2b"], styles["a2b"] = Gb.encode(images["a2b"])
                contents[phase], styles[phase] = self.generators[phase[-1]].encode(images[phase])
                if self.config.loss.recon.cycle.weight > 0:
                    images[f"{phase}2{phase[0]}"] = self.generators[phase[0]].decode(contents[phase], styles[phase[0]])
        return dict(styles=styles, contents=contents, images=images)

    def criterion_generators(self, batch, generated) -> dict:
        loss = dict()
        for phase in "ab":
            loss[f"recon_image_{phase}"] = self.config.loss.recon.image.weight * self.recon_loss(
                batch[phase], generated["images"]["{0}2{0}".format(phase)])
            loss[f"recon_content_{phase}"] = self.config.loss.recon.content.weight * self.recon_loss(
                generated["contents"][phase], generated["contents"]["a2b" if phase == "a" else "b2a"])
            loss[f"recon_style_{phase}"] = self.config.loss.recon.style.weight * self.recon_loss(
                generated["styles"][f"random_{phase}"], generated["styles"]["b2a" if phase == "a" else "a2b"])
            pred_fake = self.discriminators[phase](generated["images"]["b2a" if phase == "a" else "a2b"])
            loss[f"gan_{phase}"] = 0
            for sub_pred_fake in pred_fake:
                # last output is actual prediction
                loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
            if self.config.loss.recon.cycle.weight > 0:
                loss[f"recon_cycle_{phase}"] = self.config.loss.recon.cycle.weight * self.recon_loss(
                    batch[phase], generated["images"]["a2b2a" if phase == "a" else "b2a2b"])
            if self.config.loss.perceptual.weight > 0:
                loss[f"perceptual_{phase}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
                    batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
        return loss

    def criterion_discriminators(self, batch, generated) -> dict:
        loss = dict()
        for phase in ("a2b", "b2a"):
            pred_real = self.discriminators[phase[-1]](batch[phase[-1]])
            pred_fake = self.discriminators[phase[-1]](generated["images"][phase].detach())
            loss[f"gan_{phase[-1]}"] = 0
            for i in range(len(pred_fake)):
                loss[f"gan_{phase[-1]}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
                                             + self.gan_loss(pred_real[i][-1], True, is_discriminator=True)) / 2
        return loss

    def intermediate_images(self, batch, generated) -> dict:
        """
        returned dict must be like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        :param batch:
        :param generated: dict of images
        :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
        """
        generated = {img: generated["images"][img].detach() for img in generated["images"]}
        images = dict()
        for phase in "ab":
            images[phase] = [batch[phase].detach(), generated["{0}2{0}".format(phase)],
                             generated["a2b" if phase == "a" else "b2a"]]
            if self.config.loss.recon.cycle.weight > 0:
                images[phase].append(generated["a2b2a" if phase == "a" else "b2a2b"])
        return images


class MUNITTestEngineKernel(TestEngineKernel):
    def __init__(self, config):
        super().__init__(config)

    def build_generators(self) -> dict:
        generators = dict(
            a=build_model(self.config.model.generator),
            b=build_model(self.config.model.generator)
        )
        return generators

    def to_load(self):
        return {f"generator_{k}": self.generators[k] for k in self.generators}

    def inference(self, batch):
        with torch.no_grad():
            fake, _, _ = self.generators["a2b"](batch[0])
        return fake.detach()


def run(task, config, _):
    if task == "train":
        kernel = MUNITEngineKernel(config)
        run_kernel(task, config, kernel)
    elif task == "test":
        kernel = MUNITTestEngineKernel(config)
        run_kernel(task, config, kernel)
    else:
        raise NotImplementedError
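
The kernel keys every tensor by a phase string ("a2b" means content from domain a decoded with a random b style). A minimal sketch of that translation step in isolation, assuming two trained generators with the encode/decode API from model/GAN/MUNIT.py below:

import torch

# g_a and g_b stand for the two trained MUNIT generators; x_a is a batch of domain-a images
def translate_a_to_b(g_a, g_b, x_a):
    content_a, style_a = g_a.encode(x_a)        # content code + per-image style code
    random_style_b = torch.randn_like(style_a)  # sample a target style, as in forward()
    return g_b.decode(content_a, random_style_b)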

Modified: engine/TAFG.py

@@ -3,8 +3,7 @@ from itertools import chain
 import ignite.distributed as idist
 import torch
 import torch.nn as nn
-from ignite.engine import Events
-from omegaconf import read_write, OmegaConf
+from omegaconf import OmegaConf

 from engine.base.i2i import EngineKernel, run_kernel
 from engine.util.build import build_model
@@ -21,17 +20,14 @@ class TAFGEngineKernel(EngineKernel):
         perceptual_loss_cfg.pop("weight")
         self.perceptual_loss = PerceptualLoss(**perceptual_loss_cfg).to(idist.device())
-        style_loss_cfg = OmegaConf.to_container(config.loss.style)
-        style_loss_cfg.pop("weight")
-        self.style_loss = PerceptualLoss(**style_loss_cfg).to(idist.device())
         gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
         gan_loss_cfg.pop("weight")
         self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
-        self.fm_loss = nn.L1Loss() if config.loss.fm.level == 1 else nn.MSELoss()
         self.recon_loss = nn.L1Loss() if config.loss.recon.level == 1 else nn.MSELoss()
         self.style_recon_loss = nn.L1Loss() if config.loss.style_recon.level == 1 else nn.MSELoss()
+        self.content_recon_loss = nn.L1Loss() if config.loss.content_recon.level == 1 else nn.MSELoss()
+        self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
         self.edge_loss = EdgeLoss("HED", hed_pretrained_model_path=config.loss.edge.hed_pretrained_model_path).to(
             idist.device())
@@ -67,47 +63,67 @@ class TAFGEngineKernel(EngineKernel):
     def forward(self, batch, inference=False) -> dict:
         generator = self.generators["main"]
         batch = self._process_batch(batch, inference)
+        styles = dict()
+        contents = dict()
+        images = dict()
         with torch.set_grad_enabled(not inference):
-            fake = dict(
-                a=generator(content_img=batch["edge_a"], style_img=batch["a"], which_decoder="a"),
-                b=generator(content_img=batch["edge_a"], style_img=batch["b"], which_decoder="b"),
-            )
-        return fake
+            for ph in "ab":
+                contents[ph], styles[ph] = generator.encode(batch[ph]["edge"], batch[ph]["img"], ph, ph)
+            for ph in ("a2b", "b2a"):
+                images[f"fake_{ph[-1]}"] = generator.decode(contents[ph[0]], styles[ph[-1]], ph[-1])
+            contents["recon_a"], styles["recon_b"] = generator.encode(
+                self.edge_loss.edge_extractor(images["fake_b"]), images["fake_b"], "b", "b")
+            images["a2a"] = generator.decode(contents["a"], styles["a"], "a")
+            images["b2b"] = generator.decode(contents["b"], styles["recon_b"], "b")
+            images["cycle_a"] = generator.decode(contents["recon_a"], styles["a"], "a")
+        return dict(styles=styles, contents=contents, images=images)

     def criterion_generators(self, batch, generated) -> dict:
         batch = self._process_batch(batch)
         loss = dict()
-        loss_perceptual, _ = self.perceptual_loss(generated["b"], batch["a"])
-        _, loss_style = self.style_loss(generated["a"], batch["a"])
-        loss["style"] = self.config.loss.style.weight * loss_style
-        loss["perceptual"] = self.config.loss.perceptual.weight * loss_perceptual
-        for phase in "ab":
-            pred_fake = self.discriminators[phase](generated[phase])
-            loss[f"gan_{phase}"] = 0
+        for ph in "ab":
+            loss[f"recon_image_{ph}"] = self.config.loss.recon.weight * self.recon_loss(
+                generated["images"][f"{ph}2{ph}"], batch[ph]["img"])
+            pred_fake = self.discriminators[ph](generated["images"][f"fake_{ph}"])
+            loss[f"gan_{ph}"] = 0
             for sub_pred_fake in pred_fake:
                 # last output is actual prediction
-                loss[f"gan_{phase}"] += self.gan_loss(sub_pred_fake[-1], True)
-            if self.config.loss.fm.weight > 0 and phase == "b":
-                pred_real = self.discriminators[phase](batch[phase])
-                loss_fm = 0
-                num_scale_discriminator = len(pred_fake)
-                for i in range(num_scale_discriminator):
-                    # last output is the final prediction, so we exclude it
-                    num_intermediate_outputs = len(pred_fake[i]) - 1
-                    for j in range(num_intermediate_outputs):
-                        loss_fm += self.fm_loss(pred_fake[i][j], pred_real[i][j].detach()) / num_scale_discriminator
-                loss[f"fm_{phase}"] = self.config.loss.fm.weight * loss_fm
-        loss["recon"] = self.config.loss.recon.weight * self.recon_loss(generated["a"], batch["a"])
-        loss["edge"] = self.config.loss.edge.weight * self.edge_loss(generated["b"], batch["edge_a"][:, 0:1, :, :])
+                loss[f"gan_{ph}"] += self.gan_loss(sub_pred_fake[-1], True) * self.config.loss.gan.weight
+        loss[f"recon_content_a"] = self.config.loss.content_recon.weight * self.content_recon_loss(
+            generated["contents"]["a"], generated["contents"]["recon_a"]
+        )
+        loss[f"recon_style_b"] = self.config.loss.style_recon.weight * self.style_recon_loss(
+            generated["styles"]["b"], generated["styles"]["recon_b"]
+        )
+        for ph in ("a2b", "b2a"):
+            if self.config.loss.perceptual.weight > 0:
+                loss[f"perceptual_{ph}"] = self.config.loss.perceptual.weight * self.perceptual_loss(
+                    batch[ph[0]]["img"], generated["images"][f"fake_{ph[-1]}"]
+                )
+        if self.config.loss.edge.weight > 0:
+            loss[f"edge_a"] = self.config.loss.edge.weight * self.edge_loss(
+                generated["images"]["fake_b"], batch["a"]["edge"][:, 0:1, :, :]
+            )
+            loss[f"edge_b"] = self.config.loss.edge.weight * self.edge_loss(
+                generated["images"]["fake_a"], batch["b"]["edge"]
+            )
+        if self.config.loss.cycle.weight > 0:
+            loss[f"cycle_a"] = self.config.loss.cycle.weight * self.cycle_loss(
+                batch["a"]["img"], generated["images"]["cycle_a"]
+            )
         return loss

     def criterion_discriminators(self, batch, generated) -> dict:
         loss = dict()
         # batch = self._process_batch(batch)
         for phase in self.discriminators.keys():
-            pred_real = self.discriminators[phase](batch[phase])
-            pred_fake = self.discriminators[phase](generated[phase].detach())
+            pred_real = self.discriminators[phase](batch[phase]["img"])
+            pred_fake = self.discriminators[phase](generated["images"][f"fake_{phase}"].detach())
             loss[f"gan_{phase}"] = 0
             for i in range(len(pred_fake)):
                 loss[f"gan_{phase}"] += (self.gan_loss(pred_fake[i][-1], False, is_discriminator=True)
@@ -122,17 +138,25 @@ class TAFGEngineKernel(EngineKernel):
         :return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
         """
         batch = self._process_batch(batch)
-        edge = batch["edge_a"][:, 0:1, :, :]
         return dict(
-            a=[edge.expand(-1, 3, -1, -1).detach(), batch["a"].detach(), batch["b"].detach(), generated["a"].detach(),
-               generated["b"].detach()]
+            a=[batch["a"]["edge"][:, 0:1, :, :].expand(-1, 3, -1, -1).detach(),
+               batch["a"]["img"].detach(),
+               generated["images"]["a2a"].detach(),
+               generated["images"]["fake_b"].detach(),
+               generated["images"]["cycle_a"].detach(),
+               ],
+            b=[batch["b"]["edge"].expand(-1, 3, -1, -1).detach(),
+               batch["b"]["img"].detach(),
+               generated["images"]["b2b"].detach(),
+               generated["images"]["fake_a"].detach()]
         )

     def change_engine(self, config, trainer):
-        @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
-        def change_config(engine):
-            with read_write(config):
-                config.loss.perceptual.weight = 5
+        pass
+        # @trainer.on(Events.ITERATION_STARTED(once=int(config.max_iteration / 3)))
+        # def change_config(engine):
+        #     with read_write(config):
+        #         config.loss.perceptual.weight = 5


 def run(task, config, _):
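
The reworked forward keeps one shared generator with per-domain encoders/decoders, indexed by phase strings. A condensed sketch of the new data flow; g is the shared TAFG Generator, hed stands for self.edge_loss.edge_extractor, and the tensors come from the nested batch dict above:

import torch

def tafg_step(g, hed, edge_a, img_a, edge_b, img_b):
    content_a, style_a = g.encode(edge_a, img_a, "a", "a")  # content from edges, style from photo
    _, style_b = g.encode(edge_b, img_b, "b", "b")
    fake_b = g.decode(content_a, style_b, "b")              # a's content rendered in b's style
    # re-encode the fake to obtain codes for the content/style reconstruction losses
    recon_content_a, _ = g.encode(hed(fake_b), fake_b, "b", "b")
    cycle_a = g.decode(recon_content_a, style_a, "a")       # back to domain a
    return fake_b, cycle_a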

Modified: engine/base/i2i.py

@@ -132,9 +132,12 @@ def get_trainer(config, kernel: EngineKernel):
         generated = kernel.forward(batch)

         if kernel.train_generator_first:
+            # simultaneous: train G with the not-yet-updated D
             loss_g = train_generators(batch, generated)
             loss_d = train_discriminators(batch, generated)
         else:
+            # update discriminators first (not simultaneous),
+            # then train G with the updated discriminators
             loss_d = train_discriminators(batch, generated)
             loss_g = train_generators(batch, generated)
@@ -152,8 +155,8 @@ def get_trainer(config, kernel: EngineKernel):
     kernel.change_engine(config, trainer)

-    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values())).attach(trainer, "loss_g")
-    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values())).attach(trainer, "loss_d")
+    RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()), epoch_bound=False).attach(trainer, "loss_g")
+    RunningAverage(output_transform=lambda x: sum(x["loss"]["d"].values()), epoch_bound=False).attach(trainer, "loss_d")

     to_save = dict(trainer=trainer)
     to_save.update({f"lr_scheduler_{k}": lr_schedulers[k] for k in lr_schedulers})
     to_save.update({f"optimizer_{k}": optimizers[k] for k in optimizers})
@@ -188,7 +191,13 @@ def get_trainer(config, kernel: EngineKernel):
         for i in range(random_start, random_start + 10):
             batch = convert_tensor(engine.state.test_dataset[i], idist.device())
             for k in batch:
-                batch[k] = batch[k].view(1, *batch[k].size())
+                if isinstance(batch[k], torch.Tensor):
+                    batch[k] = batch[k].unsqueeze(0)
+                elif isinstance(batch[k], dict):
+                    for kk in batch[k]:
+                        if isinstance(batch[k][kk], torch.Tensor):
+                            batch[k][kk] = batch[k][kk].unsqueeze(0)
             generated = kernel.forward(batch)
             images = kernel.intermediate_images(batch, generated)
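
epoch_bound=False stops the running averages from resetting at each EPOCH_STARTED, so the logged loss_g/loss_d curves stay continuous across epochs. A minimal sketch of the metric in isolation, assuming pytorch-ignite 0.4:

from ignite.engine import Engine
from ignite.metrics import RunningAverage

# toy engine whose output mimics the trainer's {"loss": {"g": {...}}} payload
engine = Engine(lambda e, b: {"loss": {"g": {"gan_a": 0.5, "recon_image_a": 1.0}}})
# epoch_bound=False: the moving average is NOT reset when a new epoch starts
RunningAverage(output_transform=lambda x: sum(x["loss"]["g"].values()),
               epoch_bound=False).attach(engine, "loss_g")
engine.run([0] * 4, max_epochs=2)
print(engine.state.metrics["loss_g"])  # smoothed across both epochs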

Modified: conda environment file

@@ -6,7 +6,6 @@ channels:
 dependencies:
   - python=3.8
   - numpy
-  - ipython
   - tqdm
   - pyyaml
   - pytorch=1.6.*

Modified: loss/I2I/perceptual_loss.py

@@ -92,6 +92,7 @@ class PerceptualLoss(nn.Module):
                  style_loss=False, norm_img=True, criterion='L1'):
         super(PerceptualLoss, self).__init__()
         self.norm_img = norm_img
+        assert perceptual_loss ^ style_loss, "There must be one and only one true in style or perceptual"
         self.perceptual_loss = perceptual_loss
         self.style_loss = style_loss
         self.layer_weights = layer_weights
@@ -127,8 +128,7 @@ class PerceptualLoss(nn.Module):
             percep_loss = 0
             for k in x_features.keys():
                 percep_loss += self.percep_criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
-        else:
-            percep_loss = None
+            return percep_loss

         # calculate style loss
         if self.style_loss:
@@ -136,10 +136,7 @@ class PerceptualLoss(nn.Module):
             for k in x_features.keys():
                 style_loss += self.style_criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k])) * \
                               self.layer_weights[k]
-        else:
-            style_loss = None
-
-        return percep_loss, style_loss
+            return style_loss

     def _gram_mat(self, x):
         """Calculate Gram matrix.

model/GAN/MUNIT.py (new file, 154 additions)

@@ -0,0 +1,154 @@
import torch
import torch.nn as nn

from model import MODEL
from model.GAN.base import Conv2dBlock, ResBlock
from model.normalization import select_norm_layer


class StyleEncoder(nn.Module):
    def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False,
                 padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
        super(StyleEncoder, self).__init__()
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
        )]
        multiple_now = 1
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 2)
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
                use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
            ))
        sequence.append(nn.AdaptiveAvgPool2d(1))
        # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
        sequence.append(nn.Conv2d(multiple_now * base_channels, out_dim, kernel_size=1, stride=1, padding=0))
        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        return self.model(x).view(x.size(0), -1)


class ContentEncoder(nn.Module):
    def __init__(self, in_channels, num_down_sampling, num_res_blocks, base_channels=64, use_spectral_norm=False,
                 padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
        super().__init__()
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
        )]
        for i in range(num_down_sampling):
            sequence.append(Conv2dBlock(
                base_channels * (2 ** i), base_channels * (2 ** (i + 1)),
                kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
                use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
            ))
        for _ in range(num_res_blocks):
            sequence.append(
                ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
                         activation_type)
            )
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        return self.sequence(x)


class Decoder(nn.Module):
    def __init__(self, in_channels, out_channels, num_up_sampling, num_res_blocks,
                 use_spectral_norm=False, res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU",
                 padding_mode='reflect'):
        super(Decoder, self).__init__()
        self.res_norm_type = res_norm_type
        self.res_blocks = nn.ModuleList([
            ResBlock(in_channels, use_spectral_norm, padding_mode, res_norm_type, activation_type=activation_type)
            for _ in range(num_res_blocks)
        ])
        sequence = list()
        channels = in_channels
        for i in range(num_up_sampling):
            sequence.append(nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(channels, channels // 2,
                            kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
                            use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
                            ),
            ))
            channels = channels // 2
        sequence.append(
            Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect",
                        use_spectral_norm=use_spectral_norm, activation_type="Tanh", norm_type="NONE"))
        self.sequence = nn.Sequential(*sequence)

    def forward(self, x):
        for blk in self.res_blocks:
            x = blk(x)
        return self.sequence(x)


class Fusion(nn.Module):
    def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
        super().__init__()
        norm_layer = select_norm_layer(norm_type)
        self.start_fc = nn.Sequential(
            nn.Linear(in_features, base_features),
            norm_layer(base_features),
            nn.ReLU(True),
        )
        self.fcs = nn.Sequential(*[
            nn.Sequential(
                nn.Linear(base_features, base_features),
                norm_layer(base_features),
                nn.ReLU(True),
            ) for _ in range(n_blocks - 2)
        ])
        self.end_fc = nn.Sequential(
            nn.Linear(base_features, out_features),
        )

    def forward(self, x):
        x = self.start_fc(x)
        x = self.fcs(x)
        return self.end_fc(x)


@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
    def __init__(self, in_channels, out_channels, base_channels, num_sampling, num_style_dim, num_style_conv,
                 num_content_res_blocks, num_decoder_res_blocks, num_fusion_dim, num_fusion_blocks,
                 use_spectral_norm=False, activation_type="ReLU", padding_mode='reflect'):
        super().__init__()
        self.num_decoder_res_blocks = num_decoder_res_blocks
        self.content_encoder = ContentEncoder(in_channels, num_sampling, num_content_res_blocks, base_channels,
                                              use_spectral_norm, padding_mode, activation_type, norm_type="IN")
        self.style_encoder = StyleEncoder(in_channels, num_style_dim, num_style_conv, base_channels, use_spectral_norm,
                                          padding_mode, activation_type, norm_type="NONE")
        content_channels = base_channels * (2 ** 2)
        self.decoder = Decoder(content_channels, out_channels, num_sampling,
                               num_decoder_res_blocks, use_spectral_norm, "AdaIN", norm_type="LN",
                               activation_type=activation_type, padding_mode=padding_mode)
        self.fusion = Fusion(num_style_dim, num_decoder_res_blocks * 2 * content_channels * 2,
                             base_features=num_fusion_dim, n_blocks=num_fusion_blocks, norm_type="NONE")

    def encode(self, x):
        return self.content_encoder(x), self.style_encoder(x)

    def decode(self, content, style):
        as_param_style = torch.chunk(self.fusion(style), self.num_decoder_res_blocks * 2, dim=1)
        # set style for decoder
        for i, blk in enumerate(self.decoder.res_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        return self.decoder(content)

    def forward(self, x):
        content, style = self.encode(content, style) if False else self.encode(x)
        return self.decode(content, style)
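
A quick shape sanity check, wiring the constructor arguments straight from the MUNIT-edges2shoes config above (a sketch, assuming the repo's registry and normalization modules are importable):

import torch
from model.GAN.MUNIT import Generator

g = Generator(in_channels=3, out_channels=3, base_channels=64, num_sampling=2,
              num_style_dim=8, num_style_conv=4, num_content_res_blocks=4,
              num_decoder_res_blocks=4, num_fusion_dim=256, num_fusion_blocks=3)
x = torch.rand(2, 3, 256, 256)
content, style = g.encode(x)
print(content.shape)  # (2, 256, 64, 64): two stride-2 downsamplings, 64 * 2**2 channels
print(style.shape)    # (2, 8): one global style code per image
print(g.decode(content, style).shape)  # (2, 3, 256, 256)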

Modified: model/GAN/TAFG.py

@@ -4,16 +4,17 @@ from torchvision.models import vgg19
 from model.normalization import select_norm_layer
 from model.registry import MODEL
-from .base import ResidualBlock
+from .MUNIT import ContentEncoder, Fusion, Decoder
+from .base import ResBlock


 class VGG19StyleEncoder(nn.Module):
     def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
-                 vgg19_layers=(0, 5, 10, 19)):
+                 vgg19_layers=(0, 5, 10, 19), fix_vgg19=True):
         super().__init__()
         self.vgg19_layers = vgg19_layers
         self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
-        self.vgg19.requires_grad_(False)
+        self.vgg19.requires_grad_(not fix_vgg19)

         norm_layer = select_norm_layer(norm_type)
@@ -52,203 +53,57 @@ class VGG19StyleEncoder(nn.Module):
         return x.view(x.size(0), -1)


-class ContentEncoder(nn.Module):
-    def __init__(self, in_channels, base_channels=64, num_blocks=8, padding_mode='reflect', norm_type="IN"):
-        super().__init__()
-        norm_layer = select_norm_layer(norm_type)
-        self.start_conv = nn.Sequential(
-            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
-                      bias=True),
-            norm_layer(num_features=base_channels),
-            nn.ReLU(inplace=True)
-        )
-        # down sampling
-        submodules = []
-        num_down_sampling = 2
-        for i in range(num_down_sampling):
-            multiple = 2 ** i
-            submodules += [
-                nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
-                          kernel_size=4, stride=2, padding=1, bias=True),
-                norm_layer(num_features=base_channels * multiple * 2),
-                nn.ReLU(inplace=True)
-            ]
-        self.encoder = nn.Sequential(*submodules)
-        res_block_channels = num_down_sampling ** 2 * base_channels
-        self.resnet = nn.Sequential(
-            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
-
-    def forward(self, x):
-        x = self.start_conv(x)
-        x = self.encoder(x)
-        x = self.resnet(x)
-        return x
-
-
-class Decoder(nn.Module):
-    def __init__(self, out_channels, base_channels=64, num_blocks=4, num_down_sampling=2, padding_mode='reflect',
-                 norm_type="LN"):
-        super(Decoder, self).__init__()
-        norm_layer = select_norm_layer(norm_type)
-        use_bias = norm_type == "IN"
-        res_block_channels = (2 ** 2) * base_channels
-        self.resnet = nn.Sequential(
-            *[ResidualBlock(res_block_channels, padding_mode, norm_type, use_bias=True) for _ in range(num_blocks)])
-        # up sampling
-        submodules = []
-        for i in range(num_down_sampling):
-            multiple = 2 ** (num_down_sampling - i)
-            submodules += [
-                nn.Upsample(scale_factor=2),
-                nn.Conv2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=5, stride=1,
-                          padding=2, padding_mode=padding_mode, bias=use_bias),
-                norm_layer(num_features=base_channels * multiple // 2),
-                nn.ReLU(inplace=True),
-            ]
-        self.decoder = nn.Sequential(*submodules)
-        self.end_conv = nn.Sequential(
-            nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
-            nn.Tanh()
-        )
-
-    def forward(self, x):
-        x = self.resnet(x)
-        x = self.decoder(x)
-        x = self.end_conv(x)
-        return x
-
-
-class Fusion(nn.Module):
-    def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
-        super().__init__()
-        norm_layer = select_norm_layer(norm_type)
-        self.start_fc = nn.Sequential(
-            nn.Linear(in_features, base_features),
-            norm_layer(base_features),
-            nn.ReLU(True),
-        )
-        self.fcs = nn.Sequential(*[
-            nn.Sequential(
-                nn.Linear(base_features, base_features),
-                norm_layer(base_features),
-                nn.ReLU(True),
-            ) for _ in range(n_blocks - 2)
-        ])
-        self.end_fc = nn.Sequential(
-            nn.Linear(base_features, out_features),
-        )
-
-    def forward(self, x):
-        x = self.start_fc(x)
-        x = self.fcs(x)
-        return self.end_fc(x)
-
-
-class StyleGenerator(nn.Module):
-    def __init__(self, style_in_channels, style_dim=512, num_blocks=8, base_channels=64, padding_mode="reflect"):
-        super().__init__()
-        self.num_blocks = num_blocks
-        self.style_encoder = VGG19StyleEncoder(
-            style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE")
-        self.fc = nn.Sequential(
-            nn.Linear(style_dim, style_dim),
-            nn.ReLU(True),
-        )
-        res_block_channels = 2 ** 2 * base_channels
-        self.fusion = Fusion(style_dim, num_blocks * 2 * res_block_channels * 2, base_features=256, n_blocks=3,
-                             norm_type="NONE")
-
-    def forward(self, x):
-        styles = self.fusion(self.fc(self.style_encoder(x)))
-        return styles
-
-
 @MODEL.register_module("TAFG-Generator")
 class Generator(nn.Module):
-    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, style_dim=512,
-                 num_adain_blocks=8, num_res_blocks=4,
+    def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
+                 style_dim=512, style_use_fc=True,
+                 num_adain_blocks=8, num_res_blocks=8,
                  base_channels=64, padding_mode="reflect"):
         super(Generator, self).__init__()
-        self.num_adain_blocks=num_adain_blocks
-        self.style_encoders = nn.ModuleDict({
-            "a": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
-                                base_channels=base_channels, padding_mode=padding_mode),
-            "b": StyleGenerator(style_in_channels, style_dim=style_dim, num_blocks=num_adain_blocks,
-                                base_channels=base_channels, padding_mode=padding_mode),
-        })
-        self.content_encoder = ContentEncoder(content_in_channels, base_channels, num_blocks=8,
-                                              padding_mode=padding_mode, norm_type="IN")
-        res_block_channels = 2 ** 2 * base_channels
-        self.resnet = nn.ModuleDict({
-            "a": nn.Sequential(*[
-                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
-            ]),
-            "b": nn.Sequential(*[
-                ResidualBlock(res_block_channels, padding_mode, "IN", use_bias=True) for _ in range(num_res_blocks)
-            ])
-        })
-        self.adain_resnet = nn.ModuleDict({
-            "a": nn.ModuleList([
-                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
-            ]),
-            "b": nn.ModuleList([
-                ResidualBlock(res_block_channels, padding_mode, "AdaIN", use_bias=True) for _ in range(num_adain_blocks)
-            ])
-        })
-        self.decoders = nn.ModuleDict({
-            "a": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode),
-            "b": Decoder(out_channels, base_channels, norm_type="LN", num_blocks=0, padding_mode=padding_mode)
-        })
+        self.num_adain_blocks = num_adain_blocks
+        self.style_encoders = nn.ModuleDict(dict(
+            a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
+                                norm_type="NONE"),
+            b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
+                                norm_type="NONE", fix_vgg19=False)
+        ))
+        resnet_channels = 2 ** 2 * base_channels
+        self.style_converters = nn.ModuleDict(dict(
+            a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
+                     norm_type="NONE"),
+            b=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
+                     norm_type="NONE"),
+        ))
+        self.content_encoders = nn.ModuleDict({
+            "a": ContentEncoder(content_in_channels, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm),
+            "b": ContentEncoder(1, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm)
+        })
+        self.content_resnet = nn.Sequential(*[
+            ResBlock(resnet_channels, use_spectral_norm, padding_mode, "IN")
+            for _ in range(num_res_blocks)
+        ])
+        self.decoders = nn.ModuleDict(dict(
+            a=Decoder(resnet_channels, out_channels, 2,
+                      num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
+            b=Decoder(resnet_channels, out_channels, 2,
+                      num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
+        ))

-    def forward(self, content_img, style_img, which_decoder: str = "a"):
-        x = self.content_encoder(content_img)
-        x = self.resnet[which_decoder](x)
-        styles = self.style_encoders[which_decoder](style_img)
-        styles = torch.chunk(styles, self.num_adain_blocks * 2, dim=1)
-        for i, ar in enumerate(self.adain_resnet[which_decoder]):
-            ar.norm1.set_style(styles[2 * i])
-            ar.norm2.set_style(styles[2 * i + 1])
-            x = ar(x)
-        return self.decoders[which_decoder](x)
+    def encode(self, content_img, style_img, which_content, which_style):
+        content = self.content_resnet(self.content_encoders[which_content](content_img))
+        style = self.style_encoders[which_style](style_img)
+        return content, style

+    def decode(self, content, style, which):
+        decoder = self.decoders[which]
+        as_param_style = torch.chunk(self.style_converters[which](style), self.num_adain_blocks * 2, dim=1)
+        # set style for decoder
+        for i, blk in enumerate(decoder.res_blocks):
+            blk.conv1.normalization.set_style(as_param_style[2 * i])
+            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
+        return decoder(content)

-@MODEL.register_module("TAFG-Discriminator")
-class Discriminator(nn.Module):
-    def __init__(self, in_channels=3, base_channels=64, num_down_sampling=2, num_blocks=3, norm_type="IN",
-                 padding_mode="reflect"):
-        super(Discriminator, self).__init__()
-        norm_layer = select_norm_layer(norm_type)
-        use_bias = norm_type == "IN"
-        sequence = [nn.Sequential(
-            nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
-                      bias=use_bias),
-            norm_layer(num_features=base_channels),
-            nn.ReLU(inplace=True)
-        )]
-        # stacked intermediate layers,
-        # gradually increasing the number of filters
-        multiple_now = 1
-        for n in range(1, num_down_sampling + 1):
-            multiple_prev = multiple_now
-            multiple_now = min(2 ** n, 4)
-            sequence += [
-                nn.Conv2d(base_channels * multiple_prev, base_channels * multiple_now, kernel_size=3,
-                          padding=1, stride=2, bias=use_bias),
-                norm_layer(base_channels * multiple_now),
-                nn.LeakyReLU(0.2, inplace=True)
-            ]
-        for _ in range(num_blocks):
-            sequence.append(ResidualBlock(base_channels * multiple_now, padding_mode, norm_type))
-        self.model = nn.Sequential(*sequence)
-
-    def forward(self, x):
-        return self.model(x)
+    def forward(self, content_img, style_img, which_content, which_style):
+        content, style = self.encode(content_img, style_img, which_content, which_style)
+        return self.decode(content, style, which_style)

Modified: model/GAN/base.py

@@ -1,10 +1,11 @@
+from functools import partial
 import math

 import torch
 import torch.nn as nn

-from model.normalization import select_norm_layer
 from model import MODEL
+from model.normalization import select_norm_layer


 class GANImageBuffer(object):
@@ -137,3 +138,66 @@ class ResidualBlock(nn.Module):
         x = self.relu1(self.norm1(self.conv1(x)))
         x = self.norm2(self.conv2(x))
         return x + res
+
+
+_DO_NO_THING_FUNC = lambda x: x
+
+
+def select_activation(t):
+    if t == "ReLU":
+        return partial(nn.ReLU, inplace=True)
+    elif t == "LeakyReLU":
+        return partial(nn.LeakyReLU, negative_slope=0.2, inplace=True)
+    elif t == "Tanh":
+        return partial(nn.Tanh)
+    elif t == "NONE":
+        return _DO_NO_THING_FUNC
+    else:
+        raise NotImplementedError
+
+
+def _use_bias_checker(norm_type):
+    return norm_type not in ["IN", "BN", "AdaIN"]
+
+
+class Conv2dBlock(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, use_spectral_norm=False, activation_type="ReLU",
+                 bias=None, norm_type="NONE", **conv_kwargs):
+        super().__init__()
+        self.norm_type = norm_type
+        self.activation_type = activation_type
+        conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
+        conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
+        self.convolution = nn.utils.spectral_norm(conv) if use_spectral_norm else conv
+        if norm_type != "NONE":
+            self.normalization = select_norm_layer(norm_type)(out_channels)
+        if activation_type != "NONE":
+            self.activation = select_activation(activation_type)()
+
+    def forward(self, x):
+        x = self.convolution(x)
+        if self.norm_type != "NONE":
+            x = self.normalization(x)
+        if self.activation_type != "NONE":
+            x = self.activation(x)
+        return x
+
+
+class ResBlock(nn.Module):
+    def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect',
+                 norm_type="IN", activation_type="ReLU", use_bias=None):
+        super().__init__()
+        self.norm_type = norm_type
+        if use_bias is None:
+            # bias will be canceled after channel wise normalization
+            use_bias = _use_bias_checker(norm_type)
+        self.conv1 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
+                                 kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias,
+                                 norm_type=norm_type, activation_type=activation_type)
+        self.conv2 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
+                                 kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias,
+                                 norm_type=norm_type, activation_type="NONE")
+
+    def forward(self, x):
+        return self.conv2(self.conv1(x)) + x
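
A small sketch of how these building blocks compose; the parameter choices mirror the MUNIT StyleEncoder stem above, and the shapes are illustrative:

import torch
from model.GAN.base import Conv2dBlock, ResBlock

# 7x7 stem followed by a residual block, as used by the MUNIT encoders
stem = Conv2dBlock(3, 64, use_spectral_norm=True, activation_type="ReLU", norm_type="IN",
                   kernel_size=7, stride=1, padding=3, padding_mode="reflect")
res = ResBlock(64, use_spectral_norm=True, norm_type="IN")
y = res(stem(torch.rand(1, 3, 256, 256)))
print(y.shape)  # torch.Size([1, 64, 256, 256]); conv bias auto-disabled under IN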

Modified: model/GAN/__init__.py

@@ -4,4 +4,5 @@ import model.GAN.TAFG
 import model.GAN.UGATIT
 import model.GAN.wrapper
 import model.GAN.base
 import model.GAN.TSIT
+import model.GAN.MUNIT