change
This commit is contained in:
@@ -1,26 +1,34 @@
import warnings

import cv2
import numpy as np
import torch
import torchvision.utils
from matplotlib.pyplot import get_cmap
from torch.nn.functional import interpolate
|
||||
|
||||
def attention_colored_map(attentions, size=None):
    """Colorize a batch of single-channel attention maps with the JET colormap.

    Each map is min-max normalized per sample, converted to uint8, optionally
    resized, and mapped through ``cv2.COLORMAP_JET``.

    :param attentions: tensor of shape (B, 1, H, W).
    :param size: optional target spatial size as (H, W); maps whose spatial
        size differs are resized with OpenCV.
    :return: float tensor of shape (B, 3, H', W') with values in [0, 1],
        on the same device as ``attentions``.
    """
    assert attentions.dim() == 4 and attentions.size(1) == 1
    device = attentions.device

    # Work on a detached copy so normalization does not mutate the caller's
    # tensor in place (the original subtracted/divided the input directly).
    attentions = attentions.detach().clone()

    batch = attentions.size(0)
    min_vals = attentions.view(batch, -1).min(1, keepdim=True)[0].view(batch, 1, 1, 1)
    attentions -= min_vals
    max_vals = attentions.view(batch, -1).max(1, keepdim=True)[0].view(batch, 1, 1, 1)
    # Guard against an all-constant map: after shifting the min off, its max
    # is 0 and the division would produce NaNs.
    attentions /= max_vals.clamp(min=1e-12)

    maps = (attentions.cpu().numpy() * 255).astype(np.uint8)  # (B, 1, H, W) uint8

    need_resize = False
    if size is not None and maps.shape[-2:] != tuple(size):
        assert len(size) == 2, "for interpolate, size must be (x, y), have two dim"
        need_resize = True

    colored = []
    for sub in maps:
        # cv2.resize takes dsize as (width, height); size is (H, W), so the
        # dimensions must be swapped or non-square targets come back transposed.
        gray = cv2.resize(sub[0], (size[1], size[0])) if need_resize else sub[0]
        colored.append(cv2.applyColorMap(gray, cv2.COLORMAP_JET))  # (H, W, 3) uint8
    stacked = np.stack(colored)  # (B, H, W, 3)

    return torch.from_numpy(stacked).permute(0, 3, 1, 2).contiguous().to(device).float() / 255
|
||||
|
||||
|
||||
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||
def fuse_attention_map(images, attentions, alpha=0.5):
|
||||
"""
|
||||
|
||||
:param images: B x H x W
|
||||
@@ -35,7 +43,7 @@ def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
|
||||
if attentions.size(1) != 1:
|
||||
warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
|
||||
return images
|
||||
colored_attentions = attention_colored_map(attentions, images.size()[-2:], cmap_name).to(images.device)
|
||||
colored_attentions = attention_colored_map(attentions, images.size()[-2:])
|
||||
return images * alpha + colored_attentions * (1 - alpha)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user