IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 学习笔记-SENet -> 正文阅读

[人工智能]学习笔记-SENet

目录

一、前言

1、什么是注意力机制

2、CV中的注意力机制有哪些

二、Squeeze-and-Excitation Networks

1、引言

2、总述

3、过程

4、应用

5、代码(SE_Resnet)

6、思考


一、前言

1、什么是注意力机制

注意力机制能够灵活的捕捉全局信息局部信息之间的联系。

其目的主要是让模型获得需要重点关注的目标区域,并对该区域投入更大的权重,突出显著地有用的特征,抑制或者忽略无用的特征。

2、CV中的注意力机制有哪些

1、通道域注意力机制

2、空间域注意力机制

3、混合域注意力机制

二、Squeeze-and-Excitation Networks

1、引言

CNN的主要作用:提取特征

对CNN的要求:具有非常强的特征表征能力,但是局部操作有限

感受野:卷积神经网络每一层输出的特征图上的像素点对应输入图片上区域大小。似乎要的是全局信息

增大感受野:堆叠卷积层、增大卷积核、下采样、空洞卷积等

2、总述

一个目的:得到一个权重矩阵,对特征进行重构

? ? ? ? ? ? ? ? ? (它是一个可以用来衡量通道重要性的数值,上图中用不同颜色展示)

两个重要操作:Squeeze 和 Excitation

四步走战略:Transformation、Squeeze 、Excitation、Scale

3、过程

第一步,Transformation (Ftr):

? ? ? ? ? ? ? ?给定一个input : 𝑋 ∈ ?𝐻′ ?𝑊′ ?𝐶 ′ ,让其经过 Ftr 做一次映射

? ? ? ? ? ? ? ?得到一个output:𝑈 ∈ ?𝐻 ? 𝑊 ? 𝐶 ,即: 𝑈 = Ftr(𝑋)

? ? ? ? ? ? ? ?注释:在传统的CNN中,这一步其实就是一个普通的卷积操作。

? ? ? ? ? ? ? ?这里边之所以单独 把这一步提出来,是因为现在不同的网络可能有着不同的操作变换。

第二步,Squeeze (Fsq):

? ? ? ? ? ? ? ?给定一个input : 𝑈 ∈ ?𝐻 ? 𝑊 ? 𝐶 ,让其经过 Fsq 做一次 变换

? ? ? ? ? ? ? ?得到一个output:𝑍 ∈ ?𝟏 ?𝟏 ? 𝐶 ,即:Zc=Fsq(Uc) = \frac{1}{H*W}\sum_{i=1}^{H} \sum_{j=1}^{W}Uc(i,j)

? ? ? ? ? ? ? ?具体操作:采用全局平均池化(GAP),将每个通道上对应的空间信息(H*W)压缩到对应通? ? ? ? ? ? ? ? ? ?道中变为1个数值,此时1个像素表示一个通道,最终维度变为1*1*C,成了一个向量。

? ? ? ? ? ? ? ?注释: 对𝑈实现全局低维嵌入——将空间进行挤压,相当于拥有了全局感受野。

第三步,Excitation (Fex):

? ? ? ? ? ? ??给定一个input : 𝑍 ∈ ?𝟏 ?𝟏 ? 𝐶 ,让其经过 Fex 得到 一个output:𝑆 ∈ ?𝟏 ?𝟏 ? 𝐶 ,即:

? ? ? ? ? ? ??s = F_{ex}(z,W) = \sigma (g(z,W)) = \sigma(W_{2}\delta (W_{1}z))

? ? ? ? ? ? ??具体操作:将上一步得到的𝑍 经过两个全连接层(bottleNeck),对应上面的𝑤1和𝑤2

? ? ? ? ? ? ??其中 𝛿是激活函数Relu,𝜎为激活函数sigmoid,最终得到的𝑺 就是我们想得到的权重值。? ? ? ? ? ? ? ? ? 注释: 这一步操作存在一个超参reduction ratio:r ,影响参数量 。

第四步,Scale (Fscale):

? ? ? ? ? ? ? 利用上一步得到的权重值,对原始的feature map进 行操作,

? ? ? ? ? ? ? 得到一个output:\tilde{X}∈ ?𝟏 ?𝟏 ? 𝐶 ,即: \tilde{X}= Fscale(𝑈, 𝑆) = 𝑆 ? 𝑈? ? ? ? ??

? ? ? ? ? ? ? 具体操作:将得到的权重施加到𝑈 上面的每一个通道上。

? ? ? ? ? ? ?其实也就是对于U每个位 置上的所有H*W上的值都乘上对应通道的权值而已,完成了特征? ? ? ? ? ? ? 图的重构

4、应用

1、设计一个新的有效的CNN很困难,本模块的基本原则是不改变原来 CNN结构,做到即插即用

2、要知道上图的transformation在哪里

5、代码(SE_Resnet)

class BasicBlock(nn.Module):
    """
    basic building block for ResNet-18, ResNet-34
    """
    message = "basic"

    def __init__(self, in_channels, out_channels, strides, is_se=False):
        super(BasicBlock, self).__init__()
        self.is_se = is_se
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=strides, padding=1, bias=False)  # same padding
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        if self.is_se:
            self.se = SE(out_channels, 16)

        # fit input with residual output
        self.short_cut = nn.Sequential()
        if strides is not 1:
            self.short_cut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=strides, padding=0, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn(out)
        if self.is_se:
            coefficient = self.se(out)
            out *= coefficient
        out += self.short_cut(x)
        return F.relu(out)


class BottleNeck(nn.Module):
    """
    BottleNeck block for RestNet-50, ResNet-101, ResNet-152
    """
    message = "bottleneck"

    def __init__(self, in_channels, out_channels, strides, is_se=False):
        super(BottleNeck, self).__init__()
        self.is_se = is_se
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False)  # same padding
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=strides, padding=1, bias=False)
        self.conv3 = nn.Conv2d(out_channels, out_channels * 4, 1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels * 4)
        if self.is_se:
            self.se = SE(out_channels * 4, 16)

        # fit input with residual output
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels * 4, 1, stride=strides, padding=0, bias=False),
            nn.BatchNorm2d(out_channels*4)
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv3(out)
        out = self.bn2(out)
        if self.is_se:
            coefficient = self.se(out)
            out *= coefficient
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """
    building ResNet_34
    """

    def __init__(self, block: object, groups: object, num_classes, is_se=False) -> object:
        super(ResNet, self).__init__()
        self.channels = 64  # out channels from the first convolutional layer
        self.block = block
        self.is_se = is_se

        self.conv1 = nn.Conv2d(3, self.channels, 7, stride=2, padding=3, bias=False)
        self.bn = nn.BatchNorm2d(self.channels)
        self.pool1 = nn.MaxPool2d(3, 2, 1)
        self.conv2_x = self._make_conv_x(channels=64, blocks=groups[0], strides=1, index=2)
        self.conv3_x = self._make_conv_x(channels=128, blocks=groups[1], strides=2, index=3)
        self.conv4_x = self._make_conv_x(channels=256, blocks=groups[2], strides=2, index=4)
        self.conv5_x = self._make_conv_x(channels=512, blocks=groups[3], strides=2, index=5)
        self.pool2 = nn.AvgPool2d(7)
        patches = 512 if self.block.message == "basic" else 512 * 4
        self.fc = nn.Linear(patches, num_classes)  # for 224 * 224 input size

    def _make_conv_x(self, channels, blocks, strides, index):
        """
        making convolutional group
        :param channels: output channels of the conv-group
        :param blocks: number of blocks in the conv-group
        :param strides: strides
        :return: conv-group
        """
        list_strides = [strides] + [1] * (blocks - 1)  # In conv_x groups, the first strides is 2, the others are ones.
        conv_x = nn.Sequential()
        for i in range(len(list_strides)):
            layer_name = str("block_%d_%d" % (index, i))  # when use add_module, the name should be difference.
            conv_x.add_module(layer_name, self.block(self.channels, channels, list_strides[i], self.is_se))
            self.channels = channels if self.block.message == "basic" else channels * 4
        return conv_x

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.bn(out))
        out = self.pool1(out)
        out = self.conv2_x(out)
        out = self.conv3_x(out)
        out = self.conv4_x(out)
        out = self.conv5_x(out)
        out = self.pool2(out)
        out = out.view(out.size(0), -1)
        out = F.softmax(self.fc(out))
        return out


def ResNet_50_SE(num_classes=1000):
    return ResNet(block=BottleNeck, groups=[3, 4, 6, 3], num_classes=num_classes, is_se=True)

6、思考

1、按照本文的思路将空间信息全部压缩到通道中是否合理?

2、能否对本文的通道注意力进行其他的改进?

3、按照本文的通道注意力思想设计空间注意力?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-04 11:02:36  更:2022-02-04 11:02:46 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/10 12:04:09-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码