首页 > 其他分享 >onnx导出-多输入+动态维度

onnx导出-多输入+动态维度

时间:2024-01-29 21:22:53浏览次数:30  
标签:onnx ort torch 导出 input 维度 output model

onnx导出-多输入+动态维度

目录

常见问题

多参数输入

import numpy as np
import cv2
import torch
import torch.nn as nn
import torchvision
import onnx
import onnxruntime

print(torch.__version__)
print(torchvision.__version__)
# 1.13.0+cu116
# 0.14.0+cu116


class SRNNnet(nn.Module):
    def __init__(self, upscale_factor=3):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.img_upsampler = nn.Upsample(
            scale_factor=self.upscale_factor,
            mode='bicubic',
            align_corners=False)

        self.conv1 = nn.Conv2d(3,64,kernel_size=9,padding=4)
        self.conv2 = nn.Conv2d(64,32,kernel_size=1,padding=0)
        self.conv3 = nn.Conv2d(32,3,kernel_size=5,padding=2)

        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.img_upsampler(x)
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        return out

def print_state_dict(state_dict):    
    print(len(state_dict))
    for layer in state_dict:
        print(layer, '\t', state_dict[layer].shape)


def init_torch_model():
    torch_model = SRNNnet(upscale_factor=3)

    state_dict = torch.load('assets/srcnn.pth')['state_dict']
    print_state_dict(state_dict)
    
    # Adapt the checkpoint
    for old_key in list(state_dict.keys()):
        new_key = '.'.join(old_key.split('.')[1:])
        state_dict[new_key] = state_dict.pop(old_key)

    torch_model.load_state_dict(state_dict)
    torch_model.eval()
    print("init_torch_model success")
    return torch_model


def test_mode():
    
    torch_model = init_torch_model()

    input_img = cv2.imread('assets/dog.jpg').astype(np.float32)
    #  # 固定图像大小为256x256
    input_img = cv2.resize(input_img,(256,256))  
    # HWC to NCHW
    input_img = np.transpose(input_img, [2, 0, 1])
    input_img = np.expand_dims(input_img, 0)

    print(input_img.shape)
    torch_output = torch_model(torch.from_numpy(input_img)).detach().numpy()
    # NCHW to HWC
    torch_output = np.squeeze(torch_output, 0)
    torch_output = np.clip(torch_output, 0, 255)
    torch_output = np.transpose(torch_output, [1, 2, 0]).astype(np.uint8)
    cv2.imwrite("assets/out.jpg", torch_output)

test_mode()

# ------------------------------------------------------------
6
generator.conv1.weight   torch.Size([64, 3, 9, 9])
generator.conv1.bias     torch.Size([64])
generator.conv2.weight   torch.Size([32, 64, 1, 1])
generator.conv2.bias     torch.Size([32])
generator.conv3.weight   torch.Size([3, 32, 5, 5])
generator.conv3.bias     torch.Size([3])
init_torch_model success
(1, 3, 256, 256)

动态输入

模型的动态化。出于性能的考虑,各推理框架都默认模型的输入形状、输出形状、结构是静态的。

而为了让模型的泛用性更强,部署时需要在尽可能不影响原有逻辑的前提下,让模型的输入输出或是结构动态化。

上面模型固定了 输入图像维度为 256x256
输入张量维度为      (1, 3, 256, 256)
如何使得模型适配任何图像维度的输入?

导出动态输入

问题-无法修改维度

通过 torch.onnx.export() 的 dynamic_axes 参数来指定动态输入和静态输入
dynamic_axes 的默认值为 None,即默认为静态输入。静态参数无法修改输入数据的维度
如下示例
def mode_export_onnx():
    model=init_torch_model()
    x = torch.randn(1, 3, 256, 256)

    input_names = ["input"]        # 定义onnx 输入节点名称
    output_names = ["output"]      # 定义onnx 输出节点名称

    with torch.no_grad():
        torch.onnx.export(
            model,
            x,
            "assets/srcnn_cpu.onnx",
            input_names=input_names,
            output_names=output_names,
            opset_version=11,
            )
    # 导出模型-验证和测试模型
def test_onnx_inter_onnx():
    onnx_model = onnx.load("assets/srcnn_cpu.onnx")
    try:
        onnx.checker.check_model(onnx_model)
    except Exception:
        print("onnx incorrect")
    else:
        print("onnx check_model success")

    input_img = cv2.imread('assets/images/dog.jpg').astype(np.float32)
    input_img = cv2.resize(input_img,(256,320))
    # 设置导出维度大小为256,320
    # HWC to NCHW
    input_img = np.transpose(input_img, [2, 0, 1])
    input_img = np.expand_dims(input_img, 0)
    print(input_img.shape)
    # # 输入维度为(1, 3, 320, 256)
    ort_session = onnxruntime.InferenceSession("assets/srcnn_cpu.onnx",
                                               providers=['CPUExecutionProvider']
                                               )

    ort_inputs = {'input': input_img}
    ort_output = ort_session.run(['output'], ort_inputs)[0]

    ort_output = np.squeeze(ort_output, 0)
    ort_output = np.clip(ort_output, 0, 255)
    ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
    print(ort_output.shape)
#  报错 报错信息如下
Traceback (most recent call last):
  File "e:\ept_exp_onnx.py", line 136, in <module>
    test_onnx_inter_onnx()
  File "e:\ept_exp_onnx.py", line 127, in test_onnx_inter_onnx
    ort_output = ort_session.run(['output'], ort_inputs)[0]
  File "D:\X_Software\Code\miniconda3\envs\py38\lib\site-packages\onnxruntime\capi\onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: input for the following indices
 index: 2 Got: 320 Expected: 256
 Please fix either the inputs or the model.

重新定义onnx输出

with torch.no_grad():
    torch.onnx.export(
        model,
        x,
        "assets/dynamic_srcnn_cpu.onnx",
        input_names=input_names,
        output_names=output_names,
        opset_version=11,
        dynamic_axes = {'input':  {0: 'batch_size', 1: 'channel', 2: "height", 3: 'width'}, 
                        'output': {0: 'batch_size', 1: 'channel', 2: 'height', 3: 'width'}}
        )
# 设置 dynamic_axes
# dynamic_axes 中的 0、2、3 表示相应的维度设置为动态值
def mode_export_dynamic_onnx():
	model=init_torch_model()
	batch_size = 1
    height = 256
    width = 256
    
    x = torch.randn(batch_size, 3,height, width)

    input_names = ["input"]        # 定义onnx 输入节点名称
    output_names = ["output"]      # 定义onnx 输出节点名称

    with torch.no_grad():
        torch.onnx.export(
            model,
            x,
            "assets/dynamic_srcnn_cpu.onnx",
            input_names=input_names,
            output_names=output_names,
            opset_version=11,
            dynamic_axes = {'input':  {0: 'batch_size', 1: 'channel', 2: "height", 3: 'width'}, 
                            'output': {0: 'batch_size', 1: 'channel', 2: 'height', 3: 'width'}}
            )
image.png

将导出的模型放入到https://netron.app/ 进行可视化

从onnx模型可视化参数来看,input 和 output 都改成了动态维度,支持实时修改输入参数维度

验证导出和测试

def test_dynamic_inter_onnx()
    onnx_model = onnx.load("assets\dynamic_srcnn_cpu.onnx")
    try:
        onnx.checker.check_model(onnx_model)

        graph = onnx_model.graph 
        # print(onnx.helper.printable_graph(onnx_model.graph))
        # print(graph.input)
        # print(graph.output)
    except Exception:
        print("onnx incorrect")
    else:
        print("onnx check_model success")

    input_img = cv2.imread('assets/images/dog.jpg').astype(np.float32)
    input_img = cv2.resize(input_img,(512,460))
    print("input_img transpose pre:", input_img.shape)
    # HWC to NCHW
    input_img = np.transpose(input_img, [2, 0, 1])
    print("input_img transpose pos:", input_img.shape)
    input_img = np.expand_dims(input_img, 0)
    print("input_img shape:", input_img.shape)
    ort_session = onnxruntime.InferenceSession("assets\dynamic_srcnn_cpu.onnx",
                                               providers=['CPUExecutionProvider']
                                               )

    ort_inputs = {'input': input_img}
    ort_output = ort_session.run(['output'], ort_inputs)[0]

    ort_output = np.squeeze(ort_output, 0)
    ort_output = np.clip(ort_output, 0, 255)
    ort_output = np.transpose(ort_output, [1, 2, 0]).astype(np.uint8)
    print("ort_output shape:", ort_output.shape)
    cv2.imwrite("assets/out.jpg", ort_output)
#  打印结果
# input_img transpose pre: (460, 512, 3)
# input_img transpose pos: (3, 460, 512)
# input_img shape: (1, 3, 460, 512)
# ort_output shape: (1380, 1536, 3)

多头输入

先来一个简单的案例,新增一个常数输入

import torch.nn as nn
import torch
import numpy as np
import onnx
import onnxruntime

#   定义一个简单的多输入网络   
# -----------------------------------#
class MyNet(nn.Module):
    def __init__(self, num_classes=10):
        super(MyNet, self).__init__()
        self.features = nn.Sequential(
            # input[3, 28, 28]  
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),   
            # output[32, 28, 28]          
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  
            # output[64, 14, 14]
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2)                             
            # output[64, 7, 7]
        )

        self.fc = nn.Linear(64 * 7 * 7, num_classes)

    def forward(self, x, ratio):
        # 输入是两个  x,ratio
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = ratio * self.fc(x)
        return x

#   导出ONNX模型函数
# -----------------------------------#
def torch2onnx():
    model = MyNet()
    model.eval() #使用测试模式
    x = torch.randn(1, 3, 28, 28)
    ratio = torch.randn(1, 1)				# 输入的常数是张量
    input_names = ["input1",'input2']       # 配置输入参数,输出参数
    output_names = ["output1"]  
    output_path = 'assets/MyNet.onnx'

    torch.onnx.export(
        model,
        (x,ratio),
        output_path,
        verbose=False,
        opset_version=11,
        input_names=input_names,
        output_names=output_names,
    )
if __name__ == '__main__':

    torch2onnx()

输入的第一个参数是张量,传入的第二个参数也必须是张量,要符合pytorch的相关要求。

保证输入的所有参数都是 torch.Tensor 类型

def test_onnx_inter_onnx():
    onnx_model = onnx.load("assets/MyNet.onnx")
    try:
        onnx.checker.check_model(onnx_model)
    except Exception:
        print("onnx incorrect")
    else:
        print("onnx check_model success")

    x = np.random.randn(1, 3, 28, 28).astype(np.float32)
    ratio = np.random.randn(1, 1).astype(np.float32)
    print(x.shape)
    ort_session = onnxruntime.InferenceSession("assets/MyNet.onnx",
                                               providers=['CPUExecutionProvider']
                                               )

    ort_inputs = {"input1": x,"input2": ratio}    # 配置输入参数,输出参数
    ort_output = ort_session.run(['output1'], ort_inputs)[0]
    print(ort_output.shape)
    
test_onnx_inter_onnx()

多头输出

同样的多头输出,也可以通过定义输出参数名称 实现多头输出

import torch.nn as nn
import torch
import numpy as np
import onnx
import onnxruntime

class Model(torch.nn.Module):
    def __init__(self, in_features, out_features, weights1, weights2, bias=False):
        super().__init__()
        self.linear1 = nn.Linear(in_features, out_features, bias)
        self.linear2 = nn.Linear(in_features, out_features, bias)
        with torch.no_grad():
            self.linear1.weight.copy_(weights1)
            self.linear2.weight.copy_(weights2)

    def forward(self, x):
        x1 = self.linear1(x)
        x2 = self.linear2(x)
        return x1, x2
    
def export_onnx():
    input    = torch.zeros(1, 1, 1, 4)
    weights1 = torch.tensor([
        [1, 2, 3, 4],
        [2, 3, 4, 5],
        [3, 4, 5, 6]
    ],dtype=torch.float32)
    weights2 = torch.tensor([
        [2, 3, 4, 5],
        [3, 4, 5, 6],
        [4, 5, 6, 7]
    ],dtype=torch.float32)
    model   = Model(4, 3, weights1, weights2)
    model.eval() #添加eval防止权重继续更新

    # pytorch导出onnx的方式,参数有很多,也可以支持动态size
    torch.onnx.export(
        model         = model, 
        args          = (input,),
        f             = "assets/two_out.onnx",
        input_names   = ["input0"],
        output_names  = ["output0", "output1"],
        opset_version = 12)
    print("Finished onnx export")


# export_onnx()
    
def test_onnx_inter_onnx():
    onnx_model = onnx.load("assets/two_out.onnx")
    try:
        onnx.checker.check_model(onnx_model)
    except Exception:
        print("onnx incorrect")
    else:
        print("onnx check_model success")
        
    input0  = np.random.randn(1, 1, 1, 4).astype(np.float32)
    ort_session = onnxruntime.InferenceSession("assets/two_out.onnx",
                                               providers=['CPUExecutionProvider']
                                               )

    ort_inputs = {"input0": input0}    # 配置输入参数,输出参数
    ort_output = ort_session.run(['output0','output1'], ort_inputs)
    print(ort_output)

test_onnx_inter_onnx()

# -----------------------------------------------------------------------------------
onnx check_model success
[array([[[[ 8.880694, 12.145373, 15.41005 ]]]], dtype=float32),
 array([[[[12.145373, 15.41005 , 18.674728]]]], dtype=float32)]

参考资料

知乎-OpenMMLab-模型部署入门教程(一):模型部署简介

知乎OpenMMLab-模型部署入门教程(二):解决模型部署中的难题

知乎-OpenMMLab-模型部署入门教程(三):PyTorch 转 ONNX 详解

正确导出onnx|onnx结构|编辑onnx各类节点|onnx算子编写|复杂后处理的添加|onnx形状推理

万字长文,一文搞懂Torch转换ONNX详细流程

标签:onnx,ort,torch,导出,input,维度,output,model
From: https://www.cnblogs.com/tian777/p/17995361

相关文章

  • onnx模型导出
    onnx模型导出目录onnx模型导出环境准备简介介绍torch.onnx.export参数解析onnx导出步骤单输入导出示例定义并准备模型加载权重并测试onnx导出和验证Netron可视化在线使用代码可视化onnx模型推理补充细节添加自定义标签读取自定义标签导出注意参考资料环境准备#环境依赖torch......
  • plsql 导出表结构和表数据并导入
    1.情景展示如何完成oracle表结构和表数据的导入和导出?2.导出表结构和表数据点击“工具”-->选择“导出表”;单击选中要导出的表;因为我要建表,所以需要勾选上“创建表”选项;其余的选项对我来说没用,那就全部取消勾选;点击右侧的文件夹按钮,选择要导出SQL文件的位置以及文件名......
  • 词向量的维度应该怎么选择
    在之前的文章《最小熵原理(六):词向量的维度应该怎么选择?》中,我们基于最小熵思想推导出了一个词向量维度公式“n>8.33logN�>8.33log⁡�”,然后在《让人惊叹的Johnson-Lindenstrauss引理:应用篇》中我们进一步指出,该结果与JL引理所给出的O(logN)�(log⁡�)是吻合的。既然理论上看上去很完美,......
  • 名企测试管理大咖解析沟通管理,多维度经验分享
    沟通管理在测试开发中扮演着至关重要的角色,它不仅是团队协作的基石,也是项目成功的关键因素之一。有效的沟通管理能够促进信息传递、问题解决以及团队协同工作,为测试开发的顺利进行提供坚实支持。但在实际工作中却有很多的问题,你是否在工作中遇到过以下问题呢?在团队会议上,需要分享自......
  • es从线上库导出数据并导入开发环境
    背景来了个需求,需要从某个线上es库查询一些数据出来并进行大屏展示。问需求方有没有开发环境的es库,答:没有,说要不直连他们的线上库。后面想想也行吧,业务方都这么说了,结果开网络的流程被打回了,理由是网络隔离。于是,只能采用从线上es库导出文件,然后在开发环境原样搭建这么一个es库......
  • Python_numpy-增加以及修改维度
    gradio组件输入组件-输出组件输入输出组件 多输入和多输出组件gr.State是一个不可见的组件,目的是在后台存储一些变量方便访问和交互BlockcomponentsTextbox:interactiveinteractive=TrueEventlistenerchange()e......
  • jmeter 将response body内容全部导出并保存到文件
    1.使用正则表达式,获取response内容2.使用beanshell后置处理器处理并保存数据脚本:importjava.io.File;importjava.io.FileWriter;importjava.io.IOException;importorg.apache.jmeter.samplers.SampleResult;Stringseq=vars.get("seq");//使用变量获取正则......
  • 封装Excel读取,导出(实体类集合List、DataTable、DataGridView、实体类集合和DataTable
     1、引入使用 #region读取excel///<summary>///根据Excel和Sheet返回DataTable///</summary>///<paramname="filePath">Excel文件地址</param>///<paramname="sheetIndex">She......
  • c# 数据放入excel导出,卡顿
    前言:导出数据到Excel,导出的时候特别卡顿原代码:usingSystem;usingSystem.Collections.Generic;usingSystem.Data;usingSystem.Linq;usingSystem.Text;usingSystem.Threading.Tasks;namespaceIPC.Helper{classExcelExportUtil{publicstatic......
  • 毕设4:导出文件对话框
    一般弹出窗口会用Dialog,但JavaFX的Dialog很难自定义,不如直接弹出一个Stage。分割文件、导出图片和文字三个功能共用一个对话框。于是新增了一个enum区别三个功能。enum可以加方法,还挺好玩的:packagecom.pdfTool.defination;publicenumExportType{SPLIT{@Ove......