IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> MMDet逐行解读之DeltaXYWHBBoxCoder -> 正文阅读

[人工智能]MMDet逐行解读之DeltaXYWHBBoxCoder

前言

? 本篇是MMdet逐行解读第三篇,代码地址:mmdet/core/bbox/coder/delta_xywh_bbox_coder.py。历史文章如下:
?AnchorGenerator解读
?MaxIOUAssigner解读

1、BaseBBoxCoder父类

? 该类是所有bbox编解码类的父类,代码比较容易理解,即所有继承该类的子类均需要实现encode和decode两个方法。

class BaseBBoxCoder(metaclass=ABCMeta):
    """Base bounding box coder"""

    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def encode(self, bboxes, gt_bboxes):
        """Encode deltas between bboxes and ground truth boxes"""
        pass

    @abstractmethod
    def decode(self, bboxes, bboxes_pred):
        """
        Decode the predicted bboxes according to prediction and base boxes
        """
        pass

2、DeltaXYWHBBoxCoder类

2.1. 理论基础

? 大多数基于anchor的目标检测算法均使用的该类。在目标检测算法中,为了利于网络的收敛,借助了anchor并回归的是anchor和gtbbox之间的偏差。所以,网络预测的是偏差,因此,在训练过程中,需要计算gtbbox和anchor之间的偏差真值t*。真实t*的计算方式如下:
在这里插入图片描述
? 其中 [x,y,w,h] 表示gtbbox的中心宽和高;[xa,ya,wa,ha] 表示anchor的中心宽和高。简单来说,tx* ,ty* 表示二者做差除以宽高做了归一化;tw*,th*就是取了个对数。

2.2、初始化部分

? 我们首先构造一个对象:

import torch
from mmdet.core.bbox import build_bbox_coder
if __name__ == '__main__':
    bbox_coder = dict(
        type='DeltaXYWHBBoxCoder',
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0])
    coder = build_bbox_coder(bbox_coder)
    # 构造两个预测向量和真值
    proposals = torch.tensor([[1,1,3,3],[4,4,6,6]])
    gt = torch.tensor([[2,2,3,3],[2,2,5,5]])
    target_t = coder.encode(proposals,gt)    # 调用解码方法

其中target_means和target_stds就是对上述的t*减均值除以标准差。

2.3、编码过程

? 这部分代码我已经做了注释,总体思路就是首先将proposals和gtbbox由[xmin, ymin, xmax, ymax]变成[cx, cy, w,h],之后计算t*,然后将t*减均值除以标准差。

    # 候选框数量和gt数量必须一致
    assert proposals.size() == gt.size() # [N,4]
    proposals = proposals.float()
    gt = gt.float()
    # proposals: [xmin, ymin, xmax, ymax] --> [cx, cy, w, h]
    px = (proposals[..., 0] + proposals[..., 2]) * 0.5   # [N]
    py = (proposals[..., 1] + proposals[..., 3]) * 0.5
    pw = proposals[..., 2] - proposals[..., 0]
    ph = proposals[..., 3] - proposals[..., 1]
    # gt: [xmin, ymin, xmax, ymax] --> [cx, cy, w, h]
    gx = (gt[..., 0] + gt[..., 2]) * 0.5
    gy = (gt[..., 1] + gt[..., 3]) * 0.5
    gw = gt[..., 2] - gt[..., 0]
    gh = gt[..., 3] - gt[..., 1]
	# 计算t*
    dx = (gx - px) / pw
    dy = (gy - py) / ph
    dw = torch.log(gw / pw)
    dh = torch.log(gh / ph)
    deltas = torch.stack([dx, dy, dw, dh], dim=-1)   # [N] --> [N,4]
	# 减均值除以标准差
    means = deltas.new_tensor(means).unsqueeze(0)    # [1,4]
    stds = deltas.new_tensor(stds).unsqueeze(0)      # [1,4]
    deltas = deltas.sub_(means).div_(stds)           # [N,4]

2.4、解码过程

? 解码过程常发生在测试阶段。将网络预测出的偏差t加到anchor上得到proposal(一阶算法)或者roi(二阶)用的。该过程就是编码过程相反操作,首先乘标准差在加上均值得到t,之后将t加到anchor上即可。

    # 均值和标准差: [4] --> [1,4]
    means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
    stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
    denorm_deltas = deltas * stds + means   # [N,4]
    # 得到dx,dy,dw,dh
    dx = denorm_deltas[:, 0::4]    # [N,1]  
    dy = denorm_deltas[:, 1::4]
    dw = denorm_deltas[:, 2::4]
    dh = denorm_deltas[:, 3::4]
    max_ratio = np.abs(np.log(wh_ratio_clip))
    dw = dw.clamp(min=-max_ratio, max=max_ratio) # 裁减下
    dh = dh.clamp(min=-max_ratio, max=max_ratio)
    # 将rois/proposal转成[cx,cy,w,h]格式:[N,] --> [N,1] --> [N,1]
    px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) 
    py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
    # Compute width/height of each roi
    pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw)
    ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh)
    # 解码过程
    gw = pw * dw.exp()
    gh = ph * dh.exp()
    gx = px + pw * dx
    gy = py + ph * dy
    # 将[cx,cy,w,h] --> [xmin, ymin, xmax, ymax]格式
    x1 = gx - gw * 0.5
    y1 = gy - gh * 0.5
    x2 = gx + gw * 0.5
    y2 = gy + gh * 0.5
    # 裁减一下bbox,过大则裁减掉
    if max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    # 返回修正过大预测框
    bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)

总结

? 总体来说该部分源码比较简单,下篇介绍anchor的sampler部分。

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-01-24 10:50:21  更:2022-01-24 10:54:03 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 21:21:22-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码