IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> YOLOv5——NMS -> 正文阅读

[人工智能]YOLOv5——NMS

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
                        labels=()):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """

    nc = prediction.shape[2] - 5  # 分类数
    # 第四个值框置信度大于conf_thres的为True,否则为False
    xc = prediction[..., 4] > conf_thres  # candidates 

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    # 是否属于多分类问题
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    # 默认是关闭的,使用的话需要修改为True
    merge = False  # use merge-NMS

    t = time.time()
    # 创建一个存储容器
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    # xi表示某张图片的索引
    # x表示某张图片的tensor张量数据,x中包含数条预测框数据
    for xi, x in enumerate(prediction):  # image index, image inference
        # xc[xi]:表示筛选某张图片框置信度大于阈值的所有数据,这里的表示形式是True或False,而不是数据
        # x表示所有的candidates 为True的数据,这里表示筛选出所有置信度大于conf_thres的框的数据
        x = x[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        # 好像没用到,暂时不管了
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        # 见2讲解
        # 计算最后六个分类置信度信息,将他们分别乘上第四个数据(框置信度)
        x[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        box = xywh2xyxy(x[:, :4])

        # Detections matrix nx6 (xyxy, conf, cls)
        # 见3讲解
        if multi_label:
        	# 如果是多分类问题,得到每一条数据的索引值i,以及每一条数据中分类置信度大于阈值的类别的索引
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            # 得到一个最终x,x是所有符合条件的框信息
            # 某条数据的框坐标信息和这个类别的得分和这个类别的类别索引放在一条数据
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        # 如果大于最大容量的框数量,进行排序,取前max_nms个
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        # 每一个框以及他的分数
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        # NMS过滤后的bouding boxes索引(降序排列)
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output

1

  • 输入prediction是一个tensor张量;
  • 每一条数据代表每一个预测框;
  • 前面四个数据代表预测框的坐标信息;
  • 第五个数据代表框的置信度;
  • 后面有六个数据,是因为我这里有六个类别的物体,因此每个数据代表每一类的分类置信度。
tensor([[[3.8242e+00, 3.9844e+00, 9.9609e+00, 1.0477e+01, 2.0862e-06,
          9.3384e-03, 3.2898e-02, 2.3975e-03, 6.5369e-02, 8.0322e-01,
          4.5258e-02],
         [1.1172e+01, 3.0117e+00, 2.1844e+01, 7.0781e+00, 4.4703e-06,
          9.1629e-03, 3.9490e-02, 2.4433e-03, 6.0974e-02, 7.5830e-01,
          5.0812e-02],
          ...
          ...
         [5.8500e+02, 6.4350e+02, 1.4875e+02, 9.3812e+01, 1.7881e-06,
          2.4658e-02, 2.3331e-02, 9.5978e-03, 5.4492e-01, 1.3086e-01,
          5.5939e-02],
         [6.2100e+02, 6.5300e+02, 1.1762e+02, 1.0444e+02, 1.4901e-06,
          3.0045e-02, 3.2715e-02, 1.2238e-02, 3.6963e-01, 2.3193e-01,
          5.3101e-02]]], device='cuda:0', dtype=torch.float16)
  • prediction.shape[0] = 1:一张图片就是1

  • prediction.shape[1] = 26460:这里可能表示预测框的数量

  • prediction.shape[2] = 11:每一条数据中数据的个数(4个位置信息+1个框置信度+6个分类置信度)

  • 我们看第一条数据

    [3.8242e+00, 3.9844e+00, 9.9609e+00, 1.0477e+01, 2.0862e-06, 9.3384e-03, 3.2898e-02, 2.3975e-03, 6.5369e-02, 8.0322e-01, 4.5258e-02]
    
    [center_x, center_y, width, height, cls_conf, obj_conf0, obj_conf1, obj_conf2, obj_conf3, obj_conf4, obj_conf5]
    

2

import torch

x = torch.tensor([[2.9950e+02, 2.9025e+02, 1.6962e+02, 3.0675e+02,
                   2.3785e-03,
                   2.0580e-03, 6.3610e-04, 9.6226e-04, 9.9756e-01, 3.0270e-03, 3.2864e-03]],
                   dtype=torch.float16)
print(x[:, 5:])
print(x[:, 4:5])
x[:, 5:] *= x[:, 4:5]
print(x[:, 5:])

输出

tensor([[2.0580e-03, 6.3610e-04, 9.6226e-04, 9.9756e-01, 3.0270e-03, 3.2864e-03]], dtype=torch.float16)
tensor([[0.0024]], dtype=torch.float16)
tensor([[4.8876e-06, 1.4901e-06, 2.2650e-06, 2.3727e-03, 7.2122e-06, 7.8082e-06]], dtype=torch.float16)

print(x[:, 5:]):表示第五个数据后面的数
print(x[:, 4:5]):表示第四个数据的近似值

3

# Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            conf, j = x[:, 5:].max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]

x[:, 5:] > conf_thres:每个分类置信度大于conf_thres为True,否则False

tensor([[False, True, False,  False, False, False],
        [False, False, False,  False, False, False],
        [False, False, False,  True, False, False],
        [False, False, False,  True, False, False],
        [False, False, False,  True, False, True],
        [False, False, False,  True, False, False]], device='cuda:0')

i:

tensor([0, 2, 3, 4, 4, 5], device='cuda:0')
tensor([0条数据有1True,2条数据有1True,3条数据有1True,4条数据有2_1个True,4条数据有2_2个True,5条数据有1True], device='cuda:0')

j:

tensor([1, 3, 3, 3, 5, 3], device='cuda:0')
tensor([0条数据第1个位置为True,2条数据第3个位置为True,3条数据第3个位置为True,4条数据第3个位置为True,4条数据第5个位置为True,5条数据第3个位置为True, device='cuda:0')

j其实就是类别索引

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-09 19:29:17  更:2021-11-09 19:34:19 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 7:50:30-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码