TorchScript模型
TorchScript 是 PyTorch 模型(`nn.Module` 的子类)的中间表示,可以在高性能环境(例如 C++)中运行。
它具有以下特点:
1. TorchScript 代码可以在其自己的解释器中调用,不受 Python 全局解释器锁(GIL)的限制,因此可以在同一实例上同时处理许多请求
2.这种格式使我们可以将整个模型保存,并将其加载到另一个环境
3.TorchScript为我们提供了一种表示形式,可以对代码进行编译器优化
4.TorchScript允许我们与许多后端/设备运行时进行接口
Tracing(跟踪)
PyTorch具有灵活和动态的特性,TorchScript也提供了捕获模型定义的工具
trace 模式顾名思义就是跟踪模型的执行,然后记录执行过程中的路径。在使用 trace 模式时,需要构造一个符合要求的输入,然后使用 TorchScript tracer 运行一遍,记录整个运行过程。在 trace 模式中运行时,每执行一个算子,就会往当前的 graph 加入一个 node。所有代码执行完毕,每一步的操作就会以一个计算图里的某个节点的形式被保存下来。值得一提的是,PyTorch 导出 ONNX 也是使用了这部分代码,所以理论上能够导出 ONNX 的模型也能够使用 trace 模式导出 TorchScript 格式模型。
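trace 模式有以下 2 点限制:
- 不能正确处理依赖输入数据的控制流(如 if-else、循环):trace 只记录示例输入实际执行的那条路径,其他分支不会被记录(见下文 Scripting 部分的 TracerWarning 示例)
- 只记录 Tensor 操作:对 List、Tuple、Dict 等 Python 容器或 Python 数值的操作不会被记录,而是在追踪时被求值,并作为常量固化到图中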
可以通过 `.graph` 属性查看其计算图。这是一种非常底层的表示形式,其中包含的大部分信息对最终用户用处不大;
也可以使用 `.code` 属性,以 Python 语法的形式查看对应的代码。
torch.jit.trace(func, example_inputs, ..., strict=True)
其中 func 为要追踪的函数或 nn.Module,example_inputs 为示例输入(有多个输入时以元组传入)
# Minimal tracing example
import torch

torch.manual_seed(42)


class TestTraceCell(torch.nn.Module):
    """Tiny cell computing ``new_h = tanh(linear(x) + h)``.

    ``forward`` returns the pair ``(new_h, new_h)``.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h


test_cell = TestTraceCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
# forward() has no control flow, so a single traced run records it completely.
traced_cell = torch.jit.trace(test_cell, (x, h))
print(traced_cell)
# TestTraceCell(
#   original_name=TestTraceCell
#   (linear): Linear(original_name=Linear)
# )

# Python-syntax rendering of the recorded code
print(traced_cell.code)
# def forward(self,
#     x: Tensor,
#     h: Tensor) -> Tuple[Tensor, Tensor]:
#   linear = self.linear
#   _0 = torch.tanh(torch.add((linear).forward(x, ), h))
#   return (_0, _0)

print(traced_cell(x, h))
# (tensor([[0.9567, 0.6879, 0.2618, 0.7927],
#         [0.8227, 0.7464, 0.4165, 0.5366],
#         [0.8193, 0.1679, 0.8132, 0.9052]], grad_fn=<TanhBackward0>), tensor([[0.9567, 0.6879, 0.2618, 0.7927],
#         [0.8227, 0.7464, 0.4165, 0.5366],
#         [0.8193, 0.1679, 0.8132, 0.9052]], grad_fn=<TanhBackward0>))

# TorchScript records the definition in an intermediate representation (IR),
# usually called a graph in deep learning. Inspect it with the .graph attribute.
print(traced_cell.graph)
# graph(%self.1 : __torch__.TestTraceCell,
#       %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
#       %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
#   %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
#   %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
#   %11 : int = prim::Constant[value=1]() # d:\Note\lcodeNoteCards\testcode\pytorch\learntc.py:11:0
#   %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # d:\Note\lcodeNoteCards\testcode\pytorch\learntc.py:11:0
#   %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # d:\Note\lcodeNoteCards\testcode\pytorch\learntc.py:11:0
#   %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
#   return (%14)
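Scripting(脚本)
当模型中含有依赖输入数据的控制流时,trace 只会记录示例输入所走的那个分支:下例中会出现 TracerWarning,生成的代码里也没有 if-else。此时应使用 torch.jit.script,它直接编译模型的 Python 源码,从而完整保留控制流。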
import torch

torch.manual_seed(42)


class MyDecisionGate(torch.nn.Module):
    """Return ``x`` if its elements sum to a positive value, otherwise ``-x``."""

    def forward(self, x):
        # Data-dependent control flow: tracing can only record the branch taken
        # for the example input, while scripting compiles both branches.
        if x.sum() > 0:
            return x
        else:
            return -x


class MyCell(torch.nn.Module):
    """Cell computing ``tanh(dg(linear(x)) + h)`` where ``dg`` is a gate module."""

    def __init__(self, dg):
        super().__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h


x, h = torch.rand(3, 4), torch.rand(3, 4)

# 1) Tracing: the if/else inside MyDecisionGate is evaluated once and frozen.
my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell.code)
# TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
#   if x.sum() > 0:
# def forward(self,
#     x: Tensor,
#     h: Tensor) -> Tuple[Tensor, Tensor]:
#   dg = self.dg
#   linear = self.linear
#   _0 = (linear).forward(x, )
#   _1 = (dg).forward(_0, )
#   _2 = torch.tanh(torch.add(_0, h))
#   return (_2, _2)
# The recorded code contains no if/else: only the branch taken for the example input survives.

# 2) Scripting: torch.jit.script compiles the Python source, so the if/else is kept.
scripted_gate = torch.jit.script(MyDecisionGate())
print(scripted_gate.code)  # the gate's if/else is visible here
my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)
print(scripted_cell.code)
# def forward(self,
#     x: Tensor,
#     h: Tensor) -> Tuple[Tensor, Tensor]:
#   dg = self.dg
#   linear = self.linear
#   _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
#   new_h = torch.tanh(_0)
#   return (new_h, new_h)
保存和加载
模型保存
torch.jit.save(m, f, _extra_files=None)
m – A ScriptModule to save.
f – a file name (str / os.PathLike) or a file-like object such as io.BytesIO.
_extra_files – {文件名: 内容} 字典,其中的内容会作为额外文件一并写入模型文件
import io

import torch


class MyModule(torch.nn.Module):
    """Toy module that adds 10 to its input."""

    def forward(self, x):
        return x + 10


m = torch.jit.script(MyModule())

# Save to a file
torch.jit.save(m, 'scriptmodule.torchscript')
# This line is equivalent to the previous one
m.save("scriptmodule.torchscript")

# Save to an io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)

# Save with extra files: each {name: content} entry is stored alongside the model
extra_files = {'foo.txt': b'bar'}
torch.jit.save(m, 'scriptmodule.torchscript', _extra_files=extra_files)
模型加载
torch.jit.load(f, map_location=None, _extra_files=None, _restore_shapes=False)
f - 文件名或类文件对象 (Union[str, PathLike, BinaryIO, IO[bytes]])
map_location - str 或 torch.device,指定将张量加载到哪个设备(例如 'cpu')
_extra_files - {文件名: 占位值} 字典;加载后,每个键对应的值会被替换为模型文件中同名额外文件的内容(bytes)
import io

import torch

# Load a ScriptModule from a file path
loaded = torch.jit.load('scriptmodule.torchscript')

# Load a ScriptModule from an io.BytesIO object
with open('scriptmodule.torchscript', 'rb') as f:
    buffer = io.BytesIO(f.read())

# Load all tensors to the original device
loaded = torch.jit.load(buffer)

# Load all tensors onto CPU, using a torch.device.
# Rewind the buffer before reusing it for another load.
buffer.seek(0)
loaded = torch.jit.load(buffer, map_location=torch.device('cpu'))

# Load all tensors onto CPU, using a device string
buffer.seek(0)
loaded = torch.jit.load(buffer, map_location='cpu')

# Load with extra files: list the names to fetch; after loading, each value is
# replaced by the stored content (as bytes).
extra_files = {'foo.txt': ''}
torch.jit.load('scriptmodule.torchscript', _extra_files=extra_files)
print(extra_files['foo.txt'])  # b'bar' if the file was written by the save example
特别注意
_extra_files 可用于保存和传递与模型计算无关的信息,例如分类模型和检测模型中的类别标签等。注意:加载时取回的内容是 bytes;若保存时写入的是 JSON 字符串,可用 json.loads 解析回字典。
import io
import json

import torch


class MyModule(torch.nn.Module):
    """Toy module carrying non-tensor attributes (``stride``, ``names``)."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.stride = 32
        self.names = ['a', 'b', 'c']

    def forward(self, im):
        # Returns a Python int: the size of im's first dimension plus 10.
        x = im.shape[0]
        return x + 10


model = MyModule()
im = torch.randn(32, 3, 224, 224)
# Metadata kept outside the graph; json encodes torch.Size (a tuple) as a list.
d = {"shape": im.shape, "stride": model.stride, "names": model.names}
# Script the same instance the metadata was read from.
m = torch.jit.script(model)

# Save with extra files
extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
torch.jit.save(m, 'scriptmodule.torchscript', _extra_files=extra_files)
print(extra_files)

# Load: list the extra file names to fetch; torch.jit.load replaces each value
# with the stored content (bytes).
extra_files = {'config.txt': ''}
loaded_model = torch.jit.load('scriptmodule.torchscript', _extra_files=extra_files)
print(loaded_model)
print(extra_files)
# Decode the stored JSON back into a dict (json.loads accepts bytes).
config = json.loads(extra_files['config.txt'])
print(config)
yolov8模型导出说明
class Exporter:
    """Excerpt of the Ultralytics YOLOv8 exporter showing the TorchScript export path.

    NOTE(review): this is a fragment. ``description``, ``model``, ``datetime``,
    ``__version__``, ``colorstr``, ``LOGGER``, ``json`` and ``torch`` are not
    defined in this snippet; they presumably come from the surrounding
    Ultralytics module / the omitted part of ``__call__``.
    """

    def __call__(self):
        # Collect model metadata; export_torchscript serializes it to JSON
        # and embeds it in the exported file.
        self.metadata = {
            'description': description,
            'author': 'Ultralytics',
            'license': 'AGPL-3.0 https://ultralytics.com/license',
            'date': datetime.now().isoformat(),
            'version': __version__,
            'stride': int(max(model.stride)),
            'task': model.task,
            'batch': self.args.batch,
            'imgsz': self.imgsz,
            'names': model.names}  # model metadata

    def export_torchscript(self, prefix=colorstr('TorchScript:')):
        """YOLOv8 TorchScript model export.

        Traces ``self.model`` with the example input ``self.im`` and embeds
        ``self.metadata`` as JSON in an extra file named ``config.txt``. Writes
        the result to ``self.file`` with a ``.torchscript`` suffix.

        Args:
            prefix: Label prepended to the log messages.

        Returns:
            A tuple ``(f, None)`` where ``f`` is the output path.
        """
        LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = self.file.with_suffix('.torchscript')
        # Export via tracing (not scripting).
        # NOTE(review): per the torch.jit.trace docs, strict=False lets the tracer record
        # mutable container (list/dict) outputs; confirm that is why it is set here.
        ts = torch.jit.trace(self.model, self.im, strict=False)
        # Extra file carrying the metadata (including the class names under 'names')
        extra_files = {'config.txt': json.dumps(self.metadata)}  # torch._C.ExtraFilesMap()
        # Optional mobile optimization
        # NOTE(review): a file written via _save_for_lite_interpreter targets the mobile lite
        # interpreter; confirm it can be read back with torch.jit.load before relying on that.
        if self.args.optimize:  # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
            LOGGER.info(f'{prefix} optimizing for mobile...')
            from torch.utils.mobile_optimizer import optimize_for_mobile
            optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
        else:
            ts.save(str(f), _extra_files=extra_files)
        return f, None
参考资料
https://www.w3cschool.cn/pytorch/pytorch-ea8n3bsm.html
https://pytorch.panchuang.net/EigthSection/torchScript/
https://zhuanlan.zhihu.com/p/135911580
标签:TorchScript,files,linear,extra,模型,torch,jit,self From: https://www.cnblogs.com/tian777/p/17815011.html