IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 详细步骤:pytorch pth转wts转tensorrt(自定义模型,不用parser) -> 正文阅读

[人工智能]详细步骤:pytorch pth转wts转tensorrt(自定义模型,不用parser)

分两步转是考虑了如下应用场景:
需要把模型部署到内存有限的嵌入式板

  • 自己的电脑上安装的有anaconda, pytorch等,但是在电脑上转的不能直接在板子上用
  • 板子的内存有限,不能安装anaconda, pytorch这些,但是需要部署模型上去。
    这时就可以现在电脑上把pth转成wts,再把wts传到板子上,在板子上转成tensorrt

(1) pth转wts
参考如下代码

import torch
from torch import nn
#load你的模型
import os
import struct

def main():
    net = torch.load('XXX.pth') #loadpth文件
    net = net.to('cuda:0')
    net.eval()
    
    f = open("XXX.wts", 'w') #自己命名wts文件
    f.write("{}\n".format(len(net.state_dict().keys())))  #保存所有keys的数量
    for k,v in net.state_dict().items():
        #print('key: ', k)
        #print('value: ', v.shape)
        vr = v.reshape(-1).cpu().numpy()
        f.write("{} {}".format(k, len(vr)))  #保存每一层名称和参数长度
        for vv in vr:
            f.write(" ")
            f.write(struct.pack(">f", float(vv)).hex())  #使用struct把权重封装成字符串
        f.write("\n")

if __name__ == '__main__':
    main()

wts文件是如下格式
第一行:数字
下面分别是网络每一层的名称,参数个数和对应的参数

(2) wts转tensorrt
常用的模型转换可参考链接
但是如果是自定义模型,不在常用模型范围,比如我建了一个几层的小语义分割网络,但是现有的parser不支持,那只能手撕tensorrt的API了

tensorRT的API使用方法可参考链接

简要说下wts转tensorrt的原理

  1. 从wts文件把weight给load出来,存到一个map里,key是网络每层的名称,value就是对应的权重
  2. 利用tensorrt的API把网络重建出来,同时导入key对应的value,也就是weightMap的形式
  3. 定义网络的输出,设置内存空间
  4. build engine
    最后出来的是一个engine文件

具体的做法:

  1. load weight
std::map<std::string, Weights> loadWeights(const std::string file)
{
    std::cout << "Loading weights: " << file << std::endl;
    std::map<std::string, Weights> weightMap;

    // Open weights file
    std::ifstream input(file);
    assert(input.is_open() && "Unable to load weight file.");

    // Read number of weight blobs
    int32_t count;
    input >> count;
    assert(count > 0 && "Invalid weight map file.");

    while (count--)
    {
        Weights wt{DataType::kFLOAT, nullptr, 0};
        uint32_t size;

        // Read name and type of blob
        std::string name;
        input >> name >> std::dec >> size;
        wt.type = DataType::kFLOAT;

        // Load blob
        uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(val) * size));
        for (uint32_t x = 0, y = size; x < y; ++x)
        {
            input >> std::hex >> val[x];
        }
        wt.values = val;
        
        wt.count = size;
        weightMap[name] = wt;
    }

    return weightMap;
}
  1. 重建网络
    具体参考tensorRT的API
    举个例子:
    比如是3x3的conv,stride=2, padding=2, kernel size=5x5,output channel=8
    首先要定义一个network
INetworkDefinition* network = builder->createNetworkV2(0U); 

输入层要定义出来,比如输入size是3xHxW

ITensor* data = network->addInput(INPUT_BLOB_NAME, DataType::kFLOAT, Dims3{3, INPUT_H, INPUT_W}); 

然后把输入传给conv层

IConvolutionLayer* conv1 = network->addConvolutionNd(*data, 8, DimsHW{5, 5}, weightMap["conv1.weight"], weightMap["conv1.bias"]); //map名称要改

一层一层传下去,直到指定output

  1. 定义输出,指定内存空间
deconv3->getOutput(0)->setName(OUTPUT_BLOB_NAME); //设置output名称
network->markOutput(*deconv3->getOutput(0)); //指定output

还要设置一个engine内存空间,太小的话会报错,尤其是报could not find implementation for node的error

builder->setMaxBatchSize(maxBatchSize);
config->setMaxWorkspaceSize(128*(1 << 20)); 
  1. build engine
ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);

这里说明一下反卷积层的写法,torch的deconvolution有padding和output_padding,如何设置
padding对应的是setPrePadding
output_padding对应的是setPostPadding

举个例子,反卷积层,output channel=16,kernel_size=5x5,
stride=2, padding=2, output_padding=1

IDeconvolutionLayer* deconv1 = network->addDeconvolutionNd(*relu3->getOutput(0), 16, DimsHW{5,5}, weightMap["deconv1.weight"], weightMap["deconv1.bias"]);
deconv1->setStrideNd(DimsHW{2, 2});
deconv1->setPrePadding(DimsHW{2, 2});
deconv1->setPostPadding(DimsHW{1, 1});

具体cpp和CMakeList.txt放在github上

  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-26 08:51:52  更:2021-11-26 08:52:57 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 4:00:33-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码