| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> 逐步理解Swin-Transformer源码 -> 正文阅读 |
|
[人工智能]逐步理解Swin-Transformer源码 |
逐步解释Swin-Transformer源码 这一行没什么好说的,就是设置参数值,并读取yaml里的模型参数。 是否采用混合精度训练,混合精度训练可以给合适的参数设置合适的位数,而不是全部为float32 分布式训练。 这一段就是什么梯度累加,学习率缩放等一些优化策略。跳到main函数中 数据预处理,logger写日志的模块,暂时不用管。 跳到构建模型中,详细讲解这一部分。 再跳到SwinTransformer中,再看模型的构建代码时,先整体看一下整体一个架构 这个图想必大家都看到过很多次了,首先输入图片(H*W*3),输入到一个Patch Partition中,这个模块SwinTrasfomer中与VIT相同但又略微不同,他将图片按照4*4划分为一个token,送到Linear Embedding中,进行维度变化,利用一个卷积,将通道数变为初始设置的C,源码中为96. ? ? ? ? 再送到Swintransformer Block中运算,需要注意的是,Block并不会改变输入图像的维度。每一个模块下面的*2,*2,*6,*2就是叠加了多少个Block,patch Merging就是将4*4变为8*8,再将通道数变为2C,以此类推。 ? ? ? ? Block中主要是一个W-MSA和一个SW-MSA,这就是Swintransformer的创新之处,之后代码中会讲到。 ? ? ? ? 接下来看代码 这里就是一初始化,源码中都有注释,这里用到再详细解释。 跳入第一个模块 初始化,patch_size就是我们想要将图片按照几乘几去划分,都很好理解,patches_resolution就是图像被划分为了几块,proj就是前面说的,通过一个卷积将通道数映射到我们设置的值上。 绝对位置编码,也就是transformer中的位置embedding ?接下来到了Basic?Block中,根据每一个stage的层数来循环相应次数。 需要注意的是,这里与论文中不同,这里的Basic?Block是包括Swintransformer Block和Pach merging的。 跳入Swintransformer Block 跳入WindowAttention中 WindowAttention就是创新之处,由于VIT采取全局各个Token之间算attention,这样会导致,计算量非常大,所以这里采用窗的机制,我们将一幅图片划分为不同的窗,每个窗内自己的attention,即W-MSA 建立相对位置的索引表,具体什么意思。 我们知道,Transformer是利用绝对位置编码的,而这里采用相对位置编码。 就是公式中的B,包含着他的一个位置信息。那B到底如何计算,首先B的维度一定是与QKT的维度相同,因为两者要相加嘛,而QKT的维度如何计算,不用计算,QKT其实就是每两个像素点之间求attention,那么肯定是要遍历所有的像素点,并且包含顺序,也就是说是len*len,如下图 比如,我喜欢你,他的相对位置有如图四种情况,那应该如何将相对位置加到QKT中呢? 拉直,每一行以该行的值当作原点叠加 当然,这个加并不是直接加上去,而是我们有一个索引表,也就是刚刚建立的索引表,0找到0代表的值,-1找到-1的,以此类推。这就是相对位置编码的原理。 那在Swintransformer中是怎样设计的呢?继续看代码 这一部分就是计算相对位置的代码, 每一步对应下面的一张图,注意这里使用了广播机制来计算相对位置,也就是coords_flatten(:,:,None),加M-1是因为防止有负数,乘以2M-1是因为,每一行不同元素之间的attention值不能一样,就比如(1,2)和(2,1) 这就是相对位置编码。同时将索引注册为不用学习的参数。 计算Q,K,V同时再利用卷积完成通道数的变化,最后再给索引表初始化。 这就是一个WindowAttention 这就是一个MLP,也没啥好说的 接下来又到了非常重要的地方,前面我们说到,将图片划分为不同的窗,每个窗内计算attention,这样的话就会导致不同的窗之间没有关联,而实际上应该也是有关联的,所以就出现了SW-MSA 如何实现的呢?就是将窗整体右移window_size/2,也就是移动3,如图 如图,移动之后,我们可以看到左上角多出来两条,将其补刀右下角就如图所展示。 移完后,我们还是要去窗内计算attention,对于0中,完全可以实现,并且完成了对于W-MSA中不同窗之间计算attention,但对于1,2,3,4,5,6,7,8来说,却不可以,因为有我们补上去的,实际的这些像素并不在这里,我们不能计算他们之间的attention。 所以就需要mask,来防止两者之间计算mask 这里就是通过slice对图片进行划分并且标号,如我自己画的图那样 如图 ?这里之所以加-100,就是为了加到atten矩阵中时,使得不相关的块之间的值很小,这样经过softmax之后就很小了,也就是相关性非常低。 到这里一个Swintransformer模块介绍完了,就是窗不移和窗移连在一起。然后重复[2,2,6,2]次 然后就到 通过一个下采样,使得原本的4*4,变为8*8,相当于卷积中的池化,增大了感受野。 这就是一个stage,然后一共有四个stage循环四次,再经过 数据集采用多少分类,这里就输出多少。再对权值初始化,至此模型搭建完成。 之后,设置模型的优化函数,损失函数,分布式训练等等。 是否继续上次的训练,这个感觉一般都不会用到,暂且不用管。 开始训练,记录开始时间,跳到train_one_epoch中 然后其实就没什么可说的了,提取bach数据,送到模型出,输出预测值,计算损失和梯度,进行梯度累积,达到累计次数,梯度更新并梯度清零。然后就计算准确率啊,损失值啊,时间啊等等。 然后就,看看迭代次数是否达到设置的值,达到保存模型,dist.get_rank()==0是分布式训练中的,这里对分布式训练还不太了解,所以这里也不清楚。 之后就计算top1,top5准确率,计算总体时间,将想要的值通过logger写入日志。 至此,Swin-transformer中的原理及代码介绍完毕。 如有错误,欢迎批评指正!! ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 | -2024/11/27 6:23:36- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |