| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> MobileVit代码解析 -> 正文阅读 |
|
[人工智能]MobileVit代码解析 |
MobileVit代码逐行解析代码链接:非官方实现 1.1导入所需模块
1.2 MobileNetv2 Block解析
以mobilevit_s()为例,其中channels = [16, 32, 64, 64, 96, 128, 160, 640] MV2Block(channels[0],channels[1],1) 则为 MV2Block(16,32,1) MV2Block(channels[1],channels[2],2) 则为 MV2Block(32,64,2) 1.3 Mobile Vit Block解析
其中首先经过两次卷积y=self.conv2(self.conv1(x))获得局部信息表示,且这两次卷积不会改变特征图尺寸,但将通道映射到了高维空间dim中。对应文章的该段文字。 将形状为 [ b s , d i m , h , w ] [bs,dim,h,w] [bs,dim,h,w]的y进行重组,其中 n h × p h = h nh \times ph=h nh×ph=h 和 n w × p w = w nw \times pw=w nw×pw=w 重组后y的形状为
[
b
s
,
P
,
N
,
d
i
m
]
[bs,P,N,dim]
[bs,P,N,dim],其中
P
=
p
h
×
p
w
P = ph \times pw
P=ph×pw 和
N
=
n
h
×
n
w
N=nh \times nw
N=nh×nw ,这里的P相当于每个patch的所有像素向量集,N相当于Patch数目,对应该片段的前半部分: X G ( p ) = \mathbf{X}_{G}(p)= XG?(p)= Transformer ( X U ( p ) ) , 1 ≤ p ≤ P \left(\mathbf{X}_{U}(p)\right), 1 \leq p \leq P (XU?(p)),1≤p≤P Transformer的结构与代码下节再做分析,只需要知道做完Transformer后,张量的维度仍然是 [ b s , P , N , d i m ] [bs,P,N,dim] [bs,P,N,dim],未改变。 随后将y重整为图片格式,经 再将维度进行重排成 [ b s , d i m , n h ? p h , n w ? p w ] [bs,dim,nh*ph,nw*pw] [bs,dim,nh?ph,nw?pw],其中ph,pw是自定义的patch的高和宽,N=nh*nw, n h ? p h nh*ph nh?ph则为图像的高h, n w ? p w nw*pw nw?pw为图像的宽w。 [ b s , d i m , n h ? p h , n w ? p w ] [bs,dim,nh*ph,nw*pw] [bs,dim,nh?ph,nw?pw]则为 [ b s , d i m , h , w ] [bs,dim,h,w] [bs,dim,h,w] 之所以要把dim放前面,是为了满足pytorch中图像tensor的格式为 [ B , C , H , W ] [B,C,H,W] [B,C,H,W] 之后经y=self.conv3(y),将 [ b s , d i m , h , w ] [bs,dim,h,w] [bs,dim,h,w]映射回指定通道in_channel的特征图 [ b s , i n c h a n n e l , h , w ] [bs,inchannel,h,w] [bs,inchannel,h,w] 之后经y=torch.cat([x,y],1),y=self.conv4(y) 将通道还原到输入x的inchannel数目上。 总的来看MobileViTAttention不会改变图片的大小,也就是不会进行下采样,同时也不会改变通道数。 下采样和通道数的变化发生在MobileNetv2 Block中。 1.4 Transformer解析
Tranformer的相关定义如上,其结构如下图所示,在实现结构上和图的顺序略有不同,图中顺序是先LNorm再做MSA,但是代码顺序是先MSA,再LNorm。
其中query向量,key向量和value向量由下两句产生,先用线性层生成总维度为 h e a d s × h e a d d i m × 3 heads \times head_dim \times 3 heads×headd?im×3 的向量,随后按最后一个维度,切分成3块。
由上述分析 输入x的维度为 [ b s , P , N , d i m ] [bs,P,N,dim] [bs,P,N,dim],其中 P = p h × p w P = ph \times pw P=ph×pw 和 N = n h × n w N=nh \times nw N=nh×nw 经过qkv=self.to_qkv(x).chunk(3,dim=-1)后,qkv是一个包含3个元素的元组,且每个元素的维度为 [ b s , P , N , i n n e r d i m ] [bs,P,N,innerdim] [bs,P,N,innerdim],其中 i n n e r d i m = h e a d s × h e a d d i m innerdim=heads \times headdim innerdim=heads×headdim
Attention ? ( Q , K , V ) = Softmax ? ( Q K T d k ) V \operatorname{Attention}(Q, K, V)=\operatorname{Softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V Attention(Q,K,V)=Softmax(dk??QKT?)V 对应以下几行代码;
其中k.transpose(-1,-2)后的维度为 [ b s , P , h e a d s , h e a d d i m , N ] [bs,P,heads,headdim,N] [bs,P,heads,headdim,N],再与q做矩阵乘法后,dots的维度为 [ b s , P , h e a d s , N , N ] [bs,P,heads,N,N] [bs,P,heads,N,N], 之后再与value向量做矩阵乘法,out维度为 [ b s , P , h e a d s , N , h e a d d i m ] [bs,P,heads,N,headdim] [bs,P,heads,N,headdim], 刚拿到out时,需要将out维度先还原到 [ b s , P , N , i n n e r d i m ] [bs,P,N,innerdim] [bs,P,N,innerdim],对应代码out=rearrange(out,‘b p h n d -> b p n (h d)’) , 之后再通过线性层将out的维度映射回原来的输入维度 [ b s , P , N , d i m ] [bs,P,N,dim] [bs,P,N,dim],用于后续计算与将patch还原成image 。 1.5 总结MobileViT的结构就是通过上述模块的堆叠,最后通过卷积池化全连接层作用到图像分类任务中,也可以不做全连接,用于到其余高阶任务中。 |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 | -2025/1/6 17:09:18- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |