IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> c++ 部署模型 -> 正文阅读

[人工智能]c++ 部署模型

将模型转为onnx

选用pytorch框架,训练resnet18二分类,将二分类模型转为onnx模型

from typing_extensions import dataclass_transform
import torch
import torch.nn as nn
from torchvision import models, transforms

import cv2 as cv
import numpy as np
from PIL import Image

class_name = ['ant', 'bee']
# device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
img = Image.open(r'D:\Projects\PythonProjects\ResnetClassifier\data\hymenoptera_data\val\ants\800px-Meat_eater_ant_qeen_excavating_hole.jpg')
img = img.convert('RGB')
img = img.resize((224, 224))

img = np.array(img, np.float32)
preprocess_img = img/255.0
img_data = np.expand_dims(np.transpose(preprocess_img, (2, 0, 1)), 0)


device = torch.device('cpu')
model = models.resnet18()
model_l = model.fc.in_features
model.fc = nn.Linear(model_l, 2)
model_x = model.to(device)
model_x.eval()

model_x.load_state_dict(torch.load("D:\\Projects\\PythonProjects\\Datasets\\resnet18_bee_and_ant.pth", map_location=torch.device('cpu')))

tensor_img = torch.from_numpy(img_data)
out = model_x(tensor_img)
_, pred = torch.max(out, 1)
pred_name = class_name[pred]
print(pred_name)

input = torch.rand(1, 3, 224, 224)
# torch.save(model_x, 'new_resnet18_v3.pth')
torch.onnx.export(model_x, input, 'new_resnet18_trtpost_v3.onnx', verbose=True)

模型部署

将转换好的onnx模型采用c++进行部署,定义一个ONNXClassifier类,初始化构造函数和分类函数。并将标准化和读标签声明为私有函数,并将重要参数设置为私有变量

class ONNXClassifier {
public:
	ONNXClassifier(const string& model_path, const string& label_path, Size input_size);
	void Classify(const Mat& input_image, string& out_name, double& confidence);
private:
	void preprocess_input(Mat& image);
	bool read_labels(const string& label_path);
private:
	Size input_size;
	cv::dnn::Net net;
	cv::Scalar default_mean;
	cv::Scalar default_std;
	std::vector<string> labels;
};

定义构造函数,并使用参数列表对对标准值、方差以及输入图片大小进行初始化。

ONNXClassifier::ONNXClassifier(const std::string& model_path, const std::string& label_path, cv::Size _input_size) :default_mean(0.485, 0.456, 0.406),
default_std(0.229, 0.224, 0.225), input_size(_input_size)
{
	if (!read_labels(label_path))
	{
		throw std::runtime_error("label read fail!");
	}
	net = cv::dnn::readNet(model_path);
	net.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV);
	net.setPreferableTarget(cv::dnn::DNN_TARGET_CPU);
}

定义读标签函数,首先声明ifstream对象用于读文件

bool ONNXClassifier::read_labels(const std::string& label_path) {
	ifstream ifs(label_path);
	string line;
	while (getline(ifs, line)) {
		size_t index = line.find_first_of(":");
		labels.push_back(line.substr(index + 1));
	}
	if (labels.size() > 0)
		return true;
	else
		return false;
}

进行图片标准化,先将图片像素点进行标准化,并转为32位浮点型

void ONNXClassifier::preprocess_input(Mat& image) {
	image.convertTo(image, CV_32F, 1.0 / 255.0); //scale:比例因子
	subtract(image, default_mean, image);  //图像相减
	divide(image, default_std, image);
}

定义分类函数

void ONNXClassifier::Classify(const cv::Mat& input_image, std::string& out_name, double& confidence)
{
	out_name.clear();
	cv::Mat image = input_image.clone();
	preprocess_input(image);
	//图像预处理 1、整体像素减去均值  2、通过放缩系数对图片像素进行放缩
	cv::Mat input_blob = cv::dnn::blobFromImage(image, 1.0, input_size, cv::Scalar(0, 0, 0), true);
	net.setInput(input_blob);
	const std::vector<cv::String>& out_names = net.getUnconnectedOutLayersNames();
	cv::Mat out_tensor = net.forward(out_names[0]);
	cout << out_tensor << endl;
	cv::Point maxLoc;
	double minV;
	cv::minMaxLoc(out_tensor, &minV, &confidence, (cv::Point*)0, &maxLoc); // 寻找矩阵或一维向量中的最大、最小值
	cout << "maxLoc.x:" << maxLoc.x<<"\tmaxLoc.y:"<<maxLoc.y<<endl;
	out_name = labels[maxLoc.x];
}

定义主函数

int main(int argc, char** argv) {
	cv::utils::logging::setLogLevel(cv::utils::logging::LogLevel::LOG_LEVEL_SILENT);
	vector<string> imgVec;
	cv::glob("D:\\Projects\\PythonProjects\\ResnetClassifier\\data\\test_data\\", imgVec);
	string model_path = ("D:\\Projects\\PythonProjects\\ResnetClassifier\\new_resnet18_trtpost_v3.onnx");
	string label_path = ("D:\\Projects\\PythonProjects\\ResnetClassifier\\ant_and_bee.txt");
	Size input_size(300, 300);

	for (size_t i = 0; i < imgVec.size(); i++) {
		Mat test_image = imread(imgVec[i], cv::IMREAD_COLOR);
		ONNXClassifier classifier(model_path, label_path, input_size);
		string result;
		double confidence = 0;
		classifier.Classify(test_image, result, confidence);
		cout << imgVec[i] << "\n"<<"预测结果为:" << result << "\tconfidence:" << confidence << endl;
	}
	return 0;
}
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-04-29 12:08:38  更:2022-04-29 12:09:23 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 8:32:39-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码