IT数码 购物 网址 头条 软件 日历 阅读 图书馆
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
图片批量下载器
↓批量下载图片,美女图库↓
图片自动播放器
↓图片自动播放器↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁
 
   -> 人工智能 -> TorchScript学习笔记 -> 正文阅读

[人工智能]TorchScript学习笔记

TorchScript学习笔记

TorchScript是一种可从python代码中创建序列化模型的方法。可以从python代码中保存,并在非python环境中加载模型。注:TorchScript 主要实现的是在 PyTorch 中表示神经网络模型所需的 Python 功能,并不适用于所有的Python特性。

torch.jit

TorchScript是Pytorch的JIT实现。
JIT ,全称是 Just In Time Compilation(即时编译)。
JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而可以使用C++把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等。(多线程执行和性能原因,一般Python代码并不适合做部署。)
导出模型主要有两种方式,Scripting和Tracing。

1.Tracing方式

tracing是相对简单的方式,输入向量,追踪向量在forward函数的流动来获得模型结构。
必须要有输入。
只适用于比较简单的模型,如果forward函数中有控制流结构,向量一次无法遍历所有的分支,这时就要借助script方式。

torch.jit.trace(func, example_inputs)

torch.jit.trace会将torch::jit::Module 转成 torch::jit::Graph 。
如果trace的是Python 函数,那么返回ScriptFunction , 如果是nn.Module.forward或者nn.Module,返回的就是ScriptModule。 如果trace的时候是eval/train模式,那么返回的ScriptModule就是eval/train模式。

  • 例子:trace一个函数
def sigmoid(z):
    s = 1 / (1 + 1 / np.exp(z))
    return s


x = torch.full((2, 2), 1)
print("x: ",x)
traced_func = torch.jit.trace(sigmoid, x)
print(traced_func.graph)
print(traced_func(x))

输出:

x:  tensor([[1, 1],
        [1, 1]])
graph(%0 : Long(2:2, 2:1, requires_grad=0, device=cpu)):
  %1 : Double(2:2, 2:1, requires_grad=0, device=cpu) = prim::Constant[value= 2.7183  2.7183  2.7183  2.7183 [ CPUDoubleType{2,2} ]]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
  %2 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::reciprocal(%1) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
  %3 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
  %4 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::mul(%2, %3) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
  %5 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # C:/Users/Administrator/Desktop/frcnn_1.7/lt_test.py:28:0
  %6 : int = prim::Constant[value=1]() # C:/Users/Administrator/Desktop/frcnn_1.7/lt_test.py:28:0
  %7 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::add(%4, %5, %6) # C:/Users/Administrator/Desktop/frcnn_1.7/lt_test.py:28:0
  %8 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::reciprocal(%7) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
  %9 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
  %10 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::mul(%8, %9) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
  return (%10)

tensor([[0.7311, 0.7311],
        [0.7311, 0.7311]], dtype=torch.float64)
  • 例子:trace一个nn.Module的子类
import torch
from torchvision.models import resnet50


model = resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
print(resnet.graph)

输出:

graph(%self.1 : __torch__.torchvision.models.resnet.ResNet,
      %input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):
  %2664 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc"](%self.1)
  %2661 : __torch__.torch.nn.modules.pooling.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
  %2660 : __torch__.torch.nn.modules.container.___torch_mangle_141.Sequential = prim::GetAttr[name="layer4"](%self.1)
  %2582 : __torch__.torch.nn.modules.container.___torch_mangle_113.Sequential = prim::GetAttr[name="layer3"](%self.1)
  %2435 : __torch__.torch.nn.modules.container.___torch_mangle_61.Sequential = prim::GetAttr[name="layer2"](%self.1)
  %2334 : __torch__.torch.nn.modules.container.___torch_mangle_25.Sequential = prim::GetAttr[name="layer1"](%self.1)
  %2256 : __torch__.torch.nn.modules.pooling.MaxPool2d = prim::GetAttr[name="maxpool"](%self.1)
  %2255 : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="relu"](%self.1)
  %2254 : __torch__.torch.nn.modules.batchnorm.BatchNorm2d = prim::GetAttr[name="bn1"](%self.1)
  %2249 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %2847 : Tensor = prim::CallMethod[name="forward"](%2249, %input.1)
  %2848 : Tensor = prim::CallMethod[name="forward"](%2254, %2847)
  %2849 : Tensor = prim::CallMethod[name="forward"](%2255, %2848)
  %2850 : Tensor = prim::CallMethod[name="forward"](%2256, %2849)
  %2851 : Tensor = prim::CallMethod[name="forward"](%2334, %2850)
  %2852 : Tensor = prim::CallMethod[name="forward"](%2435, %2851)
  %2853 : Tensor = prim::CallMethod[name="forward"](%2582, %2852)
  %2854 : Tensor = prim::CallMethod[name="forward"](%2660, %2853)
  %2855 : Tensor = prim::CallMethod[name="forward"](%2661, %2854)
  %2059 : int = prim::Constant[value=1]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torchvision\models\resnet.py:214:0
  %2060 : int = prim::Constant[value=-1]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torchvision\models\resnet.py:214:0
  %input : Float(1:2048, 2048:1, requires_grad=1, device=cpu) = aten::flatten(%2855, %2059, %2060) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torchvision\models\resnet.py:214:0
  %2856 : Tensor = prim::CallMethod[name="forward"](%2664, %input)
  return (%2856)

2. Scripting方式

script方式通过解析AST的方式生成静态图。不需要有输入。Python ast官方文档

torch.jit.script

  • 例子:script一个resnet50+FPN的backbone
from backbone import resnet50_fpn_backbone
# device = torch.device('cpu')
model = resnet50_fpn_backbone()
# model.to(device).eval()
traced_model = torch.jit.script(model)
print(traced_model.graph)
# torch.jit.save(traced_model, 'saved_cpu.pt')

输出:

graph(%self : __torch__.backbone.resnet50_fpn_model.BackboneWithFPN,
      %x.1 : Tensor):
  %2 : __torch__.backbone.resnet50_fpn_model.IntermediateLayerGetter = prim::GetAttr[name="body"](%self)
  %x.3 : Dict(str, Tensor) = prim::CallMethod[name="forward"](%2, %x.1) # C:\Users\Administrator\Desktop\frcnn_1.7\backbone\resnet50_fpn_model.py:228:12
  %5 : __torch__.backbone.feature_pyramid_network.FeaturePyramidNetwork = prim::GetAttr[name="fpn"](%self)
  %x.5 : Dict(str, Tensor) = prim::CallMethod[name="forward"](%5, %x.3) # C:\Users\Administrator\Desktop\frcnn_1.7\backbone\resnet50_fpn_model.py:229:12
  return (%x.5)

TorchScript Type 类型解释

TorchScript 类型系统划分为 TSType 和 TSModuleType

  • TSType :
    Meta Types,如 Any。更像是类型约束,可以表示任何类型的类型。
    Primitive Types,如 int,float, str
    Structural Types,如 TSTuple ,TSNamedTuple,TSList ,TSDict ,TSOptional,TSFuture,TSRRef
    Nominal Types (Python classes),如 MyClass (自定义), torch.tensor (built-in)
  • TSModuleType:
    表示torch.nn.Module及其子类。因为它的定义部分来自对象,部分来自类定义,不属于静态类型,因此不能用作TorchScript type annotation,也不能够和TSType进行组合使用。
  人工智能 最新文章
2022吴恩达机器学习课程——第二课(神经网
第十五章 规则学习
FixMatch: Simplifying Semi-Supervised Le
数据挖掘Java——Kmeans算法的实现
大脑皮层的分割方法
【翻译】GPT-3是如何工作的
论文笔记:TEACHTEXT: CrossModal Generaliz
python从零学(六)
详解Python 3.x 导入(import)
【答读者问27】backtrader不支持最新版本的
上一篇文章      下一篇文章      查看所有文章
加:2021-11-20 18:23:44  更:2021-11-20 18:26:22 
 
开发: C++知识库 Java知识库 JavaScript Python PHP知识库 人工智能 区块链 大数据 移动开发 嵌入式 开发工具 数据结构与算法 开发测试 游戏开发 网络协议 系统运维
教程: HTML教程 CSS教程 JavaScript教程 Go语言教程 JQuery教程 VUE教程 VUE3教程 Bootstrap教程 SQL数据库教程 C语言教程 C++教程 Java教程 Python教程 Python3教程 C#教程
数码: 电脑 笔记本 显卡 显示器 固态硬盘 硬盘 耳机 手机 iphone vivo oppo 小米 华为 单反 装机 图拉丁

360图书馆 购物 三丰科技 阅读网 日历 万年历 2025年1日历 -2025/1/11 4:57:25-

图片自动播放器
↓图片自动播放器↓
TxT小说阅读器
↓语音阅读,小说下载,古典文学↓
一键清除垃圾
↓轻轻一点,清除系统垃圾↓
图片批量下载器
↓批量下载图片,美女图库↓
  网站联系: qq:121756557 email:121756557@qq.com  IT数码