add attention image fuse
This commit is contained in:
@@ -12,7 +12,6 @@ from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, Output
|
||||
def empty_cuda_cache(_):
    """Free cached CUDA memory and trigger Python garbage collection.

    Intended as an Ignite event handler; the engine argument passed by
    the trigger is ignored.
    """
    import gc

    torch.cuda.empty_cache()
    gc.collect()
|
||||
|
||||
|
||||
@@ -35,6 +34,14 @@ def setup_common_handlers(trainer: Engine, config, stop_on_nan=True, clear_cuda_
|
||||
:return:
|
||||
"""
|
||||
|
||||
# if train_sampler is not None:
|
||||
# if not isinstance(train_sampler, DistributedSampler):
|
||||
# raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")
|
||||
#
|
||||
# @trainer.on(Events.EPOCH_STARTED)
|
||||
# def distrib_set_epoch(engine):
|
||||
# train_sampler.set_epoch(engine.state.epoch - 1)
|
||||
|
||||
@trainer.on(Events.STARTED)
|
||||
@idist.one_rank_only()
|
||||
def print_dataloader_size(engine):
|
||||
|
||||
@@ -1,4 +1,38 @@
|
||||
import torchvision.utils
|
||||
from matplotlib.pyplot import get_cmap
|
||||
import torch
|
||||
import warnings
|
||||
from torch.nn.functional import interpolate
|
||||
|
||||
|
||||
def fuse_attention_map(images, attentions, cmap_name="jet", alpha=0.5):
    """Blend per-sample attention heatmaps onto a batch of images.

    Each attention map is min-max normalized per sample, resized to the
    image resolution if needed, colorized with a matplotlib colormap, and
    alpha-blended with the corresponding image.

    :param images: image tensor; code treats the last two dims as spatial
        and fills a same-shaped color overlay, so it is presumably
        (B, 3, H, W) — the "B x H x W" of the original doc looks wrong;
        TODO confirm against callers.
    :param attentions: attention tensor of shape (B, 1, Ha, Wa)
    :param cmap_name: matplotlib colormap name used for colorization
    :param alpha: blend weight given to the original images
    :return: blended tensor of the same shape as ``images``; when the
        inputs are inconsistent, warns and returns ``images`` unchanged.
    """
    if attentions.size(0) != images.size(0):
        # Bug fix: `images.size` (bound method) -> `images.size()` so the
        # warning shows the actual shape instead of a method repr.
        warnings.warn(f"attentions: {attentions.size()} and images: {images.size()} do not have same batch_size")
        return images
    if attentions.size(1) != 1:
        warnings.warn(f"attentions's channels should be 1 but got {attentions.size(1)}")
        return images

    batch = attentions.size(0)
    min_attentions = attentions.view(batch, -1).min(1, keepdim=True)[0].view(batch, 1, 1, 1)
    # Out-of-place normalization: the original `-=` / `/=` mutated the
    # caller's tensor (and would fail on tensors requiring grad).
    attentions = attentions - min_attentions
    max_attentions = attentions.view(batch, -1).max(1, keepdim=True)[0].view(batch, 1, 1, 1)
    # clamp_min guards against division by zero (NaNs) when an attention
    # map is constant, in which case the normalized map stays all zeros.
    attentions = attentions / max_attentions.clamp_min(1e-12)

    if images.size() != attentions.size():
        attentions = interpolate(attentions, images.size()[-2:])
    colored_attentions = torch.zeros_like(images)
    cmap = get_cmap(cmap_name)
    for i, at in enumerate(attentions):
        # cmap returns H x W x 4 RGBA in [0, 1]; keep only the RGB planes.
        ca = cmap(at[0].cpu().numpy())[:, :, :3]
        colored_attentions[i] = torch.from_numpy(ca).permute(2, 0, 1).view(colored_attentions[i].size())
    return images * alpha + colored_attentions * (1 - alpha)
|
||||
|
||||
|
||||
def make_2d_grid(tensors, padding=0, normalize=True, range=None, scale_each=False, pad_value=0):
|
||||
|
||||
Reference in New Issue
Block a user