首页 > 其他分享 >预训练Bert模型输出类型为str问题解决

预训练Bert模型输出类型为str问题解决

时间:2023-09-26 11:11:20浏览次数:37  
标签:Bert keras 模型 ids 输出 str type bert

 

input_ids=keras.layers.Input(shape=(MAXLEN,),dtype='int32')
attention_mask=keras.layers.Input(shape=(MAXLEN,),dtype='int32') 
token_type_ids=keras.layers.Input(shape=(MAXLEN,),dtype='int32')
_,x=bert_model([input_ids,attention_mask,token_type_ids])

outputs=keras.layers.Dense(1,activation='sigmoid')(x)
#给模型加一个全连接层,将bert的输出x调整后用sigmoid函数做一个二分类

model=keras.models.Model(inputs=[input_ids,attention_mask,token_type_ids],outputs=outputs)
model.compile(loss='binary_crossentropy',optimizer=keras.optimizers.Adam(lr=LEARNING_RATE),metrics=['accuracy'])

  这是一段用预训练的bert模型来构造一个文本分类器的代码。在代码中,先用设定好的输入,输入模型得到最后一层隐藏层的输出,作为变量x再输入全连接层,再经过sigmiod函数来实现二分类的效果。最后将全连接层加入到模型中,设置好损失函数等参数进行编译。

    _,x=bert_model([input_ids,attention_mask,token_type_ids])    这句的作用是获取bert模型的最后一池化层的输出。

  在运行代码时,每次运行到Dense层就会报错提示格式错误,type(x)查看x的格式,发现它的输出是一个名为pooler_output的字符串,这显然不合要求:模型的输出应该是一个张量而不是字符串,而全连接层需要的也是一个二维的张量。但如果使用x=tf.convert_to_tensor(x)强制将x转为张量,那么得到的是一个丢失了形状信息的空张量。

  在网上查找资料后发现, 这和tranformers库的版本有关系。pip show transformer指令查看版本。如果版本高于4.0,那么输出的确实会是字符串,解决办法是在一开始的模型定义语句里增加一个参数return_dict=flase,让模型正确返回一个元组。

  如果不想考虑这么麻烦,或者加上了return_dict后,解释器报错,那就直接用

x=bert_model([input_ids,attention_mask,token_type_ids])[1]即可。因为模型的输出是包含两个张量的元组。第一部分是所有时刻的输出,第二部分就是最后一层隐藏层的输出。用[1]就能直接得到最后一层张量,避免了格式问题。不得不说这不是什么大问题,但遇到以后真的很糟心。

标签:Bert,keras,模型,ids,输出,str,type,bert
From: https://www.cnblogs.com/namezhyp/p/17729668.html

相关文章

  • using wget utility to download files while keeping path structure
    Frommanwget:-x,--force-directories:[...]createahierarchyofdirectories,evenifonewouldnothavebeencreatedotherwise.E.g.wget-xhttp://fly.srk.fer.hr/robots.txtwillsavethedownloadedfiletofly.srk.fer.hr/robots.txt.  Togetthest......
  • Linux-Stream内存带宽及MLC内存延迟性能测试方法
    1、Stream内存带宽测试  Stream是业界主流的内存带宽测试程序,测试行为相对简单可控。该程序对CPU的计算能力要求很小,对CPU内存带宽压力很大。随着处理器核心数量的增大,而内存带宽并没有随之成线性增长,因此内存带宽对提升多核心的处理能力就越发重要。Stream具有良好的空间局部......
  • thinkphp lang命令执行--struts2 代码执行--(QVD-2022-46174)&&(CVE-2020-17530)&&(CV
    thinkphplang命令执行--struts2代码执行--(QVD-2022-46174)&&(CVE-2020-17530)&&(CVE-2021-31805)thinkphplang命令执行(QVD-2022-46174)影响范围6.0.1<=ThinkPHP<=6.0.13ThinkPHP5.0.xThinkPHP5.1.x漏洞复现POC:?+config-create+/&lang=../../../../......
  • linux 中sed命令输出匹配字符的下一行或者若干行
     001、grep实现(base)[root@pc1test2]#lsa.txt(base)[root@pc1test2]#cata.txt##测试数据12keyword345keyword678(base)[root@pc1test2]#grep"keyword"-A2a.txt##输出匹配字符后面的两行keyword34--keyword67 002、s......
  • Json输出List集合对象和map对象 JSON格式
    Json输出List集合对象和map对象JSON格式//Json输出List集合对象[{"属性1":["值1"],"属性2":"值2"},{"属性3":["值3"],"属性4":"值4"}]importcom.alibaba.fastjson.JSONObject;importjava.util.ArrayList;impor......
  • Spring Boot RestController接口如何输出到终端
    背景公司项目的批处理微服务,一般是在晚上固定时段通过定时任务执行,但为了预防执行失败,我们定义了对应的应急接口,必要时可以通过运维在终端中进行curl操作。然而,部分任务耗时较长,curl命令执行后长时间没有输出,如果不查看日志,无法知道系统当前的状态,因此有必要研究一下如何在curl命......
  • hive string, map, struct类型的建表和导入数据语句
    本文转载于 https://blog.51cto.com/u_14405/6419362,https://blog.csdn.net/tototuzuoquan/article/details/115493697和 https://blog.csdn.net/weixin_43597208/article/details/117450579。今天要用到hive的string相关的数据类型和数据,直接附链接和sql语句Hive的String类......
  • pytest + yaml 框架 -56. 输出日志优化+allure报告优化
    前言v1.4.8版本优化接口请求和响应输出日志,生成的allure报告也按步骤优化request和response详情日志优化日志用例test_log1:-name:log1request:url:http://127.0.0.1:8000/api/test/demomethod:GETvalidate:-eq:[status_code,200]-eq:......
  • strncpy 出core
    core的堆栈是这样子的:(gdb)bt#00x00007ffff4a96a7cinpthread_kill()from/lib/x86_64-linux-gnu/libc.so.6#10x00007ffff4a42476inraise()from/lib/x86_64-linux-gnu/libc.so.6#20x00007ffff4a287f3inabort()from/lib/x86_64-linux-gnu/libc.so.6#30......
  • 入门篇-其之四-字符串String的简单使用
    什么是字符串?在Java编程语言中,字符串用于表示文本数据。字符串(String)属于引用数据类型,根据String的源码,其头部使用class进行修饰,属于类,即引用数据类型。字符串的表示字符串使用双引号""表示,在双引号中你可以写任意字符。和前面定义并初始化基本数据类型的变量一样,定义最简单......