CV领域常用的attention模块的pytorch实现
进入2021年,CV领域已经发展到了非常非常火的时期了,attention的出现更是让卷积模块的功能变得更加强大;通过attention模块,特征图能够更好地关注到感兴趣区域,提取跟实际任务更加契合的特征。 在这个系列中,将近几年出现的attention模块做一个简要的总结,并且将这些attention模块用pytorch实现出来。所有的attention模块都封装成了一个类,在实际调用的时候即插即用,在需要attention的地方可以直接使用。 刚接触transformer不久,仔细看才发现原来attention有这么多花招,很多博客分享的内容也挺实用的。完整的代码放在了github上,欢迎各位点一下star。 每一种attention模块对应的原文地址也放在github上了,想仔细钻研的童鞋下载来认真看。 下面先将第一个attention模块,SENet,全程为Squeeze-and-Excitation Network,论文地址点击这里获取。SENet使用的attention是通道自注意力机制,实现的原理图如下:首先将输入经过卷积等操作得到特征图,然后将特征图分成两路,其中一路通过全局平均池化得到一个Bx1x1xC的特征图(B–>batch_size, C–>channel),然后让这个特征图经过一个全连接网络,通道数变成原来的1/r,相当于一个挤压通道的过程(squeeze),经过relu激活函数(excatation)之后将通道变回原来的C,将此事的特征图扩展为跟输入特征图相同的shape,相加得到attention模块的输出。 论文中分别展示了inception module类型和residual module类型的attention结构。
pytorch实现起来不难,具体的代码如下所示:
import warnings
warnings.filterwarnings('ignore')
import torch
import torch.nn as nn
import torch.nn.functional as F
class SENet(nn.Module):
def __init__(self,channel,ratio = 16):
super(SENet,self).__init__()
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.sigmoid = nn.Sigmoid()
self.fc_1 = nn.Linear(channel,channel//ratio)
self.relu = nn.ReLU()
self.fc_2 = nn.Linear(channel//ratio,channel)
def forward(self,x):
temp = x
temp = self.global_pool(temp).view(x.size(0),x.size(1))
temp = self.fc_1(temp)
temp = self.relu(temp)
temp = self.fc_2(temp)
temp = self.sigmoid(temp).view(x.size(0),x.size(1),1,1)
temp = temp.expand_as(x)
return x+temp
|