Exporting BERT to ONNX: An Example

Contents

Exporting the BERT model

1. Load the PyTorch model via transformers
2. Create dummy inputs and run a forward inference pass through the model, tracing and recording the set of operations along the way
3. Define dynamic axes on the input and output tensors
4. Save the graph and the network weights

Exporting an NLP model is not much different from exporting a CV model; the main point to watch is that input sequences are not of fixed length, which is handled by the dynamic_axes argument of the export method.
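A minimal sketch of what that argument looks like (the tensor names here are illustrative; the real export below takes them from the ONNX config):

# Hypothetical example: mark batch (axis 0) and sequence length (axis 1)
# as dynamic so the exported graph accepts variable-sized inputs.
dynamic_axes = {
    "input_ids":      {0: "batch", 1: "sequence"},
    "attention_mask": {0: "batch", 1: "sequence"},
    "token_type_ids": {0: "batch", 1: "sequence"},
    "output":         {0: "batch", 1: "sequence"},
}
# Passed as torch.onnx.export(model, args, f, ..., dynamic_axes=dynamic_axes)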

BERT-Large, Uncased (Whole Word Masking): 24-layer, 1024-hidden, 16-heads, 340M parameters
BERT-Large, Cased (Whole Word Masking): 24-layer, 1024-hidden, 16-heads, 340M parameters
BERT-Base, Uncased: 12-layer, 768-hidden, 12-heads, 110M parameters
BERT-Large, Uncased: 24-layer, 1024-hidden, 16-heads, 340M parameters
BERT-Base, Cased: 12-layer, 768-hidden, 12-heads, 110M parameters
BERT-Large, Cased: 24-layer, 1024-hidden, 16-heads, 340M parameters
BERT-Base, Multilingual Cased (New, recommended): 104 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
BERT-Base, Multilingual Uncased (Orig, not recommended; use Multilingual Cased instead): 102 languages, 12-layer, 768-hidden, 12-heads, 110M parameters
BERT-Base, Chinese: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters


The first six are English models, "Multilingual" denotes the multilingual models, and the last one is the Chinese model (character level).
"Uncased" means all letters are lowercased, while "Cased" means the original casing is preserved.

Loading the model

import torch
from transformers import AutoTokenizer, BertConfig, BertModel

# Load the config, the model (without the pooling head), and the tokenizer.
config = BertConfig.from_pretrained("bert-base-uncased")
print(config)
bert_model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, config=config)
print(bert_model.config)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# BertConfig {
#   "_name_or_path": "bert-base-uncased",
#   "architectures": [
#     "BertForMaskedLM"
#   ],
#   "attention_probs_dropout_prob": 0.1,
#   "classifier_dropout": null,
#   "gradient_checkpointing": false,
#   "hidden_act": "gelu",
#   "hidden_dropout_prob": 0.1,
#   "hidden_size": 768,
#   "initializer_range": 0.02,
#   "intermediate_size": 3072,
#   "layer_norm_eps": 1e-12,
#   "max_position_embeddings": 512,
#   "model_type": "bert",
#   "num_attention_heads": 12,
#   "num_hidden_layers": 12,
#   "pad_token_id": 0,
#   "position_embedding_type": "absolute",
#   "transformers_version": "4.36.2",
#   "type_vocab_size": 2,
#   "use_cache": true,
#   "vocab_size": 30522
# }
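Incidentally, the transformers 4.x line used here also ships a one-call export helper in transformers.convert_graph_to_onnx; a minimal sketch, assuming that (now deprecated) module is still present in your installed version:

from pathlib import Path
from transformers.convert_graph_to_onnx import convert

# One-call export: transformers builds the dummy inputs and dynamic axes
# itself. The output file must not already exist.
convert(framework="pt", model="bert-base-uncased",
        output=Path("assets/bert_uncased_convert.onnx"), opset=11)

The manual torch.onnx.export route shown below gives finer control over input names, output names, and the opset version.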

PyTorch model inference

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Load the model, tokenizer, and config.
model = AutoModel.from_pretrained('bert-base-uncased')
config = AutoConfig.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')


# Define a sentence and tokenize it.
sentence = 'here is some text to encode'
inputs_pt = tokenizer(sentence, return_tensors='pt')
print(inputs_pt["input_ids"].shape)

# Forward pass; the output object carries last_hidden_state, pooler_output, etc.
outputs = model(**inputs_pt)
print(dir(outputs))
last_hidden_state = outputs.last_hidden_state
pooler_output = outputs.pooler_output
print("Token wise output: {}, Pooled output: {}".format(last_hidden_state.shape, pooler_output.shape))
print(last_hidden_state)

print("---" * 20)
torch.Size([1, 9])
['__annotations__', '__class__', '__contains__', '__dataclass_fields__', '__dataclass_params__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__post_init__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', '__setitem__', '__sizeof__', '__str__', '__subclasshook__', 'attentions', 'clear', 'copy', 'cross_attentions', 'fromkeys', 'get', 'hidden_states', 'items', 'keys', 'last_hidden_state', 'move_to_end', 'past_key_values', 'pooler_output', 'pop', 'popitem', 'setdefault', 'to_tuple', 'update', 'values']
Token wise output: torch.Size([1, 9, 768]), Pooled output: torch.Size([1, 768])
tensor([[[-0.0549,  0.1053, -0.1065,  ..., -0.3551,  0.0686,  0.6506],
         [-0.5759, -0.3650, -0.1383,  ..., -0.6782,  0.2092, -0.1639],
         [-0.1641, -0.5597,  0.0150,  ..., -0.1603, -0.1346,  0.6216],
         ...,
         [ 0.2448,  0.1254,  0.1587,  ..., -0.2749, -0.1163,  0.8809],
         [ 0.0481,  0.4950, -0.2827,  ..., -0.6097, -0.1212,  0.2527],
         [ 0.9046,  0.2137, -0.5897,  ...,  0.3040, -0.6172, -0.1950]]],
       grad_fn=<NativeLayerNormBackward0>)
------------------------------------------------------------
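One detail worth knowing: for BertModel the pooled output is just the [CLS] hidden state run through a dense layer with a tanh activation, implemented by the model's pooler submodule. A quick sketch to confirm this, reusing model, last_hidden_state, and pooler_output from above:

# pooler_output == tanh(W @ h_CLS + b), as computed by the pooler head.
with torch.no_grad():
    recomputed = model.pooler(last_hidden_state)
print(torch.allclose(recomputed, pooler_output, atol=1e-6))  # expected: True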

Deriving the ONNX config

# Use the model config to build an ONNX config.

from transformers.onnx.features import FeaturesManager
onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification'](config)

print(onnx_config.inputs.items())
print(onnx_config.outputs.items())

odict_items([('input_ids', {0: 'batch', 1: 'sequence'}), ('attention_mask', {0: 'batch', 1: 'sequence'}), ('token_type_ids', {0: 'batch', 1: 'sequence'})])
odict_items([('logits', {0: 'batch'})])

# Building the dummy inputs requires the tokenizer.
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')

print(dummy_inputs.keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
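To see exactly what will be traced, it helps to print the shape of each dummy tensor; a short sketch:

for name, tensor in dummy_inputs.items():
    # Each dummy input is a small (batch, sequence) LongTensor; the exact
    # size is an implementation detail of generate_dummy_inputs.
    print(name, tuple(tensor.shape), tensor.dtype)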

Exporting the ONNX model

import torch
from itertools import chain
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers.onnx.features import FeaturesManager


# Load the model, tokenizer, and config.
model = AutoModel.from_pretrained('bert-base-uncased')
config = AutoConfig.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model.eval()

# Use the model config to build an ONNX config;
# building the dummy inputs requires the tokenizer.
onnx_config = FeaturesManager._SUPPORTED_MODEL_TYPE['bert']['sequence-classification'](config)
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework='pt')

output_onnx_path = "assets/bert_uncased.onnx"

print("onnx  input", onnx_config.inputs.items())
print("onnx output", onnx_config.outputs.items())

input_ids = dummy_inputs['input_ids']
attention_mask = dummy_inputs['attention_mask']
token_type_ids = dummy_inputs['token_type_ids']

# Input names, output names, and dynamic axes all come from the ONNX config,
# so they stay consistent with one another.
torch.onnx.export(model,
                  (input_ids, attention_mask, token_type_ids),  # alternatively (dummy_inputs,)
                  f=output_onnx_path,
                  verbose=True,
                  input_names=list(onnx_config.inputs.keys()),
                  output_names=list(onnx_config.outputs.keys()),
                  dynamic_axes=dict(chain(onnx_config.inputs.items(), onnx_config.outputs.items())),
                  opset_version=onnx_config.default_onnx_opset)

print("export finished")
onnx  input odict_items([('input_ids', {0: 'batch', 1: 'sequence'}), ('attention_mask', {0: 'batch', 1: 'sequence'}), ('token_type_ids', {0: 'batch', 1: 'sequence'})])
onnx output odict_items([('logits', {0: 'batch'})])
export finished
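Before running inference it is cheap to sanity-check the exported graph with the onnx package (assuming it is installed alongside onnxruntime); a minimal sketch:

import onnx

# Structural validation: parses the graph and checks operator/IR consistency.
onnx_model = onnx.load("assets/bert_uncased.onnx")
onnx.checker.check_model(onnx_model)
print([i.name for i in onnx_model.graph.input])   # names recorded for the inputs
print([o.name for o in onnx_model.graph.output])  # names recorded for the outputs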

Loading and testing the ONNX model

import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Define a sentence.
sentence = 'here is some text to encode'


options = ort.SessionOptions()  # initialize session options
options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# Pass the path of the ONNX model saved in the previous section.
session = ort.InferenceSession(
    "assets/bert_uncased.onnx", sess_options=options, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

# Disable the session.run() fallback mechanism; it prevents a reset of the execution provider.
session.disable_fallback()

# Tokenize, then convert the PyTorch tensors to NumPy arrays for onnxruntime.
inputs = tokenizer(sentence, return_tensors='pt')
inputs = {k: v.detach().cpu().numpy() for k, v in inputs.items()}


print(inputs.keys())
# dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

# Run the session; the name 'logits' must match the output_names used at export time.
output = session.run(output_names=['logits'], input_feed=inputs)

print(output)
print(output[0].shape)

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
[array([[[-0.05490887,  0.10528212, -0.10649522, ..., -0.3550497 ,
          0.06862388,  0.650573  ],
        [-0.5759427 , -0.36500782, -0.13834022, ..., -0.6781805 ,
          0.20923868, -0.16394015],
        [-0.16414754, -0.55971897,  0.01500742, ..., -0.16027743,
         -0.13455114,  0.62159723],
        ...,
        [ 0.2447815 ,  0.125429  ,  0.15869957, ..., -0.27489156,
         -0.11634777,  0.88089377],
        [ 0.0481048 ,  0.4950128 , -0.28274378, ..., -0.6097362 ,
         -0.12124838,  0.2527281 ],
        [ 0.9046008 ,  0.21367389, -0.5896968 , ...,  0.30398968,
         -0.61721766, -0.19498175]]], dtype=float32)]
(1, 9, 768)
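Note that although the output is named logits (the name comes from the sequence-classification ONNX config), the network we actually exported is the bare AutoModel, so what comes back is the last hidden state with shape (1, 9, 768), not classification logits. That also makes it easy to verify the export numerically against PyTorch; a minimal sketch, assuming model and inputs_pt from the PyTorch inference section are still in scope:

import torch
import numpy as np

# Compare ONNX Runtime output against the original PyTorch forward pass.
with torch.no_grad():
    torch_out = model(**inputs_pt).last_hidden_state.cpu().numpy()
print(np.allclose(torch_out, output[0], atol=1e-4))  # expected: True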

References

模型推理加速系列|如何用ONNX加速BERT特征抽取(附代码)

实践演练BERT Pytorch模型转ONNX模型及预测

Bert模型导出为onnx和pb格式

NLP实践——Bert转onnx格式简介与踩坑记录

