首页 > 其他分享 >9 PyTorch的模型部署

9 PyTorch的模型部署

时间:2022-08-28 16:46:43浏览次数:55  
标签:部署 onnx 模型 PyTorch ONNX input model ort

9.1 ONNX( Open Neural Network Exchange) 简介

  ONNX( Open Neural Network Exchange) 通过定义一组与环境和平台无关的标准格式,使AI模型可以在不同框架和环境下交互使用,可以看作深度学习框架和部署端的桥梁。

  PyTorch部署流水线:PyTorch --> ONNX --> ONNX Runtime ,只需要将模型转换为 .onnx 文件,并在 ONNX Runtime 上运行模型即可。

  安装:
# 激活虚拟环境
conda activate env_name # env_name换成环境名称
# 安装onnx
pip install onnx 
# 安装onnx runtime
pip install onnxruntime # 使用CPU进行推理
# pip install onnxruntime-gpu # 使用GPU进行推理

  除此之外,我们还需要注意ONNX和ONNX Runtime之间的适配关系。我们可以访问ONNX Runtime的Github进行查看,链接地址如下:

  ONNX和ONNX Runtime的适配关系:https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md

  当我们想使用GPU进行推理时,我们需要先将安装的onnxruntime卸载,再安装onnxruntime-gpu,同时我们还需要考虑ONNX Runtime与CUDA之间的适配关系,我们可以参考以下链接进行查看:

ONNX Runtime和CUDA之间的适配关系:https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html

9.2 模型导出为ONNX

9.2.1 模型转换为ONNX格式

  使用torch.onnx.export()把模型转换成 ONNX 格式的函数。模型导成onnx格式前,我们必须调用model.eval()或者model.train(False)以确保我们的模型处在推理模式下,避免因为dropoutbatchnorm等运算符在推理和训练模式下的不同产生错误。

import torch.onnx 
# 转换的onnx格式的名称,文件后缀需为.onnx
onnx_file_name = "xxxxxx.onnx"
# 我们需要转换的模型,将torch_model设置为自己的模型
model = torch_model
# 加载权重,将model.pth转换为自己的模型权重
# 如果模型的权重是使用多卡训练出来,我们需要去除权重中多的module. 具体操作可以见5.4节
model = model.load_state_dict(torch.load("model.pth"))
# 导出模型前,必须调用model.eval()或者model.train(False)
model.eval()
# dummy_input就是一个输入的实例,仅提供输入shape、type等信息 
batch_size = 1 # 随机的取值,当设置dynamic_axes后影响不大
dummy_input = torch.randn(batch_size, 1, 224, 224, requires_grad=True) 
# 这组输入对应的模型输出
output = model(dummy_input)
# 导出模型
torch.onnx.export(model,        # 模型的名称
                  dummy_input,   # 一组实例化输入
                  onnx_file_name,   # 文件保存路径/名称
                  export_params=True,        #  如果指定为True或默认, 参数也会被导出. 如果你要导出一个没训练过的就设为 False.
                  opset_version=10,          # ONNX 算子集的版本,当前已更新到15
                  do_constant_folding=True,  # 是否执行常量折叠优化
                  input_names = ['input'],   # 输入模型的张量的名称
                  output_names = ['output'], # 输出模型的张量的名称
                  # dynamic_axes将batch_size的维度指定为动态,
                  # 后续进行推理的数据可以与导出的dummy_input的batch_size不同
                  dynamic_axes={'input' : {0 : 'batch_size'},    
                                'output' : {0 : 'batch_size'}})

9.2.2 ONNX模型的检验

  使用onnx.checker.check_model()对导出的模型进行检验

import onnx
# 我们可以使用异常处理的方法进行检验
try:
    # 当我们的模型不可用时,将会报出异常
    onnx.checker.check_model(self.onnx_model)
except onnx.checker.ValidationError as e:
    print("The model is invalid: %s"%e)
else:
    # 模型可用时,将不会报出异常,并会输出“The model is valid!”
    print("The model is valid!")

9.2.3 ONNX可视化

  使用Netron实现可视化操作,Netron下载网址:https://github.com/lutzroeder/netron。

9.3 使用ONNX Runtime进行推理

# 导入onnxruntime
import onnxruntime
# 需要进行推理的onnx模型文件名称
onnx_file_name = "xxxxxx.onnx"

# onnxruntime.InferenceSession用于获取一个 ONNX Runtime 推理器
ort_session = onnxruntime.InferenceSession(onnx_file_name)  

# 构建字典的输入数据,字典的key需要与我们构建onnx模型时的input_names相同
# 输入的input_img 也需要改变为ndarray格式
ort_inputs = {'input': input_img} 
# 我们更建议使用下面这种方法,因为避免了手动输入key
# ort_inputs = {ort_session.get_inputs()[0].name:input_img}

# run是进行模型的推理,第一个参数为输出张量名的列表,一般情况可以设置为None
# 第二个参数为构建的输入值的字典
# 由于返回的结果被列表嵌套,因此我们需要进行[0]的索引
ort_output = ort_session.run(None,ort_inputs)[0]
# output = {ort_session.get_outputs()[0].name}
# ort_output = ort_session.run([output], ort_inputs)[0]

  在上述的步骤中,我们有几个需要注意的点:

  • PyTorch模型的输入为tensor,而ONNX的输入为array,因此我们需要对张量进行变换或者直接将数据读取为array格式,我们可以实现下面的方式进行张量到array的转化。
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
  • 输入的array的shape应该和我们导出模型的dummy_input的shape相同,如果图片大小不一样,我们应该先进行resize操作。
  • run的结果是一个列表,我们需要进行索引操作才能获得array格式的结果。
  • 在构建输入的字典时,我们需要注意字典的key应与导出ONNX格式设置的input_name相同,因此我们更建议使用上述的第二种方法构建输入的字典。
 

标签:部署,onnx,模型,PyTorch,ONNX,input,model,ort
From: https://www.cnblogs.com/5466a/p/16633033.html

相关文章

  • EvaluationSystem:数据库模型建立
    1、用户table(./models/user.js)用户字段:useraccount:账号(主键)nickname:昵称password:密码evalnum:已参与测评数量2、数据table(./models/data.js)数据字段:name:数据名字......
  • 8. PyTorch生态简介
    由于本人未接触过也并未打算从事图像相关工作,所以只介绍了torchtext生态。有关torchvision和PytorchViseo只是了解了一下并未进行笔记输出。torchtext简介torch......
  • PyTorch Geometric(pyg)学习
    参考2个链接: 第十六课.Pytorch-geometric入门(一)_tzc_fly的博客-CSDN博客_pytorch-geometric 第十七课.Pytorch-geometric入门(二)_tzc_fly的博客-CSDN博客......
  • jenkins部署执行完成提示:Finished: UNSTABLE
    执行完提示:Finished:UNSTABLE原因:我遇到的这个提示因为测试时间超时解决方法:在配置的“build”中wvn命令中将命令:cleaninstall,修改为添加跳过测试时间:cleaninstall-......
  • Linux多节点部署KubeSphere
    官网参考1.使用KubeKey创建集群(master节点)#下载KubeKeyexportKKZONE=cncurl-sfLhttps://get-kk.kubesphere.io|VERSION=v2.2.1sh-chmod+xkk#创建集群配置......
  • Nginx分布式框架详解46-56nginx静态资源部署02
    error_page指令error_page指令是设置网站的错误页面。语法默认值位置error_page......[=[response]];—http、server、location......code是响应......
  • 部署若依微服务全流程(前置条件))
    (第一步)安装jdk(1)查看是否安装jdk java-version如果显示jdk版本则表示已安装jdk,显示其他则说明未安装(2)下载jdk,这里下载的jdk1.8下面是官网下载地址 https:......
  • pytorch转为mindspore模型
    MindConverter将PyTorch(ONNX)模型快速迁移到MindSpore框架下使用。第一步:pytorch模型转onnx:importtorch#根据实际情况替换以下类路径fromcustomized.path.to.py......
  • 深入理解“字符编码模型”
    深入理解“字符编码模型”作者:哲思时间:2022.8.28邮箱:[email protected]:zhe-si(哲思)(github.com)前言最近踩坑了后端的文档生成,本想写篇相关的实践总结,忽然......
  • Asp.Net Core 项目部署Centos中,httpClient 请求Https报证书错误的系列问题
    参考自https://www.cnblogs.com/leoxjy/p/10201046.html#5095270Centos报这个问题,Asp.NetCore3.1HttpClient请求Https报错的SSL证书异常的问题,请使用以下方法解决......