CSWin Transformer: A General Vision Transformer Backbone with Cross-Shaped Windows
论文:https://arxiv.org/abs/2107.00652 代码:https://github.com/microsoft/CSWin-Transformer
架构
CSWin Transformer的架构如下图所示。
- 输入图像为
H
×
W
×
3
H×W×3
H×W×3
- 利用重叠卷积 token 嵌入(7 × 7卷积层,步幅为4))得到
H
/
4
×
W
/
4
H/4×W/4
H/4×W/4 patch tokens,每个 token 的维数为 C。
- 为了产生一个层次表示,整个网络由四个阶段组成。
- 构建的特征图在第
i
i
i 阶段具有
H
2
i
+
1
×
W
2
i
+
1
\frac{H}{ 2^{i+1}}× \frac{W}{ 2^{i+1}}
2i+1H?×2i+1W? 个 token,每个阶段由
N
i
N_i
Ni? 个 CSWin Transformer Block 组成。
- 在相邻阶段之间使用一个卷积层(3 × 3,步幅2),以减少 token 数量,并使通道维数加倍。
CSWin Transformer Block 与 Multi-head self-attention(MSA) Transformer Block 的不同之处:
- 使用十字形窗口自注意机制(CSWin SA) 取代了自注意机制(MSA);
- 为了引入局部感应偏置,将 LePE 作为一个并联模块加入到自注意分支。
CSWin Transformer Block 的表达式:
X
^
l
=
CSWin-Attention(LN
(
X
l
?
1
)
)
+
X
l
?
1
X
l
=
MLP(LN
(
X
^
l
)
)
+
X
^
l
\hat X^l= \text{CSWin-Attention(LN}(X^{l?1}))+ X^{l?1} \\ X^l= \text{MLP(LN}(\hat X^l))+\hat X^l
X^l=CSWin-Attention(LN(Xl?1))+Xl?1Xl=MLP(LN(X^l))+X^l
Cross-Shaped Window Self-Attention
尽管全局 MSA 具有较强的远程上下文建模能力,但全局 MSA 的计算复杂度是特征图大小的二次。因此,对于以高分辨率特征图为输入的视觉任务,如目标检测和分割,将会面临巨大的计算成本。为了缓解这一问题,SWin Transformer 等建议在局部注意窗口中进行 self-attention(简称为 局部 MSA),并应用移位窗口(或者halo)来扩大感受野。但是,每个Transformer块中的 token 仍然只有有限的注意区域,需要叠加更多的块来实现全局感受野。为了扩大注意区域,更有效地实现全局自我注意,我们提出了十字形窗口SA(类似于高斯滤波),它通过平行的水平和垂直条纹来实现 SA,形成十字形窗口。
水平和垂直条纹。根据 MSA,输入特征
X
X
X 线性投影到 K 个 head 上,然后每个 head 在水平或垂直条纹内进行局部 MSA。对于水平条纹SA,将输入特征
X
X
X 均匀划分为不重叠等宽的水平条纹
[
X
1
,
.
.
.
,
X
M
]
[X^1, ..., X^M]
[X1,...,XM] ,宽度为 sw (可以调整以平衡学习能力和计算复杂度),每个都包含 sw × W 个 token。形式上,假设
k
t
h
k^{th}
kth head 的投影 Q、K、V 的维度都是
d
k
d_k
dk?,那么
k
t
h
k^{th}
kth head 的水平条纹 SA 输出定义为:
X
=
[
X
1
,
X
2
,
.
.
.
,
X
M
]
,
where
X
i
∈
R
(
s
w
×
W
)
×
C
and
M
=
H
/
s
w
X = [X^1, X^2, . . . , X^M], \quad \text{where} \quad X^i∈\mathbb R^{(sw×W)×C} \quad \text{and} \quad M = H/sw
X=[X1,X2,...,XM],whereXi∈R(sw×W)×CandM=H/sw
Y
k
i
=
Attention
(
X
i
W
k
Q
,
X
i
W
k
K
,
X
i
W
k
V
)
,
where
i
=
1
,
.
.
.
,
M
Y^i_k= \text{Attention}(X^iW^Q_k, X^iW^K_k, X^iW^V_k), \quad \text{where} \quad i = 1, . . . , M
Yki?=Attention(XiWkQ?,XiWkK?,XiWkV?),wherei=1,...,M
H-Attentionk
(
X
)
=
[
Y
k
1
,
Y
k
2
,
.
.
.
,
Y
k
M
]
\text{H-Attentionk}(X) = [Y^1_k, Y^2_k, . . . , Y^M_k]
H-Attentionk(X)=[Yk1?,Yk2?,...,YkM?]
其中,
W
k
Q
∈
R
C
×
d
k
,
W
k
K
∈
R
C
×
d
k
,
W
k
V
∈
R
C
×
d
k
W^Q_k∈\mathbb R^{C×d_k}, W^K_k∈\mathbb R^{C×d_k}, W^V_k∈\mathbb R^{C×d_k}
WkQ?∈RC×dk?,WkK?∈RC×dk?,WkV?∈RC×dk? 分别表示
k
t
h
k^{th}
kth head的 Q、K、V 的投影矩阵,
d
k
d_k
dk? 设为 C/ k。垂直条纹的自注意也可以用类似的方法导出,其对
k
t
h
k^{th}
kth head的输出表示为 V-Attentionk(X)。
假设自然图像没有方向偏差,我们将K个 head 平均分成两个平行组。第一组的 head 表现出水平条纹SA ,第二组的 head 表现出垂直条纹SA。最后,这两个并行组的输出将被连接在一起。
head
k
=
{
H-Attentionk
(
X
)
k
=
1
,
.
.
.
,
K
/
2
V-Attentionk
(
X
)
k
=
K
/
2
+
1
,
.
.
.
,
K
\text{head}_k=\left\{\begin{array}{l} \text{H-Attentionk}(X) & k = 1, . . . , K/2 \\ \text{V-Attentionk}(X) & k = K/2 + 1, . . . , K \end{array}\right.
headk?={H-Attentionk(X)V-Attentionk(X)?k=1,...,K/2k=K/2+1,...,K?
CSWin-Attention
(
X
)
=
Concat
(
head
1
,
.
.
.
,
head
K
)
W
O
\text{CSWin-Attention}(X) = \text{Concat}(\text{head}_1, ...,\text{head}_K)W^O
CSWin-Attention(X)=Concat(head1?,...,headK?)WO
其中
W
O
∈
R
C
×
C
W^O∈\mathbb R^{C×C}
WO∈RC×C 为常用的投影矩阵,将SA的结果投射到目标输出维(默认设为C)。如上所述,SA 设计的一个关键洞察力是将多个 head 分成不同的组,并相应地应用不同的SA操作。换句话说,通过多 head 分组,一个Transformer块中每个token的注意区域被扩大。相比之下,现有的SA在不同的多头上应用相同的SA操作。在实验部分,我们将证明这种设计将带来更好的性能。
计算复杂性分析。CSWin self-attention的计算复杂度为:
?
(
CSWin-Attention
)
=
H
W
C
?
(
4
C
+
s
w
?
H
+
s
w
?
W
)
?(\text{CSWin-Attention}) = HW C ? (4C + sw ? H + sw ? W)
?(CSWin-Attention)=HWC?(4C+sw?H+sw?W)
对于高分辨率输入,考虑到H、W在早期阶段会大于C,在后期阶段会小于C,我们选择早期的小sw,后期的大sw。换句话说,调整sw提供了灵活性,以便在后期有效地扩大每个代币的注意区域。另外,为使224 × 224输入时的中间feature map大小能被sw除,我们默认四个阶段sw分别设为1、2、7、7。
位置编码
由于SA操作是置换不变的,它将忽略二维图像中的重要位置信息。为了添加这些信息,不同的位置编码机制已经在现有的视觉 Transformer 中使用。在图3中,我们展示了一些典型的位置编码机制,并将它们与我们提出的局部增强位置编码进行了比较。具体来说,APE和CPE在输入令牌中添加位置信息,然后再输入到Transformer块中,而RPE和我们的LePE在每个Transformer块中添加位置信息。但是不同于在注意力计算中添加位置信息的RPE(即Softmax(QKT)),我们考虑一种更直接的方式,并将位置信息强加到线性投影值上。设值元素
v
i
v_{i}
vi? 和
v
j
v_{j}
vj? 之间的边由向量
e
i
j
V
∈
E
e^V_{ij}∈E
eijV?∈E 表示,则
Attention
(
Q
,
K
,
V
)
=
SoftMax
(
Q
K
?
d
)
V
+
E
V
\text{Attention}(Q, K, V ) = \text{SoftMax}(\frac{QK^{\top}}{\sqrt d})V + EV
Attention(Q,K,V)=SoftMax(d
?QK??)V+EV
但是,如果我们考虑E中的所有连接,将需要巨大的计算成本。我们假设,对于一个特定的输入元素,最重要的位置信息来自它的局部邻域。因此,我们提出了局部增强位置编码(LePE),并通过对值V应用深度卷积算子[10]来实现:
Attention
(
Q
,
K
,
V
)
=
SoftMax
(
Q
K
?
d
)
V
+
DWConv
(
V
)
\text{Attention}(Q, K, V ) = \text{SoftMax}(\frac{QK^{\top}}{\sqrt d})V + \text{DWConv}(V )
Attention(Q,K,V)=SoftMax(d
?QK??)V+DWConv(V)
通过这种方式,LePE可以友好地应用于以任意输入分辨率作为输入的下游任务。
结果
|