IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【动手学深度学习】之非极大值抑制(NMS)代码实现 -> 正文阅读

[人工智能]【动手学深度学习】之非极大值抑制(NMS)代码实现

import torch
from d2l import torch as d2l

# 更改打印设置
torch.set_printoptions(2)



def show_bboxes(axes, bboxes, labels=None, colors=None):
    """显?所有边界框"""
    def _make_list(obj, default_values=None):
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj

    labels = _make_list(labels)
    colors = _make_list(colors, ['b', 'g', 'r', 'm', 'c'])
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        rect = d2l.bbox_to_rect(bbox.detach().numpy(), color)
        axes.add_patch(rect)
    if labels and len(labels) > i:
        text_color = 'k' if color == 'w' else 'w'
        axes.text(rect.xy[0], rect.xy[1], labels[i],
            va='center', ha='center', fontsize=9, color=text_color,
            bbox=dict(facecolor=color, lw=0))

def box_iou(boxes1, boxes2):
    """计算两个锚框或边界框列表中成对的交并比"""
    box_area = lambda boxes: ((boxes[:, 2] - boxes[:, 0]) *
                              (boxes[:, 3] - boxes[:, 1]))
    # boxes1,boxes2,areas1,areas2的形状:
    # boxes1:(boxes1的数量,4),
    # boxes2:(boxes2的数量,4),
    # areas1:(boxes1的数量,),
    # areas2:(boxes2的数量,)
    # 两个锚框的面积
    areas1 = box_area(boxes1)
    areas2 = box_area(boxes2)

    # inter_upperlefts,inter_lowerrights,inters的形状:
    # (boxes1的数量,boxes2的数量,2)
    inter_upperlefts  = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    inter_lowerrights = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
    inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)

    # 交集
    inter_areas = inters[:,:,0] * inters[:,:,1]
    # 并集
    union_areas = areas1[:, None] + areas2 - inter_areas
    return inter_areas / union_areas

def offset_inverse(anchors, offset_preds):
    """根据带有预测偏移量的锚框来预测边界框"""
    anc = d2l.box_corner_to_center(anchors)
    pred_bbox_xy = (offset_preds[:, :2] * anc[:, 2:] / 10) + anc[:, :2]
    pred_bbox_wh = torch.exp(offset_preds[:, 2:] / 5) * anc[:, 2:]
    pred_bbox = torch.cat((pred_bbox_xy, pred_bbox_wh), axis=1)
    predicted_bbox = d2l.box_center_to_corner(pred_bbox)
    return predicted_bbox

def nms(boxes, scores, iou_threshold):
    """ 对预测边界框的置信度降序排列"""
    # argsort()函数默认将元素从小到大排列,提取其对应的索引输出。
    B= torch.argsort(scores, dim=-1, descending=True)
    # 保留预测边界框的指标
    keep = []

    while B.numel()>0:
        # 取B的第一个元素,也就是最大的预测概率
        i = B[0]
        keep.append(i)
        # 如果取完所有的类
        if B.numel() == 1: break
        # 计算最大预测概率相应的锚框与其他所有锚框的iou
        iou = box_iou(
            boxes[i,:].reshape(-1,4),
            boxes[B[1:],:].reshape(-1,4)
        ).reshape(-1)
        # 找出iou中小于0.5的索引,其对应元素可能是另一类的物体
        # 大于0.5的iou一般是重复的锚框
        inds = torch.nonzero(iou <= iou_threshold).reshape(-1)
        # 将B截止到第二高预测边界框,(+1是因为boxes[B[1:],:]从第二个框开始)
        B = B[inds + 1]

    return torch.tensor(keep, device=boxes.device)


def multibox_detection(cls_probs, offset_preds, anchors, nms_threshold=0.5,
pos_threshold=0.009999999):
    """NMS预测边界框"""
    # 找到数据位置(device)与概率的batch_size
    device, batch_size = cls_probs.device, cls_probs.shape[0]
    anchors = anchors.squeeze(0)
    # 得到类别数与锚框数
    num_classes, num_anchors = cls_probs.shape[1], cls_probs.shape[2]
    out = []
    for i in range(batch_size):
        # 得到预测概率与偏移值
        cls_prob, offset_pred = cls_probs[i], offset_preds[i].reshape(-1, 4)
        # 找到每一个锚框中最大的预测概率,并返回值与索引
        conf, class_id = torch.max(cls_prob[1:], 0)
        # 根据预测偏移得到预测边界框
        predicted_bb = offset_inverse(anchors, offset_pred)
        # 运用NMS算法
        keep = nms(predicted_bb, conf, nms_threshold)
        
        # 找到所有的non_keep索引,并将类设置为背景
        all_idx = torch.arange(num_anchors, dtype=torch.long, device=device)
        # 将keep与新生成的all_idx组合起来
        combined = torch.cat((keep, all_idx))
        # 返回参数数组中所有不同的值,并从小到大排序
        # return_counts=True: 统计新列表元素中出现过的次数
        uniques, counts = combined.unique(return_counts=True)
        # 找出只出现过一次的元素索引
        non_keep = uniques[counts == 1]
        # 得到所有的锚框ID,前面是最大预测概率,后面是被抑制的锚框
        all_id_sorted = torch.cat((keep, non_keep))
        # 将被抑制的锚框标注在置信度索引中
        class_id[non_keep] = -1
        # 将class_id按照all_id_sorted顺序排序
        class_id = class_id[all_id_sorted]
        # 将conf, predicted_bb按照all_id_sorted顺序排序
        conf, predicted_bb = conf[all_id_sorted], predicted_bb[all_id_sorted]
        # pos_threshold是?个?于?背景预测的阈值
        below_min_idx = (conf < pos_threshold)
        # 将背景的对应锚框ID改为-1
        class_id[below_min_idx] = -1
        # 将背景对应锚框取预测概率相反值
        conf[below_min_idx] = 1 - conf[below_min_idx]
        # 将锚框所有属性拼接起来
        # 第一个索引为预测的类索引
        # 第二个索引是预测边界框的置信度
        # 第三到第六个索引是锚框坐标
        pred_info = torch.cat((class_id.unsqueeze(1),
                                conf.unsqueeze(1),
                                predicted_bb), dim=1)
        # print(pred_info)
        out.append(pred_info)
    return torch.stack(out)

anchors = torch.tensor([[0.1, 0.08, 0.52, 0.92], 
                        [0.08, 0.2, 0.56, 0.95],
                        [0.15, 0.3, 0.62, 0.91], 
                        [0.55, 0.2, 0.9, 0.88]])
offset_preds = torch.tensor([0] * anchors.numel())
cls_probs = torch.tensor([[0] * 4, # 背景的预测概率
                        [0.9, 0.8, 0.7, 0.1], # 狗的预测概率
                        [0.1, 0.2, 0.3, 0.9]]) # 猫的预测概率

output = multibox_detection(cls_probs.unsqueeze(dim=0),
                offset_preds.unsqueeze(dim=0),
                anchors.unsqueeze(dim=0),
                nms_threshold=0.5)

# 导入图片
img = d2l.plt.imread('catdog.jpg')
h, w = img.shape[:2]
d2l.set_figsize()
bbox_scale = torch.tensor((w, h, w, h))

fig = d2l.plt.imshow(img)
show_bboxes(fig.axes, anchors * bbox_scale,['dog=0.9', 'dog=0.8', 'dog=0.7', 'cat=0.9'])

fig = d2l.plt.imshow(img)
for i in output[0].detach().numpy():
    if i[0] == -1:
        continue
    label = ('dog=', 'cat=')[int(i[0])] + str(i[1])
    show_bboxes(fig.axes, [torch.tensor(i[2:]) * bbox_scale], label)

【展示】

?(自学李沐老师《动手学深度学习》使用,仅供参考,侵权删除)

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-12-15 18:17:53  更:2021-12-15 18:19:03 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/27 0:36:01-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码