Compare commits

..

6 Commits

Author SHA1 Message Date
436bca88b4 add loss container 2020-10-11 23:09:04 +08:00
6070f08835 add GauGAN 2020-10-11 23:05:38 +08:00
06b2abd19a add flag to switch to norm-activ-conv 2020-10-11 19:02:42 +08:00
9c08b4cd09 move encoder, decoder to CycleGAN 2020-10-11 11:09:16 +08:00
04c6366c07 rewrite 2020-10-11 10:02:33 +08:00
6ea13df465 temp commit 2020-10-10 10:43:00 +08:00
31 changed files with 980 additions and 1321 deletions

2
.idea/deployment.xml generated
View File

@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
<component name="PublishConfigData" autoUpload="Always" serverName="21d" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="14d">
<serverdata>

View File

@@ -14,11 +14,15 @@ handler:
n_saved: 2
tensorboard:
scalar: 100 # log scalar `scalar` times per epoch
image: 2 # log image `image` times per epoch
image: 4 # log image `image` times per epoch
test:
random: True
images: 10
model:
generator:
_type: UGATIT-Generator
_add_spectral_norm: True
in_channels: 3
out_channels: 3
base_channels: 64
@@ -27,11 +31,13 @@ model:
light: True
local_discriminator:
_type: UGATIT-Discriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_blocks: 5
global_discriminator:
_type: UGATIT-Discriminator
_add_spectral_norm: True
in_channels: 3
base_channels: 64
num_blocks: 7
@@ -50,6 +56,8 @@ loss:
weight: 10.0
cam:
weight: 1000
mgc:
weight: 0
optimizers:
generator:
@@ -70,7 +78,7 @@ data:
target_lr: 0
buffer_size: 50
dataloader:
batch_size: 24
batch_size: 4
shuffle: True
num_workers: 2
pin_memory: True

View File

@@ -1,287 +0,0 @@
import os
import pickle
from collections import defaultdict
from itertools import permutations, combinations
from pathlib import Path
import lmdb
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
from torchvision.transforms import functional as F
from tqdm import tqdm
from .registry import DATASET
from .transform import transform_pipeline
from .util import dlib_landmark
def default_transform_way(transform, sample):
return [transform(sample[0]), *sample[1:]]
class LMDBDataset(Dataset):
def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
**lmdb_kwargs):
self.path = lmdb_path
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
lock=False, **lmdb_kwargs)
with self.db.begin(write=False) as txn:
self._len = pickle.loads(txn.get(b"$$len$$"))
self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
if pipeline is None:
self.not_done_pipeline = []
else:
self.not_done_pipeline = self._remain_pipeline(pipeline)
self.transform = transform_pipeline(self.not_done_pipeline)
self.transform_way = transform_way
essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
for ea in essential_attr:
setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))
def _remain_pipeline(self, pipeline):
for i, dp in enumerate(self.done_pipeline):
if pipeline[i] != dp:
raise ValueError(
f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
return pipeline[len(self.done_pipeline):]
def __repr__(self):
return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"
def __len__(self):
return self._len
def __getitem__(self, idx):
with self.db.begin(write=False) as txn:
sample = pickle.loads(txn.get("{}".format(idx).encode()))
sample = self.transform_way(self.transform, sample)
return sample
@staticmethod
def lmdbify(dataset, done_pipeline, lmdb_path):
env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=0):
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
txn.put(b"$$len$$", pickle.dumps(len(dataset)))
txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
essential_attr = getattr(dataset, "essential_attr", list())
txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
for ea in essential_attr:
txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
def __init__(self, root, pipeline):
super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
self.classes_list = defaultdict(list)
self.essential_attr = ["classes_list"]
for i in range(len(self)):
self.classes_list[self.samples[i][-1]].append(i)
assert len(self.classes_list) == len(self.classes)
class EpisodicDataset(Dataset):
def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
self.origin = origin_dataset
self.num_class = num_class
assert self.num_class < len(self.origin.classes_list)
self.num_query = num_query # K
self.num_support = num_support # K
self.num_episodes = num_episodes
def _fetch_list_data(self, id_list):
return [self.origin[i][0] for i in id_list]
def __len__(self):
return self.num_episodes
def __getitem__(self, _):
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
support_set = []
query_set = []
target_set = []
for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c]
if len(image_list) >= self.num_query + self.num_support:
# have enough images belong to this class
idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
else:
idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
support_set.extend(support)
query_set.extend(query)
target_set.extend([tag] * self.num_query)
return {
"support": torch.stack(support_set),
"query": torch.stack(query_set),
"target": torch.tensor(target_set)
}
def __repr__(self):
return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
@DATASET.register_module()
class SingleFolderDataset(Dataset):
def __init__(self, root, pipeline, with_path=False):
assert os.path.isdir(root)
self.root = root
samples = []
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
for fn in sorted(fns):
path = os.path.join(r, fn)
if has_file_allowed_extension(path, IMG_EXTENSIONS):
samples.append(path)
self.samples = samples
self.pipeline = transform_pipeline(pipeline)
self.with_path = with_path
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
if not self.with_path:
return self.pipeline(self.samples[idx])
else:
return self.pipeline(self.samples[idx]), self.samples[idx]
def __repr__(self):
return f"<SingleFolderDataset root={self.root} len={len(self)}>"
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
self.A = SingleFolderDataset(root_a, pipeline, with_path)
self.B = SingleFolderDataset(root_b, pipeline, with_path)
self.random_pair = random_pair
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
return dict(a=self.A[a_idx], b=self.B[b_idx])
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
def normalize_tensor(tensor):
tensor = tensor.float()
tensor -= tensor.min()
tensor /= tensor.max()
return tensor
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
with_path=False):
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
self.landmarks_path = Path(landmarks_path)
assert self.edges_path.exists()
assert self.landmarks_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
self.with_path = with_path
def get_edge(self, origin_path):
op = Path(origin_path)
if self.edge_type.startswith("landmark_"):
edge_type = self.edge_type.lstrip("landmark_")
use_landmark = op.parent.name.endswith("A")
else:
edge_type = self.edge_type
use_landmark = False
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
if not use_landmark:
return origin_edge
else:
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
part_labels = normalize_tensor(torch.from_numpy(part_labels))
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
# edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
# edges = part_edge + edges
return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output = dict(a={}, b={})
output["a"]["img"], output["a"]["path"] = self.A[a_idx]
output["b"]["img"], output["b"]["path"] = self.B[b_idx]
for p in "ab":
output[p]["edge"] = self.get_edge(output[p]["path"])
return output
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
with_order=True):
self.num_face = num_face
self.landmark_path = Path(landmark_path)
self.with_order = with_order
self.root_face = Path(root_face)
self.root_anime = Path(root_anime)
self.img_size = img_size
self.face_samples = self.iter_folders()
self.face_pipeline = transform_pipeline(face_pipeline)
self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)
def iter_folders(self):
pics_per_person = defaultdict(list)
for p in self.root_face.glob("*.jpg"):
pics_per_person[p.stem[:7]].append(p.stem)
data = []
for p in pics_per_person:
if len(pics_per_person[p]) >= self.num_face:
if self.with_order:
data.extend(list(combinations(pics_per_person[p], self.num_face)))
else:
data.extend(list(permutations(pics_per_person[p], self.num_face)))
return data
def read_pose(self, pose_txt):
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
part_labels = normalize_tensor(torch.from_numpy(part_labels))
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
return torch.cat([part_labels, part_edge, dist_tensor])
def __len__(self):
return len(self.face_samples)
def __getitem__(self, idx):
output = dict()
output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
for i, f in enumerate(self.face_samples[idx]):
output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
output[f"stem_{i}"] = f
return output

3
data/dataset/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from util.misc import import_submodules
__all__ = import_submodules(__name__).keys()

63
data/dataset/few-shot.py Normal file
View File

@@ -0,0 +1,63 @@
from collections import defaultdict
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from data.registry import DATASET
from data.transform import transform_pipeline
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
def __init__(self, root, pipeline):
super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
self.classes_list = defaultdict(list)
self.essential_attr = ["classes_list"]
for i in range(len(self)):
self.classes_list[self.samples[i][-1]].append(i)
assert len(self.classes_list) == len(self.classes)
class EpisodicDataset(Dataset):
def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
self.origin = origin_dataset
self.num_class = num_class
assert self.num_class < len(self.origin.classes_list)
self.num_query = num_query # K
self.num_support = num_support # K
self.num_episodes = num_episodes
def _fetch_list_data(self, id_list):
return [self.origin[i][0] for i in id_list]
def __len__(self):
return self.num_episodes
def __getitem__(self, _):
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
support_set = []
query_set = []
target_set = []
for tag, c in enumerate(random_classes):
image_list = self.origin.classes_list[c]
if len(image_list) >= self.num_query + self.num_support:
# have enough images belong to this class
idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
else:
idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
support_set.extend(support)
query_set.extend(query)
target_set.extend([tag] * self.num_query)
return {
"support": torch.stack(support_set),
"query": torch.stack(query_set),
"target": torch.tensor(target_set)
}
def __repr__(self):
return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"

View File

@@ -0,0 +1,62 @@
import os
import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
from data.registry import DATASET
from data.transform import transform_pipeline
@DATASET.register_module()
class SingleFolderDataset(Dataset):
def __init__(self, root, pipeline, with_path=False):
assert os.path.isdir(root)
self.root = root
samples = []
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
for fn in sorted(fns):
path = os.path.join(r, fn)
if has_file_allowed_extension(path, IMG_EXTENSIONS):
samples.append(path)
self.samples = samples
self.pipeline = transform_pipeline(pipeline)
self.with_path = with_path
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
output = dict(img=self.pipeline(self.samples[idx]))
if self.with_path:
output["path"] = self.samples[idx]
return output
def __repr__(self):
return f"<SingleFolderDataset root={self.root} len={len(self)} with_path={self.with_path}>"
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
self.A = SingleFolderDataset(root_a, pipeline, with_path)
self.B = SingleFolderDataset(root_b, pipeline, with_path)
self.with_path = with_path
self.random_pair = random_pair
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output_a = self.A[a_idx]
output_b = self.B[b_idx]
output = dict(a=output_a["img"], b=output_b["img"])
if self.with_path:
output["a_path"] = output_a["path"]
output["b_path"] = output_b["path"]
return output
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"

65
data/dataset/lmdb.py Normal file
View File

@@ -0,0 +1,65 @@
import os
import pickle
import lmdb
from torch.utils.data import Dataset
from tqdm import tqdm
from data.transform import transform_pipeline
def default_transform_way(transform, sample):
return [transform(sample[0]), *sample[1:]]
class LMDBDataset(Dataset):
def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
**lmdb_kwargs):
self.path = lmdb_path
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
lock=False, **lmdb_kwargs)
with self.db.begin(write=False) as txn:
self._len = pickle.loads(txn.get(b"$$len$$"))
self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
if pipeline is None:
self.not_done_pipeline = []
else:
self.not_done_pipeline = self._remain_pipeline(pipeline)
self.transform = transform_pipeline(self.not_done_pipeline)
self.transform_way = transform_way
essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
for ea in essential_attr:
setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))
def _remain_pipeline(self, pipeline):
for i, dp in enumerate(self.done_pipeline):
if pipeline[i] != dp:
raise ValueError(
f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
return pipeline[len(self.done_pipeline):]
def __repr__(self):
return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"
def __len__(self):
return self._len
def __getitem__(self, idx):
with self.db.begin(write=False) as txn:
sample = pickle.loads(txn.get("{}".format(idx).encode()))
sample = self.transform_way(self.transform, sample)
return sample
@staticmethod
def lmdbify(dataset, done_pipeline, lmdb_path):
env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
with env.begin(write=True) as txn:
for i in tqdm(range(len(dataset)), ncols=0):
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
txn.put(b"$$len$$", pickle.dumps(len(dataset)))
txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
essential_attr = getattr(dataset, "essential_attr", list())
txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
for ea in essential_attr:
txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))

View File

@@ -0,0 +1,122 @@
from collections import defaultdict
from itertools import permutations, combinations
from pathlib import Path
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as F
from data.registry import DATASET
from data.transform import transform_pipeline
from data.util import dlib_landmark
def normalize_tensor(tensor):
tensor = tensor.float()
tensor -= tensor.min()
tensor /= tensor.max()
return tensor
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
with_path=False):
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
self.landmarks_path = Path(landmarks_path)
assert self.edges_path.exists()
assert self.landmarks_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
self.with_path = with_path
def get_edge(self, origin_path):
op = Path(origin_path)
if self.edge_type.startswith("landmark_"):
edge_type = self.edge_type.lstrip("landmark_")
use_landmark = op.parent.name.endswith("A")
else:
edge_type = self.edge_type
use_landmark = False
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
if not use_landmark:
return origin_edge
else:
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
part_labels = normalize_tensor(torch.from_numpy(part_labels))
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
# edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
# edges = part_edge + edges
return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output = dict(a={}, b={})
output["a"]["img"], output["a"]["path"] = self.A[a_idx]
output["b"]["img"], output["b"]["path"] = self.B[b_idx]
for p in "ab":
output[p]["edge"] = self.get_edge(output[p]["path"])
return output
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
with_order=True):
self.num_face = num_face
self.landmark_path = Path(landmark_path)
self.with_order = with_order
self.root_face = Path(root_face)
self.root_anime = Path(root_anime)
self.img_size = img_size
self.face_samples = self.iter_folders()
self.face_pipeline = transform_pipeline(face_pipeline)
self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)
def iter_folders(self):
pics_per_person = defaultdict(list)
for p in self.root_face.glob("*.jpg"):
pics_per_person[p.stem[:7]].append(p.stem)
data = []
for p in pics_per_person:
if len(pics_per_person[p]) >= self.num_face:
if self.with_order:
data.extend(list(combinations(pics_per_person[p], self.num_face)))
else:
data.extend(list(permutations(pics_per_person[p], self.num_face)))
return data
def read_pose(self, pose_txt):
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
part_labels = normalize_tensor(torch.from_numpy(part_labels))
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
return torch.cat([part_labels, part_edge, dist_tensor])
def __len__(self):
return len(self.face_samples)
def __getitem__(self, idx):
output = dict()
output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
for i, f in enumerate(self.face_samples[idx]):
output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
output[f"stem_{i}"] = f
return output

View File

@@ -28,7 +28,7 @@ class Load:
def transform_pipeline(pipeline_description):
if len(pipeline_description) == 0:
if pipeline_description is None or len(pipeline_description) == 0:
return lambda x: x
transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
return transforms.Compose(transform_list)

View File

@@ -1,16 +1,16 @@
from omegaconf import OmegaConf
import ignite.distributed as idist
import torch
import torch.nn as nn
import torch.nn.functional as F
import ignite.distributed as idist
from omegaconf import OmegaConf
from loss.gan import GANLoss
from model.GAN.UGATIT import RhoClipper
from model.GAN.base import GANImageBuffer
from util.image import attention_colored_map
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
from engine.util.build import build_model
from engine.util.container import LossContainer
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
from loss.gan import GANLoss
from model.image_translation.UGATIT import RhoClipper
from util.image import attention_colored_map
def mse_loss(x, target_flag):
@@ -28,11 +28,11 @@ class UGATITEngineKernel(EngineKernel):
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
gan_loss_cfg.pop("weight")
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
self.cycle_loss = LossContainer(config.loss.cycle.weight,
nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss())
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss())
self.rho_clipper = RhoClipper(0, 1)
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
self.discriminators.keys()}
self.train_generator_first = False
def build_models(self) -> (dict, dict):
@@ -79,9 +79,9 @@ class UGATITEngineKernel(EngineKernel):
loss = dict()
for phase in "ab":
cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
generated["images"][f"{phase}2{phase}"])
loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
for dk in "lg":
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)

View File

@@ -64,7 +64,7 @@ class EngineKernel(object):
self.engine = engine
def build_models(self) -> (dict, dict):
raise NotImplemented
raise NotImplementedError
def to_save(self):
to_save = {}
@@ -73,19 +73,19 @@ class EngineKernel(object):
return to_save
def setup_after_g(self):
raise NotImplemented
raise NotImplementedError
def setup_before_g(self):
raise NotImplemented
raise NotImplementedError
def forward(self, batch, inference=False) -> dict:
raise NotImplemented
raise NotImplementedError
def criterion_generators(self, batch, generated) -> dict:
raise NotImplemented
raise NotImplementedError
def criterion_discriminators(self, batch, generated) -> dict:
raise NotImplemented
raise NotImplementedError
def intermediate_images(self, batch, generated) -> dict:
"""
@@ -94,12 +94,19 @@ class EngineKernel(object):
:param generated: dict of images
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
"""
raise NotImplemented
raise NotImplementedError
def change_engine(self, config, engine: Engine):
pass
def _remove_no_grad_loss(loss_dict):
for k in loss_dict:
if not isinstance(loss_dict[k], torch.Tensor):
loss_dict.pop(k)
return loss_dict
def get_trainer(config, kernel: EngineKernel):
logger = logging.getLogger(config.name)
generators, discriminators = kernel.generators, kernel.discriminators
@@ -147,10 +154,10 @@ def get_trainer(config, kernel: EngineKernel):
if engine.state.iteration % iteration_per_image == 0:
return {
"loss": dict(g=loss_g, d=loss_d),
"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
"img": kernel.intermediate_images(batch, generated)
}
return {"loss": dict(g=loss_g, d=loss_d)}
return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}
trainer = Engine(_step)
trainer.logger = logger

View File

@@ -1,18 +1,21 @@
import torch
import ignite.distributed as idist
import torch
import torch.optim as optim
from omegaconf import OmegaConf
from model import MODEL
import torch.optim as optim
from util.misc import add_spectral_norm
def build_model(cfg):
cfg = OmegaConf.to_container(cfg)
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
add_spectral_norm_flag = cfg.pop("_add_spectral_norm", False)
model = MODEL.build_with(cfg)
if bn_to_sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if add_spectral_norm_flag:
model.apply(add_spectral_norm)
return idist.auto_model(model)

9
engine/util/container.py Normal file
View File

@@ -0,0 +1,9 @@
class LossContainer:
def __init__(self, weight, loss):
self.weight = weight
self.loss = loss
def __call__(self, *args, **kwargs):
if self.weight > 0:
return self.weight * self.loss(*args, **kwargs)
return 0.0

View File

@@ -0,0 +1,111 @@
import torch
import torch.nn as nn
def gaussian_radial_basis_function(x, mu, sigma):
# (kernel_size) -> (batch_size, kernel_size, c*h*w)
mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
mu = mu.to(x.device)
# (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1)
return torch.exp((x - mu).pow(2) / (2 * sigma ** 2))
class MyLoss(torch.nn.Module):
def __init__(self):
super(MyLoss, self).__init__()
def forward(self, fakeI, realI):
def batch_ERSMI(I1, I2):
batch_size = I1.shape[0]
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
if I2.shape[1] == 1 and I1.shape[1] != 1:
I2 = I2.repeat(1, 3, 1, 1)
def kernel_F(y, mu_list, sigma):
tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1).cuda() # [81, 784]
tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
tmp_y = tmp_mu - tmp_y
mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
return mat_L
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).cuda()
x_mu_list = mu.repeat(9).view(-1, 81)
y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
mat_K = kernel_F(I1, x_mu_list, 1)
mat_L = kernel_F(I2, y_mu_list, 1)
H1 = ((mat_K.matmul(mat_K.transpose(1, 2))).mul(mat_L.matmul(mat_L.transpose(1, 2))) / (
img_size ** 2)).cuda()
H2 = ((mat_K.mul(mat_L)).matmul((mat_K.mul(mat_L)).transpose(1, 2)) / img_size).cuda()
h2 = ((mat_K.sum(2).view(batch_size, -1, 1)).mul(mat_L.sum(2).view(batch_size, -1, 1)) / (
img_size ** 2)).cuda()
H2 = 0.5 * H1 + 0.5 * H2
tmp = H2 + 0.05 * torch.eye(len(H2[0])).cuda()
alpha = (tmp.inverse())
alpha = alpha.matmul(h2)
ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
alpha) - 1).squeeze()
ersmi = -ersmi.mean()
return ersmi
batch_loss = batch_ERSMI(fakeI, realI)
return batch_loss
class MGCLoss(nn.Module):
"""
Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
"""
def __init__(self, beta=0.5, lambda_=0.05):
super().__init__()
self.beta = beta
self.lambda_ = lambda_
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0])
self.mu_x = mu.repeat(9)
self.mu_y = mu.unsqueeze(0).t().repeat(1, 9).view(-1)
@staticmethod
def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_):
assert img1.size() == img2.size()
num_pixel = img1.size(1) * img1.size(2) * img2.size(3)
mat_k = gaussian_radial_basis_function(img1, mu_x, sigma=1)
mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)
mat_k_mul_mat_l = mat_k * mat_l
h_hat = (1 - beta) * (mat_k_mul_mat_l.matmul(mat_k_mul_mat_l.transpose(1, 2))) / num_pixel
h_hat += beta * (mat_k.matmul(mat_k.transpose(1, 2)) * mat_l.matmul(mat_l.transpose(1, 2))) / (num_pixel ** 2)
small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)
R = torch.eye(h_hat.size(1)).to(img1.device)
alpha = (h_hat + lambda_ * R).inverse().matmul(small_h_hat)
rSMI = (2 * alpha.transpose(1, 2).matmul(small_h_hat)) - alpha.transpose(1, 2).matmul(h_hat).matmul(alpha) - 1
return rSMI
def forward(self, fake, real):
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_)
return -rSMI.squeeze().mean()
if __name__ == '__main__':
mg = MGCLoss().to("cuda")
def norm(x):
x -= x.min()
x /= x.max()
return (x - 0.5) * 2
x1 = norm(torch.randn(5, 3, 256, 256))
x2 = norm(x1 * 2 + 1)
x3 = norm(torch.randn(5, 3, 256, 256))
x4 = norm(torch.exp(x3))
print(mg(x1, x1), mg(x1, x2), mg(x1, x3), mg(x1, x4))

View File

@@ -1,62 +0,0 @@
import torch.nn as nn
from model.normalization import select_norm_layer
from model.registry import MODEL
from .base import ResidualBlock
@MODEL.register_module("CyCle-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
norm_type="IN"):
super(Generator, self).__init__()
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
self.start_conv = nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
bias=use_bias),
norm_layer(num_features=base_channels),
nn.ReLU(inplace=True)
)
# down sampling
submodules = []
num_down_sampling = 2
for i in range(num_down_sampling):
multiple = 2 ** i
submodules += [
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(num_features=base_channels * multiple * 2),
nn.ReLU(inplace=True)
]
self.encoder = nn.Sequential(*submodules)
res_block_channels = num_down_sampling ** 2 * base_channels
self.resnet_middle = nn.Sequential(
*[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
range(num_blocks)])
# up sampling
submodules = []
for i in range(num_down_sampling):
multiple = 2 ** (num_down_sampling - i)
submodules += [
nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
padding=1, output_padding=1, bias=use_bias),
norm_layer(num_features=base_channels * multiple // 2),
nn.ReLU(inplace=True),
]
self.decoder = nn.Sequential(*submodules)
self.end_conv = nn.Sequential(
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
nn.Tanh()
)
def forward(self, x):
x = self.encoder(self.start_conv(x))
x = self.resnet_middle(x)
return self.end_conv(self.decoder(x))

View File

@@ -1,150 +0,0 @@
import torch
import torch.nn as nn
from model import MODEL
from model.GAN.base import Conv2dBlock, ResBlock
from model.normalization import select_norm_layer
class StyleEncoder(nn.Module):
def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False,
max_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
super(StyleEncoder, self).__init__()
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
)]
multiple_now = 1
for i in range(1, num_conv + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** max_multiple)
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
))
sequence.append(nn.AdaptiveAvgPool2d(1))
# conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
sequence.append(nn.Conv2d(multiple_now * base_channels, out_dim, kernel_size=1, stride=1, padding=0))
self.model = nn.Sequential(*sequence)
def forward(self, x):
return self.model(x).view(x.size(0), -1)
class ContentEncoder(nn.Module):
def __init__(self, in_channels, num_down_sampling, num_res_blocks, base_channels=64, use_spectral_norm=False,
padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
super().__init__()
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
)]
for i in range(num_down_sampling):
sequence.append(Conv2dBlock(
base_channels * (2 ** i), base_channels * (2 ** (i + 1)),
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
))
sequence += [ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
activation_type) for _ in range(num_res_blocks)]
self.sequence = nn.Sequential(*sequence)
def forward(self, x):
return self.sequence(x)
class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, num_up_sampling, num_res_blocks,
use_spectral_norm=False, res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU",
padding_mode='reflect'):
super(Decoder, self).__init__()
self.res_norm_type = res_norm_type
self.res_blocks = nn.ModuleList([
ResBlock(in_channels, use_spectral_norm, padding_mode, res_norm_type, activation_type=activation_type)
for _ in range(num_res_blocks)
])
sequence = list()
channels = in_channels
for i in range(num_up_sampling):
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2,
kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
),
))
channels = channels // 2
sequence.append(
Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect",
use_spectral_norm=use_spectral_norm, activation_type="Tanh", norm_type="NONE"))
self.sequence = nn.Sequential(*sequence)
def forward(self, x):
for blk in self.res_blocks:
x = blk(x)
return self.sequence(x)
class Fusion(nn.Module):
def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
super().__init__()
norm_layer = select_norm_layer(norm_type)
self.start_fc = nn.Sequential(
nn.Linear(in_features, base_features),
norm_layer(base_features),
nn.ReLU(True),
)
self.fcs = nn.Sequential(*[
nn.Sequential(
nn.Linear(base_features, base_features),
norm_layer(base_features),
nn.ReLU(True),
) for _ in range(n_blocks - 2)
])
self.end_fc = nn.Sequential(
nn.Linear(base_features, out_features),
)
def forward(self, x):
x = self.start_fc(x)
x = self.fcs(x)
return self.end_fc(x)
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels, num_sampling, num_style_dim, num_style_conv,
num_content_res_blocks, num_decoder_res_blocks, num_fusion_dim, num_fusion_blocks,
use_spectral_norm=False, activation_type="ReLU", padding_mode='reflect'):
super().__init__()
self.num_decoder_res_blocks = num_decoder_res_blocks
self.content_encoder = ContentEncoder(in_channels, num_sampling, num_content_res_blocks, base_channels,
use_spectral_norm, padding_mode, activation_type, norm_type="IN")
self.style_encoder = StyleEncoder(in_channels, num_style_dim, num_style_conv, base_channels, use_spectral_norm,
padding_mode, activation_type, norm_type="NONE")
content_channels = base_channels * (2 ** 2)
self.decoder = Decoder(content_channels, out_channels, num_sampling,
num_decoder_res_blocks, use_spectral_norm, "AdaIN", norm_type="LN",
activation_type=activation_type, padding_mode=padding_mode)
self.fusion = Fusion(num_style_dim, num_decoder_res_blocks * 2 * content_channels * 2,
base_features=num_fusion_dim, n_blocks=num_fusion_blocks, norm_type="NONE")
def encode(self, x):
return self.content_encoder(x), self.style_encoder(x)
def decode(self, content, style):
as_param_style = torch.chunk(self.fusion(style), self.num_decoder_res_blocks * 2, dim=1)
# set style for decoder
for i, blk in enumerate(self.decoder.res_blocks):
blk.conv1.normalization.set_style(as_param_style[2 * i])
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
return self.decoder(content)
def forward(self, x):
content, style = self.encode(x)
return self.decode(content, style)

View File

@@ -1,171 +0,0 @@
import torch
import torch.nn as nn
from torchvision.models import vgg19
from model.normalization import select_norm_layer
from model.registry import MODEL
from .MUNIT import ContentEncoder, Fusion, Decoder, StyleEncoder
from .base import ResBlock
class VGG19StyleEncoder(nn.Module):
def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
vgg19_layers=(0, 5, 10, 19), fix_vgg19=True):
super().__init__()
self.vgg19_layers = vgg19_layers
self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
self.vgg19.requires_grad_(not fix_vgg19)
norm_layer = select_norm_layer(norm_type)
self.conv0 = nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
bias=True),
norm_layer(base_channels),
nn.ReLU(True),
)
self.conv = nn.ModuleList([
nn.Sequential(
nn.Conv2d(base_channels * (2 ** i), base_channels * (2 ** i), kernel_size=4, stride=2, padding=1,
padding_mode=padding_mode, bias=True),
norm_layer(base_channels),
nn.ReLU(True),
) for i in range(1, 4)
])
self.pool = nn.AdaptiveAvgPool2d(1)
self.conv1x1 = nn.Conv2d(base_channels * (2 ** 4), style_dim, kernel_size=1, stride=1, padding=0)
def fixed_style_features(self, x):
features = []
for i in range(len(self.vgg19)):
x = self.vgg19[i](x)
if i in self.vgg19_layers:
features.append(x)
return features
def forward(self, x):
fsf = self.fixed_style_features(x)
x = self.conv0(x)
for i, l in enumerate(self.conv):
x = l(torch.cat([x, fsf[i]], dim=1))
x = self.pool(torch.cat([x, fsf[-1]], dim=1))
x = self.conv1x1(x)
return x.view(x.size(0), -1)
@MODEL.register_module("TAFG-ResGenerator")
class ResGenerator(nn.Module):
def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
super().__init__()
self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks,
use_spectral_norm=use_spectral_norm)
resnet_channels = 2 ** 2 * base_channels
self.decoder = Decoder(resnet_channels, out_channels, 2,
0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect")
def forward(self, x):
return self.decoder(self.content_encoder(x))
@MODEL.register_module("TAFG-SingleGenerator")
class SingleGenerator(nn.Module):
def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
num_res_blocks=8, base_channels=64, padding_mode="reflect"):
super().__init__()
self.num_adain_blocks = num_adain_blocks
if style_encoder_type == "StyleEncoder":
self.style_encoder = StyleEncoder(
style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
)
elif style_encoder_type == "VGG19StyleEncoder":
self.style_encoder = VGG19StyleEncoder(
style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
)
else:
raise NotImplemented(f"do not support {style_encoder_type}")
resnet_channels = 2 ** 2 * base_channels
self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
n_blocks=3, norm_type="NONE")
self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
use_spectral_norm=use_spectral_norm)
self.decoder = Decoder(resnet_channels, out_channels, 2,
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)
def forward(self, content_img, style_img):
content = self.content_encoder(content_img)
style = self.style_encoder(style_img)
as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
# set style for decoder
for i, blk in enumerate(self.decoder.res_blocks):
blk.conv1.normalization.set_style(as_param_style[2 * i])
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
return self.decoder(content)
@MODEL.register_module("TAFG-Generator")
class Generator(nn.Module):
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
num_res_blocks=8, base_channels=64, padding_mode="reflect"):
super(Generator, self).__init__()
self.num_adain_blocks = num_adain_blocks
if style_encoder_type == "StyleEncoder":
self.style_encoders = nn.ModuleDict(dict(
a=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
b=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
))
elif style_encoder_type == "VGG19StyleEncoder":
self.style_encoders = nn.ModuleDict(dict(
a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
norm_type="NONE"),
b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
norm_type="NONE", fix_vgg19=False)
))
else:
raise NotImplemented(f"do not support {style_encoder_type}")
resnet_channels = 2 ** 2 * base_channels
self.style_converters = nn.ModuleDict(dict(
a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
norm_type="NONE"),
b=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
norm_type="NONE"),
))
self.content_encoders = nn.ModuleDict({
"a": ContentEncoder(content_in_channels, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm),
"b": ContentEncoder(1, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm)
})
self.content_resnet = nn.Sequential(*[
ResBlock(resnet_channels, use_spectral_norm, padding_mode, "IN")
for _ in range(num_res_blocks)
])
self.decoders = nn.ModuleDict(dict(
a=Decoder(resnet_channels, out_channels, 2,
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
b=Decoder(resnet_channels, out_channels, 2,
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
))
def encode(self, content_img, style_img, which_content, which_style):
content = self.content_resnet(self.content_encoders[which_content](content_img))
style = self.style_encoders[which_style](style_img)
return content, style
def decode(self, content, style, which):
decoder = self.decoders[which]
as_param_style = torch.chunk(self.style_converters[which](style), self.num_adain_blocks * 2, dim=1)
# set style for decoder
for i, blk in enumerate(decoder.res_blocks):
blk.conv1.normalization.set_style(as_param_style[2 * i])
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
return decoder(content)
def forward(self, content_img, style_img, which_content, which_style):
content, style = self.encode(content_img, style_img, which_content, which_style)
return self.decode(content, style, which_style)

View File

@@ -1,88 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
from model.base.module import Conv2dBlock, ResidualBlock, ReverseResidualBlock
class Interpolation(nn.Module):
def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
super(Interpolation, self).__init__()
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
def forward(self, x):
return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
recompute_scale_factor=False)
def __repr__(self):
return f"DownSampling(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
@MODEL.register_module("TSIT-Generator")
class Generator(nn.Module):
def __init__(self, content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7,
padding_mode="reflect", activation_type="ReLU"):
super().__init__()
self.num_blocks = num_blocks
self.base_channels = base_channels
self.content_stream = self.build_stream(padding_mode, activation_type)
self.start_conv = Conv2dBlock(content_in_channels, base_channels, activation_type=activation_type,
norm_type="IN", kernel_size=7, padding_mode=padding_mode, padding=3)
sequence = []
multiple_now = min(2 ** self.num_blocks, 2 ** 4)
for i in range(1, self.num_blocks + 1):
m = self.num_blocks - i
multiple_prev = multiple_now
multiple_now = min(2 ** m, 2 ** 4)
sequence.append(nn.Sequential(
ReverseResidualBlock(
multiple_prev * base_channels, multiple_now * base_channels,
padding_mode=padding_mode, norm_type="FADE",
additional_norm_kwargs=dict(
condition_in_channels=multiple_prev * base_channels,
base_norm_type="BN",
padding_mode=padding_mode
)
),
Interpolation(2, mode="nearest")
))
self.generator = nn.Sequential(*sequence)
self.end_conv = Conv2dBlock(base_channels, out_channels, activation_type="Tanh",
kernel_size=7, padding_mode=padding_mode, padding=3)
def build_stream(self, padding_mode, activation_type):
multiple_now = 1
stream_sequence = []
for i in range(1, self.num_blocks + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 4)
stream_sequence.append(nn.Sequential(
Interpolation(scale_factor=0.5, mode="nearest"),
ResidualBlock(
multiple_prev * self.base_channels, multiple_now * self.base_channels,
padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
))
return nn.ModuleList(stream_sequence)
def forward(self, content, z=None):
c = self.start_conv(content)
content_features = []
for i in range(self.num_blocks):
c = self.content_stream[i](c)
content_features.append(c)
if z is None:
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
for i in range(self.num_blocks):
m = - i - 1
res_block = self.generator[i][0]
res_block.conv1.normalization.set_feature(content_features[m])
res_block.conv2.normalization.set_feature(content_features[m])
if res_block.learn_skip_connection:
res_block.res_conv.normalization.set_feature(content_features[m])
return self.end_conv(self.generator(z))

View File

@@ -1,236 +0,0 @@
import torch
import torch.nn as nn
from .base import ResidualBlock
from model.registry import MODEL
class RhoClipper(object):
def __init__(self, clip_min, clip_max):
self.clip_min = clip_min
self.clip_max = clip_max
assert clip_min < clip_max
def __call__(self, module):
if hasattr(module, 'rho'):
w = module.rho.data
w = w.clamp(self.clip_min, self.clip_max)
module.rho.data = w
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False):
assert (num_blocks >= 0)
super(Generator, self).__init__()
self.input_channels = in_channels
self.output_channels = out_channels
self.base_channels = base_channels
self.num_blocks = num_blocks
self.img_size = img_size
self.light = light
down_encoder = [nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
padding_mode="reflect", bias=False),
nn.InstanceNorm2d(base_channels),
nn.ReLU(True)]
n_down_sampling = 2
for i in range(n_down_sampling):
mult = 2 ** i
down_encoder += [nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3, stride=2,
padding=1, bias=False, padding_mode="reflect"),
nn.InstanceNorm2d(base_channels * mult * 2),
nn.ReLU(True)]
# Down-Sampling Bottleneck
mult = 2 ** n_down_sampling
for i in range(num_blocks):
down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)]
self.down_encoder = nn.Sequential(*down_encoder)
# Class Activation Map
self.gap_fc = nn.Linear(base_channels * mult, 1, bias=False)
self.gmp_fc = nn.Linear(base_channels * mult, 1, bias=False)
self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
self.relu = nn.ReLU(True)
# Gamma, Beta block
if self.light:
fc = [nn.Linear(base_channels * mult, base_channels * mult, bias=False),
nn.ReLU(True),
nn.Linear(base_channels * mult, base_channels * mult, bias=False),
nn.ReLU(True)]
else:
fc = [
nn.Linear(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, bias=False),
nn.ReLU(True),
nn.Linear(base_channels * mult, base_channels * mult, bias=False),
nn.ReLU(True)]
self.fc = nn.Sequential(*fc)
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
# Up-Sampling Bottleneck
self.up_bottleneck = nn.ModuleList(
[ResnetAdaILNBlock(base_channels * mult, use_bias=False) for _ in range(num_blocks)])
# Up-Sampling
up_decoder = []
for i in range(n_down_sampling):
mult = 2 ** (n_down_sampling - i)
up_decoder += [nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3, stride=1,
padding=1, padding_mode="reflect", bias=False),
ILN(base_channels * mult // 2),
nn.ReLU(True)]
up_decoder += [nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
padding_mode="reflect", bias=False),
nn.Tanh()]
self.up_decoder = nn.Sequential(*up_decoder)
def forward(self, x):
x = self.down_encoder(x)
gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)
gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
x = torch.cat([gap, gmp], 1)
x = self.relu(self.conv1x1(x))
heatmap = torch.sum(x, dim=1, keepdim=True)
if self.light:
x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
x_ = self.fc(x_.view(x_.shape[0], -1))
else:
x_ = self.fc(x.view(x.shape[0], -1))
gamma, beta = self.gamma(x_), self.beta(x_)
for ub in self.up_bottleneck:
x = ub(x, gamma, beta)
x = self.up_decoder(x)
return x, cam_logit, heatmap
class ResnetAdaILNBlock(nn.Module):
def __init__(self, dim, use_bias):
super(ResnetAdaILNBlock, self).__init__()
self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
self.norm1 = AdaILN(dim)
self.relu1 = nn.ReLU(True)
self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=use_bias, padding_mode="reflect")
self.norm2 = AdaILN(dim)
def forward(self, x, gamma, beta):
out = self.conv1(x)
out = self.norm1(out, gamma, beta)
out = self.relu1(out)
out = self.conv2(out)
out = self.norm2(out, gamma, beta)
return out + x
def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True)
out_in = (x - in_mean) / torch.sqrt(in_var + eps)
ln_mean, ln_var = torch.mean(x, dim=[1, 2, 3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True)
out_ln = (x - ln_mean) / torch.sqrt(ln_var + eps)
out = rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - rho.expand(x.shape[0], -1, -1, -1)) * out_ln
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
return out
class AdaILN(nn.Module):
def __init__(self, num_features, eps=1e-5, default_rho=0.9):
super(AdaILN, self).__init__()
self.eps = eps
self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.rho.data.fill_(default_rho)
def forward(self, x, gamma, beta):
return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)
class ILN(nn.Module):
def __init__(self, num_features, eps=1e-5):
super(ILN, self).__init__()
self.eps = eps
self.rho = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
self.gamma = nn.Parameter(torch.Tensor(1, num_features))
self.beta = nn.Parameter(torch.Tensor(1, num_features))
self.rho.data.fill_(0.0)
self.gamma.data.fill_(1.0)
self.beta.data.fill_(0.0)
def forward(self, x):
return instance_layer_normalization(
x, self.gamma.expand(x.shape[0], -1), self.beta.expand(x.shape[0], -1), self.rho, self.eps)
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, num_blocks=5):
super(Discriminator, self).__init__()
encoder = [self.build_conv_block(in_channels, base_channels)]
for i in range(1, num_blocks - 2):
mult = 2 ** (i - 1)
encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2))
mult = 2 ** (num_blocks - 2 - 1)
encoder.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2, stride=1))
self.encoder = nn.Sequential(*encoder)
# Class Activation Map
mult = 2 ** (num_blocks - 2)
self.gap_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
self.gmp_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
self.leaky_relu = nn.LeakyReLU(0.2, True)
self.conv = nn.utils.spectral_norm(
nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False, padding_mode="reflect"))
@staticmethod
def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"):
return nn.Sequential(*[
nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
bias=True, padding=padding, padding_mode=padding_mode)),
nn.LeakyReLU(0.2, True),
])
def forward(self, x, return_heatmap=False):
x = self.encoder(x)
batch_size = x.size(0)
gap = torch.nn.functional.adaptive_avg_pool2d(x, 1) # B x C x 1 x 1, avg of per channel
gap_logit = self.gap_fc(gap.view(batch_size, -1))
gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)
gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
gmp_logit = self.gmp_fc(gmp.view(batch_size, -1))
gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
x = torch.cat([gap, gmp], 1)
x = self.leaky_relu(self.conv1x1(x))
if return_heatmap:
heatmap = torch.sum(x, dim=1, keepdim=True)
return self.conv(x), cam_logit, heatmap
else:
return self.conv(x), cam_logit

View File

@@ -1,203 +0,0 @@
from functools import partial
import math
import torch
import torch.nn as nn
from model import MODEL
from model.normalization import select_norm_layer
class GANImageBuffer(object):
"""This class implements an image buffer that stores previously
generated images.
This buffer allows us to update the discriminator using a history of
generated images rather than the ones produced by the latest generator
to reduce model oscillation.
Args:
buffer_size (int): The size of image buffer. If buffer_size = 0,
no buffer will be created.
buffer_ratio (float): The chance / possibility to use the images
previously stored in the buffer.
"""
def __init__(self, buffer_size, buffer_ratio=0.5):
self.buffer_size = buffer_size
# create an empty buffer
if self.buffer_size > 0:
self.img_num = 0
self.image_buffer = []
self.buffer_ratio = buffer_ratio
def query(self, images):
"""Query current image batch using a history of generated images.
Args:
images (Tensor): Current image batch without history information.
"""
if self.buffer_size == 0: # if the buffer size is 0, do nothing
return images
return_images = []
for image in images:
image = torch.unsqueeze(image.data, 0)
# if the buffer is not full, keep inserting current images
if self.img_num < self.buffer_size:
self.img_num = self.img_num + 1
self.image_buffer.append(image)
return_images.append(image)
else:
use_buffer = torch.rand(1) < self.buffer_ratio
# by self.buffer_ratio, the buffer will return a previously
# stored image, and insert the current image into the buffer
if use_buffer:
random_id = torch.randint(0, self.buffer_size, (1,)).item()
image_tmp = self.image_buffer[random_id].clone()
self.image_buffer[random_id] = image
return_images.append(image_tmp)
# by (1 - self.buffer_ratio), the buffer will return the
# current image
else:
return_images.append(image)
# collect all the images and return
return_images = torch.cat(return_images, 0)
return return_images
# based SPADE or pix2pixHD Discriminator
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
need_intermediate_feature=False):
super().__init__()
self.need_intermediate_feature = need_intermediate_feature
kernel_size = 4
padding = math.ceil((kernel_size - 1.0) / 2)
norm_layer = select_norm_layer(norm_type)
use_bias = norm_type == "IN"
padding_mode = "zeros"
sequence = [nn.Sequential(
nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
nn.LeakyReLU(0.2, False)
)]
multiple_now = 1
for i in range(1, num_conv):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 3)
stride = 1 if i == num_conv - 1 else 2
sequence.append(nn.Sequential(
self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
norm_layer(base_channels * multiple_now),
nn.LeakyReLU(0.2, inplace=False),
))
multiple_now = min(2 ** num_conv, 8)
sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
padding_mode=padding_mode))
self.conv_blocks = nn.ModuleList(sequence)
@staticmethod
def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
bias=True, padding_mode: str = 'zeros'):
conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias, padding_mode=padding_mode)
if not use_spectral:
return conv
return nn.utils.spectral_norm(conv)
def forward(self, x):
if self.need_intermediate_feature:
intermediate_feature = []
for layer in self.conv_blocks:
x = layer(x)
intermediate_feature.append(x)
return tuple(intermediate_feature)
else:
for layer in self.conv_blocks:
x = layer(x)
return x
@MODEL.register_module()
class ResidualBlock(nn.Module):
def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
super(ResidualBlock, self).__init__()
if use_bias is None:
# Only for IN, use bias since it does not have affine parameters.
use_bias = norm_type == "IN"
norm_layer = select_norm_layer(norm_type)
self.conv1 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
bias=use_bias)
self.norm1 = norm_layer(num_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
bias=use_bias)
self.norm2 = norm_layer(num_channels)
def forward(self, x):
res = x
x = self.relu1(self.norm1(self.conv1(x)))
x = self.norm2(self.conv2(x))
return x + res
_DO_NO_THING_FUNC = lambda x: x
def select_activation(t):
if t == "ReLU":
return partial(nn.ReLU, inplace=True)
elif t == "LeakyReLU":
return partial(nn.LeakyReLU, negative_slope=0.2, inplace=True)
elif t == "Tanh":
return partial(nn.Tanh)
elif t == "NONE":
return _DO_NO_THING_FUNC
else:
raise NotImplemented
def _use_bias_checker(norm_type):
return norm_type not in ["IN", "BN", "AdaIN"]
class Conv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, use_spectral_norm=False, activation_type="ReLU",
bias=None, norm_type="NONE", **conv_kwargs):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
self.convolution = nn.utils.spectral_norm(conv) if use_spectral_norm else conv
if norm_type != "NONE":
self.normalization = select_norm_layer(norm_type)(out_channels)
if activation_type != "NONE":
self.activation = select_activation(activation_type)()
def forward(self, x):
x = self.convolution(x)
if self.norm_type != "NONE":
x = self.normalization(x)
if self.activation_type != "NONE":
x = self.activation(x)
return x
class ResBlock(nn.Module):
def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect',
norm_type="IN", activation_type="ReLU", use_bias=None):
super().__init__()
self.norm_type = norm_type
if use_bias is None:
# bias will be canceled after channel wise normalization
use_bias = _use_bias_checker(norm_type)
self.conv1 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias,
norm_type=norm_type, activation_type=activation_type)
self.conv2 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias,
norm_type=norm_type, activation_type="NONE")
def forward(self, x):
return self.conv2(self.conv1(x)) + x

View File

@@ -1,25 +0,0 @@
import torch.nn as nn
import torch.nn.functional as F
from model import MODEL
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
def __init__(self, num_scale, discriminator_cfg):
super().__init__()
self.discriminator_list = nn.ModuleList([
MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
])
@staticmethod
def down_sample(x):
return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
def forward(self, x):
results = []
for discriminator in self.discriminator_list:
results.append(discriminator(x))
x = self.down_sample(x)
return results

View File

@@ -1,10 +1,3 @@
from model.registry import MODEL, NORMALIZATION
import model.GAN.CycleGAN
import model.GAN.MUNIT
import model.GAN.TAFG
import model.GAN.TSIT
import model.GAN.UGATIT
import model.GAN.base
import model.GAN.wrapper
import model.base.normalization
import model.image_translation

View File

@@ -20,22 +20,40 @@ def _normalization(norm, num_features, additional_kwargs=None):
return NORMALIZATION.build_with(kwargs)
def _activation(activation):
def _activation(activation, inplace=True):
if activation == "NONE":
return _DO_NO_THING_FUNC
elif activation == "ReLU":
return nn.ReLU(inplace=True)
return nn.ReLU(inplace=inplace)
elif activation == "LeakyReLU":
return nn.LeakyReLU(negative_slope=0.2, inplace=True)
return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
elif activation == "Tanh":
return nn.Tanh()
else:
raise NotImplemented(activation)
raise NotImplementedError(f"{activation} not valid")
class LinearBlock(nn.Module):
def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
bias = _use_bias_checker(norm_type) if bias is None else bias
self.linear = nn.Linear(in_features, out_features, bias)
self.normalization = _normalization(norm_type, out_features)
self.activation = _activation(activation_type)
def forward(self, x):
return self.activation(self.normalization(self.linear(x)))
class Conv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, bias=None,
activation_type="ReLU", norm_type="NONE", **conv_kwargs):
activation_type="ReLU", norm_type="NONE",
additional_norm_kwargs=None, **conv_kwargs):
super().__init__()
self.norm_type = norm_type
self.activation_type = activation_type
@@ -44,65 +62,63 @@ class Conv2dBlock(nn.Module):
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
self.normalization = _normalization(norm_type, out_channels)
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
self.activation = _activation(activation_type)
def forward(self, x):
return self.activation(self.normalization(self.convolution(x)))
class ResidualBlock(nn.Module):
def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
activation_type="ReLU", out_activation_type=None, norm_type="IN"):
super().__init__()
self.norm_type = norm_type
if out_channels is None:
out_channels = num_channels
if out_activation_type is None:
out_activation_type = "NONE"
self.learn_skip_connection = num_channels != out_channels
self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=activation_type)
self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
if self.learn_skip_connection:
self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
norm_type=norm_type, activation_type=out_activation_type)
def forward(self, x):
res = x if not self.learn_skip_connection else self.res_conv(x)
return self.conv2(self.conv1(x)) + res
class ReverseConv2dBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int,
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
super().__init__()
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
self.activation = _activation(activation_type)
self.activation = _activation(activation_type, inplace=False)
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
def forward(self, x):
return self.convolution(self.activation(self.normalization(x)))
class ReverseResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, padding_mode="reflect",
norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
class ResidualBlock(nn.Module):
def __init__(self, in_channels,
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
"""
Residual Conv Block
:param in_channels:
:param out_channels:
:param padding_mode:
:param activation_type:
:param norm_type:
:param out_activation_type:
:param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
"""
super().__init__()
self.norm_type = norm_type
if out_channels is None:
out_channels = in_channels
if out_activation_type is None:
# if not specify `out_activation_type`, using default `out_activation_type`
# `out_activation_type` default mode:
# "NONE" for not full pre-activation
# `norm_type` for full pre-activation
out_activation_type = "NONE" if not pre_activation else norm_type
self.learn_skip_connection = in_channels != out_channels
self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
kernel_size=3, padding=1, padding_mode=padding_mode)
self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
kernel_size=3, padding=1, padding_mode=padding_mode)
conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
additional_norm_kwargs=additional_norm_kwargs,
padding_mode=padding_mode)
self.conv1 = conv_block(in_channels, in_channels, **conv_param)
self.conv2 = conv_block(in_channels, out_channels, **conv_param)
if self.learn_skip_connection:
self.res_conv = ReverseConv2dBlock(
in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
kernel_size=3, padding=1, padding_mode=padding_mode)
self.res_conv = conv_block(in_channels, out_channels, **conv_param)
def forward(self, x):
res = x if not self.learn_skip_connection else self.res_conv(x)

View File

@@ -16,18 +16,19 @@ for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
@NORMALIZATION.register_module("ADE")
class AdaptiveDenormalization(nn.Module):
def __init__(self, num_features, base_norm_type="BN"):
def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
super().__init__()
self.num_features = num_features
self.base_norm_type = base_norm_type
self.norm = self.base_norm(num_features)
self.gamma = None
self.gamma_bias = gamma_bias
self.beta = None
self.have_set_condition = False
def base_norm(self, num_features):
if self.base_norm_type == "IN":
return nn.InstanceNorm2d(num_features)
return nn.InstanceNorm2d(num_features, affine=False)
elif self.base_norm_type == "BN":
return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
@@ -38,13 +39,13 @@ class AdaptiveDenormalization(nn.Module):
def forward(self, x):
assert self.have_set_condition
x = self.norm(x)
x = self.gamma * x + self.beta
x = (self.gamma + self.gamma_bias) * x + self.beta
self.have_set_condition = False
return x
def __repr__(self):
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
f"base_norm_type={self.base_norm_type})"
#
# def __repr__(self):
# return f"{self.__class__.__name__}(num_features={self.num_features}, " \
# f"base_norm_type={self.base_norm_type})"
@NORMALIZATION.register_module("AdaIN")
@@ -61,8 +62,9 @@ class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
@NORMALIZATION.register_module("FADE")
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
super().__init__(num_features, base_norm_type)
def __init__(self, num_features: int, condition_in_channels,
base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0):
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
padding_mode=padding_mode)
self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
@@ -77,9 +79,9 @@ class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
@NORMALIZATION.register_module("SPADE")
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
activation_type="ReLU", padding_mode="zeros"):
super().__init__(num_features, base_norm_type)
self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type,
activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
@@ -93,7 +95,7 @@ class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
out = out * gamma + beta
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
return out
@@ -115,7 +117,7 @@ class ILN(nn.Module):
def forward(self, x):
return _instance_layer_normalization(
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
x, self.gamma.view(1, -1), self.beta.view(1, -1), self.rho.view(1, -1, 1, 1), self.eps)
@NORMALIZATION.register_module("AdaILN")
@@ -136,7 +138,6 @@ class AdaILN(nn.Module):
def forward(self, x):
assert self.have_set_condition
out = _instance_layer_normalization(
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
out = _instance_layer_normalization(x, self.gamma, self.beta, self.rho.view(1, -1, 1, 1), self.eps)
self.have_set_condition = False
return out

View File

@@ -0,0 +1,76 @@
import torch.nn as nn
from model.base.module import Conv2dBlock, ResidualBlock
class Encoder(nn.Module):
def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
padding_mode='reflect', activation_type="ReLU",
down_conv_norm_type="IN", down_conv_kernel_size=3,
res_norm_type="IN", pre_activation=False):
super().__init__()
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
activation_type=activation_type, norm_type=down_conv_norm_type
)]
multiple_now = 1
for i in range(1, num_conv + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** max_down_sampling_multiple)
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=down_conv_norm_type
))
self.out_channels = multiple_now * base_channels
sequence += [
ResidualBlock(
self.out_channels,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type=res_norm_type,
pre_activation=pre_activation
) for _ in range(num_res)
]
self.sequence = nn.Sequential(*sequence)
def forward(self, x):
return self.sequence(x)
class Decoder(nn.Module):
def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
activation_type="ReLU", padding_mode='reflect',
up_conv_kernel_size=5, up_conv_norm_type="LN",
res_norm_type="AdaIN", pre_activation=False):
super().__init__()
self.residual_blocks = nn.ModuleList([
ResidualBlock(
in_channels,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type=res_norm_type,
pre_activation=pre_activation
) for _ in range(num_residual_blocks)
])
sequence = list()
channels = in_channels
for i in range(num_up_sampling):
sequence.append(nn.Sequential(
nn.Upsample(scale_factor=2),
Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
padding=int(up_conv_kernel_size / 2), padding_mode=padding_mode,
activation_type=activation_type, norm_type=up_conv_norm_type),
))
channels = channels // 2
sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
self.up_sequence = nn.Sequential(*sequence)
def forward(self, x):
for i, blk in enumerate(self.residual_blocks):
x = blk(x)
return self.up_sequence(x)

View File

@@ -0,0 +1,95 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.base.module import ResidualBlock, ReverseConv2dBlock, Conv2dBlock
class StyleEncoder(nn.Module):
def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
super().__init__()
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
)]
multiple_now = 0
max_multiple = 3
for i in range(1, num_conv + 1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** max_multiple)
sequence.append(Conv2dBlock(
multiple_prev * base_channels, multiple_now * base_channels,
kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
))
self.sequence = nn.Sequential(*sequence)
self.fc_avg = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
self.fc_var = nn.Linear(base_channels * (2 ** max_multiple) * end_size[0] * end_size[1], style_dim)
def forward(self, x):
x = self.sequence(x)
x = x.view(x.size(0), -1)
return self.fc_avg(x), self.fc_var(x)
class SPADEGenerator(nn.Module):
def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
padding_mode='reflect', activation_type="LeakyReLU"):
super().__init__()
self.sx, self.sy = start_size
self.use_vae = use_vae
self.num_z_dim = num_z_dim
if use_vae:
self.input_converter = nn.Linear(num_z_dim, 16 * base_channels * self.sx * self.sy)
else:
self.input_converter = nn.Conv2d(in_channels, 16 * base_channels, kernel_size=3, padding=1)
sequence = []
multiple_now = 16
for i in range(num_blocks - 1, -1, -1):
multiple_prev = multiple_now
multiple_now = min(2 ** i, 2 ** 4)
if i != num_blocks - 1:
sequence.append(nn.Upsample(scale_factor=2))
sequence.append(ResidualBlock(
base_channels * multiple_prev,
out_channels=base_channels * multiple_now,
padding_mode=padding_mode,
activation_type=activation_type,
norm_type="SPADE",
pre_activation=True,
additional_norm_kwargs=dict(
condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
)
))
self.sequence = nn.Sequential(*sequence)
self.output_converter = nn.Sequential(
ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
nn.Tanh()
)
def forward(self, seg, z=None):
if self.use_vae:
if z is None:
z = torch.randn(seg.size(0), self.num_z_dim, device=seg.device)
x = self.input_converter(z).view(seg.size(0), -1, self.sx, self.sy)
else:
x = self.input_converter(F.interpolate(seg, size=(self.sx, self.sy)))
for blk in self.sequence:
if isinstance(blk, ResidualBlock):
downsampling_seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
blk.conv1.normalization.set_condition_image(downsampling_seg)
blk.conv2.normalization.set_condition_image(downsampling_seg)
if blk.learn_skip_connection:
blk.res_conv.normalization.set_condition_image(downsampling_seg)
x = blk(x)
return self.output_converter(x)
if __name__ == '__main__':
g = SPADEGenerator(3, 3, 7, False, 256)
print(g)
print(g(torch.randn(2, 3, 256, 256)).size())

View File

@@ -0,0 +1,89 @@
import torch
import torch.nn as nn
from model import MODEL
from model.base.module import LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class StyleEncoder(nn.Module):
def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE",
pre_activation=False):
super().__init__()
self.down_encoder = Encoder(
in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type=norm_type, down_conv_kernel_size=4, pre_activation=pre_activation,
)
sequence = list()
sequence.append(nn.AdaptiveAvgPool2d(1))
# conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim, kernel_size=1, stride=1, padding=0))
self.sequence = nn.Sequential(*sequence)
def forward(self, image):
return self.sequence(image).view(image.size(0), -1)
class MLPFusion(nn.Module):
def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
super().__init__()
sequence = [LinearBlock(in_features, base_features, activation_type=activation_type, norm_type=norm_type)]
sequence += [
LinearBlock(base_features, base_features, activation_type=activation_type, norm_type=norm_type)
for _ in range(n_blocks - 2)
]
sequence.append(LinearBlock(base_features, out_features, activation_type=activation_type, norm_type=norm_type))
self.sequence = nn.Sequential(*sequence)
def forward(self, x):
return self.sequence(x)
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, style_dim=8,
num_mlp_base_feature=256, num_mlp_blocks=3,
max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
padding_mode='reflect', activation_type="ReLU", pre_activation=False):
super().__init__()
self.content_encoder = Encoder(
in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
max_down_sampling_multiple=num_content_down_sampling,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type="IN", down_conv_kernel_size=4,
res_norm_type="IN", pre_activation=pre_activation
)
self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
max_down_sampling_multiple, padding_mode, activation_type,
norm_type="NONE", pre_activation=pre_activation)
content_channels = base_channels * (2 ** max_down_sampling_multiple)
self.fusion = MLPFusion(style_dim, decoder_num_residual_blocks * 2 * content_channels * 2,
num_mlp_base_feature, num_mlp_blocks, activation_type,
norm_type="NONE")
self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
activation_type=activation_type, padding_mode=padding_mode,
up_conv_kernel_size=5, up_conv_norm_type="LN",
res_norm_type="AdaIN", pre_activation=pre_activation)
def encode(self, x):
return self.content_encoder(x), self.style_encoder(x)
def decode(self, content, style):
as_param_style = torch.chunk(self.fusion(style), 2 * len(self.decoder.residual_blocks), dim=1)
# set style for decoder
for i, blk in enumerate(self.decoder.residual_blocks):
blk.conv1.normalization.set_style(as_param_style[2 * i])
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
self.decoder(content)
def forward(self, x):
content, style = self.encode(x)
return self.decode(content, style)

View File

@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
from model import MODEL
from model.base.module import Conv2dBlock, LinearBlock
from model.image_translation.CycleGAN import Encoder, Decoder
class RhoClipper(object):
def __init__(self, clip_min, clip_max):
self.clip_min = clip_min
self.clip_max = clip_max
assert clip_min < clip_max
def __call__(self, module):
if hasattr(module, 'rho'):
w = module.rho.data
w = w.clamp(self.clip_min, self.clip_max)
module.rho.data = w
class CAMClassifier(nn.Module):
def __init__(self, in_channels, activation_type="ReLU"):
super(CAMClassifier, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.avg_fc = nn.Linear(in_channels, 1, bias=False)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.max_fc = nn.Linear(in_channels, 1, bias=False)
self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, kernel_size=1, stride=1, bias=True,
activation_type=activation_type, norm_type="NONE")
def forward(self, x):
avg_logit = self.avg_fc(self.avg_pool(x).view(x.size(0), -1))
max_logit = self.max_fc(self.max_pool(x).view(x.size(0), -1))
return self.fusion_conv(torch.cat(
[x * self.avg_fc.weight.unsqueeze(2).unsqueeze(3), x * self.max_fc.weight.unsqueeze(2).unsqueeze(3)],
dim=1
)), torch.cat([avg_logit, max_logit], 1)
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
super(Generator, self).__init__()
self.light = light
n_down_sampling = 2
self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
padding_mode=padding_mode, activation_type=activation_type,
down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type,
pre_activation=pre_activation)
mult = 2 ** n_down_sampling
self.cam = CAMClassifier(base_channels * mult, activation_type)
# Gamma, Beta block
if self.light:
self.fc = nn.Sequential(
LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE"),
LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE")
)
else:
self.fc = nn.Sequential(
LinearBlock(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, False,
"ReLU", "NONE"),
LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE")
)
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
self.decoder = Decoder(
base_channels * mult, out_channels, n_down_sampling, num_blocks,
activation_type=activation_type, padding_mode=padding_mode,
up_conv_kernel_size=3, up_conv_norm_type="ILN",
res_norm_type="AdaILN", pre_activation=pre_activation
)
def forward(self, x):
x = self.encoder(x)
x, cam_logit = self.cam(x)
heatmap = torch.sum(x, dim=1, keepdim=True)
if self.light:
x_ = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
x_ = self.fc(x_.view(x_.shape[0], -1))
else:
x_ = self.fc(x.view(x.shape[0], -1))
gamma, beta = self.gamma(x_), self.beta(x_)
for blk in self.decoder.residual_blocks:
blk.conv1.normalization.set_condition(gamma, beta)
blk.conv2.normalization.set_condition(gamma, beta)
return self.decoder(x), cam_logit, heatmap
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
def __init__(self, in_channels, base_channels=64, num_blocks=5,
activation_type="LeakyReLU", norm_type="NONE", padding_mode='reflect'):
super().__init__()
sequence = [Conv2dBlock(
in_channels, base_channels, kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type
)]
sequence += [Conv2dBlock(
base_channels * (2 ** i), base_channels * (2 ** i) * 2,
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type) for i in range(num_blocks - 3)]
sequence.append(
Conv2dBlock(base_channels * (2 ** (num_blocks - 3)), base_channels * (2 ** (num_blocks - 2)),
kernel_size=4, stride=1, padding=1, padding_mode=padding_mode,
activation_type=activation_type, norm_type=norm_type)
)
self.sequence = nn.Sequential(*sequence)
mult = 2 ** (num_blocks - 2)
self.cam = CAMClassifier(base_channels * mult, activation_type)
self.conv = nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False,
padding_mode="reflect")
def forward(self, x, return_heatmap=False):
x = self.sequence(x)
x, cam_logit = self.cam(x)
if return_heatmap:
heatmap = torch.sum(x, dim=1, keepdim=True)
return self.conv(x), cam_logit, heatmap
else:
return self.conv(x), cam_logit

View File

View File

@@ -1,4 +1,6 @@
import importlib
import logging
import pkgutil
from pathlib import Path
from typing import Optional
@@ -6,12 +8,30 @@ import torch.nn as nn
def add_spectral_norm(module):
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
return nn.utils.spectral_norm(module)
else:
return module
def import_submodules(package, recursive=True):
""" Import all submodules of a module, recursively, including subpackages
:param package: package (name or actual module)
:type package: str | module
:rtype: dict[str, types.ModuleType]
"""
if isinstance(package, str):
package = importlib.import_module(package)
results = {}
for loader, name, is_pkg in pkgutil.walk_packages(package.__path__):
full_name = package.__name__ + '.' + name
results[name] = importlib.import_module(full_name)
if recursive and is_pkg:
results.update(import_submodules(full_name))
return results
def setup_logger(
name: Optional[str] = None,
level: int = logging.INFO,