Compare commits
6 Commits
776fe40199
...
436bca88b4
| Author | SHA1 | Date | |
|---|---|---|---|
| 436bca88b4 | |||
| 6070f08835 | |||
| 06b2abd19a | |||
| 9c08b4cd09 | |||
| 04c6366c07 | |||
| 6ea13df465 |
2
.idea/deployment.xml
generated
2
.idea/deployment.xml
generated
@@ -1,6 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="14d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<component name="PublishConfigData" autoUpload="Always" serverName="21d" remoteFilesAllowedToDisappearOnAutoupload="false">
|
||||
<serverData>
|
||||
<paths name="14d">
|
||||
<serverdata>
|
||||
|
||||
@@ -14,11 +14,15 @@ handler:
|
||||
n_saved: 2
|
||||
tensorboard:
|
||||
scalar: 100 # log scalar `scalar` times per epoch
|
||||
image: 2 # log image `image` times per epoch
|
||||
image: 4 # log image `image` times per epoch
|
||||
test:
|
||||
random: True
|
||||
images: 10
|
||||
|
||||
model:
|
||||
generator:
|
||||
_type: UGATIT-Generator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
out_channels: 3
|
||||
base_channels: 64
|
||||
@@ -27,11 +31,13 @@ model:
|
||||
light: True
|
||||
local_discriminator:
|
||||
_type: UGATIT-Discriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_blocks: 5
|
||||
global_discriminator:
|
||||
_type: UGATIT-Discriminator
|
||||
_add_spectral_norm: True
|
||||
in_channels: 3
|
||||
base_channels: 64
|
||||
num_blocks: 7
|
||||
@@ -50,6 +56,8 @@ loss:
|
||||
weight: 10.0
|
||||
cam:
|
||||
weight: 1000
|
||||
mgc:
|
||||
weight: 0
|
||||
|
||||
optimizers:
|
||||
generator:
|
||||
@@ -70,7 +78,7 @@ data:
|
||||
target_lr: 0
|
||||
buffer_size: 50
|
||||
dataloader:
|
||||
batch_size: 24
|
||||
batch_size: 4
|
||||
shuffle: True
|
||||
num_workers: 2
|
||||
pin_memory: True
|
||||
|
||||
287
data/dataset.py
287
data/dataset.py
@@ -1,287 +0,0 @@
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from itertools import permutations, combinations
|
||||
from pathlib import Path
|
||||
|
||||
import lmdb
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
|
||||
from torchvision.transforms import functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
from .registry import DATASET
|
||||
from .transform import transform_pipeline
|
||||
from .util import dlib_landmark
|
||||
|
||||
|
||||
def default_transform_way(transform, sample):
|
||||
return [transform(sample[0]), *sample[1:]]
|
||||
|
||||
|
||||
class LMDBDataset(Dataset):
|
||||
def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
|
||||
**lmdb_kwargs):
|
||||
self.path = lmdb_path
|
||||
self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
|
||||
lock=False, **lmdb_kwargs)
|
||||
|
||||
with self.db.begin(write=False) as txn:
|
||||
self._len = pickle.loads(txn.get(b"$$len$$"))
|
||||
self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
|
||||
if pipeline is None:
|
||||
self.not_done_pipeline = []
|
||||
else:
|
||||
self.not_done_pipeline = self._remain_pipeline(pipeline)
|
||||
self.transform = transform_pipeline(self.not_done_pipeline)
|
||||
self.transform_way = transform_way
|
||||
essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
|
||||
for ea in essential_attr:
|
||||
setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))
|
||||
|
||||
def _remain_pipeline(self, pipeline):
|
||||
for i, dp in enumerate(self.done_pipeline):
|
||||
if pipeline[i] != dp:
|
||||
raise ValueError(
|
||||
f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
|
||||
return pipeline[len(self.done_pipeline):]
|
||||
|
||||
def __repr__(self):
|
||||
return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"
|
||||
|
||||
def __len__(self):
|
||||
return self._len
|
||||
|
||||
def __getitem__(self, idx):
|
||||
with self.db.begin(write=False) as txn:
|
||||
sample = pickle.loads(txn.get("{}".format(idx).encode()))
|
||||
sample = self.transform_way(self.transform, sample)
|
||||
return sample
|
||||
|
||||
@staticmethod
|
||||
def lmdbify(dataset, done_pipeline, lmdb_path):
|
||||
env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
|
||||
with env.begin(write=True) as txn:
|
||||
for i in tqdm(range(len(dataset)), ncols=0):
|
||||
txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
|
||||
txn.put(b"$$len$$", pickle.dumps(len(dataset)))
|
||||
txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
|
||||
essential_attr = getattr(dataset, "essential_attr", list())
|
||||
txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
|
||||
for ea in essential_attr:
|
||||
txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class ImprovedImageFolder(ImageFolder):
|
||||
def __init__(self, root, pipeline):
|
||||
super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
|
||||
self.classes_list = defaultdict(list)
|
||||
self.essential_attr = ["classes_list"]
|
||||
for i in range(len(self)):
|
||||
self.classes_list[self.samples[i][-1]].append(i)
|
||||
assert len(self.classes_list) == len(self.classes)
|
||||
|
||||
|
||||
class EpisodicDataset(Dataset):
|
||||
def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
|
||||
self.origin = origin_dataset
|
||||
self.num_class = num_class
|
||||
assert self.num_class < len(self.origin.classes_list)
|
||||
self.num_query = num_query # K
|
||||
self.num_support = num_support # K
|
||||
self.num_episodes = num_episodes
|
||||
|
||||
def _fetch_list_data(self, id_list):
|
||||
return [self.origin[i][0] for i in id_list]
|
||||
|
||||
def __len__(self):
|
||||
return self.num_episodes
|
||||
|
||||
def __getitem__(self, _):
|
||||
random_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
|
||||
support_set = []
|
||||
query_set = []
|
||||
target_set = []
|
||||
for tag, c in enumerate(random_classes):
|
||||
image_list = self.origin.classes_list[c]
|
||||
|
||||
if len(image_list) >= self.num_query + self.num_support:
|
||||
# have enough images belong to this class
|
||||
idx_list = torch.randperm(len(image_list))[:self.num_query + self.num_support].tolist()
|
||||
else:
|
||||
idx_list = torch.randint(high=len(image_list), size=(self.num_query + self.num_support,)).tolist()
|
||||
|
||||
support = self._fetch_list_data(map(image_list.__getitem__, idx_list[:self.num_support]))
|
||||
query = self._fetch_list_data(map(image_list.__getitem__, idx_list[self.num_support:]))
|
||||
support_set.extend(support)
|
||||
query_set.extend(query)
|
||||
target_set.extend([tag] * self.num_query)
|
||||
return {
|
||||
"support": torch.stack(support_set),
|
||||
"query": torch.stack(query_set),
|
||||
"target": torch.tensor(target_set)
|
||||
}
|
||||
|
||||
def __repr__(self):
|
||||
return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class SingleFolderDataset(Dataset):
|
||||
def __init__(self, root, pipeline, with_path=False):
|
||||
assert os.path.isdir(root)
|
||||
self.root = root
|
||||
samples = []
|
||||
for r, _, fns in sorted(os.walk(self.root, followlinks=True)):
|
||||
for fn in sorted(fns):
|
||||
path = os.path.join(r, fn)
|
||||
if has_file_allowed_extension(path, IMG_EXTENSIONS):
|
||||
samples.append(path)
|
||||
self.samples = samples
|
||||
self.pipeline = transform_pipeline(pipeline)
|
||||
self.with_path = with_path
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if not self.with_path:
|
||||
return self.pipeline(self.samples[idx])
|
||||
else:
|
||||
return self.pipeline(self.samples[idx]), self.samples[idx]
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SingleFolderDataset root={self.root} len={len(self)}>"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDataset(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path)
|
||||
self.random_pair = random_pair
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
return dict(a=self.A[a_idx], b=self.B[b_idx])
|
||||
|
||||
def __len__(self):
|
||||
return max(len(self.A), len(self.B))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
|
||||
|
||||
def normalize_tensor(tensor):
|
||||
tensor = tensor.float()
|
||||
tensor -= tensor.min()
|
||||
tensor /= tensor.max()
|
||||
return tensor
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class GenerationUnpairedDatasetWithEdge(Dataset):
|
||||
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
|
||||
with_path=False):
|
||||
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
|
||||
self.edge_type = edge_type
|
||||
self.size = size
|
||||
self.edges_path = Path(edges_path)
|
||||
self.landmarks_path = Path(landmarks_path)
|
||||
assert self.edges_path.exists()
|
||||
assert self.landmarks_path.exists()
|
||||
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
|
||||
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
|
||||
self.random_pair = random_pair
|
||||
self.with_path = with_path
|
||||
|
||||
def get_edge(self, origin_path):
|
||||
op = Path(origin_path)
|
||||
if self.edge_type.startswith("landmark_"):
|
||||
edge_type = self.edge_type.lstrip("landmark_")
|
||||
use_landmark = op.parent.name.endswith("A")
|
||||
else:
|
||||
edge_type = self.edge_type
|
||||
use_landmark = False
|
||||
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
|
||||
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
|
||||
if not use_landmark:
|
||||
return origin_edge
|
||||
else:
|
||||
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
|
||||
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
|
||||
|
||||
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
|
||||
part_labels = normalize_tensor(torch.from_numpy(part_labels))
|
||||
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
|
||||
# edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
|
||||
# edges = part_edge + edges
|
||||
return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])
|
||||
|
||||
def __getitem__(self, idx):
|
||||
a_idx = idx % len(self.A)
|
||||
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
|
||||
output = dict(a={}, b={})
|
||||
output["a"]["img"], output["a"]["path"] = self.A[a_idx]
|
||||
output["b"]["img"], output["b"]["path"] = self.B[b_idx]
|
||||
for p in "ab":
|
||||
output[p]["edge"] = self.get_edge(output[p]["path"])
|
||||
return output
|
||||
|
||||
def __len__(self):
|
||||
return max(len(self.A), len(self.B))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
|
||||
|
||||
@DATASET.register_module()
|
||||
class PoseFacesWithSingleAnime(Dataset):
|
||||
def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
|
||||
with_order=True):
|
||||
self.num_face = num_face
|
||||
self.landmark_path = Path(landmark_path)
|
||||
self.with_order = with_order
|
||||
self.root_face = Path(root_face)
|
||||
self.root_anime = Path(root_anime)
|
||||
self.img_size = img_size
|
||||
self.face_samples = self.iter_folders()
|
||||
self.face_pipeline = transform_pipeline(face_pipeline)
|
||||
self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)
|
||||
|
||||
def iter_folders(self):
|
||||
pics_per_person = defaultdict(list)
|
||||
for p in self.root_face.glob("*.jpg"):
|
||||
pics_per_person[p.stem[:7]].append(p.stem)
|
||||
data = []
|
||||
for p in pics_per_person:
|
||||
if len(pics_per_person[p]) >= self.num_face:
|
||||
if self.with_order:
|
||||
data.extend(list(combinations(pics_per_person[p], self.num_face)))
|
||||
else:
|
||||
data.extend(list(permutations(pics_per_person[p], self.num_face)))
|
||||
return data
|
||||
|
||||
def read_pose(self, pose_txt):
|
||||
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
|
||||
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
|
||||
part_labels = normalize_tensor(torch.from_numpy(part_labels))
|
||||
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
|
||||
return torch.cat([part_labels, part_edge, dist_tensor])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.face_samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
output = dict()
|
||||
output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
|
||||
for i, f in enumerate(self.face_samples[idx]):
|
||||
output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
|
||||
output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
|
||||
output[f"stem_{i}"] = f
|
||||
return output
|
||||
3
data/dataset/__init__.py
Normal file
3
data/dataset/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from util.misc import import_submodules

# Import every submodule so their @DATASET.register_module() decorators run,
# and re-export the submodule names. Materialize a list: `__all__` should be a
# list of strings, not a live dict_keys view of the import table.
__all__ = list(import_submodules(__name__).keys())
|
||||
63
data/dataset/few-shot.py
Normal file
63
data/dataset/few-shot.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
from data.registry import DATASET
|
||||
from data.transform import transform_pipeline
|
||||
|
||||
|
||||
@DATASET.register_module()
class ImprovedImageFolder(ImageFolder):
    """ImageFolder that also records, per class index, the sample indices belonging to it."""

    def __init__(self, root, pipeline):
        # loader is the identity: samples stay as paths and the pipeline does the loading.
        super().__init__(root, transform_pipeline(pipeline), loader=lambda x: x)
        self.classes_list = defaultdict(list)
        # Attributes to persist when this dataset is serialized (e.g. lmdbify-ed).
        self.essential_attr = ["classes_list"]
        for index, sample in enumerate(self.samples):
            label = sample[-1]
            self.classes_list[label].append(index)
        # Every class must have at least one sample.
        assert len(self.classes_list) == len(self.classes)
|
||||
|
||||
|
||||
class EpisodicDataset(Dataset):
    """Wrap a dataset exposing ``classes_list`` to yield few-shot episodes.

    Each item is one episode: ``num_class`` randomly drawn classes, each
    contributing ``num_support`` support samples and ``num_query`` query
    samples, with episode-local targets 0..num_class-1.
    """

    def __init__(self, origin_dataset, num_class, num_query, num_support, num_episodes):
        self.origin = origin_dataset
        self.num_class = num_class  # N ways per episode
        assert self.num_class < len(self.origin.classes_list)
        self.num_query = num_query  # queries per class
        self.num_support = num_support  # K shots per class
        self.num_episodes = num_episodes

    def _fetch_list_data(self, id_list):
        # Keep only the sample itself; the original label is replaced by the
        # episode-local tag assigned in __getitem__.
        return [self.origin[index][0] for index in id_list]

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, _):
        needed = self.num_query + self.num_support
        chosen_classes = torch.randperm(len(self.origin.classes_list))[:self.num_class].tolist()
        support_set, query_set, target_set = [], [], []
        for tag, cls in enumerate(chosen_classes):
            candidates = self.origin.classes_list[cls]

            if len(candidates) >= needed:
                # Enough distinct images: sample without replacement.
                picks = torch.randperm(len(candidates))[:needed].tolist()
            else:
                # Too few images for this class: sample with replacement.
                picks = torch.randint(high=len(candidates), size=(needed,)).tolist()

            support_set.extend(self._fetch_list_data(map(candidates.__getitem__, picks[:self.num_support])))
            query_set.extend(self._fetch_list_data(map(candidates.__getitem__, picks[self.num_support:])))
            target_set.extend([tag] * self.num_query)
        return {
            "support": torch.stack(support_set),
            "query": torch.stack(query_set),
            "target": torch.tensor(target_set)
        }

    def __repr__(self):
        return f"<EpisodicDataset NE{self.num_episodes} NC{self.num_class} NS{self.num_support} NQ{self.num_query}>"
|
||||
62
data/dataset/image_translation.py
Normal file
62
data/dataset/image_translation.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
|
||||
|
||||
from data.registry import DATASET
|
||||
from data.transform import transform_pipeline
|
||||
|
||||
|
||||
@DATASET.register_module()
class SingleFolderDataset(Dataset):
    """Dataset over every image file found (recursively) under one folder.

    Items are dicts with an ``img`` entry and, when ``with_path`` is set,
    the originating file path under ``path``.
    """

    def __init__(self, root, pipeline, with_path=False):
        assert os.path.isdir(root)
        self.root = root
        # Walk deterministically: directories and filenames both sorted.
        self.samples = [
            os.path.join(dirpath, filename)
            for dirpath, _, filenames in sorted(os.walk(self.root, followlinks=True))
            for filename in sorted(filenames)
            if has_file_allowed_extension(os.path.join(dirpath, filename), IMG_EXTENSIONS)
        ]
        self.pipeline = transform_pipeline(pipeline)
        self.with_path = with_path

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.samples[idx]
        output = dict(img=self.pipeline(sample_path))
        if self.with_path:
            output["path"] = sample_path
        return output

    def __repr__(self):
        return f"<SingleFolderDataset root={self.root} len={len(self)} with_path={self.with_path}>"
|
||||
|
||||
|
||||
@DATASET.register_module()
class GenerationUnpairedDataset(Dataset):
    """Pair two SingleFolderDatasets (domains A and B) for unpaired image translation.

    With ``random_pair`` the B sample is drawn uniformly at random each access;
    otherwise A and B are walked in lock-step modulo their lengths.
    """

    def __init__(self, root_a, root_b, random_pair, pipeline, with_path=False):
        self.A = SingleFolderDataset(root_a, pipeline, with_path)
        self.B = SingleFolderDataset(root_b, pipeline, with_path)
        self.with_path = with_path
        self.random_pair = random_pair

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        if self.random_pair:
            b_idx = torch.randint(len(self.B), (1,)).item()
        else:
            b_idx = idx % len(self.B)
        sample_a = self.A[a_idx]
        sample_b = self.B[b_idx]
        output = dict(a=sample_a["img"], b=sample_b["img"])
        if self.with_path:
            output["a_path"] = sample_a["path"]
            output["b_path"] = sample_b["path"]
        return output

    def __len__(self):
        # One epoch covers the longer of the two domains.
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDataset:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
65
data/dataset/lmdb.py
Normal file
65
data/dataset/lmdb.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import lmdb
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from data.transform import transform_pipeline
|
||||
|
||||
|
||||
def default_transform_way(transform, sample):
    """Apply *transform* to the first element of *sample*; pass the rest through unchanged."""
    head, *tail = sample
    return [transform(head), *tail]
|
||||
|
||||
|
||||
class LMDBDataset(Dataset):
    """Dataset reading pickled samples from an LMDB database built by :meth:`lmdbify`.

    Keys are stringified sample indices, plus bookkeeping entries:
    ``$$len$$`` (sample count), ``$$done_pipeline$$`` (transforms already
    applied when the data was written), ``$$essential_attr$$`` (names of extra
    attributes) and one ``$<attr>$`` entry per essential attribute.
    """

    def __init__(self, lmdb_path, pipeline=None, transform_way=default_transform_way, map_size=2 ** 40, readonly=True,
                 **lmdb_kwargs):
        self.path = lmdb_path
        # lock=False so multiple reader processes (e.g. DataLoader workers) can share the env.
        self.db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), map_size=map_size, readonly=readonly,
                            lock=False, **lmdb_kwargs)

        with self.db.begin(write=False) as txn:
            self._len = pickle.loads(txn.get(b"$$len$$"))
            # Transforms that were already applied before the samples were stored.
            self.done_pipeline = pickle.loads(txn.get(b"$$done_pipeline$$"))
            if pipeline is None:
                self.not_done_pipeline = []
            else:
                # Only the transforms not baked into the stored samples run at read time.
                self.not_done_pipeline = self._remain_pipeline(pipeline)
            self.transform = transform_pipeline(self.not_done_pipeline)
            self.transform_way = transform_way
            # Restore attributes the source dataset flagged as essential.
            essential_attr = pickle.loads(txn.get(b"$$essential_attr$$"))
            for ea in essential_attr:
                setattr(self, ea, pickle.loads(txn.get(f"${ea}$".encode(encoding="utf-8"))))

    def _remain_pipeline(self, pipeline):
        """Return the suffix of ``pipeline`` that still has to run at read time.

        Raises:
            ValueError: if ``pipeline`` does not start with the stored
                ``done_pipeline``. (Previously a pipeline shorter than
                ``done_pipeline`` escaped as an IndexError; now both mismatch
                shapes raise ValueError consistently.)
        """
        if len(pipeline) < len(self.done_pipeline):
            raise ValueError(
                f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
        for i, dp in enumerate(self.done_pipeline):
            if pipeline[i] != dp:
                raise ValueError(
                    f"pipeline {self.done_pipeline} saved in this lmdb database is not match with pipeline:{pipeline}")
        return pipeline[len(self.done_pipeline):]

    def __repr__(self):
        return f"LMDBDataset: {self.path}\nlength: {len(self)}\n{self.transform}"

    def __len__(self):
        return self._len

    def __getitem__(self, idx):
        with self.db.begin(write=False) as txn:
            sample = pickle.loads(txn.get("{}".format(idx).encode()))
        # Transform outside the read transaction; transform_way decides which
        # elements of the stored sample the transform applies to.
        sample = self.transform_way(self.transform, sample)
        return sample

    @staticmethod
    def lmdbify(dataset, done_pipeline, lmdb_path):
        """Serialize ``dataset`` into an LMDB database at ``lmdb_path``.

        ``done_pipeline`` records which transforms are already applied to the
        stored samples so readers can skip them later.
        """
        env = lmdb.open(lmdb_path, map_size=2 ** 40, subdir=os.path.isdir(lmdb_path))
        with env.begin(write=True) as txn:
            for i in tqdm(range(len(dataset)), ncols=0):
                txn.put("{}".format(i).encode(), pickle.dumps(dataset[i]))
            txn.put(b"$$len$$", pickle.dumps(len(dataset)))
            txn.put(b"$$done_pipeline$$", pickle.dumps(done_pipeline))
            essential_attr = getattr(dataset, "essential_attr", list())
            txn.put(b"$$essential_attr$$", pickle.dumps(essential_attr))
            for ea in essential_attr:
                txn.put(f"${ea}$".encode(encoding="utf-8"), pickle.dumps(getattr(dataset, ea)))
|
||||
122
data/dataset/pose_transfer.py
Normal file
122
data/dataset/pose_transfer.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from collections import defaultdict
|
||||
from itertools import permutations, combinations
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
from data.registry import DATASET
|
||||
from data.transform import transform_pipeline
|
||||
from data.util import dlib_landmark
|
||||
|
||||
|
||||
def normalize_tensor(tensor):
    """Min-max normalize *tensor* to [0, 1] as float, without mutating the input.

    Fixes two defects of the previous in-place version:
    - ``tensor.float()`` returns the *same* object for an already-float tensor,
      so the in-place ``-=`` / ``/=`` silently mutated the caller's tensor
      (and, via ``torch.from_numpy``, the backing numpy array).
    - A constant tensor divided by a zero range, producing NaN/inf; it now
      normalizes to all zeros.
    """
    tensor = tensor.float() - tensor.float().min()
    peak = tensor.max()
    if peak > 0:
        tensor = tensor / peak
    return tensor
|
||||
|
||||
|
||||
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
    """Unpaired A/B dataset that additionally returns a precomputed edge map per image.

    ``edge_type`` selects which edge image to load; the ``landmark_`` variants
    additionally stack dlib facial-landmark channels for images of domain A
    (folders whose name ends with "A").
    """

    # NOTE(review): this module references SingleFolderDataset but the visible
    # imports of this file do not provide it — confirm the file-level imports.
    def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
                 with_path=False):
        assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
        self.edge_type = edge_type
        self.size = size
        self.edges_path = Path(edges_path)
        self.landmarks_path = Path(landmarks_path)
        assert self.edges_path.exists()
        assert self.landmarks_path.exists()
        # Paths are always requested internally to locate edge/landmark files.
        self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
        self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
        self.random_pair = random_pair
        self.with_path = with_path

    def get_edge(self, origin_path):
        """Load the edge tensor for *origin_path*, plus landmark channels when applicable."""
        op = Path(origin_path)
        if self.edge_type.startswith("landmark_"):
            # BUG FIX: str.lstrip strips a *character set*, not a prefix —
            # slice off the literal "landmark_" prefix instead.
            edge_type = self.edge_type[len("landmark_"):]
            use_landmark = op.parent.name.endswith("A")
        else:
            edge_type = self.edge_type
            use_landmark = False
        edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
        origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
        if not use_landmark:
            return origin_edge
        else:
            landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
            key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)

            dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
            part_labels = normalize_tensor(torch.from_numpy(part_labels))
            part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
            # edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
            # edges = part_edge + edges
            return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])

    def __getitem__(self, idx):
        a_idx = idx % len(self.A)
        b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
        output = dict(a={}, b={})
        # BUG FIX: SingleFolderDataset.__getitem__ returns a dict (img/path),
        # not a tuple — tuple-unpacking it yielded the key strings, not values.
        sample_a = self.A[a_idx]
        sample_b = self.B[b_idx]
        output["a"]["img"], output["a"]["path"] = sample_a["img"], sample_a["path"]
        output["b"]["img"], output["b"]["path"] = sample_b["img"], sample_b["path"]
        for p in "ab":
            output[p]["edge"] = self.get_edge(output[p]["path"])
        return output

    def __len__(self):
        return max(len(self.A), len(self.B))

    def __repr__(self):
        return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
|
||||
|
||||
|
||||
@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
    """Tuples of ``num_face`` photos of one person (each with a pose map) plus one random anime image.

    Face images are ``<root_face>/<stem>.jpg`` where the first 7 characters of
    the stem identify the person; landmark files are expected at
    ``<landmark_path>/<root_face name>/<stem>.txt``.
    """

    # NOTE(review): SingleFolderDataset is referenced but not imported in this
    # module — confirm the file-level imports provide it.
    def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
                 with_order=True):
        self.num_face = num_face
        self.landmark_path = Path(landmark_path)
        self.with_order = with_order
        self.root_face = Path(root_face)
        self.root_anime = Path(root_anime)
        self.img_size = img_size
        self.face_samples = self.iter_folders()
        self.face_pipeline = transform_pipeline(face_pipeline)
        self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)

    def iter_folders(self):
        """Group face stems by person id and expand each group into num_face-sized tuples."""
        pics_per_person = defaultdict(list)
        for p in self.root_face.glob("*.jpg"):
            # First 7 characters of the file stem identify the person.
            pics_per_person[p.stem[:7]].append(p.stem)
        data = []
        for p in pics_per_person:
            if len(pics_per_person[p]) >= self.num_face:
                # NOTE(review): combinations *ignores* order while permutations
                # preserves it — this mapping to `with_order` looks inverted;
                # kept as-is, confirm intent before changing.
                if self.with_order:
                    data.extend(list(combinations(pics_per_person[p], self.num_face)))
                else:
                    data.extend(list(permutations(pics_per_person[p], self.num_face)))
        return data

    def read_pose(self, pose_txt):
        """Build the pose tensor: part labels + part edges + landmark distance channels."""
        key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
        dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
        part_labels = normalize_tensor(torch.from_numpy(part_labels))
        part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
        return torch.cat([part_labels, part_edge, dist_tensor])

    def __len__(self):
        return len(self.face_samples)

    def __getitem__(self, idx):
        output = dict()
        # BUG FIX: SingleFolderDataset.__getitem__ returns a dict (img/path),
        # not a tuple — tuple-unpacking it yielded the key strings, not values.
        anime_sample = self.B[torch.randint(len(self.B), (1,)).item()]
        output["anime_img"], output["anime_path"] = anime_sample["img"], anime_sample["path"]
        for i, f in enumerate(self.face_samples[idx]):
            output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
            output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
            output[f"stem_{i}"] = f
        return output
|
||||
@@ -28,7 +28,7 @@ class Load:
|
||||
|
||||
|
||||
def transform_pipeline(pipeline_description):
|
||||
if len(pipeline_description) == 0:
|
||||
if pipeline_description is None or len(pipeline_description) == 0:
|
||||
return lambda x: x
|
||||
transform_list = [TRANSFORM.build_with(pd) for pd in pipeline_description]
|
||||
return transforms.Compose(transform_list)
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import ignite.distributed as idist
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import ignite.distributed as idist
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from loss.gan import GANLoss
|
||||
from model.GAN.UGATIT import RhoClipper
|
||||
from model.GAN.base import GANImageBuffer
|
||||
from util.image import attention_colored_map
|
||||
from engine.base.i2i import EngineKernel, run_kernel, TestEngineKernel
|
||||
from engine.util.build import build_model
|
||||
from engine.util.container import LossContainer
|
||||
from loss.I2I.minimal_geometry_distortion_constraint_loss import MyLoss
|
||||
from loss.gan import GANLoss
|
||||
from model.image_translation.UGATIT import RhoClipper
|
||||
from util.image import attention_colored_map
|
||||
|
||||
|
||||
def mse_loss(x, target_flag):
|
||||
@@ -28,11 +28,11 @@ class UGATITEngineKernel(EngineKernel):
|
||||
gan_loss_cfg = OmegaConf.to_container(config.loss.gan)
|
||||
gan_loss_cfg.pop("weight")
|
||||
self.gan_loss = GANLoss(**gan_loss_cfg).to(idist.device())
|
||||
self.cycle_loss = nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss()
|
||||
self.id_loss = nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss()
|
||||
self.cycle_loss = LossContainer(config.loss.cycle.weight,
|
||||
nn.L1Loss() if config.loss.cycle.level == 1 else nn.MSELoss())
|
||||
self.mgc_loss = LossContainer(config.loss.mgc.weight, MyLoss())
|
||||
self.id_loss = LossContainer(config.loss.id.weight, nn.L1Loss() if config.loss.id.level == 1 else nn.MSELoss())
|
||||
self.rho_clipper = RhoClipper(0, 1)
|
||||
self.image_buffers = {k: GANImageBuffer(config.data.train.buffer_size or 50) for k in
|
||||
self.discriminators.keys()}
|
||||
self.train_generator_first = False
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
@@ -79,9 +79,9 @@ class UGATITEngineKernel(EngineKernel):
|
||||
loss = dict()
|
||||
for phase in "ab":
|
||||
cycle_image = generated["images"]["a2b2a" if phase == "a" else "b2a2b"]
|
||||
loss[f"cycle_{phase}"] = self.config.loss.cycle.weight * self.cycle_loss(cycle_image, batch[phase])
|
||||
loss[f"id_{phase}"] = self.config.loss.id.weight * self.id_loss(batch[phase],
|
||||
generated["images"][f"{phase}2{phase}"])
|
||||
loss[f"cycle_{phase}"] = self.cycle_loss(cycle_image, batch[phase])
|
||||
loss[f"id_{phase}"] = self.id_loss(batch[phase], generated["images"][f"{phase}2{phase}"])
|
||||
loss[f"mgc_{phase}"] = self.mgc_loss(batch[phase], generated["images"]["a2b" if phase == "a" else "b2a"])
|
||||
for dk in "lg":
|
||||
generated_image = generated["images"]["a2b" if phase == "b" else "b2a"]
|
||||
pred_fake, cam_pred = self.discriminators[dk + phase](generated_image)
|
||||
|
||||
@@ -64,7 +64,7 @@ class EngineKernel(object):
|
||||
self.engine = engine
|
||||
|
||||
def build_models(self) -> (dict, dict):
|
||||
raise NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def to_save(self):
|
||||
to_save = {}
|
||||
@@ -73,19 +73,19 @@ class EngineKernel(object):
|
||||
return to_save
|
||||
|
||||
def setup_after_g(self):
|
||||
raise NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def setup_before_g(self):
|
||||
raise NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, batch, inference=False) -> dict:
|
||||
raise NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def criterion_generators(self, batch, generated) -> dict:
|
||||
raise NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def criterion_discriminators(self, batch, generated) -> dict:
|
||||
raise NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def intermediate_images(self, batch, generated) -> dict:
|
||||
"""
|
||||
@@ -94,12 +94,19 @@ class EngineKernel(object):
|
||||
:param generated: dict of images
|
||||
:return: dict like: {"a": [img1, img2, ...], "b": [img3, img4, ...]}
|
||||
"""
|
||||
raise NotImplemented
|
||||
raise NotImplementedError
|
||||
|
||||
def change_engine(self, config, engine: Engine):
|
||||
pass
|
||||
|
||||
|
||||
def _remove_no_grad_loss(loss_dict):
|
||||
for k in loss_dict:
|
||||
if not isinstance(loss_dict[k], torch.Tensor):
|
||||
loss_dict.pop(k)
|
||||
return loss_dict
|
||||
|
||||
|
||||
def get_trainer(config, kernel: EngineKernel):
|
||||
logger = logging.getLogger(config.name)
|
||||
generators, discriminators = kernel.generators, kernel.discriminators
|
||||
@@ -147,10 +154,10 @@ def get_trainer(config, kernel: EngineKernel):
|
||||
|
||||
if engine.state.iteration % iteration_per_image == 0:
|
||||
return {
|
||||
"loss": dict(g=loss_g, d=loss_d),
|
||||
"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d)),
|
||||
"img": kernel.intermediate_images(batch, generated)
|
||||
}
|
||||
return {"loss": dict(g=loss_g, d=loss_d)}
|
||||
return {"loss": dict(g=_remove_no_grad_loss(loss_g), d=_remove_no_grad_loss(loss_d))}
|
||||
|
||||
trainer = Engine(_step)
|
||||
trainer.logger = logger
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
import torch
|
||||
import ignite.distributed as idist
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from model import MODEL
|
||||
import torch.optim as optim
|
||||
from util.misc import add_spectral_norm
|
||||
|
||||
|
||||
def build_model(cfg):
|
||||
cfg = OmegaConf.to_container(cfg)
|
||||
bn_to_sync_bn = cfg.pop("_bn_to_sync_bn", False)
|
||||
add_spectral_norm_flag = cfg.pop("_add_spectral_norm", False)
|
||||
model = MODEL.build_with(cfg)
|
||||
if bn_to_sync_bn:
|
||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||
if add_spectral_norm_flag:
|
||||
model.apply(add_spectral_norm)
|
||||
return idist.auto_model(model)
|
||||
|
||||
|
||||
|
||||
9
engine/util/container.py
Normal file
9
engine/util/container.py
Normal file
@@ -0,0 +1,9 @@
|
||||
class LossContainer:
|
||||
def __init__(self, weight, loss):
|
||||
self.weight = weight
|
||||
self.loss = loss
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.weight > 0:
|
||||
return self.weight * self.loss(*args, **kwargs)
|
||||
return 0.0
|
||||
111
loss/I2I/minimal_geometry_distortion_constraint_loss.py
Normal file
111
loss/I2I/minimal_geometry_distortion_constraint_loss.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def gaussian_radial_basis_function(x, mu, sigma):
|
||||
# (kernel_size) -> (batch_size, kernel_size, c*h*w)
|
||||
mu = mu.view(1, mu.size(0), 1).expand(x.size(0), -1, x.size(1) * x.size(2) * x.size(3))
|
||||
mu = mu.to(x.device)
|
||||
# (batch_size, c, h, w) -> (batch_size, kernel_size, c*h*w)
|
||||
x = x.view(x.size(0), 1, -1).expand(-1, mu.size(1), -1)
|
||||
return torch.exp((x - mu).pow(2) / (2 * sigma ** 2))
|
||||
|
||||
|
||||
class MyLoss(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyLoss, self).__init__()
|
||||
|
||||
def forward(self, fakeI, realI):
|
||||
def batch_ERSMI(I1, I2):
|
||||
batch_size = I1.shape[0]
|
||||
img_size = I1.shape[1] * I1.shape[2] * I1.shape[3]
|
||||
if I2.shape[1] == 1 and I1.shape[1] != 1:
|
||||
I2 = I2.repeat(1, 3, 1, 1)
|
||||
|
||||
def kernel_F(y, mu_list, sigma):
|
||||
tmp_mu = mu_list.view(-1, 1).repeat(1, img_size).repeat(batch_size, 1, 1).cuda() # [81, 784]
|
||||
tmp_y = y.view(batch_size, 1, -1).repeat(1, 81, 1)
|
||||
tmp_y = tmp_mu - tmp_y
|
||||
mat_L = torch.exp(tmp_y.pow(2) / (2 * sigma ** 2))
|
||||
return mat_L
|
||||
|
||||
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0]).cuda()
|
||||
|
||||
x_mu_list = mu.repeat(9).view(-1, 81)
|
||||
y_mu_list = mu.unsqueeze(0).t().repeat(1, 9).view(-1, 81)
|
||||
|
||||
mat_K = kernel_F(I1, x_mu_list, 1)
|
||||
mat_L = kernel_F(I2, y_mu_list, 1)
|
||||
|
||||
H1 = ((mat_K.matmul(mat_K.transpose(1, 2))).mul(mat_L.matmul(mat_L.transpose(1, 2))) / (
|
||||
img_size ** 2)).cuda()
|
||||
H2 = ((mat_K.mul(mat_L)).matmul((mat_K.mul(mat_L)).transpose(1, 2)) / img_size).cuda()
|
||||
h2 = ((mat_K.sum(2).view(batch_size, -1, 1)).mul(mat_L.sum(2).view(batch_size, -1, 1)) / (
|
||||
img_size ** 2)).cuda()
|
||||
H2 = 0.5 * H1 + 0.5 * H2
|
||||
tmp = H2 + 0.05 * torch.eye(len(H2[0])).cuda()
|
||||
alpha = (tmp.inverse())
|
||||
|
||||
alpha = alpha.matmul(h2)
|
||||
ersmi = (2 * (alpha.transpose(1, 2)).matmul(h2) - ((alpha.transpose(1, 2)).matmul(H2)).matmul(
|
||||
alpha) - 1).squeeze()
|
||||
ersmi = -ersmi.mean()
|
||||
return ersmi
|
||||
|
||||
batch_loss = batch_ERSMI(fakeI, realI)
|
||||
return batch_loss
|
||||
|
||||
|
||||
class MGCLoss(nn.Module):
|
||||
"""
|
||||
Minimal Geometry-Distortion Constraint Loss from https://openreview.net/forum?id=R5M7Mxl1xZ
|
||||
"""
|
||||
|
||||
def __init__(self, beta=0.5, lambda_=0.05):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.lambda_ = lambda_
|
||||
mu = torch.Tensor([-1.0, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1.0])
|
||||
self.mu_x = mu.repeat(9)
|
||||
self.mu_y = mu.unsqueeze(0).t().repeat(1, 9).view(-1)
|
||||
|
||||
@staticmethod
|
||||
def batch_rSMI(img1, img2, mu_x, mu_y, beta, lambda_):
|
||||
assert img1.size() == img2.size()
|
||||
|
||||
num_pixel = img1.size(1) * img1.size(2) * img2.size(3)
|
||||
|
||||
mat_k = gaussian_radial_basis_function(img1, mu_x, sigma=1)
|
||||
mat_l = gaussian_radial_basis_function(img2, mu_y, sigma=1)
|
||||
|
||||
mat_k_mul_mat_l = mat_k * mat_l
|
||||
h_hat = (1 - beta) * (mat_k_mul_mat_l.matmul(mat_k_mul_mat_l.transpose(1, 2))) / num_pixel
|
||||
h_hat += beta * (mat_k.matmul(mat_k.transpose(1, 2)) * mat_l.matmul(mat_l.transpose(1, 2))) / (num_pixel ** 2)
|
||||
small_h_hat = mat_k.sum(2, keepdim=True) * mat_l.sum(2, keepdim=True) / (num_pixel ** 2)
|
||||
|
||||
R = torch.eye(h_hat.size(1)).to(img1.device)
|
||||
alpha = (h_hat + lambda_ * R).inverse().matmul(small_h_hat)
|
||||
|
||||
rSMI = (2 * alpha.transpose(1, 2).matmul(small_h_hat)) - alpha.transpose(1, 2).matmul(h_hat).matmul(alpha) - 1
|
||||
return rSMI
|
||||
|
||||
def forward(self, fake, real):
|
||||
rSMI = self.batch_rSMI(fake, real, self.mu_x, self.mu_y, self.beta, self.lambda_)
|
||||
return -rSMI.squeeze().mean()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mg = MGCLoss().to("cuda")
|
||||
|
||||
|
||||
def norm(x):
|
||||
x -= x.min()
|
||||
x /= x.max()
|
||||
return (x - 0.5) * 2
|
||||
|
||||
|
||||
x1 = norm(torch.randn(5, 3, 256, 256))
|
||||
x2 = norm(x1 * 2 + 1)
|
||||
x3 = norm(torch.randn(5, 3, 256, 256))
|
||||
x4 = norm(torch.exp(x3))
|
||||
print(mg(x1, x1), mg(x1, x2), mg(x1, x3), mg(x1, x4))
|
||||
@@ -1,62 +0,0 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from model.normalization import select_norm_layer
|
||||
from model.registry import MODEL
|
||||
from .base import ResidualBlock
|
||||
|
||||
|
||||
@MODEL.register_module("CyCle-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=9, padding_mode='reflect',
|
||||
norm_type="IN"):
|
||||
super(Generator, self).__init__()
|
||||
assert num_blocks >= 0, f'Number of residual blocks must be non-negative, but got {num_blocks}.'
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
use_bias = norm_type == "IN"
|
||||
|
||||
self.start_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding_mode=padding_mode, padding=3,
|
||||
bias=use_bias),
|
||||
norm_layer(num_features=base_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
# down sampling
|
||||
submodules = []
|
||||
num_down_sampling = 2
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** i
|
||||
submodules += [
|
||||
nn.Conv2d(in_channels=base_channels * multiple, out_channels=base_channels * multiple * 2,
|
||||
kernel_size=3, stride=2, padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple * 2),
|
||||
nn.ReLU(inplace=True)
|
||||
]
|
||||
self.encoder = nn.Sequential(*submodules)
|
||||
|
||||
res_block_channels = num_down_sampling ** 2 * base_channels
|
||||
self.resnet_middle = nn.Sequential(
|
||||
*[ResidualBlock(res_block_channels, padding_mode, norm_type) for _ in
|
||||
range(num_blocks)])
|
||||
|
||||
# up sampling
|
||||
submodules = []
|
||||
for i in range(num_down_sampling):
|
||||
multiple = 2 ** (num_down_sampling - i)
|
||||
submodules += [
|
||||
nn.ConvTranspose2d(base_channels * multiple, base_channels * multiple // 2, kernel_size=3, stride=2,
|
||||
padding=1, output_padding=1, bias=use_bias),
|
||||
norm_layer(num_features=base_channels * multiple // 2),
|
||||
nn.ReLU(inplace=True),
|
||||
]
|
||||
self.decoder = nn.Sequential(*submodules)
|
||||
|
||||
self.end_conv = nn.Sequential(
|
||||
nn.Conv2d(base_channels, out_channels, kernel_size=7, padding=3, padding_mode=padding_mode),
|
||||
nn.Tanh()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(self.start_conv(x))
|
||||
x = self.resnet_middle(x)
|
||||
return self.end_conv(self.decoder(x))
|
||||
@@ -1,150 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from model import MODEL
|
||||
from model.GAN.base import Conv2dBlock, ResBlock
|
||||
from model.normalization import select_norm_layer
|
||||
|
||||
|
||||
class StyleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, out_dim, num_conv, base_channels=64, use_spectral_norm=False,
|
||||
max_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE"):
|
||||
super(StyleEncoder, self).__init__()
|
||||
|
||||
sequence = [Conv2dBlock(
|
||||
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
||||
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
|
||||
)]
|
||||
|
||||
multiple_now = 1
|
||||
for i in range(1, num_conv + 1):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** i, 2 ** max_multiple)
|
||||
sequence.append(Conv2dBlock(
|
||||
multiple_prev * base_channels, multiple_now * base_channels,
|
||||
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
|
||||
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
|
||||
))
|
||||
sequence.append(nn.AdaptiveAvgPool2d(1))
|
||||
# conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
|
||||
sequence.append(nn.Conv2d(multiple_now * base_channels, out_dim, kernel_size=1, stride=1, padding=0))
|
||||
self.model = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x).view(x.size(0), -1)
|
||||
|
||||
|
||||
class ContentEncoder(nn.Module):
|
||||
def __init__(self, in_channels, num_down_sampling, num_res_blocks, base_channels=64, use_spectral_norm=False,
|
||||
padding_mode='reflect', activation_type="ReLU", norm_type="IN"):
|
||||
super().__init__()
|
||||
sequence = [Conv2dBlock(
|
||||
in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
||||
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
|
||||
)]
|
||||
|
||||
for i in range(num_down_sampling):
|
||||
sequence.append(Conv2dBlock(
|
||||
base_channels * (2 ** i), base_channels * (2 ** (i + 1)),
|
||||
kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
|
||||
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
|
||||
))
|
||||
|
||||
sequence += [ResBlock(base_channels * (2 ** num_down_sampling), use_spectral_norm, padding_mode, norm_type,
|
||||
activation_type) for _ in range(num_res_blocks)]
|
||||
self.sequence = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, x):
|
||||
return self.sequence(x)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_up_sampling, num_res_blocks,
|
||||
use_spectral_norm=False, res_norm_type="AdaIN", norm_type="LN", activation_type="ReLU",
|
||||
padding_mode='reflect'):
|
||||
super(Decoder, self).__init__()
|
||||
self.res_norm_type = res_norm_type
|
||||
self.res_blocks = nn.ModuleList([
|
||||
ResBlock(in_channels, use_spectral_norm, padding_mode, res_norm_type, activation_type=activation_type)
|
||||
for _ in range(num_res_blocks)
|
||||
])
|
||||
sequence = list()
|
||||
channels = in_channels
|
||||
for i in range(num_up_sampling):
|
||||
sequence.append(nn.Sequential(
|
||||
nn.Upsample(scale_factor=2),
|
||||
Conv2dBlock(channels, channels // 2,
|
||||
kernel_size=5, stride=1, padding=2, padding_mode=padding_mode,
|
||||
use_spectral_norm=use_spectral_norm, activation_type=activation_type, norm_type=norm_type
|
||||
),
|
||||
))
|
||||
channels = channels // 2
|
||||
sequence.append(
|
||||
Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect",
|
||||
use_spectral_norm=use_spectral_norm, activation_type="Tanh", norm_type="NONE"))
|
||||
self.sequence = nn.Sequential(*sequence)
|
||||
|
||||
def forward(self, x):
|
||||
for blk in self.res_blocks:
|
||||
x = blk(x)
|
||||
return self.sequence(x)
|
||||
|
||||
|
||||
class Fusion(nn.Module):
|
||||
def __init__(self, in_features, out_features, base_features, n_blocks, norm_type="NONE"):
|
||||
super().__init__()
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
self.start_fc = nn.Sequential(
|
||||
nn.Linear(in_features, base_features),
|
||||
norm_layer(base_features),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
self.fcs = nn.Sequential(*[
|
||||
nn.Sequential(
|
||||
nn.Linear(base_features, base_features),
|
||||
norm_layer(base_features),
|
||||
nn.ReLU(True),
|
||||
) for _ in range(n_blocks - 2)
|
||||
])
|
||||
self.end_fc = nn.Sequential(
|
||||
nn.Linear(base_features, out_features),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.start_fc(x)
|
||||
x = self.fcs(x)
|
||||
return self.end_fc(x)
|
||||
|
||||
|
||||
@MODEL.register_module("MUNIT-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels, num_sampling, num_style_dim, num_style_conv,
|
||||
num_content_res_blocks, num_decoder_res_blocks, num_fusion_dim, num_fusion_blocks,
|
||||
use_spectral_norm=False, activation_type="ReLU", padding_mode='reflect'):
|
||||
super().__init__()
|
||||
self.num_decoder_res_blocks = num_decoder_res_blocks
|
||||
self.content_encoder = ContentEncoder(in_channels, num_sampling, num_content_res_blocks, base_channels,
|
||||
use_spectral_norm, padding_mode, activation_type, norm_type="IN")
|
||||
self.style_encoder = StyleEncoder(in_channels, num_style_dim, num_style_conv, base_channels, use_spectral_norm,
|
||||
padding_mode, activation_type, norm_type="NONE")
|
||||
content_channels = base_channels * (2 ** 2)
|
||||
self.decoder = Decoder(content_channels, out_channels, num_sampling,
|
||||
num_decoder_res_blocks, use_spectral_norm, "AdaIN", norm_type="LN",
|
||||
activation_type=activation_type, padding_mode=padding_mode)
|
||||
self.fusion = Fusion(num_style_dim, num_decoder_res_blocks * 2 * content_channels * 2,
|
||||
base_features=num_fusion_dim, n_blocks=num_fusion_blocks, norm_type="NONE")
|
||||
|
||||
def encode(self, x):
|
||||
return self.content_encoder(x), self.style_encoder(x)
|
||||
|
||||
def decode(self, content, style):
|
||||
as_param_style = torch.chunk(self.fusion(style), self.num_decoder_res_blocks * 2, dim=1)
|
||||
# set style for decoder
|
||||
for i, blk in enumerate(self.decoder.res_blocks):
|
||||
blk.conv1.normalization.set_style(as_param_style[2 * i])
|
||||
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
|
||||
return self.decoder(content)
|
||||
|
||||
def forward(self, x):
|
||||
content, style = self.encode(x)
|
||||
return self.decode(content, style)
|
||||
@@ -1,171 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models import vgg19
|
||||
|
||||
from model.normalization import select_norm_layer
|
||||
from model.registry import MODEL
|
||||
from .MUNIT import ContentEncoder, Fusion, Decoder, StyleEncoder
|
||||
from .base import ResBlock
|
||||
|
||||
|
||||
class VGG19StyleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, base_channels=64, style_dim=512, padding_mode='reflect', norm_type="NONE",
|
||||
vgg19_layers=(0, 5, 10, 19), fix_vgg19=True):
|
||||
super().__init__()
|
||||
self.vgg19_layers = vgg19_layers
|
||||
self.vgg19 = vgg19(pretrained=True).features[:vgg19_layers[-1] + 1]
|
||||
self.vgg19.requires_grad_(not fix_vgg19)
|
||||
|
||||
norm_layer = select_norm_layer(norm_type)
|
||||
|
||||
self.conv0 = nn.Sequential(
|
||||
nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
|
||||
bias=True),
|
||||
norm_layer(base_channels),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
self.conv = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
nn.Conv2d(base_channels * (2 ** i), base_channels * (2 ** i), kernel_size=4, stride=2, padding=1,
|
||||
padding_mode=padding_mode, bias=True),
|
||||
norm_layer(base_channels),
|
||||
nn.ReLU(True),
|
||||
) for i in range(1, 4)
|
||||
])
|
||||
self.pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.conv1x1 = nn.Conv2d(base_channels * (2 ** 4), style_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def fixed_style_features(self, x):
|
||||
features = []
|
||||
for i in range(len(self.vgg19)):
|
||||
x = self.vgg19[i](x)
|
||||
if i in self.vgg19_layers:
|
||||
features.append(x)
|
||||
return features
|
||||
|
||||
def forward(self, x):
|
||||
fsf = self.fixed_style_features(x)
|
||||
x = self.conv0(x)
|
||||
for i, l in enumerate(self.conv):
|
||||
x = l(torch.cat([x, fsf[i]], dim=1))
|
||||
x = self.pool(torch.cat([x, fsf[-1]], dim=1))
|
||||
x = self.conv1x1(x)
|
||||
return x.view(x.size(0), -1)
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-ResGenerator")
|
||||
class ResGenerator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=3, use_spectral_norm=False, num_res_blocks=8, base_channels=64):
|
||||
super().__init__()
|
||||
self.content_encoder = ContentEncoder(in_channels, 2, num_res_blocks=num_res_blocks,
|
||||
use_spectral_norm=use_spectral_norm)
|
||||
resnet_channels = 2 ** 2 * base_channels
|
||||
self.decoder = Decoder(resnet_channels, out_channels, 2,
|
||||
0, use_spectral_norm, "IN", norm_type="LN", padding_mode="reflect")
|
||||
|
||||
def forward(self, x):
|
||||
return self.decoder(self.content_encoder(x))
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-SingleGenerator")
|
||||
class SingleGenerator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels, out_channels=3, use_spectral_norm=False,
|
||||
style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
|
||||
num_res_blocks=8, base_channels=64, padding_mode="reflect"):
|
||||
super().__init__()
|
||||
self.num_adain_blocks = num_adain_blocks
|
||||
if style_encoder_type == "StyleEncoder":
|
||||
self.style_encoder = StyleEncoder(
|
||||
style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
|
||||
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"
|
||||
)
|
||||
elif style_encoder_type == "VGG19StyleEncoder":
|
||||
self.style_encoder = VGG19StyleEncoder(
|
||||
style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode, norm_type="NONE"
|
||||
)
|
||||
else:
|
||||
raise NotImplemented(f"do not support {style_encoder_type}")
|
||||
|
||||
resnet_channels = 2 ** 2 * base_channels
|
||||
self.style_converter = Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256,
|
||||
n_blocks=3, norm_type="NONE")
|
||||
self.content_encoder = ContentEncoder(content_in_channels, 2, num_res_blocks=num_res_blocks,
|
||||
use_spectral_norm=use_spectral_norm)
|
||||
|
||||
self.decoder = Decoder(resnet_channels, out_channels, 2,
|
||||
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode)
|
||||
|
||||
def forward(self, content_img, style_img):
|
||||
content = self.content_encoder(content_img)
|
||||
style = self.style_encoder(style_img)
|
||||
as_param_style = torch.chunk(self.style_converter(style), self.num_adain_blocks * 2, dim=1)
|
||||
# set style for decoder
|
||||
for i, blk in enumerate(self.decoder.res_blocks):
|
||||
blk.conv1.normalization.set_style(as_param_style[2 * i])
|
||||
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
|
||||
return self.decoder(content)
|
||||
|
||||
|
||||
@MODEL.register_module("TAFG-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, style_in_channels, content_in_channels=3, out_channels=3, use_spectral_norm=False,
|
||||
style_encoder_type="StyleEncoder", num_style_conv=4, style_dim=512, num_adain_blocks=8,
|
||||
num_res_blocks=8, base_channels=64, padding_mode="reflect"):
|
||||
super(Generator, self).__init__()
|
||||
self.num_adain_blocks = num_adain_blocks
|
||||
if style_encoder_type == "StyleEncoder":
|
||||
self.style_encoders = nn.ModuleDict(dict(
|
||||
a=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
|
||||
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
|
||||
b=StyleEncoder(style_in_channels, style_dim, num_style_conv, base_channels, use_spectral_norm,
|
||||
max_multiple=4, padding_mode=padding_mode, norm_type="NONE"),
|
||||
))
|
||||
elif style_encoder_type == "VGG19StyleEncoder":
|
||||
self.style_encoders = nn.ModuleDict(dict(
|
||||
a=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
|
||||
norm_type="NONE"),
|
||||
b=VGG19StyleEncoder(style_in_channels, base_channels, style_dim=style_dim, padding_mode=padding_mode,
|
||||
norm_type="NONE", fix_vgg19=False)
|
||||
))
|
||||
else:
|
||||
raise NotImplemented(f"do not support {style_encoder_type}")
|
||||
resnet_channels = 2 ** 2 * base_channels
|
||||
self.style_converters = nn.ModuleDict(dict(
|
||||
a=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
|
||||
norm_type="NONE"),
|
||||
b=Fusion(style_dim, num_adain_blocks * 2 * resnet_channels * 2, base_features=256, n_blocks=3,
|
||||
norm_type="NONE"),
|
||||
))
|
||||
self.content_encoders = nn.ModuleDict({
|
||||
"a": ContentEncoder(content_in_channels, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm),
|
||||
"b": ContentEncoder(1, 2, num_res_blocks=0, use_spectral_norm=use_spectral_norm)
|
||||
})
|
||||
|
||||
self.content_resnet = nn.Sequential(*[
|
||||
ResBlock(resnet_channels, use_spectral_norm, padding_mode, "IN")
|
||||
for _ in range(num_res_blocks)
|
||||
])
|
||||
self.decoders = nn.ModuleDict(dict(
|
||||
a=Decoder(resnet_channels, out_channels, 2,
|
||||
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
|
||||
b=Decoder(resnet_channels, out_channels, 2,
|
||||
num_adain_blocks, use_spectral_norm, "AdaIN", norm_type="LN", padding_mode=padding_mode),
|
||||
))
|
||||
|
||||
def encode(self, content_img, style_img, which_content, which_style):
|
||||
content = self.content_resnet(self.content_encoders[which_content](content_img))
|
||||
style = self.style_encoders[which_style](style_img)
|
||||
return content, style
|
||||
|
||||
def decode(self, content, style, which):
|
||||
decoder = self.decoders[which]
|
||||
as_param_style = torch.chunk(self.style_converters[which](style), self.num_adain_blocks * 2, dim=1)
|
||||
# set style for decoder
|
||||
for i, blk in enumerate(decoder.res_blocks):
|
||||
blk.conv1.normalization.set_style(as_param_style[2 * i])
|
||||
blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
|
||||
return decoder(content)
|
||||
|
||||
def forward(self, content_img, style_img, which_content, which_style):
|
||||
content, style = self.encode(content_img, style_img, which_content, which_style)
|
||||
return self.decode(content, style, which_style)
|
||||
@@ -1,88 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import MODEL
|
||||
from model.base.module import Conv2dBlock, ResidualBlock, ReverseResidualBlock
|
||||
|
||||
|
||||
class Interpolation(nn.Module):
|
||||
def __init__(self, scale_factor=None, mode='nearest', align_corners=None):
|
||||
super(Interpolation, self).__init__()
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners,
|
||||
recompute_scale_factor=False)
|
||||
|
||||
def __repr__(self):
|
||||
return f"DownSampling(scale_factor={self.scale_factor}, mode={self.mode}, align_corners={self.align_corners})"
|
||||
|
||||
|
||||
@MODEL.register_module("TSIT-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, content_in_channels=3, out_channels=3, base_channels=64, num_blocks=7,
|
||||
padding_mode="reflect", activation_type="ReLU"):
|
||||
super().__init__()
|
||||
self.num_blocks = num_blocks
|
||||
self.base_channels = base_channels
|
||||
|
||||
self.content_stream = self.build_stream(padding_mode, activation_type)
|
||||
self.start_conv = Conv2dBlock(content_in_channels, base_channels, activation_type=activation_type,
|
||||
norm_type="IN", kernel_size=7, padding_mode=padding_mode, padding=3)
|
||||
|
||||
sequence = []
|
||||
multiple_now = min(2 ** self.num_blocks, 2 ** 4)
|
||||
for i in range(1, self.num_blocks + 1):
|
||||
m = self.num_blocks - i
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** m, 2 ** 4)
|
||||
sequence.append(nn.Sequential(
|
||||
ReverseResidualBlock(
|
||||
multiple_prev * base_channels, multiple_now * base_channels,
|
||||
padding_mode=padding_mode, norm_type="FADE",
|
||||
additional_norm_kwargs=dict(
|
||||
condition_in_channels=multiple_prev * base_channels,
|
||||
base_norm_type="BN",
|
||||
padding_mode=padding_mode
|
||||
)
|
||||
),
|
||||
Interpolation(2, mode="nearest")
|
||||
))
|
||||
self.generator = nn.Sequential(*sequence)
|
||||
self.end_conv = Conv2dBlock(base_channels, out_channels, activation_type="Tanh",
|
||||
kernel_size=7, padding_mode=padding_mode, padding=3)
|
||||
|
||||
def build_stream(self, padding_mode, activation_type):
|
||||
multiple_now = 1
|
||||
stream_sequence = []
|
||||
for i in range(1, self.num_blocks + 1):
|
||||
multiple_prev = multiple_now
|
||||
multiple_now = min(2 ** i, 2 ** 4)
|
||||
stream_sequence.append(nn.Sequential(
|
||||
Interpolation(scale_factor=0.5, mode="nearest"),
|
||||
ResidualBlock(
|
||||
multiple_prev * self.base_channels, multiple_now * self.base_channels,
|
||||
padding_mode=padding_mode, activation_type=activation_type, norm_type="IN")
|
||||
))
|
||||
return nn.ModuleList(stream_sequence)
|
||||
|
||||
def forward(self, content, z=None):
|
||||
c = self.start_conv(content)
|
||||
content_features = []
|
||||
for i in range(self.num_blocks):
|
||||
c = self.content_stream[i](c)
|
||||
content_features.append(c)
|
||||
if z is None:
|
||||
z = torch.randn(size=content_features[-1].size(), device=content_features[-1].device)
|
||||
|
||||
for i in range(self.num_blocks):
|
||||
m = - i - 1
|
||||
res_block = self.generator[i][0]
|
||||
res_block.conv1.normalization.set_feature(content_features[m])
|
||||
res_block.conv2.normalization.set_feature(content_features[m])
|
||||
if res_block.learn_skip_connection:
|
||||
res_block.res_conv.normalization.set_feature(content_features[m])
|
||||
return self.end_conv(self.generator(z))
|
||||
@@ -1,236 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .base import ResidualBlock
|
||||
from model.registry import MODEL
|
||||
|
||||
|
||||
class RhoClipper(object):
|
||||
def __init__(self, clip_min, clip_max):
|
||||
self.clip_min = clip_min
|
||||
self.clip_max = clip_max
|
||||
assert clip_min < clip_max
|
||||
|
||||
def __call__(self, module):
|
||||
if hasattr(module, 'rho'):
|
||||
w = module.rho.data
|
||||
w = w.clamp(self.clip_min, self.clip_max)
|
||||
module.rho.data = w
|
||||
|
||||
|
||||
@MODEL.register_module("UGATIT-Generator")
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False):
|
||||
assert (num_blocks >= 0)
|
||||
super(Generator, self).__init__()
|
||||
self.input_channels = in_channels
|
||||
self.output_channels = out_channels
|
||||
self.base_channels = base_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.img_size = img_size
|
||||
self.light = light
|
||||
|
||||
down_encoder = [nn.Conv2d(in_channels, base_channels, kernel_size=7, stride=1, padding=3,
|
||||
padding_mode="reflect", bias=False),
|
||||
nn.InstanceNorm2d(base_channels),
|
||||
nn.ReLU(True)]
|
||||
|
||||
n_down_sampling = 2
|
||||
for i in range(n_down_sampling):
|
||||
mult = 2 ** i
|
||||
down_encoder += [nn.Conv2d(base_channels * mult, base_channels * mult * 2, kernel_size=3, stride=2,
|
||||
padding=1, bias=False, padding_mode="reflect"),
|
||||
nn.InstanceNorm2d(base_channels * mult * 2),
|
||||
nn.ReLU(True)]
|
||||
|
||||
# Down-Sampling Bottleneck
|
||||
mult = 2 ** n_down_sampling
|
||||
for i in range(num_blocks):
|
||||
down_encoder += [ResidualBlock(base_channels * mult, use_bias=False)]
|
||||
self.down_encoder = nn.Sequential(*down_encoder)
|
||||
|
||||
# Class Activation Map
|
||||
self.gap_fc = nn.Linear(base_channels * mult, 1, bias=False)
|
||||
self.gmp_fc = nn.Linear(base_channels * mult, 1, bias=False)
|
||||
self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
|
||||
self.relu = nn.ReLU(True)
|
||||
|
||||
# Gamma, Beta block
|
||||
if self.light:
|
||||
fc = [nn.Linear(base_channels * mult, base_channels * mult, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.Linear(base_channels * mult, base_channels * mult, bias=False),
|
||||
nn.ReLU(True)]
|
||||
else:
|
||||
fc = [
|
||||
nn.Linear(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.Linear(base_channels * mult, base_channels * mult, bias=False),
|
||||
nn.ReLU(True)]
|
||||
self.fc = nn.Sequential(*fc)
|
||||
|
||||
self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
|
||||
self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
|
||||
|
||||
# Up-Sampling Bottleneck
|
||||
self.up_bottleneck = nn.ModuleList(
|
||||
[ResnetAdaILNBlock(base_channels * mult, use_bias=False) for _ in range(num_blocks)])
|
||||
|
||||
# Up-Sampling
|
||||
up_decoder = []
|
||||
for i in range(n_down_sampling):
|
||||
mult = 2 ** (n_down_sampling - i)
|
||||
up_decoder += [nn.Upsample(scale_factor=2, mode='nearest'),
|
||||
nn.Conv2d(base_channels * mult, base_channels * mult // 2, kernel_size=3, stride=1,
|
||||
padding=1, padding_mode="reflect", bias=False),
|
||||
ILN(base_channels * mult // 2),
|
||||
nn.ReLU(True)]
|
||||
|
||||
up_decoder += [nn.Conv2d(base_channels, out_channels, kernel_size=7, stride=1, padding=3,
|
||||
padding_mode="reflect", bias=False),
|
||||
nn.Tanh()]
|
||||
self.up_decoder = nn.Sequential(*up_decoder)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.down_encoder(x)
|
||||
|
||||
gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
|
||||
gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
|
||||
gap = x * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
|
||||
gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
|
||||
gmp = x * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
cam_logit = torch.cat([gap_logit, gmp_logit], 1)
|
||||
|
||||
x = torch.cat([gap, gmp], 1)
|
||||
x = self.relu(self.conv1x1(x))
|
||||
|
||||
heatmap = torch.sum(x, dim=1, keepdim=True)
|
||||
|
||||
if self.light:
|
||||
x_ = torch.nn.functional.adaptive_avg_pool2d(x, 1)
|
||||
x_ = self.fc(x_.view(x_.shape[0], -1))
|
||||
else:
|
||||
x_ = self.fc(x.view(x.shape[0], -1))
|
||||
gamma, beta = self.gamma(x_), self.beta(x_)
|
||||
|
||||
for ub in self.up_bottleneck:
|
||||
x = ub(x, gamma, beta)
|
||||
|
||||
x = self.up_decoder(x)
|
||||
return x, cam_logit, heatmap
|
||||
|
||||
|
||||
class ResnetAdaILNBlock(nn.Module):
    """Residual block whose two convolutions are normalized with AdaILN,
    conditioned on externally supplied (gamma, beta)."""

    def __init__(self, dim, use_bias):
        super(ResnetAdaILNBlock, self).__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1,
                           bias=use_bias, padding_mode="reflect")
        self.conv1 = nn.Conv2d(dim, dim, **conv_kwargs)
        self.norm1 = AdaILN(dim)
        self.relu1 = nn.ReLU(True)
        self.conv2 = nn.Conv2d(dim, dim, **conv_kwargs)
        self.norm2 = AdaILN(dim)

    def forward(self, x, gamma, beta):
        # Two conv -> AdaILN stages; ReLU only after the first one.
        h = self.relu1(self.norm1(self.conv1(x), gamma, beta))
        h = self.norm2(self.conv2(h), gamma, beta)
        # Identity skip connection.
        return x + h
|
||||
|
||||
|
||||
def instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
    """Blend Instance-Norm and Layer-Norm statistics of ``x`` with weight
    ``rho``, then apply a per-channel affine transform.

    Args:
        x: Input of shape (B, C, H, W).
        gamma, beta: Per-channel affine parameters of shape (B, C).
        rho: Mixing weight, broadcastable to (B, C, 1, 1); 1 = pure IN.
        eps: Numerical-stability constant.
    """
    # Per-sample, per-channel statistics (Instance Norm).
    in_mean = torch.mean(x, dim=[2, 3], keepdim=True)
    in_var = torch.var(x, dim=[2, 3], keepdim=True)
    normalized_in = (x - in_mean) / torch.sqrt(in_var + eps)

    # Per-sample statistics over all channels (Layer Norm).
    ln_mean = torch.mean(x, dim=[1, 2, 3], keepdim=True)
    ln_var = torch.var(x, dim=[1, 2, 3], keepdim=True)
    normalized_ln = (x - ln_mean) / torch.sqrt(ln_var + eps)

    mix = rho.expand(x.shape[0], -1, -1, -1)
    blended = mix * normalized_in + (1 - mix) * normalized_ln
    return blended * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
|
||||
class AdaILN(nn.Module):
    """Adaptive Instance-Layer Normalization (UGATIT).

    Holds only the learnable IN/LN mixing weight ``rho``; the affine
    parameters (gamma, beta) are supplied by the caller at forward time.
    """

    def __init__(self, num_features, eps=1e-5, default_rho=0.9):
        super(AdaILN, self).__init__()
        self.eps = eps
        # Per-channel mixing weight, initialized towards Instance Norm.
        self.rho = nn.Parameter(torch.full((1, num_features, 1, 1), float(default_rho)))

    def forward(self, x, gamma, beta):
        return instance_layer_normalization(x, gamma, beta, self.rho, self.eps)
|
||||
|
||||
|
||||
class ILN(nn.Module):
    """Instance-Layer Normalization with learnable rho/gamma/beta (UGATIT)."""

    def __init__(self, num_features, eps=1e-5):
        super(ILN, self).__init__()
        self.eps = eps
        # rho = 0 starts as pure Layer Norm; gamma/beta start as identity.
        self.rho = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.gamma = nn.Parameter(torch.ones(1, num_features))
        self.beta = nn.Parameter(torch.zeros(1, num_features))

    def forward(self, x):
        batch = x.shape[0]
        return instance_layer_normalization(
            x, self.gamma.expand(batch, -1), self.beta.expand(batch, -1),
            self.rho, self.eps)
|
||||
|
||||
|
||||
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """UGATIT PatchGAN discriminator with a CAM auxiliary head.

    Returns per-patch logits plus CAM logits (and optionally the
    attention heatmap for visualization).
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5):
        super(Discriminator, self).__init__()
        layers = [self.build_conv_block(in_channels, base_channels)]
        # Stride-2 blocks that double the channel width each time.
        for i in range(1, num_blocks - 2):
            mult = 2 ** (i - 1)
            layers.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2))
        # One final stride-1 block.
        mult = 2 ** (num_blocks - 3)
        layers.append(self.build_conv_block(base_channels * mult, base_channels * mult * 2, stride=1))
        self.encoder = nn.Sequential(*layers)

        # Class Activation Map head.
        mult = 2 ** (num_blocks - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(base_channels * mult, 1, bias=False))
        self.conv1x1 = nn.Conv2d(base_channels * mult * 2, base_channels * mult, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        # Final 1-channel patch classifier.
        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False,
                      padding_mode="reflect"))

    @staticmethod
    def build_conv_block(in_channels, out_channels, kernel_size=4, stride=2, padding=1, padding_mode="reflect"):
        """Spectral-normalized conv followed by LeakyReLU."""
        return nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                             bias=True, padding=padding, padding_mode=padding_mode)),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, x, return_heatmap=False):
        feat = self.encoder(x)
        batch = feat.size(0)

        # CAM: logits come from pooled features, attention maps from the
        # corresponding FC weights.
        avg_pooled = torch.nn.functional.adaptive_avg_pool2d(feat, 1)  # B x C x 1 x 1
        avg_logit = self.gap_fc(avg_pooled.view(batch, -1))
        avg_attn = feat * self.gap_fc.weight.unsqueeze(2).unsqueeze(3)

        max_pooled = torch.nn.functional.adaptive_max_pool2d(feat, 1)
        max_logit = self.gmp_fc(max_pooled.view(batch, -1))
        max_attn = feat * self.gmp_fc.weight.unsqueeze(2).unsqueeze(3)

        cam_logit = torch.cat([avg_logit, max_logit], 1)
        feat = self.leaky_relu(self.conv1x1(torch.cat([avg_attn, max_attn], 1)))

        if return_heatmap:
            heatmap = torch.sum(feat, dim=1, keepdim=True)
            return self.conv(feat), cam_logit, heatmap
        return self.conv(feat), cam_logit
|
||||
@@ -1,203 +0,0 @@
|
||||
from functools import partial
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from model import MODEL
|
||||
from model.normalization import select_norm_layer
|
||||
|
||||
|
||||
class GANImageBuffer(object):
    """History buffer of previously generated images.

    Updating the discriminator with a mix of current and historical fakes
    (the CycleGAN / SimGAN trick) reduces model oscillation.

    Args:
        buffer_size (int): Capacity of the buffer; 0 disables buffering.
        buffer_ratio (float): Probability of returning a stored image (and
            replacing it with the current one) once the buffer is full.
    """

    def __init__(self, buffer_size, buffer_ratio=0.5):
        self.buffer_size = buffer_size
        if self.buffer_size > 0:
            # Lazily filled storage plus a fill counter.
            self.img_num = 0
            self.image_buffer = []
        self.buffer_ratio = buffer_ratio

    def query(self, images):
        """Return a batch mixing ``images`` with buffered history.

        Args:
            images (Tensor): Current batch of generated images.
        """
        if self.buffer_size == 0:
            # Buffering disabled: pass the batch through untouched.
            return images
        outputs = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.img_num < self.buffer_size:
                # Still filling: store and return the current image.
                self.img_num = self.img_num + 1
                self.image_buffer.append(image)
                outputs.append(image)
            elif torch.rand(1) < self.buffer_ratio:
                # Swap: hand back a random stored image, keep the new one.
                idx = torch.randint(0, self.buffer_size, (1,)).item()
                stored = self.image_buffer[idx].clone()
                self.image_buffer[idx] = image
                outputs.append(stored)
            else:
                # Otherwise return the current image unchanged.
                outputs.append(image)
        return torch.cat(outputs, 0)
|
||||
|
||||
|
||||
# based SPADE or pix2pixHD Discriminator
|
||||
@MODEL.register_module("PatchDiscriminator")
class PatchDiscriminator(nn.Module):
    """N-layer PatchGAN discriminator (pix2pixHD / SPADE style).

    Args:
        in_channels: Channels of the input image.
        base_channels: Width of the first conv layer.
        num_conv: Number of conv blocks before the 1-channel head.
        use_spectral: Wrap intermediate convs in spectral normalization.
        norm_type: Normalization for intermediate blocks (e.g. "IN").
        need_intermediate_feature: If True, forward() returns every
            intermediate activation (for feature-matching losses).
    """

    def __init__(self, in_channels, base_channels, num_conv=4, use_spectral=False, norm_type="IN",
                 need_intermediate_feature=False):
        super().__init__()
        self.need_intermediate_feature = need_intermediate_feature

        kernel_size = 4
        padding = math.ceil((kernel_size - 1.0) / 2)
        norm_layer = select_norm_layer(norm_type)
        use_bias = norm_type == "IN"  # IN has no affine params, so keep the conv bias
        padding_mode = "zeros"

        sequence = [nn.Sequential(
            nn.Conv2d(in_channels, base_channels, kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, False)
        )]
        multiple_now = 1
        for i in range(1, num_conv):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 3)
            stride = 1 if i == num_conv - 1 else 2  # last block keeps resolution
            sequence.append(nn.Sequential(
                self.build_conv2d(use_spectral, base_channels * multiple_prev, base_channels * multiple_now,
                                  kernel_size, stride, padding, bias=use_bias, padding_mode=padding_mode),
                norm_layer(base_channels * multiple_now),
                nn.LeakyReLU(0.2, inplace=False),
            ))
        # Final 1-channel patch head.  Reuse the running `multiple_now`:
        # the previous `min(2 ** num_conv, 8)` recomputation disagreed with
        # the last block's actual output width whenever num_conv < 4,
        # producing a channel mismatch at forward time.
        sequence.append(nn.Conv2d(base_channels * multiple_now, 1, kernel_size, stride=1, padding=padding,
                                  padding_mode=padding_mode))
        self.conv_blocks = nn.ModuleList(sequence)

    @staticmethod
    def build_conv2d(use_spectral, in_channels: int, out_channels: int, kernel_size, stride, padding,
                     bias=True, padding_mode: str = 'zeros'):
        """Conv2d, optionally wrapped in spectral normalization.

        Note: `bias` must be passed by keyword to nn.Conv2d — passing it
        positionally (as before) landed in the `dilation` slot, so
        bias=False silently produced dilation=0.
        """
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                         bias=bias, padding_mode=padding_mode)
        if not use_spectral:
            return conv
        return nn.utils.spectral_norm(conv)

    def forward(self, x):
        if self.need_intermediate_feature:
            # Collect every intermediate activation for feature matching.
            intermediate_feature = []
            for layer in self.conv_blocks:
                x = layer(x)
                intermediate_feature.append(x)
            return tuple(intermediate_feature)
        else:
            for layer in self.conv_blocks:
                x = layer(x)
            return x
|
||||
|
||||
|
||||
@MODEL.register_module()
class ResidualBlock(nn.Module):
    """Two conv+norm layers with an identity skip connection."""

    def __init__(self, num_channels, padding_mode='reflect', norm_type="IN", use_bias=None):
        super(ResidualBlock, self).__init__()
        if use_bias is None:
            # Only for IN, use bias since it does not have affine parameters.
            use_bias = norm_type == "IN"
        norm_layer = select_norm_layer(norm_type)
        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode=padding_mode, bias=use_bias)
        self.conv1 = nn.Conv2d(num_channels, num_channels, **conv_kwargs)
        self.norm1 = norm_layer(num_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_channels, num_channels, **conv_kwargs)
        self.norm2 = norm_layer(num_channels)

    def forward(self, x):
        out = self.relu1(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        # Identity residual connection.
        return out + x
|
||||
|
||||
|
||||
_DO_NO_THING_FUNC = lambda x: x
|
||||
|
||||
|
||||
def select_activation(t):
    """Map an activation name to a factory producing that activation.

    Args:
        t (str): One of "ReLU", "LeakyReLU", "Tanh", or "NONE".

    Returns:
        A callable that builds the activation module (the identity
        function for "NONE").

    Raises:
        NotImplementedError: For unknown names.  (The previous
        ``raise NotImplemented`` raised a TypeError instead, because
        ``NotImplemented`` is not an exception class.)
    """
    if t == "ReLU":
        return partial(nn.ReLU, inplace=True)
    elif t == "LeakyReLU":
        return partial(nn.LeakyReLU, negative_slope=0.2, inplace=True)
    elif t == "Tanh":
        return partial(nn.Tanh)
    elif t == "NONE":
        return _DO_NO_THING_FUNC
    else:
        raise NotImplementedError(t)
|
||||
|
||||
|
||||
def _use_bias_checker(norm_type):
|
||||
return norm_type not in ["IN", "BN", "AdaIN"]
|
||||
|
||||
|
||||
class Conv2dBlock(nn.Module):
    """Conv2d optionally followed by normalization and an activation.

    A ``norm_type`` / ``activation_type`` of "NONE" skips that stage
    (and the corresponding submodule is not created).
    """

    def __init__(self, in_channels: int, out_channels: int, use_spectral_norm=False, activation_type="ReLU",
                 bias=None, norm_type="NONE", **conv_kwargs):
        super().__init__()
        self.norm_type = norm_type
        self.activation_type = activation_type
        # Bias defaults off under channel-wise affine norms.
        conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
        base_conv = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.convolution = nn.utils.spectral_norm(base_conv) if use_spectral_norm else base_conv
        if norm_type != "NONE":
            self.normalization = select_norm_layer(norm_type)(out_channels)
        if activation_type != "NONE":
            self.activation = select_activation(activation_type)()

    def forward(self, x):
        x = self.convolution(x)
        # Apply only the stages that were configured.
        for stage_type, attr in ((self.norm_type, "normalization"),
                                 (self.activation_type, "activation")):
            if stage_type != "NONE":
                x = getattr(self, attr)(x)
        return x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
    """Residual block of two Conv2dBlocks; the second has no activation so
    the skip sum is added before any nonlinearity."""

    def __init__(self, num_channels, use_spectral_norm=False, padding_mode='reflect',
                 norm_type="IN", activation_type="ReLU", use_bias=None):
        super().__init__()
        self.norm_type = norm_type
        if use_bias is None:
            # Bias will be canceled after channel-wise normalization.
            use_bias = _use_bias_checker(norm_type)

        shared = dict(kernel_size=3, padding=1, padding_mode=padding_mode,
                      bias=use_bias, norm_type=norm_type)
        self.conv1 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
                                 activation_type=activation_type, **shared)
        self.conv2 = Conv2dBlock(num_channels, num_channels, use_spectral_norm,
                                 activation_type="NONE", **shared)

    def forward(self, x):
        residual = x
        out = self.conv2(self.conv1(x))
        return out + residual
|
||||
@@ -1,25 +0,0 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model import MODEL
|
||||
|
||||
|
||||
@MODEL.register_module()
class MultiScaleDiscriminator(nn.Module):
    """Runs the same discriminator configuration at several image scales,
    halving the input between scales."""

    def __init__(self, num_scale, discriminator_cfg):
        super().__init__()
        # One independently built discriminator per scale.
        self.discriminator_list = nn.ModuleList([
            MODEL.build_with(discriminator_cfg) for _ in range(num_scale)
        ])

    @staticmethod
    def down_sample(x):
        # Smoothed (average-pool) halving between scales.
        return F.avg_pool2d(x, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, x):
        outputs = []
        for net in self.discriminator_list:
            outputs.append(net(x))
            x = self.down_sample(x)
        return outputs
|
||||
@@ -1,10 +1,3 @@
|
||||
from model.registry import MODEL, NORMALIZATION
|
||||
import model.GAN.CycleGAN
|
||||
import model.GAN.MUNIT
|
||||
import model.GAN.TAFG
|
||||
import model.GAN.TSIT
|
||||
import model.GAN.UGATIT
|
||||
import model.GAN.base
|
||||
import model.GAN.wrapper
|
||||
import model.base.normalization
|
||||
|
||||
import model.image_translation
|
||||
|
||||
@@ -20,22 +20,40 @@ def _normalization(norm, num_features, additional_kwargs=None):
|
||||
return NORMALIZATION.build_with(kwargs)
|
||||
|
||||
|
||||
def _activation(activation):
|
||||
def _activation(activation, inplace=True):
|
||||
if activation == "NONE":
|
||||
return _DO_NO_THING_FUNC
|
||||
elif activation == "ReLU":
|
||||
return nn.ReLU(inplace=True)
|
||||
return nn.ReLU(inplace=inplace)
|
||||
elif activation == "LeakyReLU":
|
||||
return nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
||||
return nn.LeakyReLU(negative_slope=0.2, inplace=inplace)
|
||||
elif activation == "Tanh":
|
||||
return nn.Tanh()
|
||||
else:
|
||||
raise NotImplemented(activation)
|
||||
raise NotImplementedError(f"{activation} not valid")
|
||||
|
||||
|
||||
class LinearBlock(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: int, bias=None, activation_type="ReLU", norm_type="NONE"):
|
||||
super().__init__()
|
||||
|
||||
self.norm_type = norm_type
|
||||
self.activation_type = activation_type
|
||||
|
||||
bias = _use_bias_checker(norm_type) if bias is None else bias
|
||||
self.linear = nn.Linear(in_features, out_features, bias)
|
||||
|
||||
self.normalization = _normalization(norm_type, out_features)
|
||||
self.activation = _activation(activation_type)
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(self.normalization(self.linear(x)))
|
||||
|
||||
|
||||
class Conv2dBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, bias=None,
|
||||
activation_type="ReLU", norm_type="NONE", **conv_kwargs):
|
||||
activation_type="ReLU", norm_type="NONE",
|
||||
additional_norm_kwargs=None, **conv_kwargs):
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
self.activation_type = activation_type
|
||||
@@ -44,65 +62,63 @@ class Conv2dBlock(nn.Module):
|
||||
conv_kwargs["bias"] = _use_bias_checker(norm_type) if bias is None else bias
|
||||
|
||||
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
||||
self.normalization = _normalization(norm_type, out_channels)
|
||||
self.normalization = _normalization(norm_type, out_channels, additional_norm_kwargs)
|
||||
self.activation = _activation(activation_type)
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(self.normalization(self.convolution(x)))
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, num_channels, out_channels=None, padding_mode='reflect',
|
||||
activation_type="ReLU", out_activation_type=None, norm_type="IN"):
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
|
||||
if out_channels is None:
|
||||
out_channels = num_channels
|
||||
if out_activation_type is None:
|
||||
out_activation_type = "NONE"
|
||||
|
||||
self.learn_skip_connection = num_channels != out_channels
|
||||
|
||||
self.conv1 = Conv2dBlock(num_channels, num_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=activation_type)
|
||||
self.conv2 = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=out_activation_type)
|
||||
|
||||
if self.learn_skip_connection:
|
||||
self.res_conv = Conv2dBlock(num_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode,
|
||||
norm_type=norm_type, activation_type=out_activation_type)
|
||||
|
||||
def forward(self, x):
|
||||
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||
return self.conv2(self.conv1(x)) + res
|
||||
|
||||
|
||||
class ReverseConv2dBlock(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int,
|
||||
activation_type="ReLU", norm_type="NONE", additional_norm_kwargs=None, **conv_kwargs):
|
||||
super().__init__()
|
||||
self.normalization = _normalization(norm_type, in_channels, additional_norm_kwargs)
|
||||
self.activation = _activation(activation_type)
|
||||
self.activation = _activation(activation_type, inplace=False)
|
||||
self.convolution = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.convolution(self.activation(self.normalization(x)))
|
||||
|
||||
|
||||
class ReverseResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, padding_mode="reflect",
|
||||
norm_type="IN", additional_norm_kwargs=None, activation_type="ReLU"):
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_channels,
|
||||
padding_mode='reflect', activation_type="ReLU", norm_type="IN", pre_activation=False,
|
||||
out_channels=None, out_activation_type=None, additional_norm_kwargs=None):
|
||||
"""
|
||||
Residual Conv Block
|
||||
:param in_channels:
|
||||
:param out_channels:
|
||||
:param padding_mode:
|
||||
:param activation_type:
|
||||
:param norm_type:
|
||||
:param out_activation_type:
|
||||
:param pre_activation: full pre-activation mode from https://arxiv.org/pdf/1603.05027v3.pdf, figure 4
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm_type = norm_type
|
||||
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
if out_activation_type is None:
|
||||
# if not specify `out_activation_type`, using default `out_activation_type`
|
||||
# `out_activation_type` default mode:
|
||||
# "NONE" for not full pre-activation
|
||||
# `norm_type` for full pre-activation
|
||||
out_activation_type = "NONE" if not pre_activation else norm_type
|
||||
|
||||
self.learn_skip_connection = in_channels != out_channels
|
||||
self.conv1 = ReverseConv2dBlock(in_channels, in_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
self.conv2 = ReverseConv2dBlock(in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
|
||||
conv_block = ReverseConv2dBlock if pre_activation else Conv2dBlock
|
||||
conv_param = dict(kernel_size=3, padding=1, norm_type=norm_type, activation_type=activation_type,
|
||||
additional_norm_kwargs=additional_norm_kwargs,
|
||||
padding_mode=padding_mode)
|
||||
|
||||
self.conv1 = conv_block(in_channels, in_channels, **conv_param)
|
||||
self.conv2 = conv_block(in_channels, out_channels, **conv_param)
|
||||
|
||||
if self.learn_skip_connection:
|
||||
self.res_conv = ReverseConv2dBlock(
|
||||
in_channels, out_channels, activation_type, norm_type, additional_norm_kwargs,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
self.res_conv = conv_block(in_channels, out_channels, **conv_param)
|
||||
|
||||
def forward(self, x):
|
||||
res = x if not self.learn_skip_connection else self.res_conv(x)
|
||||
|
||||
@@ -16,18 +16,19 @@ for abbr, name in _VALID_NORM_AND_ABBREVIATION.items():
|
||||
|
||||
@NORMALIZATION.register_module("ADE")
|
||||
class AdaptiveDenormalization(nn.Module):
|
||||
def __init__(self, num_features, base_norm_type="BN"):
|
||||
def __init__(self, num_features, base_norm_type="BN", gamma_bias=0.0):
|
||||
super().__init__()
|
||||
self.num_features = num_features
|
||||
self.base_norm_type = base_norm_type
|
||||
self.norm = self.base_norm(num_features)
|
||||
self.gamma = None
|
||||
self.gamma_bias = gamma_bias
|
||||
self.beta = None
|
||||
self.have_set_condition = False
|
||||
|
||||
def base_norm(self, num_features):
|
||||
if self.base_norm_type == "IN":
|
||||
return nn.InstanceNorm2d(num_features)
|
||||
return nn.InstanceNorm2d(num_features, affine=False)
|
||||
elif self.base_norm_type == "BN":
|
||||
return nn.BatchNorm2d(num_features, affine=False, track_running_stats=True)
|
||||
|
||||
@@ -38,13 +39,13 @@ class AdaptiveDenormalization(nn.Module):
|
||||
def forward(self, x):
|
||||
assert self.have_set_condition
|
||||
x = self.norm(x)
|
||||
x = self.gamma * x + self.beta
|
||||
x = (self.gamma + self.gamma_bias) * x + self.beta
|
||||
self.have_set_condition = False
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(num_features={self.num_features}, " \
|
||||
f"base_norm_type={self.base_norm_type})"
|
||||
#
|
||||
# def __repr__(self):
|
||||
# return f"{self.__class__.__name__}(num_features={self.num_features}, " \
|
||||
# f"base_norm_type={self.base_norm_type})"
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("AdaIN")
|
||||
@@ -61,8 +62,9 @@ class AdaptiveInstanceNorm2d(AdaptiveDenormalization):
|
||||
|
||||
@NORMALIZATION.register_module("FADE")
|
||||
class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
def __init__(self, num_features: int, condition_in_channels, base_norm_type="BN", padding_mode="zeros"):
|
||||
super().__init__(num_features, base_norm_type)
|
||||
def __init__(self, num_features: int, condition_in_channels,
|
||||
base_norm_type="BN", padding_mode="zeros", gamma_bias=0.0):
|
||||
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
|
||||
self.beta_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
|
||||
padding_mode=padding_mode)
|
||||
self.gamma_conv = nn.Conv2d(condition_in_channels, self.num_features, kernel_size=3, padding=1,
|
||||
@@ -77,9 +79,9 @@ class FeatureAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
@NORMALIZATION.register_module("SPADE")
|
||||
class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
def __init__(self, num_features: int, condition_in_channels, base_channels=128, base_norm_type="BN",
|
||||
activation_type="ReLU", padding_mode="zeros"):
|
||||
super().__init__(num_features, base_norm_type)
|
||||
self.base_conv_block = Conv2dBlock(condition_in_channels, num_features, activation_type=activation_type,
|
||||
activation_type="ReLU", padding_mode="zeros", gamma_bias=0.0):
|
||||
super().__init__(num_features, base_norm_type, gamma_bias=gamma_bias)
|
||||
self.base_conv_block = Conv2dBlock(condition_in_channels, base_channels, activation_type=activation_type,
|
||||
kernel_size=3, padding=1, padding_mode=padding_mode, norm_type="NONE")
|
||||
self.beta_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
self.gamma_conv = nn.Conv2d(base_channels, num_features, kernel_size=3, padding=1, padding_mode=padding_mode)
|
||||
@@ -93,7 +95,7 @@ class SpatiallyAdaptiveDenormalization(AdaptiveDenormalization):
|
||||
|
||||
def _instance_layer_normalization(x, gamma, beta, rho, eps=1e-5):
|
||||
out = rho * F.instance_norm(x, eps=eps) + (1 - rho) * F.layer_norm(x, x.size()[1:], eps=eps)
|
||||
out = out * gamma + beta
|
||||
out = out * gamma.unsqueeze(2).unsqueeze(3) + beta.unsqueeze(2).unsqueeze(3)
|
||||
return out
|
||||
|
||||
|
||||
@@ -115,7 +117,7 @@ class ILN(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return _instance_layer_normalization(
|
||||
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
|
||||
x, self.gamma.view(1, -1), self.beta.view(1, -1), self.rho.view(1, -1, 1, 1), self.eps)
|
||||
|
||||
|
||||
@NORMALIZATION.register_module("AdaILN")
|
||||
@@ -136,7 +138,6 @@ class AdaILN(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
assert self.have_set_condition
|
||||
out = _instance_layer_normalization(
|
||||
x, self.gamma.expand_as(x), self.beta.expand_as(x), self.rho.expand_as(x), self.eps)
|
||||
out = _instance_layer_normalization(x, self.gamma, self.beta, self.rho.view(1, -1, 1, 1), self.eps)
|
||||
self.have_set_condition = False
|
||||
return out
|
||||
|
||||
76
model/image_translation/CycleGAN.py
Normal file
76
model/image_translation/CycleGAN.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from model.base.module import Conv2dBlock, ResidualBlock
|
||||
|
||||
|
||||
class Encoder(nn.Module):
    """Down-sampling conv stack followed by residual blocks.

    Exposes ``out_channels``: the channel width after down-sampling.
    """

    def __init__(self, in_channels, base_channels, num_conv, num_res, max_down_sampling_multiple=2,
                 padding_mode='reflect', activation_type="ReLU",
                 down_conv_norm_type="IN", down_conv_kernel_size=3,
                 res_norm_type="IN", pre_activation=False):
        super().__init__()

        # 7x7 stem at full resolution.
        layers = [Conv2dBlock(
            in_channels, base_channels, kernel_size=7, stride=1, padding=3, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=down_conv_norm_type
        )]
        # Stride-2 convs doubling the width up to a capped multiple.
        width = 1
        for step in range(1, num_conv + 1):
            prev_width, width = width, min(2 ** step, 2 ** max_down_sampling_multiple)
            layers.append(Conv2dBlock(
                prev_width * base_channels, width * base_channels,
                kernel_size=down_conv_kernel_size, stride=2, padding=1, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=down_conv_norm_type
            ))
        self.out_channels = width * base_channels
        layers.extend(
            ResidualBlock(
                self.out_channels,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type=res_norm_type,
                pre_activation=pre_activation
            ) for _ in range(num_res)
        )
        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        return self.sequence(x)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
    """Residual blocks followed by upsample+conv stages and a final 7x7
    Tanh projection.

    The residual blocks are kept in a ModuleList (not folded into
    ``up_sequence``) so callers can reach them individually, e.g. to set
    AdaIN conditions.
    """

    def __init__(self, in_channels, out_channels, num_up_sampling, num_residual_blocks,
                 activation_type="ReLU", padding_mode='reflect',
                 up_conv_kernel_size=5, up_conv_norm_type="LN",
                 res_norm_type="AdaIN", pre_activation=False):
        super().__init__()
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(
                in_channels,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type=res_norm_type,
                pre_activation=pre_activation
            ) for _ in range(num_residual_blocks)
        ])

        sequence = []
        channels = in_channels
        for _ in range(num_up_sampling):
            # Nearest upsample + conv halves the channel width each stage.
            sequence.append(nn.Sequential(
                nn.Upsample(scale_factor=2),
                Conv2dBlock(channels, channels // 2, kernel_size=up_conv_kernel_size, stride=1,
                            padding=up_conv_kernel_size // 2, padding_mode=padding_mode,
                            activation_type=activation_type, norm_type=up_conv_norm_type),
            ))
            channels //= 2
        sequence.append(Conv2dBlock(channels, out_channels, kernel_size=7, stride=1, padding=3,
                                    padding_mode=padding_mode, activation_type="Tanh", norm_type="NONE"))
        self.up_sequence = nn.Sequential(*sequence)

    def forward(self, x):
        # Iterate the blocks directly (the old enumerate index was unused).
        for blk in self.residual_blocks:
            x = blk(x)
        return self.up_sequence(x)
|
||||
95
model/image_translation/GauGAN.py
Normal file
95
model/image_translation/GauGAN.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from model.base.module import ResidualBlock, ReverseConv2dBlock, Conv2dBlock
|
||||
|
||||
|
||||
class StyleEncoder(nn.Module):
    """Conv encoder producing (mu, logvar) of the VAE style code.

    Args:
        in_channels: Channels of the input image.
        style_dim: Dimension of the style code.
        num_conv: Number of stride-2 conv blocks after the stem.
        end_size: Expected spatial size of the final feature map, used to
            size the fully-connected heads.
        base_channels: Width of the stem conv.
    """

    def __init__(self, in_channels, style_dim, num_conv, end_size=(4, 4), base_channels=64,
                 norm_type="IN", padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        sequence = [Conv2dBlock(
            in_channels, base_channels, kernel_size=3, stride=1, padding=1, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type
        )]
        # The stem outputs `base_channels` (multiple 1).  Starting the
        # running multiple at 0 (as before) made the first strided conv
        # expect 0 input channels, which can never match the stem output.
        multiple_now = 1
        max_multiple = 3
        for i in range(1, num_conv + 1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** max_multiple)
            sequence.append(Conv2dBlock(
                multiple_prev * base_channels, multiple_now * base_channels,
                kernel_size=3, stride=2, padding=1, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=norm_type
            ))
        self.sequence = nn.Sequential(*sequence)
        flat_features = base_channels * (2 ** max_multiple) * end_size[0] * end_size[1]
        self.fc_avg = nn.Linear(flat_features, style_dim)
        self.fc_var = nn.Linear(flat_features, style_dim)

    def forward(self, x):
        x = self.sequence(x)
        x = x.view(x.size(0), -1)
        # Mean and log-variance heads of the style distribution.
        return self.fc_avg(x), self.fc_var(x)
|
||||
|
||||
|
||||
class SPADEGenerator(nn.Module):
    """SPADE/GauGAN generator: a stack of SPADE residual blocks that
    upsample a small seed feature map to a full-resolution image,
    re-conditioned on the segmentation map at every block.
    """

    def __init__(self, in_channels, out_channels, num_blocks, use_vae, num_z_dim, start_size=(4, 4), base_channels=64,
                 padding_mode='reflect', activation_type="LeakyReLU"):
        super().__init__()
        # Spatial size of the seed feature map before any upsampling.
        self.sx, self.sy = start_size
        self.use_vae = use_vae
        self.num_z_dim = num_z_dim
        if use_vae:
            # VAE mode: project a latent vector into the seed feature map.
            self.input_converter = nn.Linear(num_z_dim, 16 * base_channels * self.sx * self.sy)
        else:
            # Otherwise derive the seed directly from the downscaled seg map.
            self.input_converter = nn.Conv2d(in_channels, 16 * base_channels, kernel_size=3, padding=1)

        sequence = []

        # Channel width shrinks from 16x towards the output resolution.
        multiple_now = 16
        for i in range(num_blocks - 1, -1, -1):
            multiple_prev = multiple_now
            multiple_now = min(2 ** i, 2 ** 4)
            if i != num_blocks - 1:
                # Upsample between blocks (not before the first one).
                sequence.append(nn.Upsample(scale_factor=2))
            sequence.append(ResidualBlock(
                base_channels * multiple_prev,
                out_channels=base_channels * multiple_now,
                padding_mode=padding_mode,
                activation_type=activation_type,
                norm_type="SPADE",
                pre_activation=True,
                additional_norm_kwargs=dict(
                    condition_in_channels=in_channels, base_channels=128, base_norm_type="BN",
                    activation_type="ReLU", padding_mode="zeros", gamma_bias=1.0
                )
            ))
        self.sequence = nn.Sequential(*sequence)
        self.output_converter = nn.Sequential(
            ReverseConv2dBlock(base_channels, out_channels, kernel_size=3, stride=1, padding=1,
                               padding_mode=padding_mode, activation_type=activation_type, norm_type="NONE"),
            nn.Tanh()
        )

    def forward(self, seg, z=None):
        """Generate an image conditioned on segmentation map ``seg``.

        Args:
            seg: Condition tensor of shape (B, in_channels, H, W).
            z: Optional latent vector (VAE mode); sampled when omitted.
        """
        if self.use_vae:
            if z is None:
                z = torch.randn(seg.size(0), self.num_z_dim, device=seg.device)
            x = self.input_converter(z).view(seg.size(0), -1, self.sx, self.sy)
        else:
            x = self.input_converter(F.interpolate(seg, size=(self.sx, self.sy)))
        for blk in self.sequence:
            if isinstance(blk, ResidualBlock):
                # SPADE norm layers read their condition via side-channel
                # state: push the seg map (resized to the current feature
                # resolution) into every normalization owned by this block
                # before running it.
                downsampling_seg = F.interpolate(seg, size=x.size()[2:], mode='nearest')
                blk.conv1.normalization.set_condition_image(downsampling_seg)
                blk.conv2.normalization.set_condition_image(downsampling_seg)
                if blk.learn_skip_connection:
                    blk.res_conv.normalization.set_condition_image(downsampling_seg)
            x = blk(x)
        return self.output_converter(x)
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build a 7-block, non-VAE generator and push a dummy batch
    # through it.
    generator = SPADEGenerator(3, 3, 7, False, 256)
    print(generator)
    output = generator(torch.randn(2, 3, 256, 256))
    print(output.size())
|
||||
89
model/image_translation/MUNIT.py
Normal file
89
model/image_translation/MUNIT.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from model import MODEL
|
||||
from model.base.module import LinearBlock
|
||||
from model.image_translation.CycleGAN import Encoder, Decoder
|
||||
|
||||
|
||||
class StyleEncoder(nn.Module):
    """MUNIT style encoder: a strided-conv pyramid (``down_encoder``)
    followed by global average pooling and a 1x1 conv head, producing a
    flat style code of length *out_dim* per image.
    """

    def __init__(self, in_channels, out_dim, num_conv, base_channels=64,
                 max_down_sampling_multiple=2, padding_mode='reflect', activation_type="ReLU", norm_type="NONE",
                 pre_activation=False):
        super().__init__()
        self.down_encoder = Encoder(
            in_channels, base_channels, num_conv, num_res=0, max_down_sampling_multiple=max_down_sampling_multiple,
            padding_mode=padding_mode, activation_type=activation_type,
            down_conv_norm_type=norm_type, down_conv_kernel_size=4, pre_activation=pre_activation,
        )
        sequence = list()
        sequence.append(nn.AdaptiveAvgPool2d(1))
        # conv1x1 works as fc when tensor's size is (batch_size, channels, 1, 1), keep same with origin code
        sequence.append(nn.Conv2d(self.down_encoder.out_channels, out_dim, kernel_size=1, stride=1, padding=0))
        self.sequence = nn.Sequential(*sequence)

    def forward(self, image):
        """Return the style code, shape (batch, out_dim).

        BUG FIX: the original applied ``self.sequence`` directly to the raw
        image, never running ``self.down_encoder`` — the 1x1 conv expects
        ``down_encoder.out_channels`` inputs, so that path could not work
        and left the encoder as a dead submodule.
        """
        features = self.down_encoder(image)
        return self.sequence(features).view(image.size(0), -1)
|
||||
|
||||
|
||||
class MLPFusion(nn.Module):
    """Plain MLP used to map a style code to decoder parameters: one input
    projection, ``n_blocks - 2`` hidden layers, and one output projection.
    """

    def __init__(self, in_features, out_features, base_features, n_blocks, activation_type="ReLU", norm_type="NONE"):
        super().__init__()
        layers = [LinearBlock(in_features, base_features, activation_type=activation_type, norm_type=norm_type)]
        layers.extend(
            LinearBlock(base_features, base_features, activation_type=activation_type, norm_type=norm_type)
            for _ in range(n_blocks - 2)
        )
        layers.append(LinearBlock(base_features, out_features, activation_type=activation_type, norm_type=norm_type))
        self.sequence = nn.Sequential(*layers)

    def forward(self, x):
        return self.sequence(x)
|
||||
|
||||
|
||||
@MODEL.register_module("MUNIT-Generator")
class Generator(nn.Module):
    """MUNIT generator: disentangles an image into a content code and a
    style code, and reconstructs an image from a (content, style) pair by
    mapping the style through an MLP into the AdaIN parameters of the
    decoder's residual blocks.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, style_dim=8,
                 num_mlp_base_feature=256, num_mlp_blocks=3,
                 max_down_sampling_multiple=2, num_content_down_sampling=2, num_style_down_sampling=2,
                 encoder_num_residual_blocks=4, decoder_num_residual_blocks=4,
                 padding_mode='reflect', activation_type="ReLU", pre_activation=False):
        super().__init__()
        # NOTE(review): max_down_sampling_multiple is set to
        # num_content_down_sampling here (not the max_down_sampling_multiple
        # parameter) — looks intentional when both default to 2, but verify
        # for other configurations.
        self.content_encoder = Encoder(
            in_channels, base_channels, num_content_down_sampling, encoder_num_residual_blocks,
            max_down_sampling_multiple=num_content_down_sampling,
            padding_mode=padding_mode, activation_type=activation_type,
            down_conv_norm_type="IN", down_conv_kernel_size=4,
            res_norm_type="IN", pre_activation=pre_activation
        )

        self.style_encoder = StyleEncoder(in_channels, style_dim, num_style_down_sampling, base_channels,
                                          max_down_sampling_multiple, padding_mode, activation_type,
                                          norm_type="NONE", pre_activation=pre_activation)

        content_channels = base_channels * (2 ** max_down_sampling_multiple)

        # Two AdaIN layers per residual block, each needing (gamma, beta)
        # of size content_channels -> 2 * 2 * content_channels per block.
        self.fusion = MLPFusion(style_dim, decoder_num_residual_blocks * 2 * content_channels * 2,
                                num_mlp_base_feature, num_mlp_blocks, activation_type,
                                norm_type="NONE")

        self.decoder = Decoder(in_channels, out_channels, max_down_sampling_multiple, decoder_num_residual_blocks,
                               activation_type=activation_type, padding_mode=padding_mode,
                               up_conv_kernel_size=5, up_conv_norm_type="LN",
                               res_norm_type="AdaIN", pre_activation=pre_activation)

    def encode(self, x):
        """Return (content_code, style_code) for image *x*."""
        return self.content_encoder(x), self.style_encoder(x)

    def decode(self, content, style):
        """Decode *content* under *style*.

        The style code is expanded by the fusion MLP and chunked into one
        slice per AdaIN layer (two per residual block), which are written
        into the decoder before it runs.
        """
        as_param_style = torch.chunk(self.fusion(style), 2 * len(self.decoder.residual_blocks), dim=1)
        # set style for decoder
        for i, blk in enumerate(self.decoder.residual_blocks):
            blk.conv1.normalization.set_style(as_param_style[2 * i])
            blk.conv2.normalization.set_style(as_param_style[2 * i + 1])
        # BUG FIX: the decoded image was computed but never returned, so
        # decode() (and therefore forward()) returned None.
        return self.decoder(content)

    def forward(self, x):
        """Reconstruct *x* through its own content and style codes."""
        content, style = self.encode(x)
        return self.decode(content, style)
|
||||
138
model/image_translation/UGATIT.py
Normal file
138
model/image_translation/UGATIT.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from model import MODEL
|
||||
from model.base.module import Conv2dBlock, LinearBlock
|
||||
from model.image_translation.CycleGAN import Encoder, Decoder
|
||||
|
||||
|
||||
class RhoClipper(object):
    """Callable intended for ``Module.apply``: clamps any submodule's
    ``rho`` parameter into ``[clip_min, clip_max]`` in place.
    """

    def __init__(self, clip_min, clip_max):
        self.clip_min = clip_min
        self.clip_max = clip_max
        assert clip_min < clip_max

    def __call__(self, module):
        # Modules without a `rho` attribute are left untouched.
        if hasattr(module, 'rho'):
            module.rho.data = module.rho.data.clamp(self.clip_min, self.clip_max)
|
||||
|
||||
|
||||
class CAMClassifier(nn.Module):
    """Class-Activation-Map head from U-GAT-IT.

    Scores the feature map with two linear classifiers (on avg- and
    max-pooled features), reweights the input features by each classifier's
    weights (the CAM attention), and fuses both branches back to the
    original channel count with a 1x1 conv.

    ``forward`` returns ``(attended_features, cam_logits)`` where
    ``cam_logits`` has shape (batch, 2) — one logit per pooling branch.
    """

    def __init__(self, in_channels, activation_type="ReLU"):
        super(CAMClassifier, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.avg_fc = nn.Linear(in_channels, 1, bias=False)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.max_fc = nn.Linear(in_channels, 1, bias=False)
        self.fusion_conv = Conv2dBlock(in_channels * 2, in_channels, kernel_size=1, stride=1, bias=True,
                                       activation_type=activation_type, norm_type="NONE")

    def forward(self, x):
        batch = x.size(0)
        avg_logit = self.avg_fc(self.avg_pool(x).view(batch, -1))
        max_logit = self.max_fc(self.max_pool(x).view(batch, -1))

        # Reuse the classifier weights as per-channel attention maps over
        # the features, then fuse the two attended copies.
        avg_attention = x * self.avg_fc.weight.unsqueeze(2).unsqueeze(3)
        max_attention = x * self.max_fc.weight.unsqueeze(2).unsqueeze(3)
        fused = self.fusion_conv(torch.cat([avg_attention, max_attention], dim=1))
        return fused, torch.cat([avg_logit, max_logit], 1)
|
||||
|
||||
|
||||
@MODEL.register_module("UGATIT-Generator")
class Generator(nn.Module):
    """U-GAT-IT generator: encoder -> CAM attention -> AdaILN decoder.

    The CAM head yields attended features plus a classification logit
    (consumed by the CAM loss). A small MLP (``fc``) maps the attended
    features to a single (gamma, beta) pair that conditions every AdaILN
    layer in the decoder.

    Args:
        in_channels / out_channels: image channel counts.
        base_channels: width multiplier for all conv layers.
        num_blocks: residual blocks in both encoder and decoder.
        img_size: input resolution; only used by the heavy (light=False)
            MLP, which flattens the full feature map and therefore ties
            the model to this exact input size.
        light: if True, global-average-pool before the MLP — the
            memory-light, resolution-independent variant from the paper.
    """

    def __init__(self, in_channels, out_channels, base_channels=64, num_blocks=6, img_size=256, light=False,
                 activation_type="ReLU", norm_type="IN", padding_mode='reflect', pre_activation=False):
        super(Generator, self).__init__()

        self.light = light

        n_down_sampling = 2
        self.encoder = Encoder(in_channels, base_channels, n_down_sampling, num_blocks,
                               padding_mode=padding_mode, activation_type=activation_type,
                               down_conv_norm_type=norm_type, down_conv_kernel_size=3, res_norm_type=norm_type,
                               pre_activation=pre_activation)
        # Channel multiple after the two stride-2 downsamplings.
        mult = 2 ** n_down_sampling
        self.cam = CAMClassifier(base_channels * mult, activation_type)

        # Gamma, Beta block
        if self.light:
            self.fc = nn.Sequential(
                LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE"),
                LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE")
            )
        else:
            # Consumes the full flattened feature map (img_size-dependent).
            self.fc = nn.Sequential(
                LinearBlock(img_size // mult * img_size // mult * base_channels * mult, base_channels * mult, False,
                            "ReLU", "NONE"),
                LinearBlock(base_channels * mult, base_channels * mult, False, "ReLU", "NONE")
            )

        # Project the fc output to the AdaILN scale and shift vectors.
        self.gamma = nn.Linear(base_channels * mult, base_channels * mult, bias=False)
        self.beta = nn.Linear(base_channels * mult, base_channels * mult, bias=False)

        self.decoder = Decoder(
            base_channels * mult, out_channels, n_down_sampling, num_blocks,
            activation_type=activation_type, padding_mode=padding_mode,
            up_conv_kernel_size=3, up_conv_norm_type="ILN",
            res_norm_type="AdaILN", pre_activation=pre_activation
        )

    def forward(self, x):
        """Return ``(generated_image, cam_logit, heatmap)``."""
        x = self.encoder(x)

        x, cam_logit = self.cam(x)

        # Channel-sum of the attended features, kept as a 1-channel map for
        # attention visualisation.
        heatmap = torch.sum(x, dim=1, keepdim=True)

        if self.light:
            x_ = torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
            x_ = self.fc(x_.view(x_.shape[0], -1))
        else:
            x_ = self.fc(x.view(x.shape[0], -1))
        gamma, beta = self.gamma(x_), self.beta(x_)

        # Every decoder residual block shares the same AdaILN condition;
        # it must be stashed into the norms BEFORE running the decoder.
        for blk in self.decoder.residual_blocks:
            blk.conv1.normalization.set_condition(gamma, beta)
            blk.conv2.normalization.set_condition(gamma, beta)
        return self.decoder(x), cam_logit, heatmap
|
||||
|
||||
|
||||
@MODEL.register_module("UGATIT-Discriminator")
class Discriminator(nn.Module):
    """U-GAT-IT PatchGAN discriminator with a CAM attention head.

    Depth layout for ``num_blocks``: one stride-2 stem conv, then
    ``num_blocks - 3`` further stride-2 convs (doubling channels each
    time), one stride-1 conv, the CAM head, and a final 1-channel
    patch-score conv.
    """

    def __init__(self, in_channels, base_channels=64, num_blocks=5,
                 activation_type="LeakyReLU", norm_type="NONE", padding_mode='reflect'):
        super().__init__()

        layers = [Conv2dBlock(
            in_channels, base_channels, kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
            activation_type=activation_type, norm_type=norm_type
        )]

        for i in range(num_blocks - 3):
            channels = base_channels * (2 ** i)
            layers.append(Conv2dBlock(
                channels, channels * 2,
                kernel_size=4, stride=2, padding=1, padding_mode=padding_mode,
                activation_type=activation_type, norm_type=norm_type))

        layers.append(
            Conv2dBlock(base_channels * (2 ** (num_blocks - 3)), base_channels * (2 ** (num_blocks - 2)),
                        kernel_size=4, stride=1, padding=1, padding_mode=padding_mode,
                        activation_type=activation_type, norm_type=norm_type)
        )
        self.sequence = nn.Sequential(*layers)

        mult = 2 ** (num_blocks - 2)
        self.cam = CAMClassifier(base_channels * mult, activation_type)
        self.conv = nn.Conv2d(base_channels * mult, 1, kernel_size=4, stride=1, padding=1, bias=False,
                              padding_mode="reflect")

    def forward(self, x, return_heatmap=False):
        features = self.sequence(x)
        features, cam_logit = self.cam(features)
        score = self.conv(features)
        if not return_heatmap:
            return score, cam_logit
        # Channel-sum of the attended features, for visualisation.
        return score, cam_logit, torch.sum(features, dim=1, keepdim=True)
|
||||
0
model/image_translation/__init__.py
Normal file
0
model/image_translation/__init__.py
Normal file
22
util/misc.py
22
util/misc.py
@@ -1,4 +1,6 @@
|
||||
import importlib
|
||||
import logging
|
||||
import pkgutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -6,12 +8,30 @@ import torch.nn as nn
|
||||
|
||||
|
||||
def add_spectral_norm(module):
    """Wrap *module* in spectral normalization if it is a linear or conv
    layer; return it unchanged otherwise.

    Intended for use when mapping over a module tree. The ``weight_u``
    check (an attribute added by ``spectral_norm``) makes the call
    idempotent: already-normalized modules are returned as-is.

    NOTE: SOURCE showed the diff's old and new conditional lines fused
    into one body; this keeps the updated condition (including
    ``nn.Linear``), matching the post-diff state.
    """
    if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)) and not hasattr(module, 'weight_u'):
        return nn.utils.spectral_norm(module)
    else:
        return module
|
||||
|
||||
|
||||
def import_submodules(package, recursive=True):
    """Import every submodule of *package*, optionally recursing into
    subpackages.

    :param package: package (name or actual module)
    :type package: str | module
    :rtype: dict[str, types.ModuleType]
    """
    if isinstance(package, str):
        package = importlib.import_module(package)
    results = {}
    for _, name, is_pkg in pkgutil.walk_packages(package.__path__):
        full_name = '{}.{}'.format(package.__name__, name)
        results[name] = importlib.import_module(full_name)
        if recursive and is_pkg:
            # Merge the subpackage's own submodules into the result.
            results.update(import_submodules(full_name))
    return results
|
||||
|
||||
|
||||
def setup_logger(
|
||||
name: Optional[str] = None,
|
||||
level: int = logging.INFO,
|
||||
|
||||
Reference in New Issue
Block a user