参考代码:stereo-transformer
1. 概述
导读:这篇文章通过transformer机制实现了一种立体匹配算法(STTR),在该方法中将立体匹配问题转换为序列上的响应问题,使用未知信息编码与注意力机制替换了传统匹配方法中的cost volume策略。由于替换了cost volume解除了预定max-disparity假设的限制,增强了网络的泛化表达能力。在估计视差图的同时显示地估计遮挡区域的概率结果。此外,为了寻找右视图到左视图的最佳匹配,文中对其中的匹配矩阵添加熵约束,从而实现对匹配过程的添加唯一性约束。
将文章的方法(STTR)与correlation-based和3D convolution-based方法进行比较,可以归纳为:
- 1)STTR与correlation-based方法比较:STTR在进行左右视图匹配的时候通过self-attention和cross-attention建立相关性,并且对像素点匹配的结果进行唯一性约束;
- 2)STTR与3D convolution-based方法比较:STTR通过attention建立像素之间的关联,而不是通过设定max-disparity的形式建立cost volume;
文章的方法在下面几个数据集下的结果:
2. 方法设计
2.1 pipline
文章提出的方法pipeline见下图所示: 在上图中可以看到左右两视图经过一个共享backbone和tokenizer进行特征抽取,这部分的实现可以参考:
class SppBackbone(nn.Module):
…
class Tokenizer(nn.Module):
…
经过上面两个过程对特征进行抽取得到的是channel为
C
e
C_e
Ce?,空间分辨率与原输入尺度
(
I
h
,
I
w
)
(I_h,I_w)
(Ih?,Iw?)一致的特征图。之后这些特征图便与位置编码组合经过
N
N
N层的attention操作预测得到粗预测结果,之后再改结果的基础上进行refine得到最后的结果。
2.2 Transformer操作
文章提出的transformer结构可见下图: 在上图中可以看到其中首先会经过几个attention层(带position encoding),之后 特征经过带mask的cross-attention得到最后优化的特征。
2.2.1 attention操作
这里采用的attention操作是multi-head attention,可以参考pytorch的实现nn.MultiheadAttention 。这里会将特征图在channel维度进行分组操作,对不同的分组进行运算从而增强特征的表达的能力,对于组的划分可以描述为:
C
h
=
C
e
N
h
C_h=\frac{C_e}{N_h}
Ch?=Nh?Ce??,其中分母是划分的组的数量。在每个分组中会产生回应的query、key和value向量,其分别表示为:
Q
h
=
W
Q
h
e
I
+
b
Q
h
Q_h=W_{Q_h}e_I+b_{Q_h}
Qh?=WQh??eI?+bQh??
K
h
=
W
K
h
e
I
+
b
K
h
K_h=W_{K_h}e_I+b_{K_h}
Kh?=WKh??eI?+bKh??
V
h
=
W
V
h
e
I
+
b
V
h
V_h=W_{V_h}e_I+b_{V_h}
Vh?=WVh??eI?+bVh?? 有了计算权重的所需的变量之后,结下来就是计算加权组合因子了:
α
h
=
s
o
f
t
m
a
x
(
Q
h
T
K
h
C
h
)
\alpha_h=softmax(\frac{Q_h^TK_h}{\sqrt{C_h}})
αh?=softmax(Ch?
?QhT?Kh??) 接下来就是对之前划分出来的组进行组合:
V
o
=
W
o
C
o
n
c
a
t
(
α
1
V
1
,
…
,
α
h
V
h
)
+
b
o
V_o=W_oConcat(\alpha_1V_1,\dots,\alpha_hV_h)+b_o
Vo?=Wo?Concat(α1?V1?,…,αh?Vh?)+bo? 最后与原特征进行相加得到增强之后的特征:
e
I
=
e
I
+
V
o
e_I=e_I+V_o
eI?=eI?+Vo? 在文中计算上述attention是采用两种方式:
- 1)self-attention:在该计算过程中所有attention操作所需的
Q
h
,
K
h
,
V
h
Q_h,K_h,V_h
Qh?,Kh?,Vh?都是来自于同一视图产生的特征,这部分可以参考:
class TransformerSelfAttnLayer(nn.Module):
…
- 2)cross-attention:在这个操作中
Q
h
Q_h
Qh?是来自于source图像产生的特征,而
K
h
,
V
h
K_h,V_h
Kh?,Vh?是来自于target图像产生的特征。需要注意的是这里的source和target是相对的,相当于在计算cross-attention的过程中是会进行交换的,实现双向计算。其实现可以参考:
class TransformerCrossAttnLayer(nn.Module):
…
2.2.2 position encoding
在上面的多层attention过程中描述了像素与像素之间的关系,但是对于那些弱纹理甚至是无纹理区域的处理就变得比较困难了。对此文章为这些点通过建立相邻点(特别是那些诸如边缘点的显著性特征)的联系,优化对于弱纹理区域的适应能力,因而这里就使用到了用于相对位置建模的position encoding,其实现可以参考:
class PositionEncodingSine1DRelative(nn.Module):
…
则上一节中讲到的attention权重经过position encoding的重新编码可以得到下面的权值组合形式: 对于这部分的实现可以参考:
2.2.3 attention mask
在经过多层attention操作之后,已经可以构架出左视图和右视图上每个像素的对应关系了,但为了排除一些无关干扰,文章通过建立下三角mask的形式去约束对应点的位置,这部分的计算描述为:
def _generate_square_subsequent_mask(self, sz: int):
…
2.2.4 Optimal Transport
在进行特征匹配的时候为了右视图中的像素能被对应到左视图中最匹配的像素,文章对匹配矩阵
T
\mathcal{T}
T添加了约束,也就是上文中提到的唯一性约束。其是在匹配矩阵的基础上添加熵正则化,可以描述为:
T
=
arg?min
?
T
∈
R
+
I
w
?
I
h
∑
i
,
j
=
1
I
w
,
I
h
T
i
j
M
i
j
?
γ
E
(
T
)
\mathcal{T}=\argmin_{\mathcal{T}\in R_{+}^{I_w*I_h}}\sum_{i,j=1}^{I_w,I_h}\mathcal{T}_{ij}M_{ij}-\gamma E(\mathcal{T})
T=T∈R+Iw??Ih??argmin?i,j=1∑Iw?,Ih??Tij?Mij??γE(T)
s
.
t
.
?
T
1
I
w
=
a
,
T
1
I
h
=
b
s.t.\ \mathcal{T}1I_w=a,\mathcal{T}1I_h=b
s.t.?T1Iw?=a,T1Ih?=b 其中,上式子中
M
M
M是边缘分布
a
,
b
a,b
a,b的代价矩阵,其长度为
I
w
I_w
Iw?。这部分的实现可以参考:
def _optimal_transport(self, attn: Tensor, iters: int):
…
2.3 视差和遮挡mask预测
2.3.1 第一阶段raw预测
在上述内容中得到了左右视图之间的像素匹配矩阵
T
\mathcal{T}
T,那么去寻找视图间最佳匹配的方式可以是硬性直接argmax的,也可以是在一定窗口内软性操作的。文章中采取的就是第二种方式,在匹配到的最佳位置
k
k
k处采用一个大小为3的窗口
N
3
(
k
)
N_3(k)
N3?(k),之后使用这个窗口内的归一化加权值作为最后的预测结果:
t
l
ˉ
=
t
l
∑
l
∈
N
3
(
k
)
t
l
,
f
o
r
?
l
∈
N
3
(
k
)
\bar{t_l}=\frac{t_l}{\sum_{l\in N_3(k)}t_l},for\ l\in N_3(k)
tl?ˉ?=∑l∈N3?(k)?tl?tl??,for?l∈N3?(k)
d
ˉ
r
a
w
(
k
)
=
∑
l
∈
N
3
(
k
)
d
l
t
l
ˉ
\bar{d}_{raw}(k)=\sum_{l\in N_3(k)}d_l\bar{t_l}
dˉraw?(k)=l∈N3?(k)∑?dl?tl?ˉ? 在匹配矩阵的情况下对于遮挡区域的像素的概率描述为:
p
o
c
c
(
k
)
=
1
?
∑
l
∈
N
3
(
k
)
t
l
p_{occ}(k)=1-\sum_{l\in N_3(k)}t_l
pocc?(k)=1?l∈N3?(k)∑?tl?
2.3.2 第二阶段预测
在上一个阶段中已经在一个极线(epipolar line)上预测得到初始视差和遮挡图,但是却缺少跨越多个极线的上下文信息(来自于多个实例),因而在第二阶段预测中使用CNN网络对于这些信息进行编码,来使得整体pipeline能够基于输入的图像和网络在多个极线的编码信息生成对应的预测结果,其中对于遮挡区域的预测网络为: 对于视差部分的预测网络为: 上述的几个结构对于最后性能的影响:
2.4 损失函数
在得到匹配矩阵
T
\mathcal{T}
T之后,经过与GT进行比较就可以得到匹配的像素
M
\mathcal{M}
M和不匹配的像素
U
\mathcal{U}
U。则raw预测阶段匹配部分的损失函数为:
L
r
r
=
1
N
M
∑
i
∈
M
?
l
o
g
(
t
i
?
)
+
1
N
U
∑
i
∈
U
?
l
o
g
(
t
i
,
?
)
L_{rr}=\frac{1}{N_{\mathcal{M}}}\sum_{i\in \mathcal{M}}-log(t_i^{*})+\frac{1}{N_{\mathcal{U}}}\sum_{i\in\mathcal{U}}-log(t_{i,\phi})
Lrr?=NM?1?i∈M∑??log(ti??)+NU?1?i∈U∑??log(ti,??) 其中,
t
i
?
=
i
n
t
e
r
p
(
T
i
,
p
i
?
d
g
t
,
i
)
t_i^{*}=interp(\mathcal{T}_{i,p_i}-d_{gt,i})
ti??=interp(Ti,pi???dgt,i?)。则整个损失函数描述为:
L
=
w
1
L
r
r
,
r
a
w
+
w
2
L
d
1
,
r
a
w
+
w
3
L
d
1
,
f
i
n
a
l
+
w
4
L
b
e
,
f
i
n
a
l
L=w_1L_{rr,raw}+w_2L_{d1,raw}+w_3L_{d1,final}+w_4L_{be,final}
L=w1?Lrr,raw?+w2?Ld1,raw?+w3?Ld1,final?+w4?Lbe,final? 其中,
d
1
d1
d1代表的是smooth L1损失,
c
e
ce
ce是二值较差熵损失。
3. 实验结果
|