首页 > 其他分享 >关于BERT输出的一点记录

关于BERT输出的一点记录

时间:2022-11-10 14:22:12浏览次数:45  
标签:BERT None return 记录 torch 输出 output hidden Optional

关于Bert的 [cls] 的输出

在抱抱脸团队发布的Pytorch版的Bert中,要想取到每句话的第一个cls特征是一件容易的事情。直接使用Bert的输出,然后.pooler_output 就可以了。

BERT的最后一层的输出是一个[batch, seq_length,dim]的东西,dim通常为768。seq_length 是句子被填充后的长度,论文中说最长不能超过512。那么如何取得[cls]所对应的768维度的向量呢?

其实就是最后一层的输出的句子长度的第一个。翻译成python

last_hidden[:,0]就这样简单。这种切片操作返回的是[batch , dim]。

那么第二个词就是last_hidden[:,1]。

在抱抱脸实现的版本中,可以通过 output.pooler_output 获得经过加工的[cls]对应的向量。

在源码中,通过了dense和activation。


使用BERT做 Token 级别的分类

最近任务中用到了NER,在网上发现直接使用BERT的准确率就非常好了,再往后面加东西,准确率反而会下降。但是再具体实现过程中遇到了些问题。

抱抱脸有一个 Token 级别的BERT实现,BertForTokenClassification 。

def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

源码中output是bert的输出,但是bert的输出是一个BaseModelOutputWithPoolingAndCrossAttentions 对象,它是一个dataclass(我第一次听说这个词)

    @dataclass
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None

下面的output[0]是 last_hidden_state(我第一次见这种写法)。那么它的shape应该是[batch , max_length , 768]。对输出经过一词 dropout 后,送入线性模型,那么输出logits是[batch ,max_length , classes]。

labels 的shape 是[batch ,max_length]看样子应该是每个字符所对应的类别id。接下来定义了一个 CrossEntropyLoss 函数,它的输入input 被置换为了[batch*max_length , classes],输出被拉平(1维)[batch *max_length ]。

举个例子,一个批次60句话,每句话被填充为126个字,每个字对应3种类别。那么sequence_output=【60,126,768】,logits为【60,126,3】,logits.view(-1, self.num_labels)为 【60* 126,3】,labels.view(-1)为【60* 126】

标签:BERT,None,return,记录,torch,输出,output,hidden,Optional
From: https://www.cnblogs.com/aerith/p/16876873.html

相关文章

  • 记录一下Stream流的一个坑
    List<String>list=newArrayList<>();booleana=list.stream().anyMatch("a"::equals);//Ifthestreamisemptythenfalseisreturnedandthepredi......
  • Redis对于字符串(String)知识点理解和实操过程例子的详解记录
    一.Redis字符串1.1基本操作如果字符串内容为整数的时候。1.1.1set、mset、get、mget存和取Redis的Set是String类型的无序集合。集合成员是唯一的,这就意味......
  • 学习记录24常用API
    Math\System\Runtime\Object\BigInteger\BigDecima\正则表达式(爬虫、捕获)主要记忆类名和作用MathString时间原点:1970年1月1日08:00:001秒=1000毫秒1毫秒=1000......
  • Cmake 相关语法记录
    CMake说明cmake的定义是什么?-----高级编译配置工具当多个人用不同的语言或者编译器开发一个项目,最终要输出一个可执行文件或者共享库(dll,so等等)这时候神器就出现了-----......
  • 删除文件后,磁盘空间没有释放的处理记录
    问题说明:一台服务器的/分区使用率爆满了!已达到100%!经查看发现有个文件过大(50G),于是在跟有关同事确认后rm-f果断删除该文件。但是发现删除该文件后,/分区的磁盘空间压根没有......
  • 记录一次springboot 集成 openfeign 实现模块间调用异常
    记录一次springboot集成openfeign实现模块间调用异常 问题背景product 服务作为服务端,提供了一个对外通信Fegin接口ProductClient,放在了com.imooc.product.clie......
  • 2022 icpc 沈阳站 记录(非题解)
    赛前大概是赛前三周才突然知道拥有了比赛机会。赛前训练和vp频率很高,有一段时间cf上都是绿的。比赛的那一周只有一天没在vp,到了周六热身赛我人都有点麻木。(可能正赛也是......
  • Java输出SSL握手日志和查看cacerts路径
    在JAVA启动时添加下面的VM参数就可以启动握手日志了!!!-Djavax.net.debug=all另外,在debug日志中,有一个trustStoreis关键字,根据这个可以找到使用的是哪个truststor......
  • 将hex printf输出存储到变量
    Ihavetoroundoffafloattodecimal.Afterroundingoff,Ishouldconvertthisnumbertohexadecimal.IthinkIgottheroundoffpartokaywith round()我必......
  • 输入与输出
    输入/输出流read和write方法在执行时都将阻塞,直至字节确实被读入或写出。完成操作后要通过close方法将资源关闭,输出流在关闭时会冲刷缓冲区。完整的流家族任何实......