TANG 0.0.1
This commit is contained in:
0
loss/I2I/__init__.py
Normal file
0
loss/I2I/__init__.py
Normal file
129
loss/I2I/edge_loss.py
Normal file
129
loss/I2I/edge_loss.py
Normal file
@@ -0,0 +1,129 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class HED(nn.Module):
|
||||
def __init__(self, pretrained_model_path, norm_img=True):
|
||||
"""
|
||||
HED module to get edge
|
||||
:param pretrained_model_path: path to pretrained HED.
|
||||
:param norm_img(bool): If True, the image will be normed to [0, 1]. Note that
|
||||
this is different from the `use_input_norm` which norm the input in
|
||||
in forward function of vgg according to the statistics of dataset.
|
||||
Importantly, the input image must be in range [-1, 1].
|
||||
"""
|
||||
super().__init__()
|
||||
self.norm_img = norm_img
|
||||
|
||||
self.vgg_nets = nn.ModuleList([torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False),
|
||||
torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False)
|
||||
), torch.nn.Sequential(
|
||||
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False),
|
||||
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False)
|
||||
), torch.nn.Sequential(
|
||||
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False),
|
||||
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False),
|
||||
torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False)
|
||||
), torch.nn.Sequential(
|
||||
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False),
|
||||
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False),
|
||||
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False)
|
||||
), torch.nn.Sequential(
|
||||
torch.nn.MaxPool2d(kernel_size=2, stride=2),
|
||||
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False),
|
||||
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False),
|
||||
torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
|
||||
torch.nn.ReLU(inplace=False)
|
||||
)])
|
||||
|
||||
self.score_nets = nn.ModuleList([
|
||||
torch.nn.Conv2d(in_channels=i, out_channels=1, kernel_size=1, stride=1, padding=0)
|
||||
for i in [64, 128, 256, 512, 512]
|
||||
])
|
||||
|
||||
self.combine_net = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
|
||||
torch.nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.load_weights(pretrained_model_path)
|
||||
self.register_buffer('mean', torch.Tensor([104.00698793, 116.66876762, 122.67891434]).view(1, 3, 1, 1))
|
||||
for v in self.parameters():
|
||||
v.requies_grad = False
|
||||
|
||||
def load_weights(self, pretrained_model_path):
|
||||
checkpoint_path = Path(pretrained_model_path)
|
||||
if not checkpoint_path.exists():
|
||||
raise FileNotFoundError(f"Checkpoint '{checkpoint_path}' is not found")
|
||||
ckp = torch.load(checkpoint_path.as_posix(), map_location="cpu")
|
||||
m = {"One": "0", "Two": "1", "Thr": "2", "Fou": "3", "Fiv": "4"}
|
||||
|
||||
def replace_key(key):
|
||||
if key.startswith("moduleVgg"):
|
||||
return f"vgg_nets.{m[key[9:12]]}{key[12:]}"
|
||||
elif key.startswith("moduleScore"):
|
||||
return f"score_nets.{m[key[11:14]]}{key[14:]}"
|
||||
elif key.startswith("moduleCombine"):
|
||||
return f"combine_net{key[13:]}"
|
||||
else:
|
||||
raise ValueError("wrong checkpoint for HED")
|
||||
|
||||
module_dict = {replace_key(k): v for k, v in ckp.items()}
|
||||
self.load_state_dict(module_dict, strict=True)
|
||||
|
||||
def forward(self, x):
|
||||
if self.norm_img:
|
||||
x = (x + 1.) * 0.5
|
||||
x = x * 255.0 - self.mean
|
||||
img_size = (x.size(2), x.size(3))
|
||||
|
||||
to_combine = []
|
||||
for i in range(5):
|
||||
x = self.vgg_nets[i](x)
|
||||
score_x = self.score_nets[i](x)
|
||||
to_combine.append(F.interpolate(input=score_x, size=img_size, mode='bilinear', align_corners=False))
|
||||
out = self.combine_net(torch.cat(to_combine, 1))
|
||||
return out.clamp(0.0, 1.0)
|
||||
|
||||
|
||||
class EdgeLoss(nn.Module):
|
||||
def __init__(self, edge_extractor_type="HED", norm_img=True, criterion='L1', **kwargs):
|
||||
super(EdgeLoss, self).__init__()
|
||||
if edge_extractor_type == "HED":
|
||||
pretrained_model_path = kwargs.get("hed_pretrained_model_path")
|
||||
self.edge_extractor = HED(pretrained_model_path, norm_img)
|
||||
else:
|
||||
raise NotImplemented(f"do not support edge_extractor_type {edge_extractor_type}")
|
||||
|
||||
if criterion == 'L1':
|
||||
self.criterion = nn.L1Loss()
|
||||
elif criterion == "L2":
|
||||
self.criterion = nn.MSELoss()
|
||||
else:
|
||||
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
|
||||
|
||||
def forward(self, x, gt, gt_is_edge=True):
|
||||
edge = self.edge_extractor(x)
|
||||
if not gt_is_edge:
|
||||
gt = self.edge_extractor(gt.detach())
|
||||
loss = self.criterion(edge, gt)
|
||||
return loss
|
||||
155
loss/I2I/perceptual_loss.py
Normal file
155
loss/I2I/perceptual_loss.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.models.vgg as vgg
|
||||
|
||||
|
||||
class PerceptualVGG(nn.Module):
|
||||
"""VGG network used in calculating perceptual loss.
|
||||
In this implementation, we allow users to choose whether use normalization
|
||||
in the input feature and the type of vgg network. Note that the pretrained
|
||||
path must fit the vgg type.
|
||||
Args:
|
||||
layer_name_list (list[str]): According to the index in this list,
|
||||
forward function will return the corresponding features. This
|
||||
list contains the name each layer in `vgg.feature`. An example
|
||||
of this list is ['4', '10'].
|
||||
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
|
||||
use_input_norm (bool): If True, normalize the input image.
|
||||
Importantly, the input feature must in the range [0, 1].
|
||||
Default: True.
|
||||
"""
|
||||
|
||||
def __init__(self, layer_name_list, vgg_type='vgg19', use_input_norm=True):
|
||||
super(PerceptualVGG, self).__init__()
|
||||
self.layer_name_list = layer_name_list
|
||||
self.use_input_norm = use_input_norm
|
||||
|
||||
# get vgg model and load pretrained vgg weight
|
||||
# remove _vgg from attributes to avoid `find_unused_parameters` bug
|
||||
_vgg = getattr(vgg, vgg_type)(pretrained=True)
|
||||
num_layers = max(map(int, layer_name_list)) + 1
|
||||
assert len(_vgg.features) >= num_layers
|
||||
# only borrow layers that will be used from _vgg to avoid unused params
|
||||
self.vgg_layers = _vgg.features[:num_layers]
|
||||
|
||||
if self.use_input_norm:
|
||||
# the mean is for image with range [0, 1]
|
||||
self.register_buffer(
|
||||
'mean',
|
||||
torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
||||
# the std is for image with range [-1, 1]
|
||||
self.register_buffer(
|
||||
'std',
|
||||
torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
||||
|
||||
for v in self.vgg_layers.parameters():
|
||||
v.requies_grad = False
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function.
|
||||
Args:
|
||||
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||
Returns:
|
||||
Tensor: Forward results.
|
||||
"""
|
||||
|
||||
if self.use_input_norm:
|
||||
x = (x - self.mean) / self.std
|
||||
output = {}
|
||||
|
||||
for i, l in enumerate(self.vgg_layers):
|
||||
x = l(x)
|
||||
if str(i) in self.layer_name_list:
|
||||
output[str(i)] = x.clone()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class PerceptualLoss(nn.Module):
|
||||
"""Perceptual loss with commonly used style loss.
|
||||
Args:
|
||||
layer_weights (dict): The weight for each layer of vgg feature.
|
||||
Here is an example: {'4': 1., '9': 1., '18': 1.}, which means the
|
||||
5th, 10th and 18th feature layer will be extracted with weight 1.0
|
||||
in calculating losses.
|
||||
vgg_type (str): The type of vgg network used as feature extractor.
|
||||
Default: 'vgg19'.
|
||||
use_input_norm (bool): If True, normalize the input image in vgg.
|
||||
Default: True.
|
||||
perceptual_loss (bool): If `perceptual_loss == True`, the perceptual
|
||||
loss will be calculated.
|
||||
Default: True.
|
||||
style_loss (bool): If `style_loss == False`, the style loss will be calculated.
|
||||
Default: False.
|
||||
norm_img (bool): If True, the image will be normed to [0, 1]. Note that
|
||||
this is different from the `use_input_norm` which norm the input in
|
||||
in forward function of vgg according to the statistics of dataset.
|
||||
Importantly, the input image must be in range [-1, 1].
|
||||
"""
|
||||
|
||||
def __init__(self, layer_weights, vgg_type='vgg19', use_input_norm=True, perceptual_loss=True,
|
||||
style_loss=False, norm_img=True, criterion='L1'):
|
||||
super(PerceptualLoss, self).__init__()
|
||||
self.norm_img = norm_img
|
||||
self.perceptual_loss = perceptual_loss
|
||||
self.style_loss = style_loss
|
||||
self.layer_weights = layer_weights
|
||||
self.vgg = PerceptualVGG(layer_name_list=list(layer_weights.keys()), vgg_type=vgg_type,
|
||||
use_input_norm=use_input_norm)
|
||||
|
||||
if criterion == 'L1':
|
||||
self.criterion = torch.nn.L1Loss()
|
||||
elif criterion == "L2":
|
||||
self.criterion = torch.nn.MSELoss()
|
||||
else:
|
||||
raise NotImplementedError(f'{criterion} criterion has not been supported in this version.')
|
||||
|
||||
def forward(self, x, gt):
|
||||
"""Forward function.
|
||||
Args:
|
||||
x (Tensor): Input tensor with shape (n, c, h, w).
|
||||
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
|
||||
Returns:
|
||||
Tensor: Forward results.
|
||||
"""
|
||||
|
||||
if self.norm_img:
|
||||
x = (x + 1.) * 0.5
|
||||
gt = (gt + 1.) * 0.5
|
||||
# extract vgg features
|
||||
x_features = self.vgg(x)
|
||||
gt_features = self.vgg(gt.detach())
|
||||
|
||||
# calculate preceptual loss
|
||||
if self.perceptual_loss:
|
||||
percep_loss = 0
|
||||
for k in x_features.keys():
|
||||
percep_loss += self.criterion(
|
||||
x_features[k], gt_features[k]) * self.layer_weights[k]
|
||||
else:
|
||||
percep_loss = None
|
||||
|
||||
# calculate style loss
|
||||
if self.style_loss:
|
||||
style_loss = 0
|
||||
for k in x_features.keys():
|
||||
style_loss += self.criterion(
|
||||
self._gram_mat(x_features[k]),
|
||||
self._gram_mat(gt_features[k])) * self.layer_weights[k]
|
||||
else:
|
||||
style_loss = None
|
||||
|
||||
return percep_loss, style_loss
|
||||
|
||||
def _gram_mat(self, x):
|
||||
"""Calculate Gram matrix.
|
||||
Args:
|
||||
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
||||
Returns:
|
||||
torch.Tensor: Gram matrix.
|
||||
"""
|
||||
(n, c, h, w) = x.size()
|
||||
features = x.view(n, c, w * h)
|
||||
features_t = features.transpose(1, 2)
|
||||
gram = features.bmm(features_t) / (c * h * w)
|
||||
return gram
|
||||
Reference in New Issue
Block a user