IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> ONNX构建并运行模型 -> 正文阅读

[人工智能]ONNX构建并运行模型

????????ONNX是开放式神经网络(Open Neural Network Exchange)的简称,主要由微软和合作伙伴社区创建和维护。很多深度学习训练框架(如Tensorflow, PyTorch, Scikit-learn, MXNet等)的模型都可以导出或转换为标准的ONNX格式,采用ONNX格式作为统一的界面,各种嵌入式平台就可以只需要解析ONNX格式的模型而不用支持多种多样的训练框架,本文主要介绍如何通过代码或JSON文件的形式来构造一个ONNX单算子模型或者整个graph,以及使用ONNX Runtime进行推理得到算子或模型的计算结果。

一. ONNX文件格式

????????ONNX文件是基于Protobuf进行序列化。了解Protobuf协议的同学应该知道,Protobuf都会有一个*.proto的文件定义协议,ONNX的该协议定义在https://github.com/onnx/onnx/blob/master/onnx/onnx.proto3 文件中。

????????从onnx.proto3协议中我们需要重点知道的数据结构如下:

  • ModelProto:模型的定义,包含版本信息,生产者和GraphProto。
  • GraphProto: 包含很多重复的NodeProto, initializer, ValueInfoProto等,这些元素共同构成一个计算图,在GraphProto中,这些元素都是以列表的方式存储,连接关系是通过Node之间的输入输出进行表达的。
  • NodeProto: onnx的计算图是一个有向无环图(DAG),NodeProto定义算子类型,节点的输入输出,还包含属性。
  • ValueInforProto: 定义输入输出这类变量的类型。
  • TensorProto: 序列化的权重数据,包含数据的数据类型,shape等。
  • AttributeProto: 具有名字的属性,可以存储基本的数据类型(int, float, string, vector等)也可以存储onnx定义的数据结构(TENSOR, GRAPH等)。

二. Python API

2.1 搭建ONNX模型

????????ONNX是用DAG来描述网络结构的,也就是一个网络(Graph)由节点(Node)和边(Tensor)组成,ONNX提供的helper类中有很多API可以用来构建一个ONNX网络模型,比如make_node, make_graph, make_tensor等,下面是一个单个Conv2d的网络构造示例:

import onnx
from onnx import helper
from onnx import TensorProto
import numpy as np

weight = np.random.randn(36)
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2, 4, 4])
W = helper.make_tensor('W', TensorProto.FLOAT, [2, 2, 3, 3], weight)
B = helper.make_tensor('B', TensorProto.FLOAT, [2], [1.0, 2.0])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 2, 2, 2])

node_def = helper.make_node(
    'Conv', # node name
    ['X', 'W', 'B'],
    ['Y'], # outputs
    # attributes
    strides=[2,2],
    )

graph_def = helper.make_graph(
    [node_def],
    'test_conv_mode',
    [X], # graph inputs
    [Y], # graph outputs
    initializer=[W, B],
)

mode_def = helper.make_model(graph_def, producer_name='onnx-example')
onnx.checker.check_model(mode_def)
onnx.save(mode_def, "./Conv.onnx")

????????搭建的这个Conv算子模型使用netron可视化如下图所示:

?????????这个示例演示了如何使用helper的make_tensor_value_info, make_mode, make_graph, make_model等方法来搭建一个onnx模型。

????????相比于PyTorch或其它框架,这些API看起来仍然显得比较繁琐,一般我们也不会用ONNX来搭建一个大型的网络模型,而是通过其它框架转换得到一个ONNX模型。

2.2 Shape Inference

????????很多时候我们从pytorch, tensorflow或其他框架转换过来的onnx模型中间节点并没有shape信息,如下图所示:

?????????我们经常希望能直接看到网络中某些node的shape信息,shape_inference模块可以推导出所有node的shape信息,这样可视化模型时将会更友好:

import onnx
from onnx import shape_inference

onnx_model = onnx.load("./test_data/mobilenetv2-1.0.onnx")
onnx_model = shape_inference.infer_shapes(onnx_model)
onnx.save(onnx_model, "./test_data/mobilenetv2-1.0_shaped.onnx")

????????可视化经过shape_inference之后的模型如下图:

2.3 ONNX Optimizer

????????ONNX的optimizer模块提供部分图优化的功能,例如最常用的:fuse_bn_into_conv,fuse_pad_into_conv等等。

????????查看onnx支持的优化方法:

from onnx import optimizer
all_passes = optimizer.get_available_passes()
print("Available optimization passes:")
for p in all_passes:
    print(p)
print()

????????应用图优化到onnx模型上进行变换:

passes = ['fuse_bn_into_conv']
# Apply the optimization on the original model
optimized_model = optimizer.optimize(onnx_model, passes)

????????将mobile net v2应用fuse_bn_into_conv之后,BatchNormalization的参数合并到了Conv的weight和bias参数中,如下图所示:

三. ONNX Runtime计算ONNX模型

????????onnx本身只是一个协议,定义算子与模型结构等,不涉及具体的计算。onnx runtime是类似JVM一样将ONNX格式的模型运行起来的解释器,包括对模型的解析、图优化、后端运行等。

????????安装onnx runtime:

python3 -m pip install onnxruntime

????????推理:

import onnx
import onnxruntime as ort
import numpy as np
import cv2

def preprocess(img_data):
    mean_vec = np.array([0.485, 0.456, 0.406])
    stddev_vec = np.array([0.229, 0.224, 0.225])
    norm_img_data = np.zeros(img_data.shape).astype('float32')
    for i in range(img_data.shape[0]):
         # for each pixel in each channel, divide the value by 255 to get value between [0, 1] and then normalize
        norm_img_data[i,:,:] = (img_data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]
    return norm_img_data

img = cv2.imread("test_data/dog.jpeg")
img = cv2.resize(img, (224,224), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
input_data = np.transpose(img, (2, 0, 1))
input_data = preprocess(input_data)
input_data = input_data.reshape([1, 3, 224, 224])
sess = ort.InferenceSession("test_data/mobilenetv2-1.0.onnx")
input_name = sess.get_inputs()[0].name
result = sess.run([], {input_name: input_data})
result = np.reshape(result, [1, -1])
index = np.argmax(result)
print("max index:", index)

?

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-08-08 11:20:38  更:2021-08-08 11:21:56 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/12 3:54:10-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码