TorchScript Study Notes
TorchScript is a way to create serializable models from Python code: a model can be saved from Python and then loaded in an environment where Python is not available. Note: TorchScript implements the subset of Python needed to represent neural network models in PyTorch; it does not support every Python feature.
torch.jit
TorchScript is PyTorch's JIT implementation. JIT stands for Just-In-Time compilation. The JIT is the bridge between Python and C++: we can train a model in Python, use the JIT to convert it into a language-independent module, and then deploy the PyTorch model with C++ on any platform or device, such as a Raspberry Pi, iOS, or Android. (Plain Python code is generally not suitable for deployment, mainly for multithreading and performance reasons.) There are two main ways to export a model: Scripting and Tracing.
1. Tracing
Tracing is the simpler approach: it feeds an example input through the model and records how tensors flow through the forward function to recover the model structure, so an example input is required. It only works for relatively simple models: if the forward function contains control flow, a single pass cannot visit every branch, and the script approach is needed instead (see the sketch after the first example below).
torch.jit.trace(func, example_inputs)
torch.jit.trace converts a torch::jit::Module into a torch::jit::Graph. If the traced callable is a plain Python function, the result is a ScriptFunction; if it is nn.Module.forward or an nn.Module, the result is a ScriptModule. The module's eval/train mode at trace time is preserved in the returned ScriptModule.
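A minimal sketch (the function add_one and the module Net below are made up for illustration) checking the return types and the mode-preserving behavior described above:
import torch
import torch.nn as nn

def add_one(x):
    return x + 1

class Net(nn.Module):
    def forward(self, x):
        return x * 2

example = torch.rand(2, 2)
traced_fn = torch.jit.trace(add_one, example)        # tracing a plain function
traced_mod = torch.jit.trace(Net().eval(), example)  # tracing an nn.Module in eval mode

print(isinstance(traced_fn, torch.jit.ScriptFunction))  # True
print(isinstance(traced_mod, torch.jit.ScriptModule))   # True
print(traced_mod.training)                               # False, eval mode is preserved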
import numpy as np
import torch

def sigmoid(z):
    # Numerically this is 1 / (1 + exp(-z))
    s = 1 / (1 + 1 / np.exp(z))
    return s

x = torch.full((2, 2), 1)
print("x: ", x)
traced_func = torch.jit.trace(sigmoid, x)
print(traced_func.graph)
print(traced_func(x))
Output:
x: tensor([[1, 1],
[1, 1]])
graph(%0 : Long(2:2, 2:1, requires_grad=0, device=cpu)):
%1 : Double(2:2, 2:1, requires_grad=0, device=cpu) = prim::Constant[value= 2.7183 2.7183 2.7183 2.7183 [ CPUDoubleType{2,2} ]]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
%2 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::reciprocal(%1) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
%3 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
%4 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::mul(%2, %3) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
%5 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # C:/Users/Administrator/Desktop/frcnn_1.7/lt_test.py:28:0
%6 : int = prim::Constant[value=1]() # C:/Users/Administrator/Desktop/frcnn_1.7/lt_test.py:28:0
%7 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::add(%4, %5, %6) # C:/Users/Administrator/Desktop/frcnn_1.7/lt_test.py:28:0
%8 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::reciprocal(%7) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
%9 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
%10 : Double(2:2, 2:1, requires_grad=0, device=cpu) = aten::mul(%8, %9) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torch\tensor.py:519:0
return (%10)
tensor([[0.7311, 0.7311],
[0.7311, 0.7311]], dtype=torch.float64)
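To make the control-flow limitation concrete (a sketch; the Gate module below is hypothetical): the traced graph only records the branch that the example input happens to take, and the other branch is silently dropped (trace usually emits a TracerWarning about converting a tensor to a Python boolean here).
import torch
import torch.nn as nn

class Gate(nn.Module):
    def forward(self, x):
        # Data-dependent control flow: trace can only follow one branch
        if x.sum() > 0:
            return x * 2
        return x * -1

traced_gate = torch.jit.trace(Gate(), torch.ones(2, 2))  # sum > 0, so only the "x * 2" branch is recorded
print(traced_gate(-torch.ones(2, 2)))  # still computes x * 2 (all -2); the eager module would return x * -1 (all 1)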
import torch
from torchvision.models import resnet50
model = resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
print(resnet.graph)
Output:
graph(%self.1 : __torch__.torchvision.models.resnet.ResNet,
%input.1 : Float(1:150528, 3:50176, 224:224, 224:1, requires_grad=0, device=cpu)):
%2664 : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="fc"](%self.1)
%2661 : __torch__.torch.nn.modules.pooling.AdaptiveAvgPool2d = prim::GetAttr[name="avgpool"](%self.1)
%2660 : __torch__.torch.nn.modules.container.___torch_mangle_141.Sequential = prim::GetAttr[name="layer4"](%self.1)
%2582 : __torch__.torch.nn.modules.container.___torch_mangle_113.Sequential = prim::GetAttr[name="layer3"](%self.1)
%2435 : __torch__.torch.nn.modules.container.___torch_mangle_61.Sequential = prim::GetAttr[name="layer2"](%self.1)
%2334 : __torch__.torch.nn.modules.container.___torch_mangle_25.Sequential = prim::GetAttr[name="layer1"](%self.1)
%2256 : __torch__.torch.nn.modules.pooling.MaxPool2d = prim::GetAttr[name="maxpool"](%self.1)
%2255 : __torch__.torch.nn.modules.activation.ReLU = prim::GetAttr[name="relu"](%self.1)
%2254 : __torch__.torch.nn.modules.batchnorm.BatchNorm2d = prim::GetAttr[name="bn1"](%self.1)
%2249 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self.1)
%2847 : Tensor = prim::CallMethod[name="forward"](%2249, %input.1)
%2848 : Tensor = prim::CallMethod[name="forward"](%2254, %2847)
%2849 : Tensor = prim::CallMethod[name="forward"](%2255, %2848)
%2850 : Tensor = prim::CallMethod[name="forward"](%2256, %2849)
%2851 : Tensor = prim::CallMethod[name="forward"](%2334, %2850)
%2852 : Tensor = prim::CallMethod[name="forward"](%2435, %2851)
%2853 : Tensor = prim::CallMethod[name="forward"](%2582, %2852)
%2854 : Tensor = prim::CallMethod[name="forward"](%2660, %2853)
%2855 : Tensor = prim::CallMethod[name="forward"](%2661, %2854)
%2059 : int = prim::Constant[value=1]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torchvision\models\resnet.py:214:0
%2060 : int = prim::Constant[value=-1]() # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torchvision\models\resnet.py:214:0
%input : Float(1:2048, 2048:1, requires_grad=1, device=cpu) = aten::flatten(%2855, %2059, %2060) # E:\Anaconda2021\envs\torch1.7\lib\site-packages\torchvision\models\resnet.py:214:0
%2856 : Tensor = prim::CallMethod[name="forward"](%2664, %input)
return (%2856)
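As a follow-up, the traced module can be serialized and reloaded without the original model definition; the same .pt file can also be loaded from C++ via torch::jit::load. A sketch (the file name is arbitrary):
import torch
from torchvision.models import resnet50

model = resnet50(pretrained=True).eval()
traced = torch.jit.trace(model, torch.rand(1, 3, 224, 224))
torch.jit.save(traced, "resnet50_traced.pt")

# Reload; the torchvision model definition is no longer needed
loaded = torch.jit.load("resnet50_traced.pt")
with torch.no_grad():
    out = loaded(torch.rand(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])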
2. Scripting
The script approach generates the static graph by parsing the abstract syntax tree (AST) of the source code, so no example input is needed. See the official Python ast documentation for background on ASTs.
torch.jit.script
- Example: scripting a ResNet-50 + FPN backbone
import torch
from backbone import resnet50_fpn_backbone

# device = torch.device('cpu')
model = resnet50_fpn_backbone()
# model.to(device).eval()
traced_model = torch.jit.script(model)
print(traced_model.graph)
# torch.jit.save(traced_model, 'saved_cpu.pt')
Output:
graph(%self : __torch__.backbone.resnet50_fpn_model.BackboneWithFPN,
%x.1 : Tensor):
%2 : __torch__.backbone.resnet50_fpn_model.IntermediateLayerGetter = prim::GetAttr[name="body"](%self)
%x.3 : Dict(str, Tensor) = prim::CallMethod[name="forward"](%2, %x.1) # C:\Users\Administrator\Desktop\frcnn_1.7\backbone\resnet50_fpn_model.py:228:12
%5 : __torch__.backbone.feature_pyramid_network.FeaturePyramidNetwork = prim::GetAttr[name="fpn"](%self)
%x.5 : Dict(str, Tensor) = prim::CallMethod[name="forward"](%5, %x.3) # C:\Users\Administrator\Desktop\frcnn_1.7\backbone\resnet50_fpn_model.py:229:12
return (%x.5)
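Unlike tracing, scripting compiles the source of forward, so data-dependent control flow is kept in the graph as a prim::If node. A sketch reusing the hypothetical Gate module from the tracing section:
import torch
import torch.nn as nn

class Gate(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x * -1

scripted_gate = torch.jit.script(Gate())
print(scripted_gate.graph)               # contains a prim::If node for the branch
print(scripted_gate(torch.ones(2, 2)))   # takes the "x * 2" branch
print(scripted_gate(-torch.ones(2, 2)))  # takes the "x * -1" branch, unlike the traced version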
TorchScript Types
The TorchScript type system is divided into TSType and TSModuleType.
- TSType:
  Meta types, e.g. Any. These act more like type constraints and can stand for a value of any type.
  Primitive types, e.g. int, float, str.
  Structural types, e.g. TSTuple, TSNamedTuple, TSList, TSDict, TSOptional, TSFuture, TSRRef.
  Nominal types (Python classes), e.g. MyClass (user-defined) and torch.Tensor (built-in).
- TSModuleType:
  Represents torch.nn.Module and its subclasses. Because its definition comes partly from the object instance and partly from the class definition, it is not a static type; it therefore cannot be used as a TorchScript type annotation and cannot be combined with TSType.
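A sketch of how TSType annotations are used in practice (the function pick_feature and its arguments are illustrative): the script compiler reads the Dict[str, Tensor], Optional[Tensor], str, and Tensor annotations and type-checks the function against them.
import torch
from torch import Tensor
from typing import Dict, Optional

@torch.jit.script
def pick_feature(feats: Dict[str, Tensor], key: str, default: Optional[Tensor] = None) -> Tensor:
    # Dict, Optional, str and Tensor are all valid TSType annotations
    if key in feats:
        return feats[key]
    if default is not None:
        return default
    return torch.zeros(1)

feats = {"layer1": torch.rand(2, 2)}
print(pick_feature(feats, "layer1").shape)  # torch.Size([2, 2])
print(pick_feature(feats, "missing"))       # tensor([0.])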