IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> 训练TFlite模型 -> 正文阅读

[Python知识库]训练TFlite模型

注意:不是所有的模型都可以转换成TFlite模型,本文采用SSD网络模型
由于VC编译器的原因,无法在windows上完成所有步骤,所以本文分为上中下两部分,第一部分windows篇介绍Tensorflow模型的训练
第二部分Linux篇介绍从Tensorflow模型转化成Tflite模型,第三那部分树梅派篇是介绍在数梅派上运行训练得到的TFlite模型

参考文章

  1. TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi
  2. How To Train an Object Detection Classifier for Multiple Objects Using TensorFlow (GPU) on Windows 10

Windows篇

1. Anaconda

2.下载Tensorflow model、 SSD预训练模型和文件库

链接1
解压到

C:\tensorflow1

并重命名为models
链接2
下载解压到

C:\tensorflow1\models\research\object_detection

链接3
下载解压到

C:\tensorflow1\models\research\object_detection

3. 设置conda虚拟环境

C:\>  conda create -n tensorflow1 pip python=3.5
C:\>  activate tensorflow1
(tensorflow1) C:\> conda install tensorflow-gpu=1.13
(tensorflow1) C:\> conda install -c anaconda protobuf
(tensorflow1) C:\> pip install pillow
(tensorflow1) C:\> pip install lxml
(tensorflow1) C:\> pip install Cython
(tensorflow1) C:\> pip install contextlib2
(tensorflow1) C:\> pip install jupyter
(tensorflow1) C:\> pip install matplotlib
(tensorflow1) C:\> pip install pandas
(tensorflow1) C:\> pip install opencv-python

设置python环境变量

(tensorflow1) C:\> set PYTHONPATH=C:\tensorflow1\models;C:\tensorflow1\models\research;C:\tensorflow1\models\research\slim

4. 编译Protobuf

(tensorflow1) C:\> cd C:\tensorflow1\models\research
(tensorflow1) C:\tensorflow1\models\research> protoc --python_out=. .\object_detection\protos\anchor_generator.proto .\object_detection\protos\argmax_matcher.proto .\object_detection\protos\bipartite_matcher.proto .\object_detection\protos\box_coder.proto .\object_detection\protos\box_predictor.proto .\object_detection\protos\eval.proto .\object_detection\protos\faster_rcnn.proto .\object_detection\protos\faster_rcnn_box_coder.proto .\object_detection\protos\grid_anchor_generator.proto .\object_detection\protos\hyperparams.proto .\object_detection\protos\image_resizer.proto .\object_detection\protos\input_reader.proto .\object_detection\protos\losses.proto .\object_detection\protos\matcher.proto .\object_detection\protos\mean_stddev_box_coder.proto .\object_detection\protos\model.proto .\object_detection\protos\optimizer.proto .\object_detection\protos\pipeline.proto .\object_detection\protos\post_processing.proto .\object_detection\protos\preprocessor.proto .\object_detection\protos\region_similarity_calculator.proto .\object_detection\protos\square_box_coder.proto .\object_detection\protos\ssd.proto .\object_detection\protos\ssd_anchor_generator.proto .\object_detection\protos\string_int_label_map.proto .\object_detection\protos\train.proto .\object_detection\protos\keypoint_box_coder.proto .\object_detection\protos\multiscale_anchor_generator.proto .\object_detection\protos\graph_rewriter.proto .\object_detection\protos\calibration.proto .\object_detection\protos\flexible_grid_anchor_generator.proto

在object_detection目录下会生成一个name_pb2.py文件
然后运行

(tensorflow1) C:\tensorflow1\models\research> python setup.py build
(tensorflow1) C:\tensorflow1\models\research> python setup.py install

5. 标注图片

将训练图片和测试图片分别放在

C:\tensorflow1\models\research\object_detection\images\

中的train和test文件夹下

标注步骤请参阅 我之前写的训练yolov5模型的教程

6. 生成训练数据

运行

(tensorflow1) C:\tensorflow1\models\research\object_detection> python xml_to_csv.py

会在 \object_detection\images中生成两个文件
train_labels.csv 和 test_labels.csv

修改object_detection\目录下的generate_tfrecord.py
要修改的部分是

# TO-DO replace this with label map
def class_text_to_int(row_label):
    if row_label == 'nine':
        return 1
    elif row_label == 'ten':
        return 2
    elif row_label == 'jack':
        return 3
    elif row_label == 'queen':
        return 4
    elif row_label == 'king':
        return 5
    elif row_label == 'ace':
        return 6
    else:
        None

将这部分内容中的row_label==’ ’ 修改为要训练的类别,请根据自己的情况适当增删
比如

# TO-DO replace this with label map
def class_text_to_int(row_label):
    if row_label == 'basketball':
        return 1
    elif row_label == 'shirt':
        return 2
    elif row_label == 'shoe':
        return 3
    else:
        None

然后生成TFRecord 文件

python generate_tfrecord.py --csv_input=images\train_labels.csv --image_dir=images\train --output_path=train.record
python generate_tfrecord.py --csv_input=images\test_labels.csv --image_dir=images\test --output_path=test.record

会在object_detection\目录下生成train.record和test.record文件

7. 创建Label Map

label map 文件为object_detection\training\labelmap.pbtxt
修改其内容,原内容是

item {
  id: 1
  name: 'nine'
}

item {
  id: 2
  name: 'ten'
}

item {
  id: 3
  name: 'jack'
}

item {
  id: 4
  name: 'queen'
}

item {
  id: 5
  name: 'king'
}

item {
  id: 6
  name: 'ace'
}

修改为自己的训练类别和对应ID,ID就是第7步中每个类别返回的数字
比如

    if row_label == 'basketball':
        return 1

则在labelmap.pbtxt中修改为

item {
  id: 1
  name: 'basketball'
}

8. 配置训练文件

C:\tensorflow1\models\research\object_detection\samples\configs中的ssd_mobilenet_v2_quantized_300x300_coco.config复制到object_detection\training文件夹中

然后修改该文件

1. num_classes : 修改为自己的训练类别
2. fine_tune_checkpoint : "C:/tensorflow1/models/research/object_detection/ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/model.ckpt"
## in train_input_reader section
3. input_path : "C:/tensorflow1/models/research/object_detection/train.record"       
4. label_map_path: "C:/tensorflow1/models/research/object_detection/training/labelmap.pbtxt"
5. num_examples:修改为\images\test 中的图片数量
## in eval_input_reader section
6. input_path : "C:/tensorflow1/models/research/object_detection/test.record"
7. label_map_path: "C:/tensorflow1/models/research/object_detection/training/labelmap.pbtxt"

9. 开始训练

重新训练时需要清除training\中除了labelmap和ssd_mobilenet。。。以外的所有文件,否则会报错

(tensorflow1) C:\tensorflow1\models\research\object_detection> python train.py --logtostderr --train_dir=training/ --pipeline_config_path=training/ssd_mobilenet_v2_quantized_300x300_coco.config

当loss降到2以下的时候,就可以停止训练 , 直接 Ctrl + C即可

10. 生成.pb文件

python export_inference_graph.py --input_type image_tensor --pipeline_config_path training/ssd_mobilenet_v2_quantized_300x300_coco.config --trained_checkpoint_prefix training/model.ckpt-XXXX --output_directory inference_graph

将model.ckpt-XXXX替换为training\目录中数值最大的一个文件,这步之后会在\object_detection\inference_graph中生成一个frozen_inference_graph.pb文件

11. 测试Tensorflow模型

修改Object_detection_image.py,将NUM_CLASSES的数值修改为训练的类别数量
将IMAGE_NAME修改为测试图片的路径
然后运行

(tensorflow1) C:\tensorflow1\models\research\object_detection> python Object_detection_image.py

应该可以看到检测结果

12. 导出 frozen inference graph

(tensorflow1) C:\tensorflow1\models\research\object_detection>  mkdir TFLite_model
(tensorflow1) C:\tensorflow1\models\research\object_detection> set CONFIG_FILE=C:\\tensorflow1\models\research\object_detection\training\ssd_mobilenet_v2_quantized_300x300_coco.config
(tensorflow1) C:\tensorflow1\models\research\object_detection> set CHECKPOINT_PATH=C:\\tensorflow1\models\research\object_detection\training\model.ckpt-XXXX
(tensorflow1) C:\tensorflow1\models\research\object_detection> set OUTPUT_DIR=C:\\tensorflow1\models\research\object_detection\TFLite_model

注意将上述步骤中的model.ckpt-XXXX替换为training\目录中数值最大的一个文件

(tensorflow1) C:\tensorflow1\models\research\object_detection> python export_tflite_ssd_graph.py --pipeline_config_path=%CONFIG_FILE% --trained_checkpoint_prefix=%CHECKPOINT_PATH% --output_directory=%OUTPUT_DIR% --add_postprocessing_op=true

之后会在 \object_detection\TFLite_model中生成tflite_graph.pbtflite_graph.pbtxt

Linux篇 (Ubuntu 18.04)

1. 下载TensorFlow 源码

2.3.1版本 链接
解压,然后

cd tensorflow-2.3.1

2. 设置conda虚拟环境

conda create -n tensorflow-build pip python=3.6
conda activate tensorflow-build

3. 下载Bazel和Python依赖

pip install six numpy==1.18.0 wheel
pip install keras_applications==1.0.6 --no-deps
pip install keras_preprocessing==1.0.5 --no-deps
conda install -c conda-forge bazel=3.1.0

注意 TF 2.3.1 需要搭配 Bazel 3.1.0 以及 numpy 1.18.0 ,否则编译TF源码时会报错

4. 配置build

./configure

注意修改python的路径和lib路径为虚拟环境中的路径,会自动列举检测到的python路径,只需要复制然后粘贴回车即可

其余y/N的选项一律选择N
非y/N的选项除了python的路径和lib路径都默认即可,直接回车

也无需添加CUDA支持,仅编译CPU版本即可

5. 编译TF

bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package 

编译时间较长,请耐心等待。只要不报错就好

6. 构建、安装软件包

构建,即生成wheel文件

./bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg

安装

pip install /tmp/tensorflow_pkg/tensorflow-version-tags.wheel

安装完成之后离开tensorflow-2.3.1目录

cd ..

然后进行后续操作

注意下次训练生成TFLite 模型还是得在tensorflow-2.3.1目录下执行下面的命令,不然会报错

7. 生成TFlite 模型

bazel run --config=opt tensorflow/lite/toco:toco -- --input_file=/media/liuxuwei/Windows/tensorflow1/models/research/object_detection/TFLite_model/tflite_graph.pb --output_file=/media/liuxuwei/Windows/tensorflow1/models/research/object_detection/TFLite_model/detect.tflite --input_shapes=1,300,300,3 --input_arrays=normalized_input_image_tensor --output_arrays=TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3 --inference_type=QUANTIZED_UINT8 --mean_values=128 --std_values=128 --change_concat_input_ranges=false --allow_custom_ops

这一步执行完毕之后会生成detect.tflite文件
请根据自己的情况,修改
input_file=/media/liuxuwei/Windows/tensorflow1/models/research/object_detection/TFLite_model/tflite_graph.pb
output_file=/media/liuxuwei/Windows/tensorflow1/models/research/object_detection/TFLite_model/detect.tflite

/media/liuxuwei/Windows/tensorflow1/models/research/object_detection/TFLite_model下新建txt文件
命名为labelmap.txt,这是TFlite所需要的label map文件。根据windows篇第6步中的训练类别,填写labelmap.txt
比如

basketball
shirt
shoe

到这里Tflite模型就完成了

树梅派篇

前期配置请参阅我之前写的树梅派部署深度学习的文章

将TFLite_model文件夹复制到树梅派中的tflite文件夹中,然后将一张测试图片也复制到tflite文件夹中,然后运行

python3 TFLite_detection_image.py --modeldir=TFLite_model --image=xxx
  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2021-07-30 22:42:56  更:2021-07-30 22:43:15 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年12日历 -2024/12/25 14:56:44-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码
数据统计