| |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
-> 人工智能 -> 双目深度算法——基于Transformer的方法(STTR) -> 正文阅读 |
|
[人工智能]双目深度算法——基于Transformer的方法(STTR) |
双目深度算法——基于Transformer的方法(STTR)双目深度算法——基于Transformer的方法(STTR)STTR是STereo TranformeR的缩写,原论文名为《Revisiting Stereo Depth Estimation From a Sequence-to-Sequence Perspectivewith Transformers》,发表于2021年,据我了解这应该是第一篇使用Transformer进行双目视差估计的方法,打破了基于Correlation或者Cost Volume进行视差估计的方法论,在论文的摘要中,作者提到该方法主要有三大优势:(1)解放了视差的限制;(2)明确定义了遮挡区域;(3)保证了匹配的唯一性,这篇文章实验做得非常充分,开源代码也些得很好,下面我们结合代码和文中的实验来详细学习下这篇论文。 1. 网络构架网络整体结果如下图所示,主要包括三部分:Feature Extractor、Transformer和Context Adjustment Layer,其中Feature Extractor主要用于特征提取,Transformer通过Attention计算视差,Context Adjustment Layer用于后处理。
其中backbone为encoder部分,tokenizer为decoder部分。 1.1 Feature ExtractorFeature Extractor主要分为Encoder和Decoder两部分,其中Encoder部分使用的是类似Hourglass的结构,在Decoder部分使用的是转置卷积和Dense Block,特征提取的网络结构就不在此展开,其主要作用就是从原始的图像输入中提取图像特征,特征图大小和原始图像大小相同,但是每个像素变成了一个长为 C e C_{e} Ce?的特征向量。 尽管论文中是讲,基于Transformer的网络结构没有视差的限制,但是由于特征提取使用的CNN网络,因此计算Self Attention和Cross Attention使用的特征向量还是从图像的一个局部区域(感受野)抽象出来的。 1.2 TransformerTransformer部分结构如下图所示: 本论文使用的是带残差的多头注意力机制,公式如下: Q h = W Q h e I + b Q h \mathcal{Q}_{h}=W_{\mathcal{Q}_{h}} e_{I}+b_{\mathcal{Q}_{h}} Qh?=WQh??eI?+bQh?? K h = W K h e I + b K h \mathcal{K}_{h}=W_{\mathcal{K}_{h}} e_{I}+b_{\mathcal{K}_{h}} Kh?=WKh??eI?+bKh?? V h = W V h e I + b V h \mathcal{V}_{h}=W_{\mathcal{V}_{h}} e_{I}+b_{\mathcal{V}_{h}} Vh?=WVh??eI?+bVh?? α h = softmax ? ( Q h T K h C h ) \alpha_{h}=\operatorname{softmax}\left(\frac{\mathcal{Q}_{h}^{T} K_{h}}{\sqrt{C_{h}}}\right) αh?=softmax(Ch??QhT?Kh??) V O = W O ?Concat? ( α 1 V 1 , … , α N h V N h ) + b O V_{\mathcal{O}}=W_{\mathcal{O}} \text { Concat }\left(\alpha_{1} \mathcal{V}_{1}, \ldots, \alpha_{N_{h}} \mathcal{V}_{N_{h}}\right)+b_{\mathcal{O}} VO?=WO??Concat?(α1?V1?,…,αNh??VNh??)+bO? e I = e I + V O e_{I}=e_{I}+\mathcal{V}_{\mathcal{O}} eI?=eI?+VO?其中 W Q h , W K h , W V h ∈ R C h × C h , b Q h , b K h , b V h ∈ R C h W_{\mathcal{Q}_{h}}, W_{\mathcal{K}_{h}}, W_{\mathcal{V}_{h}} \in \mathbb{R}^{C_{h} \times C_{h}}, b_{\mathcal{Q}_{h}}, b_{\mathcal{K}_{h}}, b_{\mathcal{V}_{h}} \in \mathbb{R}^{C_{h}} WQh??,WKh??,WVh??∈RCh?×Ch?,bQh??,bKh??,bVh??∈RCh?以及 W O ∈ R C e × C e , b O ∈ R C e W_{\mathcal{O}} \in \mathbb{R}^{C_{e} \times C_{e}}, b_{\mathcal{O}} \in \mathbb{R}^{C_{e}} WO?∈RCe?×Ce?,bO?∈RCe?,这就是普通的Attention计算公式,我们就不在此赘述,细节不清楚的同学可以参考计算机视觉算法——Transformer学习笔记,作者在实现Attention机制时是因为加入了相对位置编码和注意力掩膜,所以是继承原始pytorch中的MultiheadAttention类重新实现了下,但是这一部分基本的操作是保持不变的:
作者在补充材料中可视化了不同层的注意力分布结果,如下图所示: 1.2.1 Relative Positional Encoding为了保证算法在大范围无纹理区域也能够有合理的视差估计,作者提到需要在输入中加入与数据无关位置信息,即 e = e I + e p e=e_{I}+e_{p} e=eI?+ep?上式展开后可以获得 α i , j = e I , i T W Q T W K e I , j ? ( 1 ) ?data-data? + e I , i T W Q T W K e p , j ? ( 2 ) ?data-position? + e p , i T W Q T W K e I , j ? (3)?position-data? + e p , i T W Q T W K e p , j ? (4)?position-position? . \begin{gathered} \alpha_{i, j}=\underbrace{e_{I, i}^{T} W_{\mathcal{Q}}^{T} W_{\mathcal{K}} e_{I, j}}_{(1) \text { data-data }}+\underbrace{e_{I, i}^{T} W_{\mathcal{Q}}^{T} W_{K} e_{p, j}}_{(2) \text { data-position }}+ \\ \underbrace{e_{p, i}^{T} W_{\mathcal{Q}}^{T} W_{\mathcal{K}} e_{I, j}}_{\text {(3) position-data }}+\underbrace{e_{p, i}^{T} W_{\mathcal{Q}}^{T} W_{\mathcal{K}} e_{p, j}}_{\text {(4) position-position }} . \end{gathered} αi,j?=(1)?data-data? eI,iT?WQT?WK?eI,j???+(2)?data-position? eI,iT?WQT?WK?ep,j???+(3)?position-data? ep,iT?WQT?WK?eI,j???+(4)?position-position? ep,iT?WQT?WK?ep,j???.?我们注意到上式中第四项仅仅取决于像素点位置,而和图像信息完全无关了,是不必存在的一项,我们将第四项移除后得到最后的结果: α i , j = e I , i T W Q T W K e I , j ? ( 1 ) ?data-data? + e I , i T W Q T W K e p , i ? j ? (2)?data-position? + e p , i ? j T W Q T W K e I , j ? (3)?position-data? , \begin{aligned} \alpha_{i, j}=\underbrace{e_{I, i}^{T} W_{\mathcal{Q}}^{T} W_{\mathcal{K}} e_{I, j}}_{(1) \text { data-data }}+\\ \underbrace{e_{I, i}^{T} W_{\mathcal{Q}}^{T} W_{K} e_{p, i-j}}_{\text {(2) data-position }}+\underbrace{e_{p, i-j}^{T} W_{\mathcal{Q}}^{T} W_{\mathcal{K}} e_{I, j}}_{\text {(3) position-data }}, \end{aligned} αi,j?=(1)?data-data? eI,iT?WQT?WK?eI,j???+(2)?data-position? eI,iT?WQT?WK?ep,i?j???+(3)?position-data? ep,i?jT?WQT?WK?eI,j???,?在具体实现时,就是在计算完 α h \alpha_{h} αh?后加上Positional Encoding
在论文的Ablation实验中,作者对比了有Positional Encoding和没有Positional Encoding的特征图的区别: 但在这里我有疑惑的一点的是,在ViT论文的实验中,发现Absolute Positional Encoding和Relative Positional Encoding最后的结果是差别不大,在本文论中,也并没有对比作者提出的Relative Positional Encoding相对简单的Absolute Positional Encoding的区别,这一点值得考究。 1.2.2 Optimal TransportOptimal Transport算法是一种可微分的二分图匹配算法,这一部分的内容作者是完全参考SuperGlue的做法,包括Dustbin的设置(本文在匹配到Dustbin的数据上加了一个可学习参数 ? \phi ?),这里感兴趣的同学可以参考视觉SLAM总结——SuperPoint / SuperGlue 1.2.3 Attention MaskAttention Mask只在最后一层Cross Attention中用到了,其基本原理如下图所示:
然后在最后一层Cross Attention上将生成的mask加到原始的Attention结果上
其中就是在不需要计算的区域加上一个负无穷的值,这样在Softmax后这些区域的权重就自然变成了0,在论文的Ablation实验中,这个模块也可以带来精度的提高。 1.2.4 Raw Disparity and Occlusion Regression在基于Correlation和Cost Volume的方法中,求视差通常采用的是加权平均,而本文由于采用的是匹配的方法,因此最终的结果是通过Modified Winner-Take-All的方式获得的,如何Modified的呢?作者先从分配矩阵 T \mathcal{T} T中取得一个 3 × 3 3 \times 3 3×3的Patch进行归一化 t ~ l = t l ∑ l ∈ N 3 ( k ) t l , ?for? l ∈ N 3 ( k ) \tilde{t}_{l}=\frac{t_{l}}{\sum_{l \in \mathcal{N}_{3}(k)} t_{l}}, \text { for } l \in \mathcal{N}_{3}(k) t~l?=∑l∈N3?(k)?tl?tl??,?for?l∈N3?(k)然后使用归一化的结果对视差进行加权平均 d ~ r a w ( k ) = ∑ l ∈ N 3 ( k ) d l t ~ l \tilde{d}_{r a w}(k)=\sum_{l \in \mathcal{N}_{3}(k)} d_{l} \tilde{t}_{l} d~raw?(k)=l∈N3?(k)∑?dl?t~l?这个 3 × 3 3 \times 3 3×3的Patch也正好代表着网络对于当前匹配结果的一个确定程度,因此对其求反就可以作为一个被遮挡的概率 p o c c ( k ) = 1 ? ∑ l ∈ N 3 ( k ) t l . p_{o c c}(k)=1-\sum_{l \in \mathcal{N}_{3}(k)} t_{l} . pocc?(k)=1?l∈N3?(k)∑?tl?.但是这里求得的视差和遮挡概率都是最原始的结果,后面还需要通过Context Adjustment Layer进行Refine。 1.3 Context Adjustment LayerContext Adjustment Layer的主要目的是对输出的视差和遮挡不确定度进行Refine,其结构如下图所示: 2. 损失函数STTR算法中损失函数构建得相对复杂,一共由四部分组成,如下:
L
=
w
1
L
r
r
+
w
2
L
d
1
,
r
+
w
3
L
d
1
,
f
+
w
4
L
b
e
,
f
L=w_{1} L_{r r}+w_{2} L_{d 1, r}+w_{3} L_{d 1, f}+w_{4} L_{b e, f}
L=w1?Lrr?+w2?Ld1,r?+w3?Ld1,f?+w4?Lbe,f?其中 第二部分 L d 1 , r L_{d 1, r} Ld1,r?和 L d 1 , f L_{d 1, f} Ld1,f?分别为原始视差和Refine后的视差的L1损失。 第三部分 L b e , f L_{\mathrm{be}, f} Lbe,f?为判断是否为遮挡区域的二分类交叉熵损失。 3. 实验及实验结果为了减小内存占用,作者采用的是Gradient Checkpointing技术,该技术是使得一些临时变量在前向推导时不进行存储,而时在反向导数传递时重新计算,这样就是用时间换空间,在训练大型模型时是一种常见的处理手段,具体实现其实就是使用了torch.utils.checkpoint这个库,如下所示:
在算法对比部分,作者显示对比了模型的泛化性,作者将所有模型在Sense Flow数据集上训练,然后直接在KITTI数据集上推理,结果如下图所示: |
|
|
上一篇文章 下一篇文章 查看所有文章 |
|
开发:
C++知识库
Java知识库
JavaScript
Python
PHP知识库
人工智能
区块链
大数据
移动开发
嵌入式
开发工具
数据结构与算法
开发测试
游戏开发
网络协议
系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程 数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁 |
360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 | -2024/12/29 8:34:49- |
|
网站联系: qq:121756557 email:121756557@qq.com IT数码 |
数据统计 |