前言
CNN中的特征可视化大体可分为两类:
- 细节信息:ZFNet中使用的deconvolution,改进的guide backpropagation
- 信息的重要性区分:类激活图(CAM),改进的Grad-CAM
第一类方法只显示了在深层特征中保留了哪些信息,而没有突出显示这些信息的相对重要性。第二类方法则具有一定的解释性,例如在分类任务中,通过CAM能够解释模型究竟是通过重点学习哪些信息来判断类别的。
1. CAM(Class Activation Map)
Network in Network中提出了用全局平均池化(GAP)替代全连接层以加强特征映射与类别之间的联系,更具可解释性。受该思想启发,CAM可视化技术应运而出。生成CAM的流程如下图所示(论文原图):
可以看出,生成CAM的步骤非常简单,但是对网络结构有要求(网络末端为GAP+FC这样的结构,并且FC只有一层,用于输出类别概率)。假设分类任务采用的是VGG网络,此时生成CAM的步骤为:
- 将VGG中的前两个FC替换为GAP,重新训练;
- 获取最后一个卷积层输出的特征图
[
f
1
,
f
2
,
.
.
.
,
f
n
]
[f_1, f_2, ..., f_n]
[f1?,f2?,...,fn?],以及全连接层的权重
[
w
1
,
w
2
,
.
.
.
,
w
n
]
[w_1, w_2, ..., w_n]
[w1?,w2?,...,wn?];
- 计算
C
A
M
=
∑
i
=
1
n
w
i
f
i
CAM=\sum_{i=1}^{n}w_if_i
CAM=∑i=1n?wi?fi?
不难发现,若网络结构不符合要求,按照上述方法计算CAM需要修改网络结构和重新训练。针对该问题,后续研究中提出了Gard-CAM。
2. Grad-CAM
由上述CAM的计算方法可知,生成CAM的关键是获取特征图的权重。基于对原始CAM的改进,Grad-CAM通过求网络输出的类别置信度对特征图的偏导来获取权重,适用于任意网络,并且能够可视化任意层的类激活图(通常选择最后一个卷积层,因为其包含了丰富的高级语义和空间信息)。
- 图片送入网络,前向传播,获取最后一个卷积层的特征图
A
k
A^k
Ak(可选,任意层均可,
k
k
k为通道index);
- 反向传播,获取网络输出的类别
c
c
c 的概率
y
c
y^c
yc关于
A
k
A^k
Ak的梯度
?
y
c
?
A
k
\frac{\partial y^c}{\partial A^k}
?Ak?yc?;
- 计算权重
α
k
c
=
1
Z
∑
i
∑
j
?
y
c
?
A
i
,
j
k
\alpha^{c}_{k}=\frac{1}{Z}\sum\limits_{i}\sum\limits_{j}\frac{\partial y^c}{\partial A^k_{i,j}}
αkc?=Z1?i∑?j∑??Ai,jk??yc?
- 计算Grad-CAM:
L
G
r
a
d
?
C
A
M
c
=
R
e
L
U
(
∑
k
α
k
c
A
k
)
L_{Grad-CAM}^{c}=ReLU(\sum\limits_{k}\alpha^{c}_{k}A^k)
LGrad?CAMc?=ReLU(k∑?αkc?Ak)
- 求偏导的意义:参考知乎中的文章,偏导表示输出关于输入的变化率,也就是特征图上变化一个单位,得到的输出变化多少单位。可以反映出输出
y
c
y^c
yc关于
A
i
,
j
k
A^k_{i,j}
Ai,jk?的敏感程度,如果梯度大,则非常敏感,表示该位置更有可能属于类别
c
c
c。
3. PyTorch中的hook机制
- PyTorch中设计hook的目的:在不改变网络代码、不在forward中返回某一层的输出的情况下,获取网络中某一层在前向传播或反向传播过程的输入和输出,并对其进行相关操作(例如:特征图可视化,梯度裁剪)。
4. Grad-CAM的PyTorch简洁实现
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
class GradCAM():
'''
Grad-cam: Visual explanations from deep networks via gradient-based localization
Selvaraju R R, Cogswell M, Das A, et al.
https://openaccess.thecvf.com/content_iccv_2017/html/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.html
'''
def __init__(self, model, target_layers, use_cuda=True):
super(GradCAM).__init__()
self.use_cuda = use_cuda
self.model = model
self.target_layers = target_layers
self.target_layers.register_forward_hook(self.forward_hook)
self.target_layers.register_full_backward_hook(self.backward_hook)
self.activations = []
self.grads = []
def forward_hook(self, module, input, output):
self.activations.append(output[0])
def backward_hook(self, module, grad_input, grad_output):
self.grads.append(grad_output[0].detach())
def calculate_cam(self, model_input):
if self.use_cuda:
device = torch.device('cuda')
self.model.to(device)
model_input = model_input.to(device)
self.model.eval()
y_hat = self.model(model_input)
max_class = np.argmax(y_hat.cpu().data.numpy(), axis=1)
model.zero_grad()
y_c = y_hat[0, max_class]
y_c.backward()
activations = self.activations[0].cpu().data.numpy().squeeze()
grads = self.grads[0].cpu().data.numpy().squeeze()
weights = np.mean(grads.reshape(grads.shape[0], -1), axis=1)
weights = weights.reshape(-1, 1, 1)
cam = (weights * activations).sum(axis=0)
cam = np.maximum(cam, 0)
cam = cam / cam.max()
return max_class, cam
def show_cam_image(self, image, cam):
h, w = image.shape[:2]
cam = cv2.resize(cam, (h,w))
cam = cam / cam.max()
heatmap = cv2.applyColorMap((255*cam).astype(np.uint8), cv2.COLORMAP_JET)
image = image / image.max()
heatmap = heatmap / heatmap.max()
result = 0.4*heatmap + 0.6*image
result = result / result.max()
plt.figure()
plt.imshow((result*255).astype(np.uint8))
plt.colorbar(shrink=0.8)
plt.tight_layout()
plt.show()
参考资料
|