首页 > 其他分享 >transformer中解码器的实现细节

transformer中解码器的实现细节

时间:2023-07-26 18:45:30浏览次数:42  
标签:transformer num self attention 细节 解码器 key decoder size

1. 前言

17年google团队发表l了论文《Attention Is All You Need》,transformer横空出世,并引领了AI学术圈的研发风向,以Transformer为基础模型的新模型层出不穷,无论是NLP还是CV或者是多模态,attention遍地开花。

这篇文章遵循encoder-decoder架构,并在其中使用了self-attention和cross-attention,如下图所示:

transformer架构图

其中,encoder的行为还是非常好理解的,至于decoder,则相关细节在原文中都只草草提过,令人留下很多疑问,譬如,

decoder第一个attention为什么需要使用masked?

decoder在训练阶段和测试阶段有什么区别?

decoder在测试阶段,decoder的query输入是将目前所有的预测输入,还是只输入上一次decoder的输出?

2. 问题探讨

decoder第一个attention为什么需要使用masked?

Transformer模型属于自回归模型,也就是说后面的token的推断是基于前面的token的。Decoder端的Mask的功能是为了保证训练阶段和推理阶段的一致性。
在推理阶段,token是按照从左往右的顺序推理的。也就是说,在推理timestep=T的token时,decoder只能“看到”timestep < T的 T-1 个Token, 不能和timestep大于它自身的token做attention(因为根本还不知道后面的token是什么)。为了保证训练时和推理时的一致性,所以,训练时要同样防止token与它之后的token去做attention。

 

decoder在训练阶段和测试阶段有什么区别?

在训练阶段,预测序列是直接全部喂到decoder的输入的,只是在算self-attention的时候加了一个mask,前面时间步的不能看到后面时间步的词,decoder的预测也是一次就全部出来了,也就是Teacher Forcing机制,如下图所示,在训练的时候,需要预测一段语音,decoder端的input,就直接把gt喂进去了,当然加进去前还需要shift right,在序列最左边增加一个Begin的特殊字符(为了和预测阶段保持一致),然后这些gt作为query,进行进入第一层mask multi-head attention层(根据时间步增加mask,以免在self-attention阶段前面的词可以看到后面的),然后以这层的输出为query,来自encoder的输出为key-value pair输入第二个子层multi-head attention,输出作为下层的输入,继续前面的过程,重复N次。

下载 (1)

如果是测试阶段,则就不一样,首先decoder会先输入Begin,预测出下一个词,然后再以已经预测的词作为输入,再进入decoder预测下一个词,直到遇到预测出的词是表示结束的特殊次元,才结束这个过程,参考以下视频:

https://www.zhihu.com/zvideo/1330559583777939456

 

decoder在测试阶段,decoder的query输入是将目前所有的预测输入,还是只输入上一次decoder的输出?

两种实现都有,具体来说,分别是:

a. 每次都将当前预测全部输入,在self-attention和cross-attention中均进行全量计算,优点是实现简单,缺点是计算量大,如下面的代码实现:

class DecoderLayer(nn.Module):
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        m = memory
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))
        return self.sublayer[2](x, self.feed_forward)

可以看到每次的做self-attention的时候query,key,value都是目前所有的词(query 做了mask操作)。

完全版可以查看:https://zhuanlan.zhihu.com/p/398039366

b. 还有另外一种实现就是增量进行计算,李沐在《动手学深度学习》中就用了这种实现,优点是每次只需要计算一个query,但是因为在self-attention中需要与其他的词进行attention操作,因此需要在每层中保存之前的词作为key和value,如下面代码所示:

class DecoderBlock(nn.Module):
    """解码器中第i个块"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # 训练阶段,输出序列的所有词元都在同一时间处理,
        # 因此state[2][self.i]初始化为None。
        # 预测阶段,输出序列是通过词元一个接着一个解码的,
        # 因此state[2][self.i]包含着直到当前时间步第i个块解码的输出表示
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # dec_valid_lens的开头:(batch_size,num_steps),
            # 其中每一行是[1,2,...,num_steps]
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        # 自注意力
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # 编码器-解码器注意力。
        # enc_outputs的开头:(batch_size,num_steps,num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state

其中state[2][self.i]就存储了目前为止所有预测到的词。

完整版可以查看:https://zh-v2.d2l.ai/chapter_attention-mechanisms/transformer.html

3. 参考

[1] Transformer源码详解(Pytorch版本)

[2] 10.7. Transformer

(完)

标签:transformer,num,self,attention,细节,解码器,key,decoder,size
From: https://www.cnblogs.com/harrymore/p/17583295.html

相关文章

  • CSS2.1规范笔记——10 视觉格式化模型细节
    视觉格式化模型细节包含块的定义元素(生成的)盒的位置有时候是根据一个特定的矩形计算的,叫做元素的包含块(containingblock)。元素包含块的定义如下:元素包含块其为根元素。其包含块是一个被称为初始包含块的矩形。对连续媒体,尺寸取自视口的尺寸,并且被固定在画布开......
  • java_方法使用细节
    java_方法使用细节1.一个方法想要返回多个值思考?一个方法如何返回多个值返回数组classAA{publicint[]getSumAndSub(intn1,intn2){//.........int[]res=newint[2];//创建一个数组res[0]=n1+n2;res[1]=n1-n2;......
  • Transformer(转换器)
    SequenceToSequence(序列对序列)输入一个序列,输出一个序列输出序列的长度由机器自己决定,例如:语音辨识、机器翻译、语音翻译 SequenceToSequence一般分成两部分:Encoder:传入一个序列,由Encoder处理后传给DecoderDecoder:决定输出什么样的序列Encoder Encoder中分......
  • 斯坦福博士一己之力让Attention提速9倍!FlashAttention燃爆显存,Transformer上下文长度
    前言 FlashAttention新升级!斯坦福博士一人重写算法,第二代实现了最高9倍速提升。本文转载自新智元仅用于学术分享,若侵权请联系删除欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。CV各大方向专栏与各个部署框架最全教程整理......
  • Transformer取代者登场!微软、清华刚推出RetNet:成本低、速度快、性能强
    前言 Transformer的训练并行性是以低效推理为代价的:每一步的复杂度为O(N)且键值缓存受内存限制,让Transformer不适合部署。不断增长的序列长度会增加GPU内存消耗和延迟,并降低推理速度。研究者们一直在努力开发下一代架构,希望保留训练并行性和Transformer的性能,同时实现......
  • 从RNN到Transformer
    1.RNN循环神经网络的内容可参考https://www.youtube.com/watch?v=UNmqTiOnRfg。RNN建模的对象是具有时间上前后依赖关系的对象。以youtube上的这个视频为例,一个厨师如果只根据天气来决定今天他做什么菜,那么就是一个普通的神经网络;但如果他第i天所做的菜不仅和第i天的天气有关,还......
  • 大语言模型的预训练[1]:基本概念原理、神经网络的语言模型、Transformer模型原理详解
    大语言模型的预训练[1]:基本概念原理、神经网络的语言模型、Transformer模型原理详解、Bert模型原理介绍1.大语言模型的预训练1.LLM预训练的基本概念预训练属于迁移学习的范畴。现有的神经网络在进行训练时,一般基于反向传播(BackPropagation,BP)算法,先对网络中的参数进行随机初始......
  • Scrapy爬虫文件代码基本认识和细节解释
    importscrapyfromscrapy.http.requestimportRequestfromscrapy.http.response.htmlimportHtmlResponsefromscrapy_demo.itemsimportForumItemclassBaiduSpider(scrapy.Spider):#name必须是唯一的,因为运行一个爬虫是通过name来选择的。#你需要运行命......
  • 一些细节记录
    (125条消息)linux源码中__asm____volatile__作用_asmvolatile_liu-yonggang的博客-CSDN博客Makefile中.PHONY的作用-veli-博客园(cnblogs.com)......
  • @Async组件的细节说明
    使用方式启动类里面使用@EnableAsync注解开启功能,自动扫描定义异步任务类并使用@Component标记组件被容器扫描,异步方法加上@Async@Async失效情况注解@Async的方法不是public方法注解@Async的返回值只能为void或者Future注解@Async方法使用static修饰也会失......