IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 0927锚框(Anchor box) -> 正文阅读

[人工智能]0927锚框(Anchor box)

锚框(Anchor box)

  • 目标检测算法中,通常会在输入图像中采样大量的区域,然后判断这些区域中是否包含所感兴趣的目标,并调整区域边界从而更加准确地预测目标的真实边界框(ground-truth bounding box)

基于锚框的目标检测算法

  • 也有不基于锚框的目标检测算法,但是基于锚框的目标检测算法占主流
  • 以每个像素为中心,生成多个缩放比(scale)宽高比(aspect ratio)不同的边界框(这里的以每个像素点为中心指的是锚框中心点的像素,当中心位置给定时,已知宽和高的锚框是确定的)

关于生成多个锚框:

  • 锚框的宽度和高度分别是 w * s * sqrt(r) 和 h * s / sqrt(r),可得锚框的面积是 w * h * s ^2,因为 s ∈(0,1],可以得到锚框的最大面积是 w * h,也就是输入图像的面积(锚框的宽度和高度的表达式好像有错误,欢迎小伙伴指出)
  • 这里的 n + m - 1 的意思是:在实践中,只考虑包含 s1 或者 r1 的组合,s1 和 m 个宽高比共有 m 个组合,r1 和 n 个缩放比共有 n 个组合,这两种情况中(s1,r1)重复算了一次,所以最终以同一个像素为中心的锚框数量是 n + m - 1,因为输入图像的高度为 h ,宽度为 w,所以输入图像总共有 m * n 个像素,因此对于整个输入图像,总共生成了 w * h * (n + m - 1) 个锚框

  • 锚框和边缘框的区别:边缘框bounding box指的是所标号的真实物体的位置;锚框anchor box指的是算法对边缘框,也就是物体真实位置的猜测
  • 提出多个被称为锚框的区域
  • 预测每个锚框里是否含有关注的物体
  • 如果是,预测从这个锚框到真实边缘框的偏移
  • (因为算法本身并不知道边缘框,也就是标号物体真实的位置,如果直接对位置进行预测,预测边缘框的四个坐标值的话相对来讲比较困难,所以一般算法的操作是先提出一些框,然后首先判断这些框中是否包含目标物体,如果确定包含目标物体的话,接下来预测基于该锚框相对于边缘框的偏移,也就是说基于锚框的目标检测不是直接对边缘框的四个坐标值进行预测,而是先提出一些锚框,然后等到差不多包含目标物体的时候,再对锚框进行调整到边缘框的位置)

整个过程中包含两次预测:

  • 类别:预测锚框中所包含物体的类别
  • 位置:预测锚框到边缘框的位置偏移

IoU-交并比

  • 用于衡量锚框和真实边缘框之间的相似度,两个框之间的交集与两个框的并集的比值
  • 取值范围[0,1]:0表示没有重叠1表示完全重合(越接近1,两个框的相似度越高)
  • 它是Jacquard指数的特殊情况(给定两个集合,Jacquard指数表示两个集合的交集和两个集合的并集之间的比值)

  • 如果将任何边界框的像素区域中的像素看成是集合中的元素,每个框就可以看成是像素的集合,IoU就等价于Jacquard指数

在训练数据中标注锚框

在训练集中,将每个锚框视为一个训练样本,为了训练目标检测模型,需要每个锚框的类别(class,与锚框相关的对象的类别)偏移量(offset,真实边缘框相对于锚框的偏移量)标签

在预测的时候,首先为每个图像生成多个锚框,预测所有锚框的类别和偏移量,根据预测的偏移量调整它们的位置以获得预测的边缘框,最后只输出符合特定条件的预测边缘框

  • 基于锚框的目标检测是首先提出多个锚框,然后对锚框是否包含所感兴趣的物体以及锚框相对于边缘框的偏移进行预测
  • 所以在训练的时候,每一个锚框是一个训练样本
  • 对于每一个锚框来说,要么被标注成背景(不包含所感兴趣的物体,只包含背景),要么关联上一个真实的边缘框(锚框所框住的物体的标号与所关联的边缘框所包含的物体的标号相同;锚框相对于边缘框的偏移就是相对于它所关联的边缘框的偏移,这个偏移量根据锚框和真实边缘框中心坐标的相对位置以及这两个框的相对大小进行标记)
  • 鉴于数据集内不同的框的位置和大小不同,可以对那些相对位置和大小应用变换,使其获得分布更均匀且易于拟合的偏移量
  • 一般来讲算法会生成大量的锚框,而只有少量的边缘框,绝大部分锚框都是背景,背景类别的锚框通常被称为“负类”锚框,其余的被称为“正类”锚框

如何赋予锚框标号?

  • 目标检测的训练集中带有真实边界框的位置以及其所包围物体类别的标签,所以如果要标记所生成的锚框,可以参考分配到的最接近此锚框的真实边界框的位置和类别标签

步骤:

关于锚框的生成

  • 固定生成
  • 根据图片生成锚框

使用非极大值抑制(non-maximum suppression,NMS)输出

  • 在预测时会为图像生成多个锚框,然后再为这些锚框逐个预测类别和偏移量,一个预测好的边界框是根据其中某个带有预测偏移量的锚框而生成的。所以最终会得到很多相似的具有明显重叠的预测边缘框,而且都是围绕着同一个目标,因此需要对这些相似的框进行剔除,最终保留下来比较干净的预测输出结果
  • NMS也是剔除方法之一,首先选中所有预测框中非背景类的最大预测值(对类的预测的softmax值,越接近于1置信度越高),然后去掉所有其它和它IoU值大于θ的预测值(也就是去掉和最大预测值相似度比较高的其它锚框),重复这个过程,直到所有的预测框要么被选中,要么被去掉,最终得到一个比较干净的输出(NMS的输出)

  • 在执行非极大值抑制前,可以将置信度较低的预测边缘框移除,从而减少算法中的计算量;也可以对非极大抑制的输出结果进行后处理,比如只保留置信度更高的结果作为最终输出

总结

  • 目前主流的目标检测算法都是基于锚框来预测的
  • 首先以图像的每个像素为中心生成大量不同形状的锚框(不同的算法生成锚框的方法不同),并对每个锚框赋予标号(对锚框赋予标号的方法也有很多种),这样每个锚框就会有一个类别标号相对与边缘框的偏移,因此每个锚框可以作为一个样本进行训练
  • 交并比用于衡量两个边界框的相似性,它等于两个边界框像素区域的交集并集之间的比值
  • 在训练集中,需要给每个锚框两种类型的标签:1、锚框中目标检测的类别;2、锚框相对于真实边缘框的偏移量
  • 在预测的时候会对每个锚框进行预测,会生成大量冗余的预测,可以使用NMS来去掉冗余的预测,从而简化输出

代码:

%matplotlib inline
import torch
from d2l import torch as d2l

torch.set_printoptions(2)

def multibox_prior(data,sizes,ratios):
    """生成以每个像素为中心具有不同高宽度的锚框"""
    in_height, in_width = data.shape[-2:] # data.shape的最后两个元素为宽和高,第一个元素为通道数  
    device, num_sizes, num_ratios = data.device, len(sizes), len(ratios)
    boxes_per_pixel = (num_sizes + num_ratios - 1) # 按照上面组合对应的所有锚框数量
    size_tensor = torch.tensor(sizes, device=device)
    ratio_tensor = torch.tensor(ratios, device=device)
    
    offset_h, offset_w = 0.5, 0.5
    steps_h = 1.0 / in_height
    steps_w = 1.0 / in_width
    
    # torch.arange(in_height, device=device)获得每一行像素
    # (torch.arange(in_height, device=device) + offset_h) 获得每一行像素的中心
    # (torch.arange(in_height, device=device) + offset_h) * steps_h 对每一行像素的中心坐标作归一化处理  
    center_h = (torch.arange(in_height, device=device) + offset_h) * steps_h
    center_w = (torch.arange(in_width, device=device) + offset_w) * steps_w
    shift_y, shift_x = torch.meshgrid(center_h, center_w)
    shift_y, shift_x = shift_y.reshape(-1), shift_x.reshape(-1)
    
    w = torch.cat((size_tensor * torch.sqrt(ratio_tensor[0]),
                  sizes[0] * torch.sqrt(ratio_tensor[1:]))) \
                    * in_height / in_width
    h = torch.cat((size_tensor / torch.sqrt(ratio_tensor[0]),
                  sizes[0] / torch.sqrt(ratio_tensor[1:])))
    
    anchor_manipulations = torch.stack((-w, -h, w, h)).T.repeat(in_height * in_width, 1) / 2
    out_grid = torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1).repeat_interleave(boxes_per_pixel, dim=0)   
    output = out_grid + anchor_manipulations
    return output.unsqueeze(0)

# 返回锚框变量Y的形状
img = d2l.plt.imread('01_Data/03_catdog.jpg')
print("img.shape:",img.shape) # 高561,宽72,3通道
h, w = img.shape[:2]
print(h,w)

X = torch.rand(size=(1,3,h,w)) # 批量大小为1,3通道
Y = multibox_prior(X, sizes=[0.75,0.5,0.25], ratios=[1,2,0.5]) # 占图片sizes尺寸的大小、高宽比ratios尺寸大小的锚框   
print(Y.shape) # 1 是批量大小,2042040是一张图片生成的锚框数量,4个元素时每个锚框对应的位置 

# 访问以(250,250)为中心的第一个锚框
boxes = Y.reshape(h,w,5,4)  # 上面的sizes×sizes=3×3,3+3-1=5,故每个像素为中心生成五个锚框    
boxes[250,250,0,:] # 以250×250为中心的第一个锚框的坐标

# 显示以图像中一个像素为中心的所有锚框
def show_bboxes(axes, bboxes, labels=None, colors=None):
    """显示所有边界框"""
    def _make_list(obj, default_values=None):
        if obj is None:
            obj = default_values
        elif not isinstance(obj, (list, tuple)):
            obj = [obj]
        return obj
    
    labels = _make_list(labels)
    colors = _make_list(colors, ['b','g','r','m','c'])
    for i, bbox in enumerate(bboxes):
        color = colors[i % len(colors)]
        rect = d2l.bbox_to_rect(bbox.detach().numpy(),color)
        axes.add_patch(rect)
        if labels and len(labels) > i:
            text_color = 'k' if color == 'w' else 'w'
            axes.text(rect.xy[0], rect.xy[1], labels[i], va='center',
                     ha='center', fontsize=9, color=text_color,
                     bbox=dict(facecolor=color, lw=0))
            
            
d2l.set_figsize()
bbox_scale = torch.tensor((w,h,w,h)) # 高宽
fig = d2l.plt.imshow(img)
print("fig.axes:",fig.axes)
# 在生成锚框的时候是0-1的值,进行缩放的话就可以省点乘法运算,因为最后输出并不需要显示所有锚框,所以可能会更快一点
print("boxes[250,250,:,:]:",boxes[250,250,:,:])
print("bbox_scale:", bbox_scale)
print("boxes[250,250,:,:] * bbox_scale:",boxes[250,250,:,:] * bbox_scale)
show_bboxes(fig.axes, boxes[250,250,:,:] * bbox_scale, ['s=0.75, r=1','s=0.5, r=1','s=0.25, r=1','s=0.75,r=2','s=0.75,r=0.5']) # 画出以250×250像素为中心的不同高宽比的五个锚框                                  
# 交并比(IoU)
def box_iou(boxes1,boxes2):
    """计算两个锚框或边界框列表中成对的交并比"""
    box_area = lambda boxes: ((boxes[:,2] - boxes[:,0]) *
                             (boxes[:,3] - boxes[:,1]))
    areas1 = box_area(boxes1) # 锚框1的面积
    areas2 = box_area(boxes2) # 锚框2的面积
    inter_upperlefts = torch.max(boxes1[:,None,:2],boxes2[:,:2]) 
    inter_lowerrights = torch.min(boxes1[:,None,2:],boxes2[:,2:])
    inters = (inter_lowerrights - inter_upperlefts).clamp(min=0)
    inter_areas = inters[:,:,0] * inters[:,:,1] # 交集的面积
    union_areas = areas1[:,None] + areas2 - inter_areas # 并集的面积
    return inter_areas / union_areas

# 将真实边界框分配给锚框
def assign_anchor_to_bbox(ground_truth,anchors,device,iou_threshold=0.5):
    """将最接近的真实边界框分配给锚框"""
    num_anchors, num_gt_boxes = anchors.shape[0], ground_truth.shape[0]
    jaccard = box_iou(anchors,ground_truth) # 计算所有的锚框和真实边缘框的IOU
    anchors_bbox_map = torch.full((num_anchors,), -1, dtype=torch.long, device=device)    
    max_ious, indices = torch.max(jaccard, dim=1)
    anc_i = torch.nonzero(max_ious >= 0.5).reshape(-1)
    box_j = indices[max_ious >= 0.5]
    anchors_bbox_map[anc_i] = box_j
    col_discard = torch.full((num_anchors,),-1)
    row_discard = torch.full((num_gt_boxes,),-1)
    for _ in range(num_gt_boxes):
        max_idx = torch.argmax(jaccard) # 找IOU最大的锚框
        box_idx = (max_idx % num_gt_boxes).long()
        anc_idx = (max_idx / num_gt_boxes).long()
        anchors_bbox_map[anc_idx] = box_idx
        jaccard[:,box_idx] = col_discard # 把最大Iou对应的锚框在 锚框-类别 矩阵中的一列删掉
        jaccard[anc_idx,:] = row_discard # 把最大Iou对应的锚框在 锚框-类别 矩阵中的一行删掉
    return anchors_bbox_map



?

def offset_boxes(anchors,assigned_bb,eps=1e-6):
    """对锚框偏移量的转换"""
    c_anc = d2l.box_corner_to_center(anchors)
    c_assigned_bb = d2l.box_corner_to_center(assigned_bb)
    offset_xy = 10 * (c_assigned_bb[:,:2] - c_anc[:,:2] / c_anc[:,2:])
    offset_wh = 5 * torch.log(eps + c_assigned_bb[:,2:] / c_anc[:,2:])         
    offset = torch.cat([offset_xy, offset_wh], axis=1)
    return offset # 尽量使得 offset 让 machine learning 算法好预测

# 标记锚框的类和偏移量
def multibox_target(anchors, labels):
    """使用真实边界框标记锚框"""
    batch_size, anchors = labels.shape[0], anchors.squeeze(0)
    batch_offset, batch_mask, batch_class_labels = [], [], []
    device, num_anchors = anchors.device, anchors.shape[0]
    for i in range(batch_size):
        label = labels[i,:,:]
        anchors_bbox_map = assign_anchor_to_bbox(label[:,1:],anchors,device)   
        bbox_mask = ((anchors_bbox_map >= 0).float().unsqueeze(-1)).repeat(1,4)   
        class_labels = torch.zeros(num_anchors, dtype=torch.long,device=device)  
        assigned_bb = torch.zeros((num_anchors,4), dtype=torch.float32,device=device)   
        indices_true =torch.nonzero(anchors_bbox_map >= 0)
        bb_idx = anchors_bbox_map[indices_true]
        class_labels[indices_true] = label[bb_idx,0].long() + 1
        assigned_bb[indices_true] = label[bb_idx, 1:]
        offset = offset_boxes(anchors, assigned_bb) * bbox_mask
        batch_offset.append(offset.reshape(-1))
        batch_mask.append(bbox_mask.reshape(-1))
        batch_class_labels.append(class_labels)
    bbox_offset = torch.stack(batch_offset)
    bbox_mask = torch.stack(batch_mask)
    class_labels = torch.stack(batch_class_labels)
    # 返回每一个锚框到真实标注框的offset偏移
    # bbox_mask为0表示背景锚框,就不用了,为1表示对应真实的物体
    # class_labels为锚框对应类的编号
    return (bbox_offset, bbox_mask, class_labels)

# 两个真实边缘框的位置信息
ground_truth = torch.tensor([[0,0.1,0.08,0.52,0.92],
                            [1,0.55,0.2,0.9,0.88]])

# 五个锚框的位置信息
anchors = torch.tensor([[0,0.1,0.2,0.3],[0.15,0.2,0.4,0.4],
                       [0.63,0.05,0.88,0.98],[0.66,0.45,0.8,0.8],
                       [0.57,0.3,0.92,0.9]])

fig = d2l.plt.imshow(img)
show_bboxes(fig.axes,ground_truth[:,1:] * bbox_scale, ['dog','cat'],'k')   
show_bboxes(fig.axes,anchors * bbox_scale, ['0','1','2','3','4'])

# anchors.unsqueeze(dim=0)在0号位置加了一个批量维度,该批量维度大小为1
labels = multibox_target(anchors.unsqueeze(dim=0),ground_truth.unsqueeze(dim=0))  
print(len(labels)) # labels 对应 multibox_target 函数返回的  (bbox_offset, bbox_mask, class_labels)
print(labels[2]) # labels[2]有五个锚框 0表示背景、1表示狗、2表示猫 这里3号框被表示为背景是因为被2号框和四号框非极大值抑制了  
print(labels[1]) # 锚框是不是对应是真实物体
print(labels[0]) # 每一个锚框有四个值,0表示不需要预测,

# 输出由非极大值抑制保存的最终预测边界框
fig = d2l.plt.imshow(img)
print("output[0]:", output[0])
for i in output[0].detach().numpy(): 
    print(i)
    if i[0] == -1: # 值-1表示背景或在非极大值抑制中被移除了
        continue
    print("int(i[0]):", int(i[0]))  # i[0]=0表示狗,i[0]=1表示猫,即i的第一个元素表示框对应的类别   
    print("str(i[1]):", str(i[1]))  # i的第二元素表示该类别的置信度
    label = ('dog=', 'cat=')[int(i[0])] + str(i[1]) # 取('dog=', 'cat=')元组的第int(i[0]位置与str(i[1])进行拼接             
    print("label:",label)
    show_bboxes(fig.axes, [torch.tensor(i[2:]) * bbox_scale], label)

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-09-30 00:52:59  更:2022-09-30 00:55:38 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/28 3:20:21-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计