build_targets作用
build_targets函数用于网络训练时计算loss所需要的目标框,即正样本。
注意
- 与yolov3/yolov4不同,yolv5支持跨网格预测。即每一个bbox,正对于任何一个输出层,都可能有anchor与之匹配。
- 该函数输出的正样本框比传入的GT数目要多。
- 当前解读版本为6.1
可视化结果
过程
- 首先通过bbox与当前层anchor做一遍过滤。对于任何一层计算当前bbox与当前层anchor的匹配程度,不采用IoU,而采用shape比例。如果anchor与bbox的宽高比差距大于4,则认为不匹配,保留下匹配的bbox。
r = t[..., 4:6] / anchors[:, None]
j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']
t = t[j]
- 最后根据留下的bbox,在上下左右四个网格四个方向扩增采样。
gxy = t[:, 2:4]
gxi = gain[[2, 3]] - gxy
j, k = ((gxy % 1 < g) & (gxy > 1)).T
l, m = ((gxi % 1 < g) & (gxi > 1)).T
j = torch.stack((torch.ones_like(j), j, k, l, m))
t = t.repeat((5, 1, 1))[j]
详细代码解读
准备
def build_targets(self, p, targets):
P是网络预测的输出。 p的shape为 :(batch_size,anchor_num,grid_cell,grid_cell,xywh+obj_confidence+classes_num) data:image/s3,"s3://crabby-images/47618/47618554bc93a632086c6352887f1aed38d72dff" alt="在这里插入图片描述" P[0]的shape data:image/s3,"s3://crabby-images/6f5f3/6f5f31d9daa93505b2201ee6a08db56098e3c8f7" alt="在这里插入图片描述" P[1]的shape data:image/s3,"s3://crabby-images/24a34/24a347e66a2e4c82a8c7b5f0dd26332c12f97fa0" alt="在这里插入图片描述" P[2]的shape data:image/s3,"s3://crabby-images/5b250/5b250daa34152cbf263f45a156aebb29c8f2ae2d" alt="在这里插入图片描述"
targets是经过数据增强(mosaic等)后总的bbox。 targets的shape为 :[num_obj, 6] , that number 6 means -> (img_index, obj_index, x, y, w, h) data:image/s3,"s3://crabby-images/4c76b/4c76b37e389bf69efc53ad7da9272aac57d56d4c" alt="在这里插入图片描述" data:image/s3,"s3://crabby-images/d6023/d6023b51fe15244025cff53cafef16ce2d31e31d" alt="在这里插入图片描述"
na, nt = self.na, targets.shape[0]
data:image/s3,"s3://crabby-images/6b959/6b959413f4844123b31099729b4ee8fad81323ad" alt="在这里插入图片描述"
tcls, tbox, indices, anch = [], [], [], []
tcls:用来存储类别。 tbox:用来存储bbox indices:用来存储第几张图片,当前层的第几个anchor,以及当前层grid的下标。
gain = torch.ones(7, device=self.device)
初始化为1,用来还原bbox为当前层的尺度大小。
ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt)
扩充anchor数量和当前bbox一样多。 ai是anchor的下标 data:image/s3,"s3://crabby-images/5007a/5007aca87e6cf02201e1dba17a7ea469aa42b699" alt="在这里插入图片描述" data:image/s3,"s3://crabby-images/683fa/683fa54a78dc9f9c2fb62b47081ac31b8b1971c7" alt="在这里插入图片描述"
targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2)
targets的shape变为(3,101,7)。 targets[0] 对应第一个anchor对应的(image_id, cls, center_x,center_y, w, h,第一个anchor) data:image/s3,"s3://crabby-images/9a49c/9a49ca958339eee69cad9ae957575274c013b52b" alt="在这里插入图片描述" targets[1] 对应第一个anchor对应的(image_id, cls, center_x,center_y, w, h,第二个anchor)data:image/s3,"s3://crabby-images/a0737/a0737c49430c33f3ef366299b659c3905ac96246" alt="在这里插入图片描述" targets[2] 对应第一个anchor对应的(image_id, cls, center_x,center_y, w, h,第三个anchor)data:image/s3,"s3://crabby-images/0af44/0af4466f27d72a398b8154b465f961c1a15de423" alt="在这里插入图片描述"
g = 0.5
off = torch.tensor(
[
[0, 0],
[1, 0],
[0, 1],
[-1, 0],
[0, -1],
],
device=self.device).float() * g
for i in range(self.nl):
anchors = self.anchors[i]
self.anchors data:image/s3,"s3://crabby-images/a585f/a585f2acd25a308eab8a44ad6bf16a9e74eca278" alt="在这里插入图片描述" self.anchors[0] 得到第一层归一化后的anchor data:image/s3,"s3://crabby-images/db6df/db6dfd4747a7babcd98e85f533322a4f68b5e00a" alt="在这里插入图片描述" 乘8得到的 data:image/s3,"s3://crabby-images/f80bc/f80bc2db145deab8e458982d8773a52801d11f1f" alt="在这里插入图片描述" self.anchors[1] 得到第二层归一化后的anchor data:image/s3,"s3://crabby-images/57cac/57cac20efb18b525fb36add45886ba0117a38b90" alt="在这里插入图片描述" 乘16得到的 data:image/s3,"s3://crabby-images/7d2d9/7d2d9cc993e39453e8d2a00ac6fe9031c656928a" alt="在这里插入图片描述" self.anchors[2] 得到第三层归一化后的anchor data:image/s3,"s3://crabby-images/48171/4817102e6a083c297f3d8a17713834df0ff8fa5e" alt="在这里插入图片描述" 乘以32得到的 data:image/s3,"s3://crabby-images/f3ad3/f3ad36a0903a21d25a4bee1f4c35fc96b65589b6" alt="在这里插入图片描述"
gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]
生成一个当前层的方格大小。 如果i=0 data:image/s3,"s3://crabby-images/20e40/20e4084745ce30941cdcd32bc1a19ef907a49458" alt="在这里插入图片描述" 如果i=1, data:image/s3,"s3://crabby-images/1bae9/1bae97045a2f0d800448afc7ce130774ded1bac8" alt="在这里插入图片描述" 如果i=2 data:image/s3,"s3://crabby-images/9dd53/9dd5377acfe43bb29739db031078a3615e2f51aa" alt="在这里插入图片描述"
t = targets * gain
将targets的大小映射到当前层,第六列是当前层的第几个anchor,第0列是位于哪张图片,第1列代表的是类别,2-5列是目标在当前层x,y,w,h。 下采样八倍的层 data:image/s3,"s3://crabby-images/a6c9b/a6c9b0f876ab0eb8f834688721dc98e7d0d64cf0" alt="在这里插入图片描述"
第一遍筛选
if nt:
r = t[..., 4:6] / anchors[:, None]
r是指bbox与当前层三个anchor的高宽的比值。 data:image/s3,"s3://crabby-images/4080e/4080ef13c877cd55db9a5de46640e43848e0c72e" alt="在这里插入图片描述" r[0] data:image/s3,"s3://crabby-images/ced92/ced92ca7fe517ff4a1ec0d662395a964f13419b5" alt="在这里插入图片描述" r[1] data:image/s3,"s3://crabby-images/f3174/f31749559c20f6815128e19c84aba788b26ab735" alt="在这里插入图片描述" r[2] data:image/s3,"s3://crabby-images/0966c/0966c62e047981798671045b0e6a9a38e3d697f3" alt="在这里插入图片描述"
j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t']
torch.max(r, 1 / r).max(2)[0] 为什么是[0] 不是[1] .[0]代表的是value,[1]代表的index。
data:image/s3,"s3://crabby-images/e2a90/e2a9034c3e4a5458831b97a9f068a063c747f55c" alt="在这里插入图片描述"
torch.max(r, 1 / r).max(2)[1]
data:image/s3,"s3://crabby-images/1c237/1c237463ec36114fc5da0b5d1fa9ef182dae4181" alt="在这里插入图片描述"
torch.max(r, 1 / r).max(1)[0]
按行获取最大值。 data:image/s3,"s3://crabby-images/46df5/46df5b0b45de8f14443f0a9ac2d4d1529e3bcd24" alt="在这里插入图片描述"
torch.max(r, 1 / r).max(1)[1]
按行获取最大值,返回索引。 data:image/s3,"s3://crabby-images/27039/27039f3aeab8fc1c3e7ee91e43a47248544768f1" alt="在这里插入图片描述"
t = t[j]
经过过滤后,全部汇总到来了一起。按照第六列anchor的顺序排列。 data:image/s3,"s3://crabby-images/7be45/7be454a17dda4360cf2598cf79c0ddba4907023d" alt="在这里插入图片描述"
扩增正样本
接下来是扩增正样本
gxy = t[:, 2:4]
gxi = gain[[2, 3]] - gxy
假设最后的特征图大小是8x8,有a-h8个目标边框如下。 data:image/s3,"s3://crabby-images/9108f/9108f2c94e39f25ccc117db167a6c32e521b455f" alt="在这里插入图片描述" 下图中深灰色的表示满足条件的。 data:image/s3,"s3://crabby-images/690ac/690ac35d5118840a5a642048f5855baf9af662c6" alt="在这里插入图片描述"
j, k = ((gxy % 1 < g) & (gxy > 1)).T
l, m = ((gxi % 1 < g) & (gxi > 1)).T
gxy % 1 < g 和gxi % 1 < g 包含两个方向,x和y方向。 data:image/s3,"s3://crabby-images/e938f/e938fb7da032b973c46c384978a544ca2d13b953" alt="在这里插入图片描述"
((gxy % 1 < g) & (gxy > 1))
data:image/s3,"s3://crabby-images/e1b6e/e1b6e4e283b2b3c7fd1c4ee27a12162d130faed6" alt="在这里插入图片描述"
(gxi % 1 < g) & (gxi > 1)
data:image/s3,"s3://crabby-images/8d3ed/8d3eddbf09d2880fd4fe6bec333331f31a6d6ba5" alt="在这里插入图片描述"
j = torch.stack((torch.ones_like(j), j, k, l, m))
t = t.repeat((5, 1, 1))[j]
|----------------------------------------------------------------------|
| 这里将t复制5个,然后使用j来过滤 |
| 第一个t是保留经过第一步过滤留下的gtbox,因为上一步里面增加了一个全为true的维度|
| 第二个t保留了靠近方格左边的gtbox, |
| 第三个t保留了靠近方格上方的gtbox, |
| 第四个t保留了靠近方格右边的gtbox, |
| 第五个t保留了靠近方格下边的gtbox, |
|----------------------------------------------------------------------|
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
j的第一行全为1,意思是指经过第一步保留下的bbox所在的grid_cell为1. data:image/s3,"s3://crabby-images/080ee/080ee562ecd46675b6d0dfcde39ba082e6f72a95" alt="在这里插入图片描述"
else:
t = targets[0]
offsets = 0
bc, gxy, gwh, a = t.chunk(4, 1)
a, (b, c) = a.long().view(-1), bc.long().T
gij = (gxy - offsets).long()
gi, gj = gij.T
下面的四张图展示了gij = (gxy - offsets).long() 做了啥。 data:image/s3,"s3://crabby-images/5f391/5f391de5a0d01aad76279144e473e69c9ae16f0e" alt="在这里插入图片描述" data:image/s3,"s3://crabby-images/1f923/1f92376ddab54b5b39356e07221ef3a2bb6ba2c1" alt="在这里插入图片描述" data:image/s3,"s3://crabby-images/00738/0073854eac46442a1b4e67188d4f972c968b8ea8" alt="在这里插入图片描述" data:image/s3,"s3://crabby-images/53bf9/53bf9f7fdd5eb96fe5a0a6919e753b91868d105d" alt="在这里插入图片描述"
**最终得到的结果如下**
data:image/s3,"s3://crabby-images/ae31c/ae31ca9953ef90711e5238c0b02ad45e94b02509" alt="在这里插入图片描述"
indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))
tbox.append(torch.cat((gxy - gij, gwh), 1))
anch.append(anchors[a])
tcls.append(c)
tbox.append(torch.cat((gxy - gij, gwh), 1)) # box 这句话做的如下: data:image/s3,"s3://crabby-images/ad59d/ad59d9e3f080b9ac0be945b1430e3e9f30526490" alt="在这里插入图片描述"
Reference
- 感谢这位UP主的详细解释,本文的正样本采样细节参考了此UP主的PPT。yolo v5 解读,训练,复现
|