Swin-Transformer
ViT
ViT使用纯Transformer结构来做图像分类任务,它开创了Transformer能够在CV领域有效工作的先河。ViT验证了在大规模数据集上进行预训练,然后迁移到小规模数据集上,Transformer性能要比CNN好。由于缺少CNN自带的归纳偏置(平移不变形和局部性),ViT在ImageNet数据集(中型数据集)上表现没有CNN好,Transformer需要充足的图像数据学习。
我们以ViT的base模型为例来描述ViT的流程。Transformer结构不能直接处理图像,首先需要将2D的图像分块(patch),CV中的patch可以近似看做NLP中的token,每块的大小为
P
?
P
?
C
P*P*C
P?P?C。假设一个大小为
224
?
224
?
3
224*224*3
224?224?3的图像,每块的大小为
16
?
16
?
3
16*16*3
16?16?3,那么此张图片将有
224
16
?
224
16
=
14
?
14
=
196
\frac{224}{16}*\frac{224}{16}=14*14=196
16224??16224?=14?14=196个块。图像预处理将一个2D的
224
?
224
?
3
224*224*3
224?224?3的图像展平为1D的
196
?
768
196*768
196?768大小的向量。接下来,进行图像块嵌入(类似于NLP中的词嵌入),就是ViT论文中的
E
E
E,
E
E
E的维度是
768
?
768
768*768
768?768。映射后的向量维度仍然为
196
?
768
196*768
196?768。类似于BERT中的[class] token,ViT中加入了一个可以学习的嵌入,如下图中的第0位置,它经过Transformer 编码器后的输出作为图像表示
y
y
y,用于分类。就这样,嵌入向量就由
196
?
768
196*768
196?768变为
197
?
768
197*768
197?768。为了保持输入图像块之间的空间位置信息,对映射后的向量添加了一个位置编码信息,如下图一中的0-9数字。位置编码采用的是1-D的可学习嵌入变量,论文中实验验证2-D的位置编码和1-D的位置编码结果近似。
图一:ViT示意图
Swim Transformer
Swim Transformer是特为视觉领域设计的一种分层Transformer结构。Swin 的两大特性是滑动窗口和分层表示。滑动窗口在局部不重叠的窗口中计算自注意力,并允许跨窗口连接。分层结构允许模型适配不同尺度的图片,并且计算复杂度与图像大小呈线性关系。
ViT只能够做分类,Swin Transformer借鉴了CNN的分层结构,如下图二(a),不仅能够做分类,还能够和CNN一样扩展到下游任务,比如检测,分割等。Swim Transformer不同于标准的Transformer结构,它计算不重叠窗口中的自注意力。为了解决窗口和窗口之间无连接的问题,Swin提出了移位窗口分割方法,见下图二(b),W-MSA和SW-MSA在连续的Swin Transformer blocks中交替出现,见下图二?。因此不论哪个Swim Transformer版本,都有偶数个blocks。
下图二(d)展示了Swin Transformer的tiny版本(Swin-T)。首先,它通过一个patch分割模块将输入的RGB图像分割成不重叠的patches,每个patch被看做是一个“token”,在论文中,patch size大小为
4
×
4
4 \times 4
4×4,每个patch的特征维度为
4
×
4
×
3
=
48
4 \times 4 \times 3 = 48
4×4×3=48。对于一个
H
×
W
H \times W
H×W大小的RGB图像,经过patch分割模块之后表示为
H
4
×
W
4
×
48
\frac{H}{4} \times \frac{W}{4} \times 48
4H?×4W?×48。紧接着一个线性嵌入层将此原始值特征映射为一个任意的维度,记为
C
C
C。Swin Transformer block 应用到这些patch token上。线性映射加上Swin Transformer block,被称为“Stage 1”。为了得到分层表示,随着网络层数的加深,token的数量通过patch merging layers减少。第一个patch merging layer层连接每组
2
×
2
2 \times 2
2×2相邻patches的特征,然后在维度为
4
C
4C
4C的连接特征上应用线性层降维到
2
C
2C
2C。“Stage 2”,“Stage 3”和“Stage 4”由patch merging layer和Swin Transformer block组成,因此每个阶段的尺寸减少
2
2
2倍,维度增大
2
2
2倍,以至于“Stage 4”的输出特征为
H
32
×
W
32
×
8
C
\frac{H}{32} \times \frac{W}{32} \times 8C
32H?×32W?×8C 。 图二:Swin Transformer 架构
Patches & Windows
图三:patches 和windows
一张
H
×
W
H \times W
H×W大小的图中,里面包含
H
×
W
H \times W
H×W个像素。一个patch就是图像中的
N
×
N
N \times N
N×N个像素区域;一个window是由
M
×
M
M \times M
M×M个patches组成的。由上图所示,图像被分成
4
4
4个窗口,每个窗口包含
4
×
4
=
16
4 \times 4 =16
4×4=16个patches。假设每个patch的大小为
4
×
4
4 \times 4
4×4,则每个patch的向量维度为
4
×
4
×
3
=
48
4 \times 4 \times 3 = 48
4×4×3=48。每个patch可以看做NLP中的“token”,仿照NLP的词嵌入,将patch映射为维度为
C
C
C的向量。
下述代码展示了如何将图像如何进行patch嵌入。假设一张
224
×
224
×
3
224 \times 224 \times 3
224×224×3的图片,patch size大小为
4
×
4
4 \times 4
4×4,经过一个卷积层(第26行代码)之后的输出shape为
(
B
,
C
,
56
,
56
)
(B, C, 56,56)
(B,C,56,56),展平后两项,并对换后两项的位置,最后嵌入的输出为
(
B
,
56
?
56
,
C
)
(B,56*56,C)
(B,56?56,C)
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
if self.norm is not None:
x = self.norm(x)
return x
patch merging layers
patch merging layers是Swim Transformer分层结构的重要组件。它连接每组
2
×
2
2 \times 2
2×2相邻patches的特征,然后在维度为
4
C
4C
4C的连接特征上应用线性层降维到
2
C
2C
2C。下图四展示了patch merging layers如何将一个
h
×
w
×
1
h \times w \times 1
h×w×1的特征如何转换为
h
2
×
w
2
×
4
\frac{h}{2} \times \frac{w}{2} \times 4
2h?×2w?×4。将$h \times w
特
征
特征
特征x$划分为大小为
2
×
2
2 \times 2
2×2的组,提取每组相同位置的特征得到
x
0
,
x
1
,
x
2
,
x
3
x_{0}, x_{1}, x_{2},x_{3}
x0?,x1?,x2?,x3?(下述代码第28-31行),合并
x
0
,
x
1
,
x
2
,
x
3
x_{0}, x_{1}, x_{2},x_{3}
x0?,x1?,x2?,x3?,通道数量则扩大
4
4
4倍(下述代码第32行),然后再通过线性层降维(下述代码第14和36行)。
图四:patch merger sample
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
x = torch.cat([x0, x1, x2, x3], -1)
x = x.view(B, -1, 4 * C)
x = self.norm(x)
x = self.reduction(x)
return x
窗口自注意力
标准的Transformer架构是全局自注意力,它计算某个token和其他token之间的自注意力,计算复杂度和token的数量呈平方关系。视觉图像的token数量要多于语言中的单词token数量,Transformer在视觉中会耗费更多的资源,尤其对于高质量图像,计算复杂度会非常大。基于此种情况,Swim Transformer采用基于窗口的自注意替换标准的全局注意力。
将一张patches数量为
h
×
w
h \times w
h×w的图像拆分成不重叠的窗口,每个窗口包含
M
×
M
M \times M
M×M个patches。我们先回忆一下标准Transformer中的多头自注意力。假设输入为
x
x
x,将
x
x
x进行线性嵌入得到
Q
,
K
,
V
Q, K, V
Q,K,V三个向量,
Q
Q
Q和
K
K
K两个向量相乘计算得到Attention,然后Attention与向量
V
V
V相乘之后再线性映射得到输出。假设patches的数量为
N
N
N,通道数为
C
C
C,那么两次线性计算复杂度为
4
N
C
2
4NC^{2}
4NC2,
Q
,
K
,
V
Q, K, V
Q,K,V的两次矩阵计算的复杂度为
2
N
2
C
2N^{2}C
2N2C。那么对于标准的多头自注意力,它的计算复杂度为
Ω
(
M
S
A
)
=
4
h
w
C
2
+
2
(
h
w
)
2
C
\Omega\left( MSA \right) = 4hwC^{2} + 2\left( hw \right)^{2}C
Ω(MSA)=4hwC2+2(hw)2C;一个窗口的自注意力计算复杂度为
4
M
2
C
2
+
2
M
4
C
4M^{2}C^{2} + 2M^{4}C
4M2C2+2M4C,此张图片总共有
h
M
×
h
M
\frac{h}{M} \times \frac{h}{M}
Mh?×Mh?个窗口,那么总的基于窗口的自注意力的计算复杂度为
Ω
(
W
?
M
S
A
)
=
4
h
w
C
2
+
2
M
2
h
w
C
\Omega\left( W-MSA \right) = 4hwC^{2} + 2M^{2} hw C
Ω(W?MSA)=4hwC2+2M2hwC。
对于
224
×
224
224 \times 224
224×224大小的图片,每一个patches的大小为
4
×
4
4 \times 4
4×4,那么总共有
56
×
56
56 \times 56
56×56个patches。论文中默认
M
=
7
M=7
M=7,
Ω
(
M
S
A
)
\Omega\left( MSA \right)
Ω(MSA)和
Ω
(
W
?
M
S
A
)
\Omega\left( W-MSA \right)
Ω(W?MSA)计算复杂度相差在矩阵相乘的部分,
2
×
56
×
56
h
w
C
2 \times 56 \times 56 hwC
2×56×56hwC是
2
×
7
2
h
w
C
2 \times 7^{2} hwC
2×72hwC的近60倍。随着图片的尺寸越大,这个差距会越大。
论文在计算自注意力时引入了相对位置偏置(relative position bias),论文实验表明,相对位置偏置在ImageNet,CoCo和ADE20k数据集上的表现要优于不加偏置和使用绝对位置偏置。下述代码展示了带有相对位置偏置的窗口多头自注意力的前向过程。它支持窗口自注意力和移动窗口自注意力。窗口自注意力计算包含三个方面,常规多头自注意力,相对位置偏置的计算和移动窗口的掩码计算。常规多头自注意力有Transformer的基础就很好理解,难点在于相对位置偏置的计算和移动窗口的掩码计算。
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
M
a
x
(
Q
K
T
/
d
+
B
)
V
Attention(Q,K,V) = SoftMax(QK^{T} / \sqrt{d}+ B) V
Attention(Q,K,V)=SoftMax(QKT/d
?+B)V
def forward(self, x, mask=None):
"""
Attention(Q,K,V) = SoftMax(QK^{T}/sqrt(d) + Bias)V
x: 输入特征 shape: (num_windows*B, N, C)
mask: 掩码
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
绝对位置编码是在进行自注意力计算之前为每个token添加一个可学习的参数,相对位置编码,是在进行自注意力计算时,在计算过程中添加一个可学习的相对位置参数。
相对位置偏置
B
∈
R
M
2
×
M
2
B \in \mathbb{R}^{M^{2} \times M^{2}}
B∈RM2×M2,每一个轴的取值范围是
[
?
M
+
1
,
M
?
1
]
[-M+1, M-1]
[?M+1,M?1]。计算自注意力时,每个token都要与其他位置上的token计算
Q
K
QK
QK值。对于一个大小为
2
×
2
2\times2
2×2的窗口,位置1上的patch要与位置1,2,3,4的patch计算
Q
K
QK
QK值,位置2上的patch要与位置1,2,3,4上的patch计算
Q
K
QK
QK值,… ,那么其他位置相对于当前位置都有一个偏移量。下图5中展示了relative_coords(下述代码第8行)其他位置相当于当前位置的偏移量(按列看),为了便于后续的计算,对每个元素都加上偏移量,使其从零开始,如下述代码第9和第10行。由于(0,1)和(1,0),(-1,0)和(0,-1)它们取和后的总偏移量结果一样,因为对某一列坐标进行乘法变换,如下述代码第11行,最后再取和得到总的偏移量relative_position_index。至此,相对位置的下标取值范围为
[
0
,
8
]
[0,8]
[0,8],可由一个
(
2
M
?
1
)
?
(
2
M
?
1
)
(2M-1)*(2M-1)
(2M?1)?(2M?1)大小的矩阵表示,参数化这个更小尺寸的偏置矩阵
B
^
∈
R
(
2
M
?
1
)
×
(
2
M
?
1
)
\hat{B} \in \mathbb{R}^{\left( 2M-1\right) \times \left( 2M-1\right)}
B^∈R(2M?1)×(2M?1),那么
B
B
B的值就可以从
B
^
\hat{B}
B^中提取。
self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
coords_flatten = torch.flatten(coords, 1)
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
self.register_buffer("relative_position_index", relative_position_index)
trunc_normal_(self.relative_position_bias_table, std=.02)
'''self.relative_position_index是计算出不可学习的量 第17行
self.relative_position_index.shape=(Wh*Ww, Wh*Ww) 第15行
self.relative_position_bias_table.shape=(2*Wh-1 * 2*Ww-1, nH) 第2行
self.relative_position_index矩阵中的所有值都是从self.relative_position_bias_table中取的
'''
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
Swim 计算每个窗口中的自注意,窗口与窗口之间无计算,失去了Transformer处理全局信息的特征。为此,Swim Transformer提出了移位窗口分割方法。假设有
8
×
8
8 \times 8
8×8 patches的图片,其中一个窗口包含
4
×
4
4 \times 4
4×4个patches,那么将有
2
×
2
2 \times 2
2×2个窗口。现在将左上角的窗口向左下移位
2
2
2个patches,四个窗口将被重新划分为
9
9
9个大小不一的窗口,如图六所示,只有标号为4的窗口和原窗口大小一致。
最直接的想法是对小窗口进行padding,并在计算的时候屏蔽掉填充的值。但是,自注意力计算将由四个被扩展到九个,计算多了2.25倍。为了不增加计算量,论文中提出了循环移位(cyclic-shift)算法,如图六所示,将编号3,6的窗口移位到编号5,8的窗口下面,将编号0,1的窗口移位到编号6,7的窗口左面,将编号为0的窗口,从左上角移位到右下角。这样就可以重新拼凑出
2
×
2
2 \times 2
2×2 (4,(7,1),(3,5),(0,2,6,8))个窗口。拼凑出的窗口在原图中属于不同的位置,不相连,以标号为0,2,6,8窗口组成的大窗口为例,这四个小窗口分别位于原图的四个顶点,关联性极低,因此,在计算窗口注意力时,需要掩码机制,只能计算相同子窗口的自注意力,不同窗口的自注意力结果要为0。标号为0,2,6,8窗口,在计算窗口自注意力时,窗口0中的每一个patch分别需要和窗口0,2,6,8中的每一patch进行自注意力计算,那么窗口0中的patch与窗口0中的patch的自注意力是有用的,但是窗口0中的patch与窗口2,6,8中的patch的自注意力需要设为0。我们回忆一下Attention的计算公式,
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
M
a
x
(
Q
K
T
/
d
+
B
)
V
Attention(Q,K,V) = SoftMax(QK^{T} / \sqrt{d}+ B) V
Attention(Q,K,V)=SoftMax(QKT/d
?+B)V,自注意力计算最后需要Softmax函数。在不同窗口的自注意力值上添加
?
100
-100
?100(下图代码第20行,mask赋值-100;第27行,将mask添加到自注意力值上,然后再进行softmax计算),在softmax计算过程中,
?
100
-100
?100会无限趋近于0,从而达到归0的效果。
图六:窗口移动
if self.shift_size > 0:
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1))
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
下图展示了循环移位之后的窗口组成和mask值的分布情况。
图七:自注意力mask
Architecture Variants
Swin Transformer有四种形式,分别命名为Swin-T,Swin-S,Swin-B和Swin-L。以Swin-T为基础模型版本,Swin-T,Swin-S,和Swin-L分别是基础模型的
0.25
×
0.25\times
0.25×,
0.5
×
0.5\times
0.5×和
2
×
2\times
2×倍。这四种模型的架构如图8所示。
图8:architecture
ViT模型将Transformer结构应用到视觉领域,但是仍然还受限于图片的尺寸大小。Swin引入移动窗口和分层结构,使得自注意力在视觉领域的计算复杂度能与图片大小成线性关系。Swin吸取了CNN和Transformer的优点,在ImageNet-1k的数据集上也能取得SOTA效果,相比于ViT模型,降低了数据的需求量。
参考
- Swin-Transformer code
- Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
- The Question about the mask of window attention
- An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
|