IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 目标检测——paddleYOLOv3 -> 正文阅读

[人工智能]目标检测——paddleYOLOv3

1.?导入包,配置参数

import time
import os
import paddle

ANCHORS = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 373, 326]

ANCHOR_MASKS = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]

IGNORE_THRESH = .7
NUM_CLASSES = 7

def get_lr(base_lr = 0.0001, lr_decay = 0.1):
    bd = [10000, 20000]
    lr = [base_lr, base_lr * lr_decay, base_lr * lr_decay * lr_decay]
    learning_rate = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr)
    return learning_rate

2. TrainDataset类设置

初始的数据集数据组成:安全帽数据集共有5000张图片和5000个标注文件xml,每个xml文件对应一张图片。xml文件中含有图片路径,图片高宽,标注框位置信息(x1,y1,x2,y2)和类别(['helmet',?'head',?'person']?共3类)

该类属于自定义类,主要完成如下工作:

  1. 解析xml文件得到字典列表存储图片和标注信息。通过公式x?=?(x1?+?x2)/2, y?=?(y1?+?y2)/2,?w?=?x2?-?x1?+1,?h?=?y2?-?y1?+?1将(x,y,x,y)格式转化为(x,y,w,h)格式
  2. 根据字典列表的字典返回一张图片数据img,及标注?gt_boxes,?gt_labels和图片高宽(h,?w)。注意这里的真实框(x,y,w,h)使用的是相对值,所以要返回图片的高宽(h,w)用于恢复。
  3. 数据增强,包括明亮变化、缩放、归一化等。注意缩放后处于统一大小,便于批量读取。经过图像增广后,img的shape被缩放了大小,但(h, w)存储的还是原来的大小。

具体实现:目标检测--数据集处理?

 TRAINDIR = '/home/aistudio/work/insects/train'
 TESTDIR = '/home/aistudio/work/insects/test'
 VALIDDIR = '/home/aistudio/work/insects/val'
 paddle.set_device("gpu:0")
 # 创建数据读取类
 train_dataset = TrainDataset(TRAINDIR, mode='train')
 valid_dataset = TrainDataset(VALIDDIR, mode='valid')
 test_dataset = TrainDataset(VALIDDIR, mode='valid')
使用Dataloader加载数据,返回的格式为:img(batch, channel, w, h), 真实框gt_boxs(batch, num_boxs, location), 类别gt_labels(batch, num_boxs), 高宽w_h(batch, 2)
例如:([2, 3, 400, 400], [2, 10, 4], [2, 10], [2, 2])
 # 使用paddle.io.DataLoader创建数据读取器,并设置batchsize,进程数量num_workers等参数
train_loader = paddle.io.DataLoader(train_dataset, batch_size=10, shuffle=True,num_workers=0, drop_last=True, use_shared_memory=False)

valid_loader = paddle.io.DataLoader(valid_dataset, batch_size=10, shuffle=False,num_workers=0, drop_last=False, use_shared_memory=False)

3. 网络

YOLOV3由darknet53做骨干网络,输出3个层级的特征图。

model = YOLOv3(num_classes = NUM_CLASSES)  #创建模型
learning_rate = get_lr()
opt = paddle.optimizer.Momentum(
                 learning_rate=learning_rate,
                 momentum=0.9,
                 weight_decay=paddle.regularizer.L2Decay(0.0005),
                 parameters=model.parameters())  #创建优化器
# opt = paddle.optimizer.Adam(learning_rate=learning_rate, weight_decay=paddle.regularizer.L2Decay(0.0005), parameters=model.parameters())

4. 实现训练函数

重点是get_loss的实现,目标检测的损失,首先要根据真实框计算出预测框,预测框标注了锚框与真实框的中心点和高宽的偏差,物体的类别。模型预测出的是这些偏差值,损失由这些偏差值来建立。?

单尺度损失的计算步骤:目标检测YOLOv3的loss计算

虽然这里使用的是多尺度的目标检测,但损失的计算是在单层计算的基础上得到的。?

############# 这段代码在本地机器上运行请慎重,容易造成死机#######################   
    MAX_EPOCH = 200
    for epoch in range(MAX_EPOCH):
        for i, data in enumerate(train_loader()):
            img, gt_boxes, gt_labels, img_scale = data
            gt_scores = np.ones(gt_labels.shape).astype('float32')
            gt_scores = paddle.to_tensor(gt_scores)
            img = paddle.to_tensor(img)
            gt_boxes = paddle.to_tensor(gt_boxes)
            gt_labels = paddle.to_tensor(gt_labels)
            outputs = model(img)  #前向传播,输出[P0, P1, P2]
            loss = model.get_loss(outputs, gt_boxes, gt_labels, gtscore=gt_scores,
                                  anchors = ANCHORS,
                                  anchor_masks = ANCHOR_MASKS,
                                  ignore_thresh=IGNORE_THRESH,
                                  use_label_smooth=False)        # 计算损失函数

            loss.backward()    # 反向传播计算梯度
            opt.step()  # 更新参数
            opt.clear_grad()
            if i % 10 == 0:
                timestring = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(time.time()))
                print('{}[TRAIN]epoch {}, iter {}, output loss: {}'.format(timestring, epoch, i, loss.numpy()))

        # save params of model
        if (epoch % 5 == 0) or (epoch == MAX_EPOCH -1):
            paddle.save(model.state_dict(), 'yolo_epoch{}'.format(epoch))

        # 每个epoch结束之后在验证集上进行测试
        model.eval()
        for i, data in enumerate(valid_loader()):
            img, gt_boxes, gt_labels, img_scale = data
            gt_scores = np.ones(gt_labels.shape).astype('float32')
            gt_scores = paddle.to_tensor(gt_scores)
            img = paddle.to_tensor(img)
            gt_boxes = paddle.to_tensor(gt_boxes)
            gt_labels = paddle.to_tensor(gt_labels)
            outputs = model(img)
            loss = model.get_loss(outputs, gt_boxes, gt_labels, gtscore=gt_scores,
                                  anchors = ANCHORS,
                                  anchor_masks = ANCHOR_MASKS,
                                  ignore_thresh=IGNORE_THRESH,
                                  use_label_smooth=False)
            if i % 1 == 0:
                timestring = time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(time.time()))
                print('{}[VALID]epoch {}, iter {}, output loss: {}'.format(timestring, epoch, i, loss.numpy()))
        model.train()

部分截图如下:

5.总结

5.1数据到真实框

这部分主要是数据处理,由标号文件和图片得到图片数据、真实框位置、类别

重点是图片数据要经过增强处理,而真实框的位置和类别不能直接用于计算损失,要经过与锚框的偏差计算,得到的偏差才是标号,放在后面一步。

在YOLOv3中该步骤会将真实框位置处理为xywh格式的相对值[0~1]。

5.2 YOLOv3模型

各种框架下的该模型的开源实现代码网上都有,重点是要调整各层级的输出形状,要与划分锚框的网格大小要一致,这样模型的输出才能与锚框的位置相对应。

5.3损失的计算

实现由真实框和锚框计算预测框是难点之一。

首先要设定下采样率(图片划分为多少网格),得到的锚框要与真实框做计算得到预测框,预测框的类别由真实框给出,位置由真实框和锚框的偏差公式给出,该公式的实现比较复杂。经由网络预测得到的值与预测框做损失计算。飞桨的API:paddle.vision.ops.yolo_loss实现了上述过程,直接调用即可。

主体程序

数据增强

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-02-26 11:31:25  更:2022-02-26 11:35:11 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 19:25:16-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码