Author :Horizon Max
✨ 编程技巧篇:各种操作小结
机器视觉篇:会变魔术 OpenCV
深度学习篇:简单入门 PyTorch
神经网络篇:经典网络模型
算法篇:再忙也别忘了 LeetCode
随着神经网路模型的不断发展,深度模型通过使用更抽象
(增加网络层数)和 更紧密
(端到端训练)实现了更好的性能 ;
但随之带来的是对于神经网络的 可解释性 :为什么会出现这样的结果?网络的关注点在哪?
基于此提出的 Grad-CAM
利用热力图的方式实现网络预测过程的可视化,并帮助我们更好的理解神经网络 ;
Grad-CAM 是 CAM 的推广,不需要更改网络结构或重新训练就能实现更多 CNN 模型的可视化 ;
论文地址:Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
GitHub:PyTorch-Grad-CAM
针对 类别为c、宽度u、高度为v 的类别定位图 Grad-CAM L G r a d − C A M c ^c_{Grad-CAM} Grad−CAMc ∈ Ru×v
Z
:宽度 i 和 高度 j 的乘积 ;c
:选取的类别 c ;k
:第 k 个通道 ;A
:需要进行可视化的特征层,一般选取最后一个卷积层的输出 ;ReLU
:使最后的输出结果 >0 ,抑制不感兴趣的权重部分 ;详细可以参考下图:
虽然 Grad-CAM 具有分类区分和局部化相关图像区域的能力 ;
但仍缺乏类似于 Guided Backpropagation 像素空间梯度可视化的那种突出细粒度细节的能力 ;
基于此,作者通过元素级乘法融合了Guided Backpropagation 和 Grad-CAM 可视化 ;
首先使用双线性插值将 L G r a d − C A M c ^c_{Grad-CAM} Grad−CAMc 上采样到输入图像分辨率 ;
# Here is the code :
import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models
from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
def main():
model = models.resnet50(pretrained=True)
target_layers = [model.layer4[-1]]
data_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Prepare image
img_path = "image.png"
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path).convert('RGB')
img = np.array(img, dtype=np.uint8)
img_tensor = data_transform(img)
input_tensor = torch.unsqueeze(img_tensor, dim=0)
# Grad CAM
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
# targets = [ClassifierOutputTarget(281)] # cat
targets = [ClassifierOutputTarget(254)] # dog
grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
grayscale_cam = grayscale_cam[0, :]
visualization = show_cam_on_image(img.astype(dtype=np.float32)/255.,
grayscale_cam, use_rgb=True)
plt.imshow(visualization)
plt.show()
if __name__ == '__main__':
main()
结果展示:
targets = [ClassifierOutputTarget(281)] # cat
targets = [ClassifierOutputTarget(254)] # dog