欢迎访问个人网络日志🌹🌹知行空间🌹🌹
1.简介
论文Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection 代码https://github.com/sfzhang15/ATSS
ATSS 是中科院自动化研究所的Shifeng Zhang 等最早于2019年12月份提交的论文中提出的方法,发表在CVPR2020会议上。
文中分析了Anchor Based 和Anchor Free 的检测方法,性能差异的主要原因在于正负训练样本的定义方式不同,而和回归目标是基于**点式(point)还是盒式(box)**关系不大。Anchor Free 检测常用的有两种方法,一种是keypoint_based ,另一种是center_based 。keypoint_based 的Anchor Free 目标检测算法同标准的keypoint estimation pipeline ,和anchor based 的目标检测算法差异较大。但center_based 的Anchor Free 目标检测算法与Anchor Based 的方法比较相近,center_based 方法将point 作为预设样本(如FCOS),Anchor Based 方法是将anchor 作为预设样本(如RetinaNet)。Anchor Based 的RetinatNet 与Center Based 的FCOS 的主要区别是:
- 1)
feature map 中每个位置的anchor 数量不同,RetinaNet 每个点生成多个anchor boxes ,FCOS 每个点生成一个anchor point - 2)正负样本的定义方式不同,
RetinaNet 使用IoU 来判定正负样本,FCOS 使用patial and scale constraints 来判断。 - 3)回归起始状态不同。
RetinaNet 是基于Anchor Box 的
(
t
x
,
t
y
.
t
ω
,
t
h
)
(t_x,t_y.t_\omega,t_h)
(tx?,ty?.tω?,th?),FCOS 是基于Anchor Point 的
(
l
l
,
l
t
,
l
r
,
l
o
)
(l_l,l_t,l_r,l_o)
(ll?,lt?,lr?,lo?)。
ATSS 分析了Anchor Based 和Anchor Free 检测算法实现上的差异,得出的结论是正负样本定义方式的不同影响了两种方法检测效果的差异。基于此论文提出了Adaptive Training Sample Selection(ATSS) 算法以基于目标特征自动的计算正负样本。本文还基于实验得出了在同个位置没必要使用多个anchor box 做检测的结论。
2.目标检测相关
ObjectDetector
AnchorBased
AnchorFree
OneStage:SSD
fusing context information from different layers
training from scratch
introducing new loss function
anchor refinement and matching
architecture redesign
feature enrichment and alignment
TwoStage:FasterR-CNN
architecture redesign and reform
context and attention mechanism
multiscale training and testing
training strategy and loss function
feature fusion and enhancement
better proposal and balance
KeyPointBased:CornerNet/GridR-CNN/ExtremeNet/CenterNet/RepPoints
CenterBased:YOLO/DenseBox/FCOS/CSP/FoveaBox
3.Anchor Based 与Anchor Free 目标检测算法的差异分析
Anchor Based 选择RetinaNet 作为代表,Anchor Free 选择FCOS 作为代表,从以下三方面进行分析:
- 1)正负样本定义
- 2)初始回归状态,是回归
t
x
,
t
y
,
t
w
,
t
h
t_x,t_y,t_w,t_h
tx?,ty?,tw?,th?还是
l
l
,
l
t
,
l
r
,
l
o
l_l,l_t,l_r,l_o
ll?,lt?,lr?,lo?
- 3)每个位置的
anchor 数量
3.1 RetinaNet 与FCOS 的对比
设置RetinaNet 的Anchor box 数量为1 。对FCOS 的改进:
- 1)将
centernerss 移到regression 分支 - 2)使用
GIoU Loss - 3)将回归目标使用对应level的stride来归一化
这些提升了FCOS 的检测效果,coco minival 上的map 从37.1 提升到了37.8 ,进一步拉开了Anchor=1 的RetinaNet 与FCOS 的差距。
FCOS 中使用的一些trick 在Anchor=1 的RetinaNet 中也能使用,如检测头中使用的Group Normlization , GIoU ,限制ground truth box 中的正样本,对特征金字塔的每层加上一个中心度分支和可训练参数。将这些trick 逐一加到RetinaNet 上的对比结果为:
从上图可以看出,将所有的通用trick 都应用到RetinaNet 上后,MAP 依然有0.8的差距。除了以上指出的通用性差异后,还有两点不同,一个是正负样本的定义方式,另一个是回归任务本身,RetinaNet 是基于Anchor Box 回归,FCOS 是基于Anchor Point 回归。
3.1 正负样本定义的区别
如上图,RetinaNet 根据ground truth box 与anchor box 之间的IoU 的值来判断是正样本还是负样本,通常设置两个超参数
(
I
o
U
n
e
g
,
I
o
U
p
o
s
)
(IoU_{neg}, IoU_{pos})
(IoUneg?,IoUpos?),小于
I
o
U
n
e
g
IoU_{neg}
IoUneg?的是负样本,大于
I
o
U
p
o
s
IoU_{pos}
IoUpos?的是正样本,在两者之间的Anchor Box 被忽略,不参与训练,RPN 生产的Proposal Box 基于FPN 论文中提出的方程式2赋值给某个feature 层。FCOS 则先根据Anchor Point 的空间位置是否落在ground truth box 中找出可能为正的Anchor Point ,再根据Anchor Point 对应feature map 上的回归范围regression scale 来近一步确认是否为正样本,参考见博客FCOSNet。基于Spatial and Scale 的正样本判定方式决定了检测器的优秀性能,如下表,使用Spatial and Scale 后,Anchor=1的RetinaNet 的MAP 也提升到了37.8 ,换用IoU 的FCOS 的MAP 降到了36.9 :
3.2 回归起始位置的差异
如下图,Anchor=1的RetinaNet 回归的是AnchorBox 相对于ground truth box 的平移缩放
(
t
x
,
t
y
,
t
w
,
t
h
)
(t_x,t_y,t_w,t_h)
(tx?,ty?,tw?,th?)即基于box 的回归,而FCOS 回归的是中心点距离ground truth box 四边的距离
l
l
,
l
t
,
l
r
,
l
b
l_l,l_t,l_r,l_b
ll?,lt?,lr?,lb?,即基于点的回归。从上图中按行方向比较可以发现,使用box 或point 的回归方式对最终的结果影响不大,37->36.9 ,‵37.8->37.8`。
综合3.1和3.2的分析,可以得出结论:是正负样本的定义方式不同影响了Anchor Based 和Anchor Free 算法的性能。
4.自适应训练样本选择
从前面作者得出的结论,How to define positive and negative samples极大影响了检测器的性能,基于此作者提出了新的samples 分类算法,自适应训练样本选择(Adaptive Training Sample Selection, ATSS)。
Anchor Based 基于IoU 和Anchor Free 基于Scale Range 的正样本定义方法都依赖预先定义好的超参数,ATSS 提出了一种自适应取阈值的方法,减少了sample definition 所需的超参数。
以一张输入图像为例说明上图ATSS 算法的工作流程:
- 1)对于1个
ground truth box ,分别在每个金字塔特征层上取中心
L
2
L_2
L2?距离最近的
k
k
k个anchor boxes 作为候选positive sample ,对于有
L
\mathcal{L}
L个金字塔特征层的网络,共得到
k
L
k\mathcal{L}
kL个candidate positive anchor boxes - 2)计算
candidates 与ground truth boxes
g
∈
D
g
g\in \mathcal{D}_{g}
g∈Dg?之间的IoU - 3)计算2)中
IoU 的均值
m
g
m_{\mathcal{g}}
mg?和标准差
v
g
\mathcal{v}_{\mathcal{g}}
vg? - 4)取
t
g
=
m
g
+
v
g
t_g=m_{\mathcal{g}}+\mathcal{v}_{\mathcal{g}}
tg?=mg?+vg?作为阈值,大于
t
g
t_g
tg?的是
positive ,其余的Anchor Boxes 都是negative
作者指出,当一个anchor box 同时落入两个ground truth box 中时,会将其分配给IoU 比较大的ground truth box 。
从上图可以看出ATSS 的作用,对于某个ground truth box ,图a中标准差较大,意味着有某个金字塔特征层比较适合预测该box ,因此阈值
t
g
t_g
tg?也比较大。图b中标准差不大,意味者可能有多个特征层适合预测当前box ,因此选取的阈值
t
g
t_g
tg?也较小。
作者还指出使用ATSS ,可以使得对于不同大小的目标对象得到相同比例的正负训练样本。对于标准正态分布有16% 的样本落在
[
v
+
σ
,
1
]
[v+\sigma,1]
[v+σ,1]之间,虽然IoU of candidates 不是正态分布,正样本的比例依然保持在了20% of
k
L
k\mathcal{L}
kL 左右,和目标
s
c
a
l
e
/
a
s
p
e
c
t
r
a
t
i
o
/
l
o
c
a
t
i
o
n
scale/aspect ratio/location
scale/aspectratio/location无关。而RetinaNet 和FCOS 都会倾向于对大目标生成更多的正样本。
ATSS 使用的超参数很少,只有k 一个,且算法效果对k 不敏感。实验证明k 取[3, 5, 7, 9, 11, 13, 15, 17, 19] 时map 变化不大:
5.代码实现
mmdetection 中ATSS 算法的实现在ATSSAssigner 类中,assign 的部分代码如下:
candidate_idxs = []
start_idx = 0
for level, bboxes_per_level in enumerate(num_level_bboxes):
end_idx = start_idx + bboxes_per_level
distances_per_level = distances[start_idx:end_idx, :]
selectable_k = min(self.topk, bboxes_per_level)
_, topk_idxs_per_level = distances_per_level.topk(
selectable_k, dim=0, largest=False)
candidate_idxs.append(topk_idxs_per_level + start_idx)
start_idx = end_idx
candidate_idxs = torch.cat(candidate_idxs, dim=0)
candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
overlaps_mean_per_gt = candidate_overlaps.mean(0)
overlaps_std_per_gt = candidate_overlaps.std(0)
overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
for gt_idx in range(num_gt):
candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
num_gt, num_bboxes).contiguous().view(-1)
ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
num_gt, num_bboxes).contiguous().view(-1)
candidate_idxs = candidate_idxs.view(-1)
l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
is_pos = is_pos & is_in_gts
欢迎访问个人网络日志🌹🌹知行空间🌹🌹
参考资料
|