def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Run Non-Maximum Suppression (NMS) on raw inference results.

    Args:
        prediction: Tensor indexed as (image, candidate, 5 + nc). Columns 0-3
            are handed to ``xywh2xyxy`` (presumably centre-x, centre-y, width,
            height -- confirm against that helper), column 4 is the box
            (objectness) confidence and the remaining ``nc`` columns are the
            per-class scores.
        conf_thres: Threshold applied first to column 4 and then to the
            combined score (class score * box confidence).
        iou_thres: IoU threshold passed to ``torchvision.ops.nms``.
        classes: Optional collection of class indices; when given, only
            detections whose class is in it are kept.
        agnostic: If true, boxes of different classes suppress each other.
        multi_label: If true (and there is more than one class), a candidate
            yields one detection per class whose score exceeds ``conf_thres``
            instead of only its best class.
        labels: Optional per-image label tensors appended to the candidates
            before NMS; column 0 is read as the class index and columns 1-4
            are copied into the box columns.

    Returns:
        A list with one (n, 6) tensor per image; columns are
        [xyxy box, conf, cls]. Images without detections (or not reached
        before the time limit) get an empty (0, 6) tensor.
    """
    nc = prediction.shape[2] - 5  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates passing the box-confidence test

    # Settings
    max_wh = 4096  # per-class box offset used for batched NMS
    max_det = 300  # maximum number of detections kept per image
    max_nms = 30000  # maximum number of boxes handed to torchvision.ops.nms
    time_limit = 10.0  # seconds after which the remaining images are skipped
    redundant = True  # merge-NMS only: require more than one overlapping box
    multi_label &= nc > 1  # multiple labels per box only make sense with > 1 class
    merge = False  # use merge-NMS

    t = time.time()
    # Build a distinct empty tensor per image. `[tensor] * n` would put the
    # same object in every slot, so images without detections would alias.
    output = [torch.zeros((0, 6), device=prediction.device) for _ in range(prediction.shape[0])]
    for xi, x in enumerate(prediction):  # image index, image inference
        x = x[xc[xi]]  # boolean-mask indexing, keeps confident candidates only

        # Append the supplied labels (if any) as extra candidates
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[:, 4] = 1.0  # box confidence
            v[range(len(lb)), lb[:, 0].long() + 5] = 1.0  # one-hot class score
            x = torch.cat((x, v), 0)

        # Nothing left for this image
        if not x.shape[0]:
            continue

        # Combined score = class score * box confidence
        x[:, 5:] *= x[:, 4:5]

        box = xywh2xyxy(x[:, :4])

        # Detections matrix, n x 6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Check the number of remaining boxes
        n = x.shape[0]
        if not n:
            continue
        elif n > max_nms:
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # keep the highest-scoring boxes

        # Batched NMS: shift each box by class_index * max_wh so boxes of
        # different classes do not overlap (no shift when agnostic).
        c = x[:, 5:6] * (0 if agnostic else max_wh)
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, iou_thres)
        if i.shape[0] > max_det:
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # merge-NMS: boxes merged using a weighted mean
            # boxes(i, 4) = weights(i, n) * boxes(n, 4)
            iou = box_iou(boxes[i], boxes) > iou_thres
            weights = iou * scores[None]
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)
            if redundant:
                i = i[iou.sum(1) > 1]

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # remaining images keep their empty (0, 6) tensor

    return output
1
- 输入
prediction 是一个tensor张量;
- 每一条数据代表一个预测框;
- 前面四个数据代表预测框的坐标信息;
- 第五个数据代表框的置信度;
- 后面有六个数据,是因为我这里有六个类别的物体,因此每个数据代表每一类的分类置信度。
tensor([[[3.8242e+00, 3.9844e+00, 9.9609e+00, 1.0477e+01, 2.0862e-06,
9.3384e-03, 3.2898e-02, 2.3975e-03, 6.5369e-02, 8.0322e-01,
4.5258e-02],
[1.1172e+01, 3.0117e+00, 2.1844e+01, 7.0781e+00, 4.4703e-06,
9.1629e-03, 3.9490e-02, 2.4433e-03, 6.0974e-02, 7.5830e-01,
5.0812e-02],
...
...
[5.8500e+02, 6.4350e+02, 1.4875e+02, 9.3812e+01, 1.7881e-06,
2.4658e-02, 2.3331e-02, 9.5978e-03, 5.4492e-01, 1.3086e-01,
5.5939e-02],
[6.2100e+02, 6.5300e+02, 1.1762e+02, 1.0444e+02, 1.4901e-06,
3.0045e-02, 3.2715e-02, 1.2238e-02, 3.6963e-01, 2.3193e-01,
5.3101e-02]]], device='cuda:0', dtype=torch.float16)
- prediction.shape[0] = 1 :图片数量(batch 大小),一张图片就是1
- prediction.shape[1] = 26460 :每张图片的候选预测框数量
- prediction.shape[2] = 11 :每一条数据中数据的个数(4个位置信息+1个框置信度+6个分类置信度)
我们看第一条数据 [3.8242e+00, 3.9844e+00, 9.9609e+00, 1.0477e+01, 2.0862e-06, 9.3384e-03, 3.2898e-02, 2.3975e-03, 6.5369e-02, 8.0322e-01, 4.5258e-02]
[center_x, center_y, width, height, obj_conf, cls_conf0, cls_conf1, cls_conf2, cls_conf3, cls_conf4, cls_conf5]
2
import torch
# One detection row, laid out as [4 box values, box confidence, 6 class scores].
box_values = [2.9950e+02, 2.9025e+02, 1.6962e+02, 3.0675e+02]
box_confidence = [2.3785e-03]
class_scores = [2.0580e-03, 6.3610e-04, 9.6226e-04, 9.9756e-01, 3.0270e-03, 3.2864e-03]
x = torch.tensor([box_values + box_confidence + class_scores], dtype=torch.float16)

# Class scores before scaling (columns 5 onward).
print(x[:, 5:])
# Box confidence (column 4); the 4:5 slice keeps it 2-D so it broadcasts per row.
print(x[:, 4:5])
# Scale every class score by the box confidence, in place.
x[:, 5:] *= x[:, 4:5]
# Class scores after scaling.
print(x[:, 5:])
输出
tensor([[2.0580e-03, 6.3610e-04, 9.6226e-04, 9.9756e-01, 3.0270e-03, 3.2864e-03]], dtype=torch.float16)
tensor([[0.0024]], dtype=torch.float16)
tensor([[4.8876e-06, 1.4901e-06, 2.2650e-06, 2.3727e-03, 7.2122e-06, 7.8082e-06]], dtype=torch.float16)
print(x[:, 5:]) :表示每行索引5及之后的数据,即6个分类置信度
print(x[:, 4:5]) :表示每行索引4的数据,即框置信度(切片写法保持二维形状,便于与分类置信度逐行相乘);打印出的0.0024是2.3785e-03的显示近似值
3
# Excerpt of the detection-building step from non_max_suppression above
# (indentation was lost in this listing; the code itself is unchanged).
if multi_label:
# Multi-label: (row, class) index pairs for every class score above conf_thres.
i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
# One output row per pair: [box of row i, score of class j, class index j].
x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
else:
# Single-label: take the best class score of each row ...
conf, j = x[:, 5:].max(1, keepdim=True)
# ... and keep only the rows whose best score exceeds conf_thres.
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
x[:, 5:] > conf_thres :每个分类置信度大于conf_thres为True,否则False
tensor([[False, True, False, False, False, False],
[False, False, False, False, False, False],
[False, False, False, True, False, False],
[False, False, False, True, False, False],
[False, False, False, True, False, True],
[False, False, False, True, False, False]], device='cuda:0')
i:
tensor([0, 2, 3, 4, 4, 5], device='cuda:0')
即 i 是每个True所在的行号:第0、2、3、5条数据各有1个True,第4条数据有2个True(因此4出现两次)
j:
tensor([1, 3, 3, 3, 5, 3], device='cuda:0')
tensor([第0条数据第1个位置为True, 第2条数据第3个位置为True, 第3条数据第3个位置为True, 第4条数据第3个位置为True, 第4条数据第5个位置为True, 第5条数据第3个位置为True], device='cuda:0')
j其实就是类别索引
|