Title | Masked-attention Mask Transformer for Universal Image Segmentation |
---|
Paper | https://arxiv.org/pdf/2112.01527.pdf | Code | https://github.com/facebookresearch/Mask2Former |
1. 整体框架
Mask2former的整体结构包含三个部分:
- Backbone:从图像中提取多尺度特征;
- Pixel decoder:类似FPN的功能,进行多尺度特征交互融合;
- Transformer Decoder:从pixel feature中迭代更新query feature,并利用query feature预测class,利用query feature和1/4高分辨率特征预测mask;
1.1 Pixel decoder
- 输入1/32,1/16,1/8,1/4四个分辨率的特征,采用Deformable DETR提出的Deformable Transformer进行多尺度特征交互。
- 具体来讲,对于1/32,1/16,1/8三个特征图上的每一个像素,在三个特征图上各预测K个点,最终使用3K个点的特征加权来更新当前像素点的特征。
- 关于每个分辨率K个点的选取,是通过以当前像素所在的位置为参考点,预测K个(x, y)偏移量得到。
- 这样,有效避免了Transformer采用全局特征来更新单个像素特征导致的极高的计算复杂度,将计算量从
H
W
×
H
W
HW \times HW
HW×HW降低到了
H
W
×
K
HW \times K
HW×K。
- 除此之外,最终还将1/8分辨率的特征上采样到1/4并与输入1/4特征sum融合用于作为mask预测的高分辨率特征。
1.2 Transformer Decoder
Transformer Decoder部分采用与DETR相似的query方式进行目标检测,即首先初始化N*256的query feature,进而采用1/32,1/16,1/8三个分辨率的特征对query feature进行迭代更新并进行最终的class和mask预测。本文的Transformer Decoder与DETR中使用的Decoder有两个区别,分别是(1)masked attention;(2)调换self-attention和cross-attention的顺序。除此之外,与DETR的不同的是,本文采用多尺度特征更新query,避免了DETR仅仅使用1/32特征图造成的小目标的检测效果较差。
1.2.1 DETR Transformer Decoder
DETR的Decoder部分主要包含self-attention,cross-attention以及FFN三部分;
- self-attention:
- Q:
N
?
256
N*256
N?256
- K:
N
?
256
N*256
N?256
- V:
N
?
256
N*256
N?256
- Cross-attention
- Q:
N
?
256
N*256
N?256
- K:
H
?
W
?
256
H*W*256
H?W?256
- V:
H
?
W
?
256
H*W*256
H?W?256
- FFN
- bbox:
N
?
256
→
N
?
4
N*256 \rightarrow N*4
N?256→N?4
- cls:
N
?
256
→
N
?
c
l
a
s
s
N*256 \rightarrow N*class
N?256→N?class
1.2.2 masked attention
Masked attention是针对DETR transformer Decoder中的cross-attention部分的改进。有研究表明,DETR类的模型收敛速度慢的部分原因是cross-attention中的全局上下文需要经过较长的训练时间才能使得注意力每次集中在目标附近,即局部上下文。本文考虑到在cross-attention中使用局部上下文能够加速模型的收敛,设计了masked attention模块,即每次仅使用特征图的前景区域的特征更新query。 两者的对比如下:
Cross attention | |
---|
Masked attention |
|
M
l
?
1
M_{l-1}
Ml?1?采用的是前一个transformer decoder层预测的mask,并使用阈值0.5进行截断得到的二值图。当某个像素的前一层被预测为背景,即
M
l
?
1
=
0
M_{l-1}=0
Ml?1?=0,映射后的
M
l
?
1
=
?
∞
\mathcal{M}_{l-1}=-\infty
Ml?1?=?∞,则经过softmax映射后,其该像素点的注意力便会下降为0。最终,便只有前景区域的像素点特征会影响query
X
l
X_{l}
Xl?的更新。
1.2.3 调换self-attention和cross-attention的顺序
由于query feature是zero初始化的,第一层transformer Decoder直接上来就进行self-attention不会产生任何有意义的特征,因此先试用cross attention对query feature进行更新后再执行query feature内部的self-attention反而是一种更佳的做法。
1.3 采样点损失函数
不同于语义分割,其最终是预测
C
l
a
s
s
?
H
?
W
Class*H*W
Class?H?W的特征图,并对其进行监督。然而,maskformer则需要预测
N
u
m
_
q
u
e
r
i
e
s
?
H
?
W
Num\_queries*H*W
Num_queries?H?W的特征图。由于query数目通常远大于类别数目,因此,这一类方法会导致巨大的训练内存占用。本文参考PointRend,在每个
H
?
W
H*W
H?W的特征图依据当前像素的不确定度选取K个采样点而不是整个特征图来计算损失函数。其中,不确定度的衡量是依据像素预测置信度得到的。
u
n
c
e
r
t
a
i
n
t
y
=
?
(
t
o
r
c
h
.
a
b
s
(
g
t
_
c
l
a
s
s
_
l
o
g
i
t
s
)
)
uncertainty=-(torch.abs(gt\_class\_logits))
uncertainty=?(torch.abs(gt_class_logits)) 除了在计算损失函数的过程中运用了采样点,本文还在matcher匈牙利匹配时才用了采样点的策略。不同之处在于,在matcher过程中,还不清楚每个query和哪个gt匹配,因此无法同样采用不确定度进行选点无法保证能够选取到gt的前景区域,因此,在matcher过程是采用均匀分布随机了K个点,同一个batch所有的query和gt均是采样这K个点。
2. 实验部分
2.1 实验设置
- 训练尺寸:1024*1024
- Epoch:50
- Batch size:16
- 数据增强:LSJ(large scale jitting)从0.1 到 2.0 范围内随机缩放
2.2 pixel decoder
Pixel decoder充当的就是FPN多尺度特征融合的功能,此处的各个模块均是前人论文所提出的。从最终效果来看,Deformable DETR提出的Multi-scale Deformable Attention是效果最佳的。
2.3 Masked attention
综合两个表格来看,单独的masked attention相比于标准的cross-attention能够产生5.9个点的AP提升。但是相比于同类型的采用局部上下文特征更新query的方式相比:
- mask pooling averages features within mask regions;
- QueryInst采用RoIAlign将box 区域的特征缩放到相同的尺寸并拉伸进而通过卷积映射来更新query feature;
从表格来看,masked attention与mask pooling相比只有0.6个点的AP提升。并且,masked attention虽然只使用了mask区域内的特征,但是其并不会降低计算量,因为其只是在cross-attention计算得到的attention map上屏蔽掉了背景区域的权重。相反,QueryInst采用的局部特征虽然没有细化到mask级别,只是bbox级别,但是其计算量却是更低的。
2.4 调换self-attention和cross-attention顺序
调换self-attention和cross-attention的顺序会产生0.5个点的AP影响。
2.5 采样点损失函数
- 对于使用不确定度的采样点损失函数,其相比于mask损失函数能够将显存占用从18G降低为6G,精度还有小幅提升。
- 对于使用均匀分布随机点采样匈牙利匹配,其竟能够产生2.7个点的AP收益,文中没有对此进行详细解释。因此我进行了一下复现,与文中的数据是一致的。详细一点来讲,使用mask匹配是在1/4分辨率进行的,将gt mask采用最近邻插值进行下采样。point匹配会随机K个点(不一定是整数位置),然后从1/4的预测mask和1/1的gt mask均采用双线性插值取值,不清楚为什么会产生如此大的差异。
2.6 query数目
关于query数目的实验,之前也在Deformable DETR上做过实验,随着query数目的增多,效果是越来越好的。此处的query越多效果却又变差,猜测可能是mask相比于bbox需要更多的训练epoch。
Method | queries | epoch | mAP | APs | APm | APl |
---|
Deformable DETR | 100 | 50 | 41.1 | 23.2 | 44.2 | 55.6 | Deformable DETR | 300 | 50 | 44.6 | 27.0 | 47.3 | 59.2 | Deformable DETR | 500 | 50 | 44.8 | 27.3 | 48.1 | 59.2 | Deformable DETR | 700 | 50 | 45.2 | 27.8 | 48.7 | 59.4 |
2.7 LSJ数据扩充
LSJ数据扩充是本文使用的一个竞赛Trick,从epoch50的曲线来看,大致有2个点的AP影响,因此mask2former效果好还也不全是模型的功劳,毕竟LSJ就能提升2个点,其他的文中主要创新点的提点都还在1个点一下。
2.8 Benchmark对比
Benchmark对比来看,mask2former在APl上完胜其他所有的实例分割网络,唯独APs效果比QueryInst和HTC++差一点。这也是接下来的文章改进的主要方向。
|