TANG 0.0.1

This commit is contained in:
2020-08-30 09:34:23 +08:00
parent 7a85499edf
commit 715a2e64a1
10 changed files with 690 additions and 2 deletions

0
loss/I2I/__init__.py Normal file
View File

129
loss/I2I/edge_loss.py Normal file
View File

@@ -0,0 +1,129 @@
from pathlib import Path
import torch
import torch.nn as nn
from torch.nn import functional as F
class HED(nn.Module):
def __init__(self, pretrained_model_path, norm_img=True):
"""
HED module to get edge
:param pretrained_model_path: path to pretrained HED.
:param norm_img(bool): If True, the image will be normed to [0, 1]. Note that
this is different from the `use_input_norm` which norm the input in
in forward function of vgg according to the statistics of dataset.
Importantly, the input image must be in range [-1, 1].
"""
super().__init__()
self.norm_img = norm_img
self.vgg_nets = nn.ModuleList([torch.nn.Sequential(
torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False)
), torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=2, stride=2),
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False)
), torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=2, stride=2),
torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False)
), torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=2, stride=2),
torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False)
), torch.nn.Sequential(
torch.nn.MaxPool2d(kernel_size=2, stride=2),
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False),
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(inplace=False)
)])
self.score_nets = nn.ModuleList([
torch.nn.Conv2d(in_channels=i, out_channels=1, kernel_size=1, stride=1, padding=0)
for i in [64, 128, 256, 512, 512]
])
self.combine_net = torch.nn.Sequential(
torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
torch.nn.Sigmoid()
)
self.load_weights(pretrained_model_path)
self.register_buffer('mean', torch.Tensor([104.00698793, 116.66876762, 122.67891434]).view(1, 3, 1, 1))
for v in self.parameters():
v.requies_grad = False
def load_weights(self, pretrained_model_path):
checkpoint_path = Path(pretrained_model_path)
if not checkpoint_path.exists():
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
m = {"One": "0", "Two": "1", "Thr": "2", "Fou": "3", "Fiv": "4"}
def replace_key(key):
if key.startswith("moduleVgg"):
return f"vgg_nets.{m[key[9:12]]}{key[12:]}"
elif key.startswith("moduleScore"):
return f"score_nets.{m[key[11:14]]}{key[14:]}"
elif key.startswith("moduleCombine"):
return f"combine_net{key[13:]}"
else:
raise ValueError("wrong checkpoint for HED")
module_dict = {replace_key(k): v for k, v in ckp.items()}
self.load_state_dict(module_dict, strict=True)
def forward(self, x):
if self.norm_img:
x = (x + 1.) * 0.5
x = x * 255.0 - self.mean
img_size = (x.size(2), x.size(3))
to_combine = []
for i in range(5):
x = self.vgg_nets[i](x)
score_x = self.score_nets[i](x)
to_combine.append(F.interpolate(input=score_x, size=img_size, mode='bilinear', align_corners=False))
out = self.combine_net(torch.cat(to_combine, 1))
return out.clamp(0.0, 1.0)
class EdgeLoss(nn.Module):
def __init__(self, edge_extractor_type="HED", norm_img=True, criterion='L1', **kwargs):
super(EdgeLoss, self).__init__()
if edge_extractor_type == "HED":
pretrained_model_path = kwargs.get("hed_pretrained_model_path")
self.edge_extractor = HED(pretrained_model_path, norm_img)
else:
raise NotImplemented(f"do not support edge_extractor_type {edge_extractor_type}")
if criterion == 'L1':
self.criterion = nn.L1Loss()
elif criterion == "L2":
self.criterion = nn.MSELoss()
else:
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
def forward(self, x, gt, gt_is_edge=True):
edge = self.edge_extractor(x)
if not gt_is_edge:
gt = self.edge_extractor(gt.detach())
loss = self.criterion(edge, gt)
return loss

155
loss/I2I/perceptual_loss.py Normal file
View File

@@ -0,0 +1,155 @@
import torch
import torch.nn as nn
import torchvision.models.vgg as vgg
class PerceptualVGG(nn.Module):
"""VGG network used in calculating perceptual loss.
In this implementation, we allow users to choose whether use normalization
in the input feature and the type of vgg network. Note that the pretrained
path must fit the vgg type.
Args:
layer_name_list (list[str]): According to the index in this list,
forward function will return the corresponding features. This
list contains the name each layer in `vgg.feature`. An example
of this list is ['4', '10'].
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image.
Importantly, the input feature must in the range [0, 1].
Default: True.
"""
def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
super(PerceptualVGG, self).__init__()
self.layer_name_list = layer_name_list
self.use_input_norm = use_input_norm
# get vgg model and load pretrained vgg weight
# remove _vgg from attributes to avoid `find_unused_parameters` bug
_vgg = getattr(vgg, vgg_type)(pretrained=True)
num_layers = max(map(int, layer_name_list)) + 1
assert len(_vgg.features) >= num_layers
# only borrow layers that will be used from _vgg to avoid unused params
self.vgg_layers = _vgg.features[:num_layers]
if self.use_input_norm:
# the mean is for image with range [0, 1]
self.register_buffer(
'mean',
torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
# the std is for image with range [-1, 1]
self.register_buffer(
'std',
torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
for v in self.vgg_layers.parameters():
v.requies_grad = False
def forward(self, x):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.use_input_norm:
x = (x - self.mean) / self.std
output = {}
for i, l in enumerate(self.vgg_layers):
x = l(x)
if str(i) in self.layer_name_list:
output[str(i)] = x.clone()
return output
class PerceptualLoss(nn.Module):
"""Perceptual loss with commonly used style loss.
Args:
layer_weights (dict): The weight for each layer of vgg feature.
Here is an example: {'4': 1., '9': 1., '18': 1.}, which means the
5th, 10th and 18th feature layer will be extracted with weight 1.0
in calculating losses.
vgg_type (str): The type of vgg network used as feature extractor.
Default: 'vgg19'.
use_input_norm (bool): If True, normalize the input image in vgg.
Default: True.
perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
loss will be calculated.
Default: True.
style_loss (bool): If `style_loss == False`, the style loss will be calculated.
Default: False.
norm_img (bool): If True, the image will be normed to [0, 1]. Note that
this is different from the `use_input_norm` which norm the input in
in forward function of vgg according to the statistics of dataset.
Importantly, the input image must be in range [-1, 1].
"""
def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
style_loss=False, norm_img=True, criterion='L1'):
super(PerceptualLoss, self).__init__()
self.norm_img = norm_img
self.perceptual_loss = perceptual_loss
self.style_loss = style_loss
self.layer_weights = layer_weights
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
use_input_norm=use_input_norm)
if criterion == 'L1':
self.criterion = torch.nn.L1Loss()
elif criterion == "L2":
self.criterion = torch.nn.MSELoss()
else:
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
def forward(self, x, gt):
"""Forward function.
Args:
x (Tensor): Input tensor with shape (n, c, h, w).
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
Returns:
Tensor: Forward results.
"""
if self.norm_img:
x = (x + 1.) * 0.5
gt = (gt + 1.) * 0.5
# extract vgg features
x_features = self.vgg(x)
gt_features = self.vgg(gt.detach())
# calculate preceptual loss
if self.perceptual_loss:
percep_loss = 0
for k in x_features.keys():
percep_loss += self.criterion(
x_features[k], gt_features[k]) * self.layer_weights[k]
else:
percep_loss = None
# calculate style loss
if self.style_loss:
style_loss = 0
for k in x_features.keys():
style_loss += self.criterion(
self._gram_mat(x_features[k]),
self._gram_mat(gt_features[k])) * self.layer_weights[k]
else:
style_loss = None
return percep_loss, style_loss
def _gram_mat(self, x):
"""Calculate Gram matrix.
Args:
x (torch.Tensor): Tensor with shape of (n, c, h, w).
Returns:
torch.Tensor: Gram matrix.
"""
(n, c, h, w) = x.size()
features = x.view(n, c, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (c * h * w)
return gram