首页 > 其他分享 >pytorch中自定义onnx新算子并导出为onnx

pytorch中自定义onnx新算子并导出为onnx

时间:2024-07-28 17:50:33浏览次数:10  
标签:return 自定义 onnx torch add custom pytorch input

import torch
from torch.autograd import Function
import torch.onnx

# Step 1: Define custom PyTorch operator
class MyCustomOp(Function):
    @staticmethod
    def forward(ctx, input):
        return input + 1

    @staticmethod
    def symbolic(g, input):
        return g.op("CustomAddOne", input)#注意此处的input参数要和后面trt中的插件层一样

def custom_add_one(input):
    return MyCustomOp.apply(input)

# Step 2: Register custom ONNX operator
def custom_add_one_symbolic(g, input):
    return g.op("CustomAddOne", input)

torch.onnx.register_custom_op_symbolic("::custom_add_one", custom_add_one_symbolic, 9)

# Step 3: Export to ONNX
class MyModel(torch.nn.Module):
    def forward(self, x):
        return custom_add_one(x)

model = MyModel()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "custom_model.onnx", opset_version=9, custom_opsets={"": 9})

print("ONNX model with custom operator exported successfully.")

 

标签:return,自定义,onnx,torch,add,custom,pytorch,input
From: https://www.cnblogs.com/chentiao/p/18328600

相关文章

  • Telegram API 禁止更改用户的自定义标题
    问题:我使用了FSMTelegramBot,最终机器人将聊天用户更改为聊天管理员并更改了管理员的自定义标题。但我在更改自定义标题时遇到问题,这是控制台输出:ERROR:aiogram.event:Causeexceptionwhileprocessupdateid=XXXXXXXXbybotid=XXXXXXXXTelegramBadRequest:Telegramse......
  • 20、flask-进阶-自定义静态文件static和模板文件templates的路径配置
    自定义static目录和templates目录的路径原本flask默认的static和templates目录是在App目录下的:如下图如果想把这两个目录更改位置,如放在根目录下:代码如下:__init__.pyfromflaskimportFlaskfrom.viewsimportbluefrom.extsimportinit_extsimportos#获......
  • Java 自定义注解
    一、Java 自定义注解的用途、 1、可以记录在特殊方法进行日志记录     2、可以进行 特殊鉴权 如@ValidateRole(“admin") 只有当前用户拥有指定角色时才放行 否则抛自定义异常 无权限    3、可以用于参数 如Controller 方法中的参数进行 参数......
  • 如何在事件wxWidgets中传递自定义数据
    情况我目前正在使用wxPython(wxWidgetsforPython)编写一个应用程序。在此应用程序中显示了一个对话列表,每行末尾都有一个“打开对话”按钮。我们将此窗口称为“所有对话”。单击任何一个按钮都会调用函数“open_conversation(self,event)”,该函数会显示完整的对话。......
  • 使用 python 支持构建自定义 vim 二进制文件
    背景Debian11vim软件包不包含python3支持。请参阅标题为“Debian11vim中不支持python-证据”的部分下面我需要vim支持python3YouCompleteMevim插件为了构建一个新的,我将vim9.0tarball下载到v......
  • Coggle数据科学 | Kaggle干货:自定义transformers数据集
    本文来源公众号“Coggle数据科学”,仅用于学术分享,侵权删,干货满满。原文链接:Kaggle干货:自定义transformers数据集transformers是现在NLP同学必备的库,但在使用的过程中主要的代码是需要自定义数据集,那么如何舒服的读取数据,并使用transformers进行训练模型呢?本文的内容如下:自......
  • 构造中心损失----pytorch详解
    当输入数据X维度为[num_classes,feat_dim]时,参考链接:Centerloss-pytorch代码详解.对于输入数据X类型为[batch_size,seq_len,feat_dim],对参考链接代码进行调整,整个代码如下:classCenterLoss_seq(nn.Module):"""Centerloss.Reference:Wenetal.ADisc......
  • STM32自定义协议串口接收解析指令程序
    1、在使用串口接收自定义协议指令时,需要串口解析收到的是什么指令,举例通信报文为上位机->单片机名称长度备注帧头1Byte0x5A0x5A帧长度1Byte数据包的长度0x00-0xFF数据包命令字1Byte功能标识数据可以为空校验1Byte数据包所有字节按位异......
  • YOLOv8-seg——基于自定义数据集训练图像分割模型
    目录一、制作分割数据集1标注2json文件转txt文件3数据集划分二、训练图像分割模型1环境搭建2训练网络3预测三、训练结果解读一.制作分割数据集1标注运用labelme软件进行手动标注,得到数据的json格式标注文件。*注意区别于labelimg软件,labelimg软件对每个......
  • 渲染三角形(自定义数据)并平移的关键代码/(OpenGL)
    废话不多说,先上结果:图1 渲染一个三角形并移动 图2打印坐标关键代码:(1)glBegin、glEnd这两个函数之间的代码用于定义要绘制的图形;glColor3f:设置顶点颜色;glVertex3f:设置顶点位置因为涉及需要打印移动前后的三角形顶点的坐标矩阵,所以在绘制三角形的时候,三角形顶点可以......