IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 跟李沐学AI—pytorch锚框代码解析——3 -> 正文阅读

[人工智能]跟李沐学AI—pytorch锚框代码解析——3

跟李沐学AI–锚框代码解析–3

非极大值抑制预测边界框

  • 当存在许多锚框时,可能会输出许多相似的具有明显重带你的预测边界框,围绕同一目标,为了简化输出,使用给非极大值抑制(non-maximum suppression, NMS)合并对应目标为同一类的类似的预测边界框
  • 其工作原理如下:
    • 基础概念:对于一个预测边界框B,目标检测模型会计算每个类的预测概率,最大预测概率 p p p 所对应的类别,就是边框 B B B 的类别,这里 p p p B B B置信度,对于同一张图像,所有非背景预测边框按照置信都降序排序,生成列表 L L L
    • 操作过程:
    • L L L 中选取置信度最高的预测边界框 B 1 B_1 B1?作为基准,然后将所有与 B 1 B_1 B1? I o U IoU IoU 超过预定阈值 ? \epsilon ? 的非基准预测边接框从 L L L 中移除。此时在 L L L 中,对于 B 1 B_1 B1? 只剩下一个可用边界框,其他的相似锚框均基于上述标准被删除
      • 注意这里求IoU值,是以 B 1 B_1 B1? 为基准,使用其他预测锚框与 B 1 B_1 B1? 进行计算IoU,而不是真实边框
    • L L L 中选取置信度第二稿的预测边框 B 2 B_2 B2? 作为有一个基准,然后将所有与 B 2 B_2 B2?的IoU大于 ? \epsilon ?的非基准预测边界框从 L L L 中移除;
    • 重复上述过程,历遍 L L L中所有的锚框,直到 L L L中所有的预测边界框都曾被用作基准;此时 L L L中任意一对预测边界框的IoU都小于阈值 ? \epsilon ?,没有一对锚框相似
    • 代码如下:
    •   def nms(boxes, scores, iou_threshold):
            """对预测边界框的置信度进行排序
            args: 
                boxes: 预测边框
                    [anchors_num, 4]
                scores: 置信度
                    [anchors_num]
                iou_threshold: iou阈值
            """
            B = torch.argsort(scores, dim=-1, descending=True)
            '''返回scores排序后的下标
               B --> tensor([0, 3, 1, 2])
            '''
            keep = []  # 保留预测边界框的指标
            '''B.numel() 返回的tensor中的元素个数'''
            while B.numel() > 0:
                i = B[0]
                keep.append(i)
                if B.numel() == 1: break
                iou = box_iou(boxes[i, :].reshape(-1, 4),
                              boxes[B[1:], :].reshape(-1, 4)).reshape(-1)
                '''iou 计算的为 B1与 B2, B3,...的iou一维矩阵
                   iou --> tensor([0.00, 0.74, 0.55])'''
                inds = torch.nonzero(iou <= iou_threshold).reshape(-1)
                ''' inds 返回的为所有iou小于阈值的下标 '''
                B = B[inds + 1]
                '''由于iou矩阵长度为 anchors_num-1, 
                   剔除了最大的数值,因此在这里需要加1'''
            return torch.tensor(keep, device=boxes.device)
      

非极大抑制方法的应用:

  • 该部分由一个函数实现,主要步骤简述如下:
    • a. 根据锚框与类的置信度矩阵,求取每个锚框的最大置信度和其最大置信度所对应的类
    • b. 利用转换函数,将带偏移量的锚框转为预测锚框,并基于预测锚框使用非极大值抑制方法进行筛选,并将keep和non_keep的下标进行排序(使用的为torch.cat直接拼接),其中keep对应的是物种种类,non_keep对应的为背景,利用合并排序后的下标对 锚框最大置信度预测边框的顺序进行重排
    • c. 对置信度小于置信度阈值的锚框进行处理,设置为背景锚框,在锚框类的预测概率中储存的为 1- p p p
    • d. 最后对上述结果进行合并,最内层维度中的六个元素提供了同一预测边界框的输出信息。 第一个元素是预测的类索引,从 0 开始(0代表狗,1代表猫),值 -1 表示背景或在非极大值抑制中被移除了。 第二个元素是预测的边界框的置信度。 其余四个元素分别是预测边界框左上角和右下角的 (x,y) 轴坐标(范围介于 0 和 1 之间)。
  • 代码:
def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,
                       pos_threshold=0.009999999):
    """使用非极大值抑制来预测边界框
    args:
        cls_probs: 锚框对于不同类别的概率
            [batch_size, 1+class_num, anchors_num]
        offset_preds: 不同锚框的偏移量
            [anchors_num * 4]
        anchors: 锚框矩阵
            [anchors_num]
    """
    device, batch_size = cls_probs.device, cls_probs.shape[0]
    anchors = anchors.squeeze(0)
    num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]
    out = []
    for i in range(batch_size):
        cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)
        '''得到最大置信度,及所对应的种类
            cls_prob 每一列代表的为单一锚框对应的不同类的置信度
            conf: 每个锚框对于不同类的最大置信度 --> [anchors_num]
            class_id: 每个锚框最大置信度对应的种类 --> [anchors_num]'''
        conf, class_id = torch.max(cls_prob[1:], 0)
        '''将带偏移量的边框转变为预测边框'''
        predicted_bb = offset_inverse(anchors, offset_pred)
        keep = nms(predicted_bb, conf, nms_threshold)
        # 找到所有的 non_keep 索引,并将类设置为背景,就是设置为-1
        all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)
        '''找到没有非极大值边框的编号,并排序,keep在前,non_keep在后'''
        combined = torch.cat((keep, all_idx))
        uniques, counts = combined.unique(return_counts=True)
        non_keep = uniques[counts == 1]
        ''' all_id_sorted作为之后置信度和预测边框的索引 '''
        all_id_sorted = torch.cat((keep, non_keep))
        '''对于没有保留的锚框,认为是背景锚框,基于non_keep将class_id变成 -1
           并按all_id_sorted 对class_id 进行重排'''
        class_id[non_keep] = -1 
        class_id = class_id[all_id_sorted]
        ''' 将各个锚框的最大置信度和各个预测框,按照NMS值进行排序 '''
        conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]
        # `pos_threshold` 是一个用于非背景预测的阈值
        '''将置信度小于阈值的预测框的id设置为 -1, 抛弃'''
        below_min_idx = (conf < pos_threshold)
        class_id[below_min_idx] = -1
        '''对小于阈值的预测锚框的概率值,进行计算处理'''
        conf[below_min_idx] = 1 - conf[below_min_idx]
        ''' 将类别信息、置信度和预测边框按列合并
           pred_info --> [anchors_num, 6] '''
        pred_info = torch.cat((class_id.unsqueeze(1),
                               conf.unsqueeze(1),
                               predicted_bb), dim=1)
        out.append(pred_info)
    return torch.stack(out)
  • 带偏移量的锚框转换函数的代码如下:
def offset_inverse(anchors, offset_preds):
    """根据带有预测偏移量的锚框来计算预测边界框。"""
    anc = d2l.box_corner_to_center(anchors)
    pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]
    pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]
    pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)
    predicted_bbox = d2l.box_center_to_corner(pred_bbox)
    return predicted_bbox
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-10-13 11:27:11  更:2021-10-13 11:27:53 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 13:01:46-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码