IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 使用Faster—RCNN训练数据集流程(学习记录) -> 正文阅读

[Python知识库]使用Faster—RCNN训练数据集流程(学习记录)

关于理论部分我看的是b站“霹雳吧啦Wz”的RCNN理论讲解,作为入门小白表示能听懂,需要的同学可以自行观看

目录

1.环境准备

2.训练步骤

3.测试过程?

4.计算map


1.环境准备

我是用的是在colab+tensorflow1.14.0上进行训练,其他linux系统训练等同

(windows+pytorch的我也有尝试,但是在配置环境时下载pycocotools时windows的很麻烦,而且还得借助Microsoft visual c++14,我电脑装的是vscode,跟网上教程不一样,也找不到相关教程,所以放弃了)

colab的使用教程如果不会可以自行查阅这篇文章colab使用方法记录_道人兄的博客-CSDN博客

2.训练步骤

(1)下载源码

!git clone https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3 

(2)下载拓展文件

!pip install -r requirements.txt

(3)下载并添加预训练模型

? ? ??源码中预训练模型使用的是VGG16,VGG16模型可直接下载:

!wget -P ./data/imagenet_weights/ http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz

? ? ??下载的模型名字应该是vgg_16.ckpt,重命名为vgg16.ckpt 后,把模型保存在data\imagenet_weights\文件夹下。

? ? ? 也可以使用其他的模型替代VGG16,其他模型在下方链接中下载:

? ? ??https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models

(4)修改训练参数

打开源码的lib\config文件夹下的config.py文件,修改其中一些重要参数,如:

①network参数

? ? ? ?该参数定义了预训练模型网络,源码中默认使用了vgg16模型,我们使用vgg16就不需修改,如果在上一步中使用其他模型就要修改。

②learning_rate

? ? ? ?这个参数是学习率,如果设定太大就可能产生振荡,如果设定太小就会使收敛速度很慢。所以我们可以先默认为源码的0.001进行实验,后期再取0.01或0.0001等多次实验,找到运行后的相对最优值。

③batch_size

? ? ? ?该参数表示梯度下降时数据批量大小,一般可以取16、32、64、128、256等。我个人的理解是,batch_size设定越大,训练时梯度下降的速率更快,也具有更高的方向准确度,但更加消耗内存;batch_size设定越小,虽然节省内存,但训练的速率比较慢,收敛效果也可能不是很好。所以在内存允许的情况下,尽量设定大一些。

④max_iters

? ? ? ?max_iters参数表示训练最大迭代的步数。源码中是40000,我实验了4000和40000的步数,发现后来的测试结果中mAP值相差不大,以后会再继续研究。这个参数可以先按照源码的40000进行(要跑好几天。。。)

⑤snapshot_iterations

? ? ? ?这个参数表示间隔多少迭代次数生成一次结果模型。

⑥roi_bg_threshold_low 和 roi_bg_threshold_high

? ? ? ?这个参数表示在背景中被设定为ROI(感兴趣区域,region of interest)的阈值。如果后面出现Exception: image invalid, skipping 这样的报错,将roi_bg_threshold_low参数修改为0.0会解决问题。

(5)准备数据集

?下载voc2007数据集至data/VOCdevkit2007

!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
!wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar

解压

!tar xvf VOCtrainval_06-Nov-2007.tar
!tar xvf VOCtest_06-Nov-2007.tar
!tar xvf VOCdevkit_08-Jun-2007.tar

如果使用自己的训练数据集,自行标记为voc格式,标注流程就不多赘述,需要可以自行查阅

yolo数据集标注软件安装+使用流程_道人兄的博客-CSDN博客_yolo数据集标注工具

将数据集整理成以下格式即可

其中main所放至为训练及验证txt文件,划分程序如下:

# 数据集划分集类
import os
from sklearn.model_selection import train_test_split

image_path = r'F:/111/data/VOCDevkit2007/VOC2007/JPEGImages'
image_list = os.listdir(image_path)
names = []

for i in image_list:
    names.append(i.split('.')[0])     # 获取图片名
trainval,test = train_test_split(names,test_size=0.5,shuffle=446)   # shuffle()中是图片总数目
validation,train = train_test_split(trainval,test_size=0.5,shuffle=446)

with open('F:/111/data/VOCDevkit2007/VOC2007/ImageSets/Main/trainval.txt','w') as f:
    for i in trainval:
        f.write(i+'\n')
with open('F:/111/data/VOCDevkit2007/VOC2007/ImageSets/Main/test.txt','w') as f:
    for i in test:
        f.write(i+'\n')
with open('F:/111/data/VOCDevkit2007/VOC2007/ImageSets/Main/validation.txt','w') as f:
    for i in validation:
        f.write(i+'\n')
with open('F:/111/data/VOCDevkit2007/VOC2007/ImageSets/Main/train.txt','w') as f:
    for i in train:
        f.write(i+'\n')

print('完成!')

(6)生成对应文件

进入?./data/coco/PythonAPI文件夹路径,分别运行下面两条命令:

!python setup.py build_ext --inplace
!python setup.py build_ext install

进入 ./lib/utils文件夹路径,运行下面一条命令:

!python setup.py build_ext --inplace

这一步其实我报错了,如下,但是好像后面也能正常训练没什么影响,如果有影响麻烦大佬说声

!python setup.py build_ext --inplace
running build_ext
building 'lib.utils.cython_bbox' extension
x86_64-linux-gnu-gcc -pthread -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/usr/local/lib/python3.7/dist-packages/numpy/core/include -I/lib/utils -I/usr/include/python3.7m -c ../../../lib/utils/bbox.c -o build/temp.linux-x86_64-3.7/../../../lib/utils/bbox.o
In file included from /usr/local/lib/python3.7/dist-packages/numpy/core/include/numpy/ndarraytypes.h:1969:0,
                 from /usr/local/lib/python3.7/dist-packages/numpy/core/include/numpy/ndarrayobject.h:12,
                 from /usr/local/lib/python3.7/dist-packages/numpy/core/include/numpy/arrayobject.h:4,
                 from ../../../lib/utils/bbox.c:770:
/usr/local/lib/python3.7/dist-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h:17:2: warning: #warning "Using deprecated NumPy API, disable it with " "#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]
 #warning "Using deprecated NumPy API, disable it with " \
  ^~~~~~~
x86_64-linux-gnu-gcc -pthread -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/usr/local/lib/python3.7/dist-packages/numpy/core/include -I/lib/utils -I/usr/include/python3.7m -c ../../../lib/utils/bbox.c -o build/temp.linux-x86_64-3.7/../../../lib/utils/bbox.o
In file included from /usr/local/lib/python3.7/dist-packages/numpy/core/include/numpy/ndarraytypes.h:1969:0,
                 from /usr/local/lib/python3.7/dist-packages/numpy/core/include/numpy/ndarrayobject.h:12,
                 from /usr/local/lib/python3.7/dist-packages/numpy/core/include/numpy/arrayobject.h:4,
                 from ../../../lib/utils/bbox.c:770:
/usr/local/lib/python3.7/dist-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h:17:2: warning: #warning "Using deprecated NumPy API, disable it with " "#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION" [-Wcpp]
 #warning "Using deprecated NumPy API, disable it with " \
  ^~~~~~~
x86_64-linux-gnu-gcc -pthread -shared -Wl,-O1 -Wl,-Bsymbolic-functions -Wl,-Bsymbolic-functions -g -fwrapv -O2 -Wl,-Bsymbolic-functions -g -fwrapv -O2 -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 build/temp.linux-x86_64-3.7/../../../lib/utils/bbox.o build/temp.linux-x86_64-3.7/../../../lib/utils/bbox.o -o /content/drive/MyDrive/colab/RCNN/Faster-RCNN/data/coco/PythonAPI/lib/utils/cython_bbox.cpython-37m-x86_64-linux-gnu.so
build/temp.linux-x86_64-3.7/../../../lib/utils/bbox.o:(.bss+0x0): multiple definition of `__pyx_module_is_main_lib__utils__cython_bbox'
build/temp.linux-x86_64-3.7/../../../lib/utils/bbox.o:(.bss+0x0): first defined here
build/temp.linux-x86_64-3.7/../../../lib/utils/bbox.o: In function `PyInit_cython_bbox':
/content/drive/MyDrive/colab/RCNN/Faster-RCNN/data/coco/PythonAPI/../../../lib/utils/bbox.c:4470: multiple definition of `PyInit_cython_bbox'
build/temp.linux-x86_64-3.7/../../../lib/utils/bbox.o:/content/drive/MyDrive/colab/RCNN/Faster-RCNN/data/coco/PythonAPI/../../../lib/utils/bbox.c:4470: first defined here
collect2: error: ld returned 1 exit status
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1

(7)修改类别

打开lib/datasets目录中的pascal_voc.py文件,第34行self._classes表示目标检测的类别,将其修改为自己数据集的类别。注意不能修改 “_background_”,它表示图片的背景。

(8)删除缓存文件

?打开源码中data/cache目录,删掉上一次训练生成的.pkl缓存文件。打开default/voc_2007_trainval/default目录,删掉上次训练生成的模型。

注意以后每次训练都要删掉上述两个文件夹中的缓存文件和模型,不删会报错的。

(9)运行

每次生成的模型都会保存在default/voc_2007_trainval/default目录下

!python train.py

报错:

第一次运行时会出现以下错误

AttributeError: module 'tensorflow' has no attribute 'app'

解决方法:

因为colab自动预装最新的TensorFlow 2.X,而源码所使用的TensorFlow是1.x的,所以我们需要将新版的卸载,并安装旧版本

!pip uninstall tensorflow
!pip install tensorflow-gpu==1.14.0

3.测试过程?

(1)添加训练模型

?新建Faster-RCNN-TensorFlow-Python3-master/output/vgg16/voc_2007_trainval/default目录。把训练生成的模型(default/voc_2007_trainval/default目录下的四个文件)复制到新建目录下,并重命名为如下图:

(2)修改demo.py文件

①修改目标类别

? ? ? ? ?修改demo.py文件中line32,CLASSES中的类别要修改为之前步骤中相同的类别。注意 “_background_”不要修改。

②修改网络模型

? ? ? ? ?找到demo.py文件中line35、line36,将其修改为如下图所示:

③修改预训练模型

?找到demo.py文件中line104,将其修改为'vgg16',如下图:

找到demo.py文件中的line148,改为自己测试用的几张图片名称。注意和data/demo目录下存放的测试图片名字一致。

(3)运行demo.py文件

4.计算map

mAP(mean Average Precision), 即各类别AP的平均值,反映出一个目标检测模型性能的总体精确度。

(1)修改pascal_voc.py文件

? ? ? ?打开pascal_voc.py文件,找到line189,将"filename"内容修改为下图:

(2)修改demo.py文件

? ? ? ?打开demo.py文件,找到line31,添加两个模块:

# 添加这两个import
from lib.utils.test import test_net
from lib.datasets.factory import get_imdb

? ? ? ?添加后如图所示:

? ? ?? 然后,找到最后一行plt.show(),在它上面添加两行代码:

# 添加这两行代码
imdb = get_imdb("voc_2007_trainval")
test_net(sess, net, imdb, 'default')

? ? ? ?添加后如图所示:

(3)运行demo.py文件

? ? ? ?新建data/VOCDevkit2007/results/VOC2007/Main目录,然后运行demo.py文件,等待运行结束就能看到mAP指标的计算结果啦!贴出我自己模型的计算结果吧!

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-09-15 01:58:19  更:2022-09-15 01:58:28 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/15 10:45:56-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码