TAHG 0.0.2

This commit is contained in:
2020-08-30 14:44:40 +08:00
parent 715a2e64a1
commit 89b54105c7
8 changed files with 172 additions and 17 deletions

View File

@@ -2,10 +2,12 @@ import os
import pickle
from pathlib import Path
from collections import defaultdict
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import functional as F
from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS, default_loader
import lmdb
@@ -176,17 +178,20 @@ class GenerationUnpairedDataset(Dataset):
@DATASET.register_module()
class GenerationUnpairedDatasetWithEdge(Dataset):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type):
def __init__(self, root_a, root_b, random_pair, pipeline, edge_type, edges_path, size=(256, 256)):
self.edge_type = edge_type
self.size = size
self.edges_path = Path(edges_path)
assert self.edges_path.exists()
self.A = SingleFolderDataset(root_a, pipeline, with_path=True)
self.B = SingleFolderDataset(root_b, pipeline, with_path=False)
self.random_pair = random_pair
def get_edge(self, origin_path):
op = Path(origin_path)
add = torch.load(op.parent / f"{op.stem}.add")
return {"edge": add["edge"].float().unsqueeze(dim=0),
"additional_info": torch.cat([add["seg"].float(), add["dist"].float()], dim=0)}
edge_path = self.edges_path / f"{op.parent.name}/{op.stem}.{self.edge_type}.png"
img = Image.open(edge_path).resize(self.size)
return {"edge": F.to_tensor(img)}
def __getitem__(self, idx):
a_idx = idx % len(self.A)

0
data/util/__init__.py Normal file
View File

View File

@@ -0,0 +1,67 @@
import numpy as np
import cv2
from skimage import feature
# https://ibug.doc.ic.ac.uk/resources/facial-point-annotations/
DLIB_LANDMARKS_PART_LIST = [
[list(range(0, 17)) + list(range(68, 83)) + [0]], # face
[range(17, 22)], # right eyebrow
[range(22, 27)], # left eyebrow
[[28, 31], range(31, 36), [35, 28]], # nose
[[36, 37, 38, 39], [39, 40, 41, 36]], # right eye
[[42, 43, 44, 45], [45, 46, 47, 42]], # left eye
[range(48, 55), [54, 55, 56, 57, 58, 59, 48]], # mouth
[range(60, 65), [64, 65, 66, 67, 60]] # tongue
]
def dist_tensor(key_points, size=(256, 256)):
dist_list = []
for edge_list in DLIB_LANDMARKS_PART_LIST:
for edge in edge_list:
pts = key_points[edge, :]
im_edge = np.zeros(size, np.uint8)
cv2.polylines(im_edge, [pts], isClosed=False, color=255, thickness=1)
im_dist = cv2.distanceTransform(255 - im_edge, cv2.DIST_L1, 3)
im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
dist_list.append(im_dist)
return np.stack(dist_list)
def read_keypoints(kp_path, origin_size=(256, 256), size=(256, 256), thickness=1):
key_points = np.loadtxt(kp_path, delimiter=",").astype(np.int32)
if origin_size != size:
# resize key_points using simplest way...
key_points = (key_points * (np.array(size) / np.array(origin_size))).astype(np.int32)
# add upper half face by symmetry
face_pts = key_points[:17, :]
face_baseline_y = (face_pts[0, 1] + face_pts[-1, 1]) // 2
upper_symmetry_face_pts = face_pts[1:-1, :].copy()
# keep x untouched
upper_symmetry_face_pts[:, 1] = face_baseline_y + (face_baseline_y - upper_symmetry_face_pts[:, 1]) * 2 // 3
key_points = np.vstack((key_points, upper_symmetry_face_pts[::-1, :]))
assert key_points.shape == (83, 2)
part_labels = np.zeros((len(DLIB_LANDMARKS_PART_LIST), *size), np.uint8)
part_edge = np.zeros(size, np.uint8)
for i, edge_list in enumerate(DLIB_LANDMARKS_PART_LIST):
indices = [item for sublist in edge_list for item in sublist]
pts = key_points[indices, :]
cv2.fillPoly(part_labels[i], pts=[pts], color=1)
if i in [1, 2]:
# some part of landmarks is a line
cv2.polylines(part_edge, [pts], isClosed=False, color=1, thickness=thickness)
else:
cv2.drawContours(part_edge, [pts], 0, color=1, thickness=thickness)
return key_points, part_labels, part_edge
def edge_map(img, part_labels, part_edge, remove_edge_within_face=True):
edges = feature.canny(np.array(img.convert("L")))
if remove_edge_within_face:
edges = edges * (part_labels.sum(0) == 0) # remove edges within face
edges = part_edge + edges
return edges