IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> GradCAM神经网络可视化解释(原理和实现) -> 正文阅读

[人工智能]GradCAM神经网络可视化解释(原理和实现)

GradCAM是经典的特征图可视化工具,在CV任务中,能用于分析CNN学到了什么东西。先看一张图:
dog
这就是GradCAM做出的效果,它直观地表示出咱们模型认为图片是Dog的是依据哪些地方。
GradCAM借用梯度来进行注意力表示,发表于ICCV2017,如今依然活跃在学术和工程界。
论文链接:https://arxiv.org/abs/1610.02391


GradCAM原理

在这里插入图片描述
对于视觉任务,包括图像分类、目标检测等,通常都是backbone+head的形式。如图1所示。所以,GradCAM可以无差别地对各种视觉任务进行可视化
在操作上,GradCAM拿到backbone的输出梯度,一般是4维张量,将这一层梯度进行平均化作为权重,然后跟这一层的输出张量做一个加权平均(先乘再加),然后过一层relu去掉负值,最后等比例投影在调整过的原图上。

我以图像分类为例进行剖析:(假设我们做5分类)
第一步,前向传播得到特征图。需要进行一次前向计算,得到backbone的特征图输出。
第二步,反向传播得到梯度。模型的输出为一个5-d的向量res,假设我们要看类别1的可视化,咱们就把res[1]当作loss进行反向传播。这么做的原理:咱们需要知道模型识别出类别1会认为哪些特征图是重要的,而梯度直接表达了参数要调整的方向,假设参数调整方向为正向,那么这些特征图就应该是重要的。所以,在这一层的对应位置的平均梯度可以表示该特征图的重要性,即权重。再结合第一步得到的特征图,进行加权平均就可以了~

咱们看一下代码:(来自https://github.com/leftthomas/GradCAM

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable


class GradCam:
    def __init__(self, model):
        self.model = model.eval()
        self.feature = None
        self.gradient = None

    def save_gradient(self, grad):
        self.gradient = grad

    def __call__(self, x):
        image_size = (x.size(-1), x.size(-2))
        datas = Variable(x)
        heat_maps = []
        for i in range(datas.size(0)):
            img = datas[i].data.cpu().numpy()
            img = img - np.min(img)
            if np.max(img) != 0:
                img = img / np.max(img)

            feature = datas[i].unsqueeze(0)
            for name, module in self.model.named_children():
                print(name)
                if name == 'classifier' or name == 'fc':
                    feature = feature.view(feature.size(0), -1)
                feature = module(feature)
                if name == 'features' or name == 'backbone':
                    feature.register_hook(self.save_gradient)  # get backbone gradients
                    self.feature = feature
            classes = torch.sigmoid(feature)
            print(torch.argmax(F.softmax(classes), dim=-1))
            one_hot, _ = classes.max(dim=-1)
            self.model.zero_grad()
            one_hot.backward()

            weight = self.gradient.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
            mask = F.relu((weight * self.feature).sum(dim=1)).squeeze(0)
            mask = cv2.resize(mask.data.cpu().numpy(), image_size)
            mask = mask - np.min(mask)
            if np.max(mask) != 0:
                mask = mask / np.max(mask)
            heat_map = np.float32(cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET))
            cam = heat_map + np.float32((np.uint8(img.transpose((1, 2, 0)) * 255)))
            cam = cam - np.min(cam)
            if np.max(cam) != 0:
                cam = cam / np.max(cam)
            heat_maps.append(transforms.ToTensor()(cv2.cvtColor(np.uint8(255 * cam), cv2.COLOR_BGR2RGB)))
        heat_maps = torch.stack(heat_maps)
        return heat_maps

代码中通过register_hook来获取backbone的梯度,想进一步了解hook的可戳《python中的register_hook

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-05-09 12:39:40  更:2022-05-09 12:44:40 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/4 16:00:20-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码