IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> CNN可视化技术 -- CAM & Grad-CAM详解及pytorch简洁实现 -> 正文阅读

[人工智能]CNN可视化技术 -- CAM & Grad-CAM详解及pytorch简洁实现

前言

CNN中的特征可视化大体可分为两类:

  • 细节信息:ZFNet中使用的deconvolution,改进的guide backpropagation
  • 信息的重要性区分:类激活图(CAM),改进的Grad-CAM

第一类方法只显示了在深层特征中保留了哪些信息,而没有突出显示这些信息的相对重要性。第二类方法则具有一定的解释性,例如在分类任务中,通过CAM能够解释模型究竟是通过重点学习哪些信息来判断类别的。

1. CAM(Class Activation Map)

Network in Network中提出了用全局平均池化(GAP)替代全连接层以加强特征映射与类别之间的联系,更具可解释性。受该思想启发,CAM可视化技术应运而出。生成CAM的流程如下图所示(论文原图):
在这里插入图片描述

可以看出,生成CAM的步骤非常简单,但是对网络结构有要求(网络末端为GAP+FC这样的结构,并且FC只有一层,用于输出类别概率)。假设分类任务采用的是VGG网络,此时生成CAM的步骤为:

  1. 将VGG中的前两个FC替换为GAP,重新训练;
  2. 获取最后一个卷积层输出的特征图 [ f 1 , f 2 , . . . , f n ] [f_1, f_2, ..., f_n] [f1?,f2?,...,fn?],以及全连接层的权重 [ w 1 , w 2 , . . . , w n ] [w_1, w_2, ..., w_n] [w1?,w2?,...,wn?]
  3. 计算 C A M = ∑ i = 1 n w i f i CAM=\sum_{i=1}^{n}w_if_i CAM=i=1n?wi?fi?

不难发现,若网络结构不符合要求,按照上述方法计算CAM需要修改网络结构和重新训练。针对该问题,后续研究中提出了Gard-CAM。

2. Grad-CAM

由上述CAM的计算方法可知,生成CAM的关键是获取特征图的权重。基于对原始CAM的改进,Grad-CAM通过求网络输出的类别置信度对特征图的偏导来获取权重,适用于任意网络,并且能够可视化任意层的类激活图(通常选择最后一个卷积层,因为其包含了丰富的高级语义和空间信息)。
在这里插入图片描述

  • 生成Grad-CAM的步骤如下:
  1. 图片送入网络,前向传播,获取最后一个卷积层的特征图 A k A^k Ak(可选,任意层均可, k k k为通道index);
  2. 反向传播,获取网络输出的类别 c c c 的概率 y c y^c yc关于 A k A^k Ak的梯度 ? y c ? A k \frac{\partial y^c}{\partial A^k} ?Ak?yc?
  3. 计算权重 α k c = 1 Z ∑ i ∑ j ? y c ? A i , j k \alpha^{c}_{k}=\frac{1}{Z}\sum\limits_{i}\sum\limits_{j}\frac{\partial y^c}{\partial A^k_{i,j}} αkc?=Z1?i?j??Ai,jk??yc?
  4. 计算Grad-CAM: L G r a d ? C A M c = R e L U ( ∑ k α k c A k ) L_{Grad-CAM}^{c}=ReLU(\sum\limits_{k}\alpha^{c}_{k}A^k) LGrad?CAMc?=ReLU(k?αkc?Ak)
  • 求偏导的意义:参考知乎中的文章,偏导表示输出关于输入的变化率,也就是特征图上变化一个单位,得到的输出变化多少单位。可以反映出输出 y c y^c yc关于 A i , j k A^k_{i,j} Ai,jk?的敏感程度,如果梯度大,则非常敏感,表示该位置更有可能属于类别 c c c

3. PyTorch中的hook机制

  • PyTorch中设计hook的目的:在不改变网络代码、不在forward中返回某一层的输出的情况下,获取网络中某一层在前向传播或反向传播过程的输入和输出,并对其进行相关操作(例如:特征图可视化,梯度裁剪)。

4. Grad-CAM的PyTorch简洁实现

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt

class GradCAM():
    '''
    Grad-cam: Visual explanations from deep networks via gradient-based localization
    Selvaraju R R, Cogswell M, Das A, et al. 
    https://openaccess.thecvf.com/content_iccv_2017/html/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.html
    '''
    def __init__(self, model, target_layers, use_cuda=True):
        super(GradCAM).__init__()
        self.use_cuda = use_cuda
        self.model = model
        self.target_layers = target_layers
        
        self.target_layers.register_forward_hook(self.forward_hook)
        self.target_layers.register_full_backward_hook(self.backward_hook)
        
        self.activations = []
        self.grads = []
        
    def forward_hook(self, module, input, output):
        self.activations.append(output[0])
        
    def backward_hook(self, module, grad_input, grad_output):
        self.grads.append(grad_output[0].detach())
        
    def calculate_cam(self, model_input):
        if self.use_cuda:
            device = torch.device('cuda')
            self.model.to(device)                 # Module.to() is in-place method 
            model_input = model_input.to(device)  # Tensor.to() is not a in-place method
        self.model.eval()
        
        # forward
        y_hat = self.model(model_input)
        max_class = np.argmax(y_hat.cpu().data.numpy(), axis=1)
        
        # backward
        model.zero_grad()
        y_c = y_hat[0, max_class]
        y_c.backward()
        
        # get activations and gradients
        activations = self.activations[0].cpu().data.numpy().squeeze()
        grads = self.grads[0].cpu().data.numpy().squeeze()
        
        # calculate weights
        weights = np.mean(grads.reshape(grads.shape[0], -1), axis=1)
        weights = weights.reshape(-1, 1, 1)
        cam = (weights * activations).sum(axis=0)
        cam = np.maximum(cam, 0) # ReLU
        cam = cam / cam.max()
        return max_class, cam
    
    def show_cam_image(self, image, cam):
        # image: [H,W,C]
        h, w = image.shape[:2]
        
        cam = cv2.resize(cam, (h,w))
        cam = cam / cam.max()
        heatmap = cv2.applyColorMap((255*cam).astype(np.uint8), cv2.COLORMAP_JET) # [H,W,C]
        
        image = image / image.max()
        heatmap = heatmap / heatmap.max()
        
        result = 0.4*heatmap + 0.6*image
        result = result / result.max()
        
        plt.figure()
        plt.imshow((result*255).astype(np.uint8))
        plt.colorbar(shrink=0.8)
        plt.tight_layout()
        plt.show()

参考资料

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-07-03 10:48:33  更:2022-07-03 10:50:32 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/29 9:30:18-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计