IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 逐步理解Swin-Transformer源码 -> 正文阅读

[人工智能]逐步理解Swin-Transformer源码

逐步解释Swin-Transformer源码

这一行没什么好说的,就是设置参数值,并读取yaml里的模型参数。

是否采用混合精度训练,混合精度训练可以给合适的参数设置合适的位数,而不是全部为float32

分布式训练。

这一段就是什么梯度累加,学习率缩放等一些优化策略。跳到main函数中

数据预处理,logger写日志的模块,暂时不用管。

跳到构建模型中,详细讲解这一部分。

再跳到SwinTransformer中,再看模型的构建代码时,先整体看一下整体一个架构

这个图想必大家都看到过很多次了,首先输入图片(H*W*3),输入到一个Patch Partition中,这个模块SwinTrasfomer中与VIT相同但又略微不同,他将图片按照4*4划分为一个token,送到Linear Embedding中,进行维度变化,利用一个卷积,将通道数变为初始设置的C,源码中为96.

? ? ? ? 再送到Swintransformer Block中运算,需要注意的是,Block并不会改变输入图像的维度。每一个模块下面的*2,*2,*6,*2就是叠加了多少个Block,patch Merging就是将4*4变为8*8,再将通道数变为2C,以此类推。

? ? ? ? Block中主要是一个W-MSA和一个SW-MSA,这就是Swintransformer的创新之处,之后代码中会讲到。

? ? ? ? 接下来看代码

这里就是一初始化,源码中都有注释,这里用到再详细解释。

跳入第一个模块

初始化,patch_size就是我们想要将图片按照几乘几去划分,都很好理解,patches_resolution就是图像被划分为了几块,proj就是前面说的,通过一个卷积将通道数映射到我们设置的值上。

绝对位置编码,也就是transformer中的位置embedding

?接下来到了Basic?Block中,根据每一个stage的层数来循环相应次数。

需要注意的是,这里与论文中不同,这里的Basic?Block是包括Swintransformer Block和Pach merging的。

跳入Swintransformer Block

跳入WindowAttention中

WindowAttention就是创新之处,由于VIT采取全局各个Token之间算attention,这样会导致,计算量非常大,所以这里采用窗的机制,我们将一幅图片划分为不同的窗,每个窗内自己的attention,即W-MSA

建立相对位置的索引表,具体什么意思。

我们知道,Transformer是利用绝对位置编码的,而这里采用相对位置编码。

就是公式中的B,包含着他的一个位置信息。那B到底如何计算,首先B的维度一定是与QKT的维度相同,因为两者要相加嘛,而QKT的维度如何计算,不用计算,QKT其实就是每两个像素点之间求attention,那么肯定是要遍历所有的像素点,并且包含顺序,也就是说是len*len,如下图

比如,我喜欢你,他的相对位置有如图四种情况,那应该如何将相对位置加到QKT中呢?

拉直,每一行以该行的值当作原点叠加

当然,这个加并不是直接加上去,而是我们有一个索引表,也就是刚刚建立的索引表,0找到0代表的值,-1找到-1的,以此类推。这就是相对位置编码的原理。

那在Swintransformer中是怎样设计的呢?继续看代码

这一部分就是计算相对位置的代码, 每一步对应下面的一张图,注意这里使用了广播机制来计算相对位置,也就是coords_flatten(:,:,None),加M-1是因为防止有负数,乘以2M-1是因为,每一行不同元素之间的attention值不能一样,就比如(1,2)和(2,1)

这就是相对位置编码。同时将索引注册为不用学习的参数。

计算Q,K,V同时再利用卷积完成通道数的变化,最后再给索引表初始化。

这就是一个WindowAttention

这就是一个MLP,也没啥好说的

接下来又到了非常重要的地方,前面我们说到,将图片划分为不同的窗,每个窗内计算attention,这样的话就会导致不同的窗之间没有关联,而实际上应该也是有关联的,所以就出现了SW-MSA

如何实现的呢?就是将窗整体右移window_size/2,也就是移动3,如图

如图,移动之后,我们可以看到左上角多出来两条,将其补刀右下角就如图所展示。

移完后,我们还是要去窗内计算attention,对于0中,完全可以实现,并且完成了对于W-MSA中不同窗之间计算attention,但对于1,2,3,4,5,6,7,8来说,却不可以,因为有我们补上去的,实际的这些像素并不在这里,我们不能计算他们之间的attention。

所以就需要mask,来防止两者之间计算mask

这里就是通过slice对图片进行划分并且标号,如我自己画的图那样

如图

?这里之所以加-100,就是为了加到atten矩阵中时,使得不相关的块之间的值很小,这样经过softmax之后就很小了,也就是相关性非常低。

到这里一个Swintransformer模块介绍完了,就是窗不移和窗移连在一起。然后重复[2,2,6,2]次

然后就到

通过一个下采样,使得原本的4*4,变为8*8,相当于卷积中的池化,增大了感受野。

这就是一个stage,然后一共有四个stage循环四次,再经过

数据集采用多少分类,这里就输出多少。再对权值初始化,至此模型搭建完成。

之后,设置模型的优化函数,损失函数,分布式训练等等。

是否继续上次的训练,这个感觉一般都不会用到,暂且不用管。

开始训练,记录开始时间,跳到train_one_epoch中

然后其实就没什么可说的了,提取bach数据,送到模型出,输出预测值,计算损失和梯度,进行梯度累积,达到累计次数,梯度更新并梯度清零。然后就计算准确率啊,损失值啊,时间啊等等。

然后就,看看迭代次数是否达到设置的值,达到保存模型,dist.get_rank()==0是分布式训练中的,这里对分布式训练还不太了解,所以这里也不清楚。

之后就计算top1,top5准确率,计算总体时间,将想要的值通过logger写入日志。

至此,Swin-transformer中的原理及代码介绍完毕。

如有错误,欢迎批评指正!!

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-09 19:29:17  更:2021-11-09 19:30:07 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 7:42:32-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码