首页 > 其他分享 >bert 的输出格式详解

bert 的输出格式详解

时间:2023-02-21 14:24:08浏览次数:42  
标签:bert 格式 torch ids 详解 input output hidden size

输出是一个元组类型的数据 ,包含四部分,

last hidden state shape是(batch_size, sequence_length, hidden_size),hidden_size=768,它是模型最后一层的隐藏状态

pooler_output:shape是(batch_size, hidden_size),这是序列的第一个token (cls) 的最后一层的隐藏状态,它是由线性层和Tanh激活函数进一步处理的,这个输出不是对输入的语义内容的一个很好的总结,对于整个输入序列的隐藏状态序列的平均化或池化可以更好的表示一句话。

hidden_states:这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它是一个元组,含有13个元素,第一个元素可以当做是embedding,其余12个元素是各层隐藏状态的输出,每个元素的形状是(batch_size, sequence_length, hidden_size),

attentions:这也是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,含有12个元素,包含每的层注意力权重,用于计算self-attention heads的加权平均值

import torch
from torch import tensor
from transformers import BertConfig, BertTokenizer, BertModel
 
model_path = 'model/chinese-roberta-wwm-ext/'#已下载的预训练模型文件路径
config = BertConfig.from_pretrained(model_path, output_hidden_states = True, output_attentions=True)
assert config.output_hidden_states == True
assert config.output_attentions == True
model = BertModel.from_pretrained(model_path, config = config)
tokenizer = BertTokenizer.from_pretrained(model_path)
 
text = '我热爱这个世界'
 
# input = tokenizer(text)
# {'input_ids': [101, 2769, 4178, 4263, 6821, 702, 686, 4518, 102], 
#'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 
#'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
 
# input = tokenizer.encode(text)
# [101, 2769, 4178, 4263, 6821, 702, 686, 4518, 102]
 
# input = tokenizer.encode_plus(text)
# {'input_ids': [101, 2769, 4178, 4263, 6821, 702, 686, 4518, 102], 
#'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 
#'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
 
input_ids = torch.tensor([tokenizer.encode(text)], dtype=torch.long)#一个输入也需要组batch
print(input_ids.shape)
#torch.Size([1, 9])
 
model.eval()
output = model(input_ids)
print(len(output))
print(output[0].shape) #最后一层的隐藏状态 (batch_size, sequence_length, hidden_size)
print(output[1].shape) #第一个token即(cls)最后一层的隐藏状态 (batch_size, hidden_size)
print(len(output[2])) #需要指定 output_hidden_states = True, 包含所有隐藏状态,第一个元素是embedding, 其余元素是各层的输出 (batch_size, sequence_length, hidden_size)
print(len(output[3])) #需要指定output_attentions=True,包含每一层的注意力权重,用于计算self-attention heads的加权平均值(batch_size, layer_nums, sequence_length, sequence_legth)
# 4
# torch.Size([1, 9, 768])
# torch.Size([1, 768])
# 13
# 12
 
all_hidden_state = output[2]
print(all_hidden_state[0].shape)
print(all_hidden_state[1].shape)
print(all_hidden_state[2].shape)
# torch.Size([1, 9, 768])
# torch.Size([1, 9, 768])
# torch.Size([1, 9, 768])
 
attentions = output[3]
print(attentions[0].shape)
print(attentions[1].shape)
print(attentions[2].shape)
# torch.Size([1, 12, 9, 9])
# torch.Size([1, 12, 9, 9])
# torch.Size([1, 12, 9, 9])

后续补充,

text = '我热爱这个世界'
 
input = tokenizer(text)
#input分词后是一个字典
# {'input_ids': [101, 2769, 4178, 4263, 6821, 702, 686, 4518, 102], 
#'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 
#'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}
 
 
#input_ids = torch.tensor([tokenizer.encode(text)], dtype=torch.long)
#一个输入也需要组batch
 
input_ids = torch.tensor([input["input_ids"])
token_type_ids = torch.tensor([input["token_type_ids"])
attention_mask = torch.tensor([input["attention_mask"]]
 
output = model(input_ids, token_type_ids, attention_mask)
 
# 可以同时输入input_ids token_type_ids 和 attention_mask得到输出
 
#另一种写法,直接在分词的过程中返回张量
input_tensor = tokenizer(input, return_tensors = "pt")

标签:bert,格式,torch,ids,详解,input,output,hidden,size
From: https://www.cnblogs.com/chaofengya/p/17140877.html

相关文章

  • 日期格式转化
    1字母含义y表示年份yyyy为四位数年份,如2023,yy为两位数年份,如23M表示月份MM为两位数月份,如02,M为以为数月份,如2d表示日期  D表示一年中的第几天h或H表示小时,h是......
  • PHP 错误 系列:编码格式错误解决
    一、Phalcon模型代码日志错误代码错误页面显示:Log日志错误代码:​​PHPmessage:PHPFatalerror:Namespacedeclarationstatementhastobetheveryfirststatemen......
  • PHP系列 | PHPexcel导入xls格式 ,提示错误:iconv(): Wro
    导入xls格式(2003版本)时会报错提示错误信息iconv():Wrongcharset,conversionfrom`CP936'to`UTF-8'isnotallowed[/var/www/web/vendor/phpoffice/phpexcel/Classe......
  • K8SYaml文件详解(云原生)
    一、K8S支持的文件格式kubernetes支持YAML和JSON文件格式管理资源对象。JSON格式:主要用于api接口之间消息的传递YAML格式:用于配置和管理,YAML是一种简洁的非标记性语言,内......
  • 音乐下载器,音乐解析软件,全网音乐免费下载,mp3格式音乐下载,flac格式音乐下载,无损音质音
    在这个音乐版权被三分天下的时代,想必大家也都会有这种的困扰,喜欢的音乐很多,刚好这些音乐的版权还分散在三大主流音乐厂商的手里。这样的话,想要听或者下载自己喜欢的音乐可......
  • Java集合Map接口详解——含源码分析
    前言关于集合中的Collection我们已经讲完了,接下来我们一起来看集合中的另一个大类:MapMap的实现类首先Map是一个接口,是一对键值对来存储信息的,K为key键,V为value值HashMapimpo......
  • 企业微信群聊机器通讯报文详解
    背景对接chatgpt时,需要支持在群聊里@机器人时回复内容@我的收到的请求{"atMe":"true","groupRemark":"","textType":"1","groupName":"吴冠冠......
  • Rust Format 格式
    fnmain(){println!("{}",1);//默认用法,打印Displayprintln!("{:o}",9);//八进制println!("{:x}",255);//十六进制小写println!("{:X}",......
  • 前端Blob数据流转JSON格式
    blobToJson=(blobData)=>{returnnewPromise((resolve,reject)=>{constreader:any=newFileReader()letjsonData:anyreader.readAsText(b......
  • IO流详解及常用方法
    1.1.什么是IO流IO流:Input/OutputStream流:指的是一串流动的数据,在数据在流中按照指定的方向进行流动。实现数据的读取、写入的功能。1.2.IO流的使用场景使用F......