首页 > 其他分享 >[机器学习]对transformer使用padding mask

[机器学习]对transformer使用padding mask

时间:2023-08-14 17:24:19浏览次数:38  
标签:dim transformer 填充 self mask padding attn

注:本文是对GPT4的回答的整理校正补充。

在处理序列数据时,由于不同的序列可能具有不同的长度,我们经常需要对较短的序列进行填充(padding)以使它们具有相同的长度。但是,在模型的计算过程中,这些填充值是没有实际意义的,因此我们需要一种方法来确保模型在其计算中忽略这些填充值。这就是padding mask的作用。

比如常用的就是在数据集准备中,想用batch来训练,就得将一个batch的数据的长度全部对齐。

1. 什么是Padding Mask?

Padding mask是一个与输入序列形状相同的二进制矩阵,用于指示哪些位置是真实的数据,哪些位置是填充值。

  • 真实数据位置的mask值为0。
  • 填充位置的mask值为1。

2. 如何使用Padding Mask?

在自注意力机制中,我们计算查询和键的点积来得到注意力分数。在应用softmax函数之前,我们可以使用padding mask来确保填充位置的注意力分数为一个非常大的负数(例如,乘以-1e9)。这样,当应用softmax函数时,这些位置的权重将接近于零,从而确保模型在其计算中忽略这些填充值。

3. 示例

假设我们有一个长度为4的序列:[A, B, C, <pad>],其中<pad>是填充标记。对应的padding mask是:[0, 0, 0, 1]

在计算注意力分数后,我们可以使用以下方法应用padding mask:

attention_scores = attention_scores.masked_fill(mask == 1, -1e9)

这里,masked_fill是一个PyTorch函数,它会将mask中值为1的位置替换为-1e9

看图,这里的attention_scores就是Q×K的矩阵,把尾部多余的部分变成-inf,再过SoftMax,这样就是0了。这样,即使V的后半部分有padding的部分,也会因为乘0而变回0。这样被padding掉的部分就从计算图上被剥离了,由此不会影响模型的训练。

4. 代码

笔者自己写的,不保证靠谱哈。

import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        
        # Apply the padding mask
        if mask is not None:
            attn = attn.masked_fill(mask.unsqueeze(1).unsqueeze(2) == 1, float('-inf'))
        
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

5. 为什么需要Padding Mask?

  • 忽略无关信息:通过使用padding mask,我们可以确保模型在其计算中忽略填充值,从而避免这些无关的信息对模型的输出产生影响。

  • 稳定性:如果不使用padding mask,填充值可能会对模型的输出产生不稳定的影响,尤其是在使用softmax函数时。

  • 解释性:使用padding mask可以提高模型的解释性,因为我们可以确保模型的输出只与真实的输入数据有关,而不是与填充值有关。

总之,padding mask是处理序列数据时的一个重要工具,它确保模型在其计算中忽略填充值,从而提高模型的性能和稳定性。

标签:dim,transformer,填充,self,mask,padding,attn
From: https://www.cnblogs.com/sherrlock/p/17629223.html

相关文章

  • 解码Transformer:自注意力机制与编解码器机制详述与代码实现
    本文全面探讨了Transformer及其衍生模型,深入分析了自注意力机制、编码器和解码器结构,并列举了其编码实现加深理解,最后列出基于Transformer的各类模型如BERT、GPT等。文章旨在深入解释Transformer的工作原理,并展示其在人工智能领域的广泛影响。作者TechLead,拥有10+年互联网服......
  • AES加密 flutter java后台用的 AES/CBC/PKCS5Padding
     可测试AES是否正确的网址https://www.toolhelper.cn/SymmetricEncryption/AES java后台代码如下publicstaticStringencrypt(StringclearText,Stringkey,Stringiv){byte[]result=null;try{byte[]key_bytes=toByte(MD5Util......
  • Windows11安装python模块transformers报错Long Path处理
    Windows11安装python模块transformers报错,报错信息如下ERROR:CouldnotinstallpackagesduetoanOSError:[Errno2]Nosuchfileordirectory:'C:\\Users\\27467\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\LocalCac......
  • maskrcnn详细注解说明(超详细)
     此代码是我对maskrcnn的一些修改,基本还原所有内容,但更加简洁,使代码更易解读。里面有很多注释,非常详细,可自己慢慢品味。若有一些问题,欢迎指正与交流。      此代码为训练文件.py """MASKRCNNalgrithmforobjectdetectionandinstancesegmentationWrittenandmodifi......
  • 【W的AC企划 - 第五期】位运算 (Bitmasks)
    往期浏览第六期-树上分治位运算讲解常见的位运算为:与、或、异或这三种。运算运算符、数学符号表示解释与&、and同1出1或|、or有1出1异或^、\(\bigoplus\)、xor不同出1这一块的内容比较散乱,以海量刷题为首要学习方向,同时需要收集一些常用结论。......
  • 强到离谱,Transformer为何能闯入CV界秒杀CNN?
    Transformer近年来已成为视觉领域的新晋霸主,这个来自NLP领域的模型架构为何能闯入CV界秒杀CNN?自提出之日起,Transformer模型已经在CV、NLP以及其他更多领域中「大展拳脚」,实力冲击CNN。Transformer为什么这么有实力?因为它在分类、检测等任务上展现了极其强劲的性能。而且骨干网络......
  • linux 中umask的作用(还可以)
    https://blog.csdn.net/sinat_42724379/article/details/124752536  ____________________________________________________________________________________________________________________ 我们知道在linux服务器中文件最大权限为666,而目录最大权限为777但是一般我......
  • 基于 Habana Gaudi 的 Transformers 入门
    几周前,我们很高兴地宣布HabanaLabs和HuggingFace将开展加速transformer模型的训练方面的合作。与最新的基于GPU的AmazonWebServices(AWS)EC2实例相比,HabanaGaudi加速卡在训练机器学习模型方面的性价比提高了40%。我们非常高兴将这种性价比优势引入Transform......
  • 医学图像领域--Transformer入门路线推荐
    本文跟那些长篇大论教你入门的文章大大不同!!你读了这些文章,对于小白来讲,原理既难又枯燥,读了等于没读,一样不会用。这里没有枯燥的理论,没有看不懂的术语,因为这些我也不懂!我能提供的,就是一个入门路线!Tina姐妙招:先实践,尝到甜头再回过头来看理论因此,本文分为两个部分,先给大家一些实践案......
  • 《Decision Transformer: Reinforcement Learning via Sequence Modeling》论文学习
    一、Introduction先前的研究工作表明,Transformer可以对处于高维分布的语义概念进行大规模建模抽象,比较典型地体现如:基于自然语言的零样本泛化(zero-shotgeneralization)分布外图像生成(out-of-distributionimagegeneration)鉴于此类模型在多个领域的成功应用,我们希望研究Tran......