首页 > 其他分享 >pytorch中bin模型文件转onnx遇到的问题

pytorch中bin模型文件转onnx遇到的问题

时间:2023-04-04 18:00:34浏览次数:23  
标签:bin tensor onnx torch pytorch input example size

pytorch中bin模型文件转onnx遇到的问题

1 常规做法

import os

import numpy as np
from transformers import GPT2LMHeadModel
import torch
localfile = r"C:\Users\min_ppl_model"
model = GPT2LMHeadModel.from_pretrained(localfile)
# 输入shape 为 1,50 其中1 为bs  50为固定输入长度,前面做量化时规定的。
# 确定输入shape
batch_size = 1
seq = 50
example_input = torch.randn(batch_size, seq)
print(example_input)
# 保存文件
save_onnx_dir = r"layer20_gpt2"
os.makedirs(save_onnx_dir, exist_ok=True)
save_onnx_model = r"layer20_gpt2\layer20_gpt2.onnx"
# 转onnx
with torch.no_grad():
    torch.onnx.export(model,
                      example_input,
                      save_onnx_model,
                      opset_version=11,

                      input_names=["inp"], output_names=["opt"])
print("-----end-------")

  其中

example_input = torch.randn(batch_size, seq, dtype=torch.float32)

  最为关键一步

类型为:

  

(Pdb) type(example_input)
<class 'torch.Tensor'>

打印为:

tensor([[-2.0319, -0.4021, 0.0092, 1.4100, -0.2214, 0.6954, 0.1764, 0.2111,
-0.4725, -0.7527, -0.0766, -1.4510, -0.2528, -1.4077, -0.9340, 0.2309,
0.5564, -0.0498, -0.7499, -1.8176, -0.8981, 1.3997, 0.2904, -0.5024,
0.8392, -0.2341, -0.5459, -1.0992, 0.0211, 0.3346, -0.5087, -0.6159,
1.3256, -0.0423, -0.0764, 1.0469, -1.7328, -1.3470, -0.3346, 0.2129,
1.1073, -0.7503, 0.3968, -0.1374, 0.6514, -0.8763, -1.1972, -1.7750,
-0.9977, -2.2836]])

报错为:RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)

原因为 embedding 为int。

直接更换为整型:

example_input = torch.randint(1,5, size=(batch_size, seq))

2 先用np 再转torch.tensor

example_input = torch.from_numpy(np.random.randint(low=6,high=10,size=(1,50),dtype=np.int32))

  可以稍微指定一个整数范围

(Pdb) type(example_input)
<class 'torch.Tensor'>

打印信息:

tensor([[6, 9, 9, 9, 8, 6, 6, 6, 8, 8, 9, 9, 9, 7, 9, 9, 7, 8, 8, 7, 9, 9, 9, 6,
7, 6, 8, 8, 9, 9, 8, 7, 7, 8, 8, 9, 6, 9, 7, 9, 8, 6, 9, 7, 7, 7, 9, 6,
7, 8]], dtype=torch.int32)

则可以成功执行。

 

总结:

1 能先用int先用int, 要求是float就先用float, 随机生成数据。

 2 如果定义数据是dtype=torch.int32 如果带了,则打印数据时也会带, 否则不显示。

tensor([[6, 9, 9, 9, 8, 6, 6, 6, 8, 8, 9, 9, 9, 7, 9, 9, 7, 8, 8, 7, 9, 9, 9, 6,
7, 6, 8, 8, 9, 9, 8, 7, 7, 8, 8, 9, 6, 9, 7, 9, 8, 6, 9, 7, 7, 7, 9, 6,
7, 8]], dtype=torch.int32)

标签:bin,tensor,onnx,torch,pytorch,input,example,size
From: https://www.cnblogs.com/lx63blog/p/17287122.html

相关文章

  • Address already in use: JVM_Bind:80
    今天maven工程中配置pom.xml时Tomcat配了80端口号,结果提示Addressalreadyinuse:JVM_Bind:80,也就是80端口被占用,于是就按照往常的惯例去查看是什么进程占用了80端口1、打开cmd命令窗口,输入netstat-ano|findstr80,回车就会看见被占用端口号8080所对应的PID为4然后去任务管......
  • MySQL 关闭 binlog 日志
    【关闭binlog日志】1、vim/etc/my.cnf注释如下内容:#log-bin=mysql-bin#binlog_format=mixed#server-id=1#expire_logs_days=102、重启mysql服务/etc/init.d/mysqlrestart......
  • 又是华为!对标 TensorFlow、PyTorch,深度学习框架 MindSpore已开源!附入手公开课
     Datawhale 作者:Datawhale摘要:随着深度学习框架等工具的全面开源,最强AI处理芯片昇腾的落地,华为的AI布局已经逐渐清晰起来,MindSpore将为各行业带来哪些新变化,让我们拭目以待。今年的华为开发者大会HDC2020上,除了昇腾、鲲鹏等自研芯片硬件平台之外,最令人期待的就是深度学习......
  • Pytorch 备忘清单_开发速查表分享
    Pytorch备忘清单Pytorch是一种开源机器学习框架,可加速从研究原型设计到生产部署的过程,备忘单是由IT宝库整理的Pytorch开发速查备忘清单为您提供了Pytorch基本语法和初步应用参考入门,为开发人员分享快速参考备忘单。开发速查表大纲入门介绍认识Pytorch创建一个......
  • windows下使用pytorch进行单机多卡分布式训练
    现在有四张卡,但是部署在windows10系统上,想尝试下在windows上使用单机多卡进行分布式训练,网上找了一圈硬是没找到相关的文章。以下是踩坑过程。首先,pytorch的版本必须是大于1.7,这里使用的环境是:pytorch==1.12+cu11.6四张4090显卡python==3.7.6使用nn.DataParallel进行分布式......
  • C++11新特性之std::function和bind绑定器
    在C++中,存在可调用对象这一个概念,可调用对象有以下几种定义:(1).是一个函数指针(2).是一个具有operator()成员函数的类对象(仿函数)(3).是一个可被转换为函数指针的类对象(4).是一个类成员(函数指针)一、可调用对象包装器----std::functionstd::function是可调用对象的包装......
  • junit单元测试报错:java.lang.NoClassDefFoundError: org/hamcrest/SelfDescribing
    今天在复习的时候对对一些知识点进行巩固,用到了junit-4.12.jar,手动导入jar包,然后运行然后报错:java.lang.NoClassDefFoundError:org/hamcrest/SelfDescribing。刚开始我以为代码错了,看了看发现不是代码的问题,是导包的问题。然后查询了百度,发现了是版本的问题:然后说换个低版本的就......
  • BinarySearch
    packageBisectionMethod;/***704.二分查找*给定一个n个元素有序的(升序)整型数组nums和一个目标值target,写一个函数搜索nums中的target,*如果目标值存在返回下标,否则返回-1。*//***思路:设定左右指针*找出中间位置,并判断该位置值是否等于target*nums[......
  • 38、K8S-安全机制-RBAC之(cluster)role、(cluster)rolebinding
    1、RBAC1.1、基本简介1.1.1、简介RBAC使用rbac.authorization.k8s.ioAPIGroup来实现授权决策,允许管理员通过KubernetesAPI动态配置策略。Kubernetes的基本特性就是它的所有资源对象都是模型化的API对象,我们可以基于api-server对各种资源进行增、删、改、查等操作,但......
  • Java(TM) Platform SE binary 打开jar文件报错
    问题描述双击jar包,使用Java(TM)PlatformSEbinary直接运行java代码,报错Error:AJNIerrorhasoccurred,pleasecheckyourinstallationandtryagain和AJavaExceptionhasoccurred.选择更多应用,进目录里选中javaw.exe也不行问题原因jdk版本与jar包的编译时的版本......