IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> 【TorchScript】PyTorch模型转换为C++支持的模型 -> 正文阅读

[人工智能]【TorchScript】PyTorch模型转换为C++支持的模型

任务简介:

使用PyTorch训练的模型只能在Python环境中使用,在自动驾驶场景中,模型推理过程通常是在硬件设备上进行。TorchScript可以将PyTorch训练的模型转换为C++环境支持的模型,推理速度比Python环境更快。本文对整体转换流程做一个简单的记录,后续需要补充TorchScript的支持的各种语法规则以及注意点。


TorchScript

TorchScript是一种从PyTorch代码创建可序列化和可优化模型的方法。任何TorchScript程序都可以从Python进程中保存,并加载到没有Python依赖的进程中。

1. 两种TorchScript模型创建方式

TorchScript模型生成有torch.jit.trace和torch.jit.script两种方法。

1.1 torch.jit.trace

传入Module和符合的示例输入。它会调用Moduel并将操作记录下来,当Module运行时记录下操作,然后创建torch.jit.ScriptModule的实例。对于有控制流的模型,直接使用torch.jit.trace()并不能跟踪到控制流,因为它只是对操作进行了记录,对于没有运行到的操作并不会记录,trace方式生成模型的示例如下:

class MyDecisionGate(torch.nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))  # trace方式

print(traced_cell.dg.code)
print(traced_cell.code)

输出:

def forward(self,
    argument_1: Tensor) -> None:
  return None

def forward(self,
    input: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = self.dg
  _1 = (self.linear).forward(input, )
  _2 = (_0).forward(_1, )
  _3 = torch.tanh(torch.add(_1, h, alpha=1))
  return (_3, _3)

可以看到.code的输出,if-else的分支没有了,控制流会被擦除。

1.2 torch.jit.script

前面提到的问题,可以使用script compiler来解决,可以直接分析Python源代码来把它转化为TrochScript。如下:

scripted_gate = torch.jit.script(MyDecisionGate())  # script方式
my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)  # script方式

print(scripted_gate.code)
print(scripted_cell.code)

输出:

def forward(self,
    x: Tensor) -> Tensor:
  _0 = bool(torch.gt(torch.sum(x, dtype=None), 0))
  if _0:
    _1 = x
  else:
    _1 = torch.neg(x)
  return _1

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  _0 = (self.dg).forward((self.linear).forward(x, ), )
  new_h = torch.tanh(torch.add(_0, h, alpha=1))
  return (new_h, new_h)

可以看到控制流保存下来了。

1.2 DenseTNT的TorchScript模型生成

torch.jit.script()会转换传入Module的所有代码,在实际转换模型的过程中会增加修改代码的工作量,因此通常将torch.jit.trace()和torch.jit.script()进行混合使用,比较灵活。

在需要使用控制流,如不定长的for循环、if-else分支时,在该函数上方输入@torch.jit.script 即可,如:

@torch.jit.script
def get_goal_2D(topk_lane_vector: Tensor, topk_points_mask: Tensor) -> Tensor:
    points = torch.zeros([1,2],device=topk_lane_vector.device)
    visit: Dict[int,bool]= {}
    for index_lane, lane_vector in enumerate(topk_lane_vector):
        for i, point in enumerate(lane_vector):
            if topk_points_mask[index_lane][i]:
                hash: int = int(torch.round((point[0] + 500) * 100) * 1000000 + torch.round((point[1] + 500) * 100))
                if hash not in visit:
                    visit[hash] = True
                    points = torch.cat([points,point.unsqueeze(0)],dim=0)
        point_num, divide_num = _get_subdivide_num(lane_vector, topk_points_mask[index_lane]) 
        if divide_num > 1:
            subdivide_points = _get_subdivide_points(lane_vector, point_num, divide_num)
            points = torch.cat([points,subdivide_points],dim=0)
    return points[1:]

再使用torch.jit.trace()将实例化的model和输入传入,即可生成TorchScript模型。

示例代码:

model.eval()
with torch.no_grad():
    traced_script_model = torch.jit.trace(model, script_inputs, strict=False)
    traced_script_model.save("models.densetnt.1/model_save/model.16_script.bin")
print(traced_script_model.code)
print('Finish converting model!!!')
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2022-07-03 10:48:34  更:2022-07-03 10:52:30 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2024年11日历 -2024/11/26 0:56:13-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码