首页 > 其他分享 >将onnx的静态batch改为动态batch及修改输入输出层的名称

将onnx的静态batch改为动态batch及修改输入输出层的名称

时间:2023-07-01 13:36:17浏览次数:47  
标签:dim name onnx 输入输出 batch input output model

目录

背景

在模型的部署中,为了高效利用硬件算力,常常会需要将多个输入组成一个batch同时输入网络进行推理,这个batch的大小根据系统的负载或者摄像头的路数时刻在变化,因此网络的输入batch是在动态变化的。对于pytorch等框架来说,我们并不会感受到这个问题,因为整个网络在pytorch中都是动态的。而在实际的工程化部署中,为了运行效率,却并不能有这样的灵活性。可能会有人说,那我就把batch固定在一个最大值,然后输入实际的batch,这样实际上网络是以最大batch在推理的,浪费了算力。所以我们需要能支持动态的batch,能够根据输入的batch数来运行。

一个常见的训练到部署的路径是:pytorch→onnx→tensorrt。在pytorch导出onnx时,我们可以指定输出为动态的输入:

torch_out = torch.onnx.export(model, inp,
                              save_path,input_names=["data"],output_names=["fc1"],dynamic_axes={
        "data":{0:'batch_size'},"fc1":{0:'batch_size'}
    })

而另一些时候,我们部署的模型来源于他人或开源模型,已经失去了原始的pytorch模型,此时如果onnx是静态batch的,在移植到tensorrt时,其输入就为静态输入了。想要动态输入,就需要对onnx模型本身进行修改了。另一方面,算法工程师在导模型的时候,如果没有指定输入层输出层的名称,导出的模型的层名有时候可读性比较差,比如输出是batchnorm_274这类名称,为了方便维护,也有需要对onnx的输入输出层名称进行修改。

操作

修改输入输出层

def change_input_output_dim(model):
    # Use some symbolic name not used for any other dimension
    sym_batch_dim = "batch"

    # The following code changes the first dimension of every input to be batch-dim
    # Modify as appropriate ... note that this requires all inputs to
    # have the same batch_dim 
    inputs = model.graph.input
    for input in inputs:
        # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
        # Add checks as needed.
        dim1 = input.type.tensor_type.shape.dim[0]
        # update dim to be a symbolic value
        dim1.dim_param = sym_batch_dim
        # or update it to be an actual value:
        # dim1.dim_value = actual_batch_dim
    
    outputs = model.graph.output
    for output in outputs:
        # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
        # Add checks as needed.
        dim1 = output.type.tensor_type.shape.dim[0]
        # update dim to be a symbolic value
        dim1.dim_param = sym_batch_dim

model = onnx.load(onnx_path)
change_input_output_dim(model)

通过将输入层和输出层的shape的第一维修改为非数字,就可以将onnx模型改为动态batch。

修改输入输出层名称

def change_input_node_name(model, input_names):
    for i,input in enumerate(model.graph.input):
        input_name = input_names[i]
        for node in model.graph.node:
            for i, name in enumerate(node.input):
                if name == input.name:
                    node.input[i] = input_name
        input.name = input_name
        

def change_output_node_name(model, output_names):
    for i,output in enumerate(model.graph.output):
        output_name = output_names[i]
        for node in model.graph.node:
            for i, name in enumerate(node.output):
                if name == output.name:
                    node.output[i] = output_name
        output.name = output_name

代码中input_names和output_names是我们希望改到的名称,做法是遍历网络,若有node的输入层名与要修改的输入层名称相同,则改成新的输入层名。输出层类似。

完整代码

import onnx
def change_input_output_dim(model):
    # Use some symbolic name not used for any other dimension
    sym_batch_dim = "batch"

    # The following code changes the first dimension of every input to be batch-dim
    # Modify as appropriate ... note that this requires all inputs to
    # have the same batch_dim 
    inputs = model.graph.input
    for input in inputs:
        # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
        # Add checks as needed.
        dim1 = input.type.tensor_type.shape.dim[0]
        # update dim to be a symbolic value
        dim1.dim_param = sym_batch_dim
        # or update it to be an actual value:
        # dim1.dim_value = actual_batch_dim
    
    outputs = model.graph.output
    for output in outputs:
        # Checks omitted.This assumes that all inputs are tensors and have a shape with first dim.
        # Add checks as needed.
        dim1 = output.type.tensor_type.shape.dim[0]
        # update dim to be a symbolic value
        dim1.dim_param = sym_batch_dim

def change_input_node_name(model, input_names):
    for i,input in enumerate(model.graph.input):
        input_name = input_names[i]
        for node in model.graph.node:
            for i, name in enumerate(node.input):
                if name == input.name:
                    node.input[i] = input_name
        input.name = input_name
        

def change_output_node_name(model, output_names):
    for i,output in enumerate(model.graph.output):
        output_name = output_names[i]
        for node in model.graph.node:
            for i, name in enumerate(node.output):
                if name == output.name:
                    node.output[i] = output_name
        output.name = output_name


onnx_path = ""
save_path = ""
model = onnx.load(onnx_path)
change_input_output_dim(model)
change_input_node_name(model, ["data"])
change_output_node_name(model, ["fc1"])

onnx.save(model, save_path)

经过修改后的onnx模型输入输出将成为动态batch,可以方便的移植到tensorrt等框架以支持高效推理。

在这里插入图片描述

标签:dim,name,onnx,输入输出,batch,input,output,model
From: https://www.cnblogs.com/haoliuhust/p/17519161.html

相关文章

  • C++输入输出,设置精度setprecision、域宽setw、填充setfill
    本文的三个函数均需要引入头文件:#include<iomanip>设置输出精度setprecision(intn)参考:C语言中文网:c++setprecision用法详解//写法1cout<<setprecision(10)<<a<<endl;//写法2:a、b、c都将以10位有效位输出cout<<setprecision(10);cout<<a<<endl;cout......
  • Pytorch | 输入的形状为[seq_len, batch_size, d_model]和 [batch_size, seq_len, d_m
    首先导入依赖的torch包。importtorch我们设:seq_len(序列的最大长度):5batch_size(批量大小):2d_model(每个单词被映射为的向量的维度):10heads(多头注意力机制的头数):5d_k(每个头的特征数):21、输入形状为:[seq_len,batch_size,d_model]input_tensor=torch.randn(5,2,10)inp......
  • TensorFlow10.4 卷积神经网络-batchnorm
    我们发现这个sigmoid函数在小于-4或者大于4的时候他的导数趋近于0。然后我们送进去的input的值在[-100,100]之间,这样很容易引起梯度弥散的现象。所以我们一般情况下使用ReLU函数,但是我们有时候又不得不使用sigmoid函数。这个时候我们在送到下一层的时候我们应该先经过Normalizatio......
  • Java-写一下输入输出
    首先写一下输入把,用的是java自带的Scanner包,但是要引用一下importjava.util.Scanner;然后介绍一下输入,如果你确定了只需要输入一个数,那么可以这么写:inta=newScanner(Systemin).nextInt;缺点是每输入一次,就要重新写一遍,所以还是更推荐下面这种输入方法:Scanner s=newSca......
  • PMML-ONNX-AI Serving等深度学习模型上线-部署实战经验分享
    AI的广泛应用是由AI在开源技术的进步推动的,利用功能强大的开源模型库,数据科学家们可以很容易的训练一个性能不错的模型。但是因为模型生产环境和开发环境的不同,涉及到不同角色人员:模型训练是数据科学家和数据分析师的工作,但是模型部署是开发和运维工程师的事情,导致模型上线部署......
  • Mybatis Plus 批量插入方法效率低问题优化方案 BatchExcutor
    1、问题描述项目用的是MybatisPlus框架操作数据库,在使用batchSave批量插入方法的时候发现效率极低,插入2w数据花了6分钟,太恐怖了。看了源码发现,项目的批量插入方法调用的是MybatisPlus的BatchExcutor,用这个本意是将多次更新sql语句集合为一条更新语句,复用同一个sql连接更新数据。......
  • IS220PAICH2A 336A4940CSP11通用电气模拟输入输出模块
    IS220PAICH2A336A4940CSP11通用电气模拟输入输出模块IS220PAICH2A336A4940CSP11通用电气模拟输入输出模块  但是传统的以太网是一种商用网络,要应用到工业控制中还存在一些问题,主要有以下几个方面。1、存在实时性差,不确定性的问题传统的以太网采用了CSMA/CD的介质......
  • C++输入输出流
    一、输入输出流三种流:istream、ostream、iostream标准输入输出流ifstream、ofstream、ftream文件输入输出流istringstream、ostringstream、stringstream字符串输入输出流三种流的关系:流的状态iostate:1.badbit:表示发生系统级的错误,如不可恢复的读写错误。......
  • Spring Batch:将数据从Web服务处理到MongoDB
    概观在这篇文章中,我们将介绍如何创建一个使用Web服务数据并将其插入MongoDB数据库的SpringBatch应用程序。要求阅读本文的开发人员必须熟悉SpringBatch(示例)和MongoDB。环境Mongo数据库部署在MLab中。请按照本快速入门中的步骤操作。批处理应用程序部署在Heroku PaaS中。详情  ......
  • SpringBatch从入门到实战(一):简介和环境搭建
    一:简介SpringBatch是一个轻量级的批处理框架,适合处理大批量的数据(如百万级别)。功能就是从一个地方读数据写到另一个地方去。一般都是系统之间不能直接访问同一个数据库,需要通过文件来交换数据。二:从文件中读然后写到数据库这代码谁都会写,那么为什么还要使用框架?try(BufferedReader......