首页 > 其他分享 >onnx模型导出

onnx模型导出

时间:2024-01-29 21:22:18浏览次数:25  
标签:onnx 模型 torch 导出 output input model

onnx模型导出

目录

环境准备

# 环境依赖
torch                      1.13.0+cu116
torchvision                0.14.0+cu116
onnx                       1.13.1
onnxruntime-gpu            1.15.0

简介介绍

ONNX(Open Neural Network Exchange)是 Facebook 和微软在 2017 年共同发布的,用于标准描述计算图的一种格式。

ONNX 已经对接了多种深度学习框架和多种推理引擎。因此,ONNX 被当成了深度学习框架到推理引擎的桥梁,就像编译器的中间语言一样。目前官方支持加载ONNX模型并进行推理的深度学习框架有: Caffe2, PyTorch, MXNet,ML.NET,TensorRT

onnx定义了一种可扩展的计算图模型\一系列内置的运算(op)和标准数据类型.每一个计算流图都定义为由节点组成的列表,并构建有向无环图

torch.onnx.export

def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
           input_names=None, output_names=None, operator_export_type=None,
           opset_version=None, _retain_param_name=None, do_constant_folding=True,
           example_outputs=None, strip_doc_string=None, dynamic_axes=None,
           keep_initializers_as_inputs=None, custom_opsets=None, enable_onnx_checker=None,
           use_external_data_format=None)

参数解析

可选 参数 说明
必填 model 需要转换的模型
必填 args 模型的输入,torch.Tensor
必填 f onnx模型导出的路径
必填 input_names 按顺序定义onnx模型输入张量名称,不设置的话,自动分配
可选 output_names 按顺序定义onnx模型输出张量名称,不设置的话,自动分配
可选 export_params=True 模型中是否存储模型权重,onnx是用同一个文件表示记录模型结构和权重,默认为True
可选 opset_version onnx 的 opset版本
可选 dynamic_axes 动态维度设置,指定输入输出张量的哪些维度是动态
可选 verbose=False 是否打印导出过程中的详细信息
dynamic_axes 为了追求效率,ONNX 默认所有参与运算的张量都是静态的(张量的形状不发生改变)。但在实际应用中,我们又希望模型的输入张量是动态的,尤其是本来就没有形状限制的全卷积模型。因此,我们需要显式地指明输入输出张量的哪几个维度的大小是可变的。

onnx导出步骤

1. 定义创建模型
2. 加载模型权重
3. 定义模型输入参数
4. 定义模型输入名称和输出名称 (输入节点-输出节点)
5. 使用torch.onnx.export()函数导出onnx
6. 自定义标签

单输入导出示例

定义并准备模型

import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision
import onnx
import onnxruntime

print(torch.__version__)
print(torchvision.__version__)
# 1.13.0+cu116
# 0.14.0+cu116

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SRNNnet(nn.Module):
    def __init__(self, upscale_factor=3):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bicubic',
            align_corners=False)

        self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4)
        self.conv2 = nn.Conv2d(64,32,kernel_size=1,padding=0)
        self.conv3 = nn.Conv2d(32,3,kernel_size=5,padding=2)

        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.img_upsampler(x)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out
print(SRNNnet())
# ------------------------------------------------------------------------
SRNNnet(
  (img_upsampler): Upsample(scale_factor=3.0, mode=bicubic)
  (conv1): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (conv2): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(32, 3, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (relu): ReLU()
)

加载权重并测试

def print_state_dict(state_dict):    
    print(len(state_dict))
    for layer in state_dict:
        print(layer, '\t', state_dict[layer].shape)


def init_torch_model():
    torch_model = SRNNnet(upscale_factor=3)

    state_dict = torch.load('assets/srcnn.pth')['state_dict']
    print_state_dict(state_dict)
    
    # Adapt the checkpoint
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    print("init_torch_model success")
    return torch_model


def test_mode():
    
    torch_model = init_torch_model()

    input_img = cv2.imread('assets/dog.jpg').astype(np.float32)
    input_img = cv2.resize(input_img,(256,256))   
    # 固定图像大小为256x256
    # HWC to NCHW
    input_img = np.transpose(input_img, [2, 0, 1])
    input_img = np.expand_dims(input_img, 0)

    print(input_img.shape)
    torch_output = torch_model(torch.from_numpy(input_img)).detach().numpy()
    # NCHW to HWC
    torch_output = np.squeeze(torch_output, 0)
    torch_output = np.clip(torch_output, 0, 255)
    torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)
    cv2.imwrite("assets/out.jpg", torch_output)

test_mode()

# ------------------------------------------------------------
6
generator.conv1.weight   torch.Size([64, 3, 9, 9])
generator.conv1.bias     torch.Size([64])
generator.conv2.weight   torch.Size([32, 64, 1, 1])
generator.conv2.bias     torch.Size([32])
generator.conv3.weight   torch.Size([3, 32, 5, 5])
generator.conv3.bias     torch.Size([3])
init_torch_model success
(1, 3, 256, 256)

onnx导出和验证

onnx导出后,需要进行检查,
检查onnx模型节点,
如果onnx算子不支持转engine时,方便定位节点,找到不支持的算子进行修改
def mode_export_onnx():

    model=init_torch_model()
    x = torch.randn(1, 3, 256, 256)

    input_names = ["input"]        # 定义onnx 输入节点名称
    output_names = ["output"]      # 定义onnx 输出节点名称

    with torch.no_grad():
        torch.onnx.export(
            model,
            x,
            "assets/srcnn.onnx",
            input_names=input_names,
            output_names=output_names,
            opset_version=11
            )
        
    print("mode_export_onnx success")

def test_onnx():

    onnx_model = onnx.load("assets/srcnn.onnx")
    try:
        onnx.checker.check_model(onnx_model)
     	print(onnx.helper.printable_graph(onnx_model.graph))
        graph = onnx_model.graph 
        print(graph.input)
        print(graph.output)
    except Exception:
        print("Model incorrect")
    else:
        print("Model correct")

# ----------------------------------------------------------------------0=--
[name: "input"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 256
      }
      dim {
        dim_value: 256
      }
    }
  }
}
]
[name: "output"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 768
      }
      dim {
        dim_value: 768
      }
    }
  }
}
]
onnx check_model success

Netron可视化

Netron 是一个开源的模型可视化工具,用于可视化深度学习模型的结构和参数。它可以加载和显示多种框架和模型格式,包括ONNX(Open Neural Network Exchange)、TensorFlow、Keras、Caffe、Core ML 等。通过图形界面,用户可以直观地查看模型的网络结构、层级关系、参数等信息

在线使用

https://netron.app/

image.png
之前定义的模型输入为 256x256
模型的输入 input=[1,3,256,256]

代码可视化

pip install netron
# 针对有网络模型,但还没有训练保存 .pth 文件的情况
import netron
import torch.onnx

netron.start(onnx_path)  # 输出网络结构

# http://localhost:8080

onnx模型推理

推理onnx模型,查看输出是否一致
def inter_onnx():
    input_img = cv2.imread('assets/dog.jpg').astype(np.float32)
    input_img = cv2.resize(input_img,(256,256))
    # HWC to NCHW
    input_img = np.transpose(input_img, [2, 0, 1])
    input_img = np.expand_dims(input_img, 0)
    
    ort_session = onnxruntime.InferenceSession("assets/srcnn.onnx",
                                               providers=['CPUExecutionProvider']
                                               )

    ort_inputs = {'input': input_img}
    ort_output = ort_session.run(['output'], ort_inputs)[0]

    ort_output = np.squeeze(ort_output, 0)
    ort_output = np.clip(ort_output, 0, 255)
    ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
    cv2.imwrite("assets/out.jpg", ort_output)


mode_export_onnx()
test_onnx()
inter_onnx()

补充细节

添加自定义标签

model_onnx = onnx.load(f)  # load onnx model
onnx.checker.check_model(model_onnx)  # check onnx model

d={1:"person",2:"car",3:"dog"}
for k, v in d.items():
    meta = model_onnx.metadata_props.add()
    meta.key, meta.value = k, str(v)
onnx.save(model_onnx, f)

读取自定义标签

onnxmodel = onnx.load(f)  # load onnx model
meta = onnxmodel.get_modelmeta().custom_metadata_map
print( meta)
{1:"person",2:"car",3:"dog"}

导出注意

Pytorch模型在执行时是动态推导的,在运行之前并不知道整个推理的流程,ONNX模型是静态的,在推理时整个图已经构建完成。

动态的模型是数据边走边计算,静态的模型是在推理时先构建了一个图,然后数据从输入节点开始,按照拓扑关系一直流向输出节点。

这就导致在采用jit.trace(jit.script模式不讨论)方法进行模型导出时,遇到分支语句,Pytorch只会记录走过的路径,其他的路径将会直接被丢弃,

遇到while循环语句,Pytorch只会记录当前转模型的固定循环次数。换句话说,如果构成网络结构的某个循环次数是依赖与输入变量的,则循环的次数不可预期。

比如RNN网络,输入序列是不一样的,在解码的过程中,不知道要经过多少次循环,这时只能将RNN拆成一个个的小的单元,在外部根据实际情况对单元模块进行循环调用。

参考资料

OpenMMLab-模型部署简介

OpenMMLab-解决模型部署常见难题-动态多输入

OpenMMLab-PyTorch 转 ONNX 详解

知乎-OpenMMLab-模型部署入门教程(一):模型部署简介

知乎-OpenMMLab-模型部署入门教程(三):PyTorch 转 ONNX 详解

onnxsim

https://www.python100.com/html/89RQ4H08DH6S.html

https://www.python100.com/html/D0Q71A1IQ25I.html

标签:onnx,模型,torch,导出,output,input,model
From: https://www.cnblogs.com/tian777/p/17995359

相关文章

  • 美国宣布启动 NAIRR 计划打造 AI 帝国;Siri 将获大语言模型支持丨 RTE 开发者日报 Vol.
      开发者朋友们大家好: 这里是「RTE开发者日报」,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享RTE(RealTimeEngagement)领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「有看点的会议」,但内容仅代表编辑......
  • 「效果图渲染」怎么用VRay渲染逼真的物理模型
    使用V-Ray渲染出逼真的物理模型首先要注重材质和光照的真实性。精细调整材质属性,如反射、透明度和质感,确保它们与现实世界中物质的特性相一致。接下来,布置合适的光源,模拟自然光线的行为,创建真实的光影效果。通过这两个基本步骤,即可开始打造高度逼真的三维渲染作品。VRay渲染室内......
  • plsql 导出表结构和表数据并导入
    1.情景展示如何完成oracle表结构和表数据的导入和导出?2.导出表结构和表数据点击“工具”-->选择“导出表”;单击选中要导出的表;因为我要建表,所以需要勾选上“创建表”选项;其余的选项对我来说没用,那就全部取消勾选;点击右侧的文件夹按钮,选择要导出SQL文件的位置以及文件名......
  • 产品解读 | 新一代湖仓集存储,多模型统一架构,高效挖掘数据价值
    星环科技TDH一直致力于给用户带来高性能、高可靠的一站式大数据基础平台,满足对海量数据的存储和复杂业务的处理需求。同时在易用性方面持续深耕,降低用户开发和运维成本,让数据处理平民化,助力用户以更便捷、高效的方式去挖掘数据价值。基于这样的宗旨,星环科技TDH正式发布了9.3......
  • 鸿蒙Stage模型--概述
    Stage模型:HarmonyOS3.1DevelperPreview版本开始新增的模型,是目前主推且会长期演进的模型。在该模型中,由于提供了AbilityStage、WindowStage等类作为应用组件和Window窗口的“舞台”,因此称这种应用模型为Stage模型。设计思想Stage模型之所以成为主推模型,源于其设计思想。Stage模......
  • LangChain大模型应用开发指南:从基础链式结构到ReAct对话解构
    在自然语言处理领域,大模型的应用已经成为了一种趋势。LangChain是一个基于深度学习的自然语言处理框架,它通过使用链式结构和ReAct对话模型,为开发者提供了一种高效、灵活的方式来进行大模型应用开发。本指南将介绍如何从基础链式结构开始,逐步构建ReAct对话解构,以实现自然语言处理应......
  • 生物科学大模型
    随着大数据和人工智能技术的迅速发展,生物科学大模型已经成为生物医学领域的研究热点。这些大规模的模型能够模拟生物系统的复杂行为,为药物研发、疾病诊断和治疗提供了全新的视角。本文将介绍生物科学大模型的原理、应用案例,以及如何应对其中的挑战,希望能为相关领域的读者提供一些启......
  • 帕金森早期诊断准确率提高至 90.2%,深圳先进院联合中山一院提出 GSP-GCNs 模型
    中山大学附属第一医院&中科大先进院等研究团队,提出了一种深度学习模型——图信号处理-图卷积网络(GSP-GCNs),利用从涉及声调调节的特定任务中获得的事件相关脑电图数据来诊断帕金森病。震颤、动作迟缓、表情僵硬……提起帕金森病,多数人会率先想到「手抖」,殊不知,在患病中晚期,患者甚......
  • 大模型技术与我们的生活
    随着科技的飞速发展,大模型技术已经成为我们生活中不可或缺的一部分。它以其强大的处理能力和深度学习能力,在语音识别、自然语言处理、图像识别等领域取得了显著成果。本文将深入探讨大模型技术如何影响我们的生活,以及我们如何更好地利用这一技术。首先,让我们了解一下什么是大模型技......
  • 使用核模型高斯过程(KMGPs)进行数据建模
    核模型高斯过程(KMGPs)作为一种复杂的工具可以处理各种数据集的复杂性。他通过核函数来扩展高斯过程的传统概念。本文将深入探讨kmgp的理论基础、实际应用以及它们所面临的挑战。核模型高斯过程是机器学习和统计学中对传统高斯过程的一种扩展。要理解kmgp,首先掌握高斯过程的基础......