IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> Python知识库 -> C++:Windows平台下利用LibTorch调用PyTorch模型 -> 正文阅读

[Python知识库]C++:Windows平台下利用LibTorch调用PyTorch模型

参考:C++调用PyTorch模型:LibTorch

环境

Windows10
VS2017
CPU

OpenCV3.0.0

Pytorch1.10.2  torchvision0.11.3
Libtorch1.10.2

Libtorch下载

Pytorch官网

在这里插入图片描述
解压后:注意红框文件夹路径,之后需要添加到项目属性配置中。
在这里插入图片描述

Pytorch将.pth转为.pt文件

所使用的模型为基于AlexNet的分类模型:AlexNet:论文阅读及pytorch网络搭建

python环境下的预测

输出结果:rose

在这里插入图片描述
在这里插入图片描述

新建pt模型生成文件

# tmp.py

import os
import torch
from PIL import Image
from torchvision import transforms
from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # create model
    model = AlexNet(num_classes=5).to(device)

    image = Image.open("rose2.jpg").convert('RGB')
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    img = data_transform(image)
    img = img.unsqueeze(dim=0)
    print(img.shape)

    # load model weights
    weights_path = "AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)

    testsize = 224

    if torch.cuda.is_available():
        modelState = torch.load(weights_path, map_location='cuda')
        model.load_state_dict(modelState, strict=False)
        model = model.cuda()
        model = model.eval()
        # An example input you would normally provide to your model's forward() method.
        example = torch.rand(1, 3, testsize, testsize)
        example = example.cuda()
        traced_script_module = torch.jit.trace(model, example)

        output = traced_script_module(img.cuda())
        print(output.shape)
        pred = torch.argmax(output, dim=1)
        print(pred)

        traced_script_module.save('model_cuda.pt')
    else:
        modelState = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(modelState, strict=False)
        example = torch.rand(1, 3, testsize, testsize)
        example = example.cpu()
        traced_script_module = torch.jit.trace(model, example)

        output = traced_script_module(img.cpu())
        print(output.shape)
        pred = torch.argmax(output, dim=1)
        print(pred)

        traced_script_module.save('model.pt')

if __name__ == '__main__':
    main()

输出结果:rose

在这里插入图片描述

在这里插入图片描述

C++调用pytorch模型

新建空项目pt_alex

在这里插入图片描述

项目属性配置

修改配置管理器

Release/x64
在这里插入图片描述

属性>VC++目录>包含目录

添加:(libtorch解压位置)
在这里插入图片描述

注意还应有opencv目录:(继承值修改可参考
在这里插入图片描述

属性>VC++目录>库目录

添加:
在这里插入图片描述

属性>链接器>输入>附加依赖项

添加:
在这里插入图片描述

注意:
如果后续出现error:找不到c10.dll,
直接把该目录下的相应dll复制到项目pt_alex/x64/Release文件夹下。

注意还应有opencv目录:(Debug下为lib*d.lib)
在这里插入图片描述

注意CUDA下的情况

链接器>命令行,添加:

/INCLUDE:?warp_size@cuda@at@@YAHXZ

属性>C/C++

常规>SDL检查:选择否
语言>符合模式:选择否

项目下新建test.cpp

// test.cpp

#include <torch/script.h> // One-stop header.
#include "torch/torch.h"
#include <opencv2/opencv.hpp>
#include "opencv2/core.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgcodecs.hpp"
#include <vector>
#include <chrono>
#include <string>
#include <vector>
#include <iostream>
#include <memory>

// class_list
/*
	"0": "daisy",
	"1": "dandelion",
	"2": "roses",
	"3": "sunflowers",
	"4": "tulips"
*/

std::string classList[5] = { "daisy", "dandelion", "rose", "sunflower", "tulip" };

std::string image_path = "rose2.jpg";

int main(int argc, const char* argv[]) {

	// Deserialize the ScriptModule from a file using torch::jit::load().
	//std::shared_ptr<torch::jit::script::Module> module = torch::jit::load("../../model_resnet_jit.pt");
	using torch::jit::script::Module;
	Module module = torch::jit::load("model.pt");

	std::cout << image_path << std::endl;

	std::cout << "cuda support:" << (torch::cuda::is_available() ? "ture" : "false") << std::endl;
	std::cout << "CUDNN:  " << torch::cuda::cudnn_is_available() << std::endl;
	std::cout << "GPU(s): " << torch::cuda::device_count() << std::endl;

	// module.to(at::kCUDA); //cpu下会在(auto image = cv::imread(image_path, cv::IMREAD_COLOR))行引起c10:error,未经处理的异常
	module.to(at::kCPU);

	//assert(module != nullptr);
	//std::cout << "ok\n";

	//输入图像
	auto image = cv::imread(image_path, cv::IMREAD_COLOR);
	cv::cvtColor(image, image, cv::COLOR_BGR2RGB);
	cv::Mat image_transfomed;
	cv::resize(image, image_transfomed, cv::Size(224, 224));

	// 转换为Tensor
	torch::Tensor tensor_image = torch::from_blob(image_transfomed.data,
		{ image_transfomed.rows, image_transfomed.cols,3 }, torch::kByte);
	tensor_image = tensor_image.permute({ 2,0,1 });
	tensor_image = tensor_image.toType(torch::kFloat);
	tensor_image = tensor_image.div(255);
	tensor_image = tensor_image.unsqueeze(0);
	// tensor_image = tensor_image.to(at::kCUDA);
	tensor_image = tensor_image.to(at::kCPU);

	// 网络前向计算
	at::Tensor output = module.forward({ tensor_image }).toTensor();
	std::cout << "output:" << output << std::endl;

	auto prediction = output.argmax(1);
	std::cout << "prediction:" << prediction << std::endl;

	int maxk = 5;
	auto top3 = std::get<1>(output.topk(maxk, 1, true, true));

	std::cout << "top3: " << top3 << '\n';

	std::vector<int> res;
	for (auto i = 0; i < maxk; i++) {
		res.push_back(top3[0][i].item().toInt());
	}
	// for (auto i : res) {
	// 	std::cout << i << " ";
	// }
	// std::cout << "\n";

	int pre = torch::Tensor(prediction).item<int>();
	std::string result = classList[pre];
	std::cout << "This is:" << result << std::endl;

	cvWaitKey();

	return 0;
	// system("pause");
}

出现以下报错不影响项目生成:
在这里插入图片描述

输出结果:rose

在这里插入图片描述

  Python知识库 最新文章
Python中String模块
【Python】 14-CVS文件操作
python的panda库读写文件
使用Nordic的nrf52840实现蓝牙DFU过程
【Python学习记录】numpy数组用法整理
Python学习笔记
python字符串和列表
python如何从txt文件中解析出有效的数据
Python编程从入门到实践自学/3.1-3.2
python变量
上一篇文章      下一篇文章      查看所有文章
加:2022-03-06 12:58:23  更:2022-03-06 13:01:29 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/15 22:24:35-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码