temp commit

This commit is contained in:
2020-10-10 10:43:00 +08:00
parent 776fe40199
commit 6ea13df465
10 changed files with 340 additions and 295 deletions

View File

@@ -0,0 +1,122 @@
from collections import defaultdict
from itertools import permutations, combinations
from pathlib import Path
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import functional as F
from data.registry import DATASET
from data.transform import transform_pipeline
from data.util import dlib_landmark
def normalize_tensor(tensor):
tensor = tensor.float()
tensor -= tensor.min()
tensor /= tensor.max()
return tensor
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, landmarks_path, size=(256, 256),
with_path=False):
assert edge_type in ["hed", "canny", "landmark_hed", "landmark_canny"]
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
self.landmarks_path = Path(landmarks_path)
assert self.edges_path.exists()
assert self.landmarks_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=True)
self.random_pair = random_pair
self.with_path = with_path
def get_edge(self, origin_path):
op = Path(origin_path)
if self.edge_type.startswith("landmark_"):
edge_type = self.edge_type.lstrip("landmark_")
use_landmark = op.parent.name.endswith("A")
else:
edge_type = self.edge_type
use_landmark = False
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{edge_type}.png"
origin_edge = F.to_tensor(Image.open(edge_path).resize(self.size, Image.BILINEAR))
if not use_landmark:
return origin_edge
else:
landmark_path = self.landmarks_path / f"{op.parent.name}/{op.stem}.txt"
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(landmark_path, size=self.size)
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.size)))
part_labels = normalize_tensor(torch.from_numpy(part_labels))
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
# edges = origin_edge * (part_labels.sum(0) == 0) # remove edges within face
# edges = part_edge + edges
return torch.cat([origin_edge, part_edge, dist_tensor, part_labels])
def __getitem__(self, idx):
a_idx = idx % len(self.A)
b_idx = idx % len(self.B) if not self.random_pair else torch.randint(len(self.B), (1,)).item()
output = dict(a={}, b={})
output["a"]["img"], output["a"]["path"] = self.A[a_idx]
output["b"]["img"], output["b"]["path"] = self.B[b_idx]
for p in "ab":
output[p]["edge"] = self.get_edge(output[p]["path"])
return output
def __len__(self):
return max(len(self.A), len(self.B))
def __repr__(self):
return f"<GenerationUnpairedDatasetWithEdge:\n\tA: {self.A}\n\tB: {self.B}>\nPipeline:\n{self.A.pipeline}"
@DATASET.register_module()
class PoseFacesWithSingleAnime(Dataset):
def __init__(self, root_face, root_anime, landmark_path, num_face, face_pipeline, anime_pipeline, img_size,
with_order=True):
self.num_face = num_face
self.landmark_path = Path(landmark_path)
self.with_order = with_order
self.root_face = Path(root_face)
self.root_anime = Path(root_anime)
self.img_size = img_size
self.face_samples = self.iter_folders()
self.face_pipeline = transform_pipeline(face_pipeline)
self.B = SingleFolderDataset(root_anime, anime_pipeline, with_path=True)
def iter_folders(self):
pics_per_person = defaultdict(list)
for p in self.root_face.glob("*.jpg"):
pics_per_person[p.stem[:7]].append(p.stem)
data = []
for p in pics_per_person:
if len(pics_per_person[p]) >= self.num_face:
if self.with_order:
data.extend(list(combinations(pics_per_person[p], self.num_face)))
else:
data.extend(list(permutations(pics_per_person[p], self.num_face)))
return data
def read_pose(self, pose_txt):
key_points, part_labels, part_edge = dlib_landmark.read_keypoints(pose_txt, size=self.img_size)
dist_tensor = normalize_tensor(torch.from_numpy(dlib_landmark.dist_tensor(key_points, size=self.img_size)))
part_labels = normalize_tensor(torch.from_numpy(part_labels))
part_edge = torch.from_numpy(part_edge).unsqueeze(0).float()
return torch.cat([part_labels, part_edge, dist_tensor])
def __len__(self):
return len(self.face_samples)
def __getitem__(self, idx):
output = dict()
output["anime_img"], output["anime_path"] = self.B[torch.randint(len(self.B), (1,)).item()]
for i, f in enumerate(self.face_samples[idx]):
output[f"face_{i}"] = self.face_pipeline(self.root_face / f"{f}.jpg")
output[f"pose_{i}"] = self.read_pose(self.landmark_path / self.root_face.name / f"{f}.txt")
output[f"stem_{i}"] = f
return output