IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【神经网络】(22) ConvMixer 代码复现,网络解析,附TensorFlow完整代码 -> 正文阅读

[人工智能]【神经网络】(22) ConvMixer 代码复现,网络解析,附TensorFlow完整代码

大家好,今天和各位分享一下如何使用 TensorFlow 构建 ConvMixer 卷积神经网络模型.

我偶然间找到了这个网络,这是一个实现起来非常简单的模型,但是能够实现较好的精度表现,超过了 Vision Transformer 模型,有种大道至简的感觉。

论文地址:https://openreview.net/forum?id=TVHS5Y4dNvM


1. 引言

近年来 Transformer?模型在 CV 领域中不断挑战卷积神经网络的统治地位,出现了能和 CNN 扳手腕的 VisionTransformer 以及划时代的 SwinTransformer。这篇文章作者主要针对的是 VIT 模型,他提出了一个问题:ViT的性能是由于其强大的Transformer结构产生的,还是由于使用patch作为输入表示产生的

在论文中,作者证明了PatchEmbedding对VIT的精度影响更大,并提出了一个非常简单的模型ConvMixer,在思想上类似于ViT和MLP-Mixer。模型直接将patch作为输入,分离空间和通道尺寸的混合建模并在整个网络中保持相同大小的分辨率

尽管ConvMixer的设计很简单,但是实验证明了ConvMixer在相似的参数计数和数据集大小方面优于ViT、MLP-Mixer及其一些变体,以及经典的视觉模型,如ResNet。


2. 模型构建

我们先导入需要用到的工具包

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

2.1 Patch Embedding

patchembedding 的主要功能是对原始输入图像(h, w)划分图像块。首先指定每个图像块的size为(patch_size, patch_size)将每张图像划分出(h//patch_size, w//patch_size)个图像块

它的实现方法就是通过一个 kernel_size 和 stride 都等于 patch_size 的卷积层来划分图像块

代码如下:

# ---------------------------------------------- #
#(1)patchembedding层
'''out_channel代表输出通道数, patch_size代表每个图像块的宽高'''
# ---------------------------------------------- #
def patchembed(inputs, out_channel, patch_size):
    
    # 卷积核大小为patch_size*patch_size,步长为patch_size的标准卷积划分图像块
    x = layers.Conv2D(filters = out_channel,   # 输出通道数
                      kernel_size = patch_size,  # 卷积核尺寸
                      strides = patch_size,  # 卷积步长
                      padding = 'same',  # 
                      use_bias = False)(inputs)

    # GELU激活函数、BN标准化
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)

    return x

2.2 特征提取层

这里的特征提取层由三部分组成,深度卷积(depthwise conv)、逐点卷积(pointwise conv)、残差连接(shortcut)。如下图ConvMixer Layer所示。

关于深度可分离卷积的原理,看我这篇博文:https://blog.csdn.net/dgvv4/article/details/123476899

首先输入特征图,经过深度卷积提取特征图长宽方向的信息,其中卷积核的个数和输入特征图的通道数相同,且输入和输出特征图的shape相同;接着残差连接输入和输出;然后经过1*1逐点卷积融合通道方向的信息,其中卷积核的个数和输出特征图的个数相同

代码如下:

# ---------------------------------------------- #
#(2)单个特征提取模块
'''out_channel代表逐点卷积的输出通道数, kernel_size代表深度卷积的卷积核大小'''
# ---------------------------------------------- #
def layer(inputs, out_channel, kernel_size):

    # 9*9深度卷积提取特征
    x = layers.DepthwiseConv2D(kernel_size = kernel_size,  # 卷积核大小
                               strides = 1,  # 不经过下采样
                               padding = 'same',  # 卷积前后size不变
                               use_bias = False)(inputs)
    # GELU激活、BN标准化
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)

    # 残差连接
    x = x + inputs
    
    # 1*1逐点卷积
    x = layers.Conv2D(filters = out_channel,  # 输出通道数
                      kernel_size = 1,  # 1*1卷积
                      strides = 1)(x)
    # GELU激活、BN标准化
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)

    return x

# ---------------------------------------------- #
#(3)堆叠多个特征提取模块
'''depth代表堆叠的次数'''
# ---------------------------------------------- #
def blocks(x, depth, out_channel, kernel_size):
    # 堆叠多个特征提取模块
    for _ in range(depth):
        x = layer(x, out_channel, kernel_size)
    
    return x

2.3 主干网络

ConvMixer的网络结构非常简单。首先图像经过 PatchEmbedding 划分图像块,然后经过12个特征提取模块,最后经过一个全连接层得到输出结果。

这里构建 ConvMixer-1536/20 网络模型,其中 1536 代表patchembedding 层的输出通道数20 代表堆叠20个特征提取模块每个图像块patch_size的大小为7*7特征提取模块中深度卷积的卷积核尺寸为 9*9

代码如下:

# ---------------------------------------------- #
#(4)主干网络
'''input_shape代表输入图像的尺寸(不包含batch维度), num_classes代表分类数'''
# ---------------------------------------------- #
def convmixer(input_shape, num_classes):

    # 构造输入层[b,224,224,3]
    inputs = keras.Input(shape=input_shape)
    # patchembedding层[b,224//7,224//7,1536]
    x = patchembed(inputs, out_channel=1536, patch_size=7)
    # 经过20个特征提取层[b,224//7,224//7,1536]
    x = blocks(x, depth=20, out_channel=1536, kernel_size=9)

    # 全局平均池化[b,1536]
    x = layers.GlobalAveragePooling2D()(x)
    # 全连接分类[b,num_classes]
    outputs = layers.Dense(num_classes)(x)

    # 构造网络
    model = keras.Model(inputs, outputs)

    return model

2.4 查看网络架构

以1000分类为例查看网络结构

# ---------------------------------------------- #
#(5)查看网络结构
# ---------------------------------------------- #
if __name__ == '__main__':
    # 接受模型
    model = convmixer(input_shape=[224,224,3],num_classes=1000)
    # 查看网络结构
    model.summary()

网络结构如下:

 conv2d_20 (Conv2D)             (None, 32, 32, 1536  2360832     ['tf.__operators__.add_19[0][0]']
                                )

 activation_40 (Activation)     (None, 32, 32, 1536  0           ['conv2d_20[0][0]']
                                )

 batch_normalization_40 (BatchN  (None, 32, 32, 1536  6144       ['activation_40[0][0]']
 ormalization)                  )

 global_average_pooling2d (Glob  (None, 1536)        0           ['batch_normalization_40[0][0]']
 alAveragePooling2D)

 dense (Dense)                  (None, 1000)         1537000     ['global_average_pooling2d[0][0]'
                                                                 ]
==================================================================================================
Total params: 51,719,656
Trainable params: 51,593,704
Non-trainable params: 125,952
__________________________________________________________________________________________________
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-06-14 22:35:35  更:2022-06-14 22:36:11 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/30 0:59:32-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码