首页 > 其他分享 >关于attention中对padding的处理:mask

关于attention中对padding的处理:mask

时间:2024-05-21 17:58:05浏览次数:21  
标签:dim attention False torch mask padding attn

先问了下chatgpt:
我正在学习torch.nn.multiheadattention,请告诉我att_mask和key_padding_mask这两个参数有什么不同,分别用于处理什么问题,以及输出有什么不同,并给出代码示例

chatgpt的回答:
torch.nn.MultiheadAttention 中的 attn_maskkey_padding_mask 是两个非常重要的参数,用于处理不同的注意力机制中的问题。

attn_mask (Attention Mask)

attn_mask 用于在计算注意力权重时屏蔽某些位置。这在例如自回归生成模型中非常有用,可以防止模型在生成当前词时看到未来的词。它可以是一个 2D 的布尔型矩阵,用于屏蔽特定位置,或是一个上三角矩阵,用于屏蔽未来的位置。

key_padding_mask

key_padding_mask 用于在计算注意力权重时忽略某些填充位置(padding)。这是在处理不同长度的序列时常用的技巧,确保模型不会将填充值(通常为0)作为有效的信息来处理。

不同之处

  • attn_mask 是用来遮蔽未来的位置或指定位置,以防止信息泄露或特定位置的信息影响计算。
  • key_padding_mask 是用来忽略填充值的位置,确保这些填充值不会影响模型的注意力计算。

代码示例

import torch
import torch.nn as nn

# 创建一些示例数据
batch_size = 4
seq_length = 6
embed_dim = 8
num_heads = 2

# 模拟输入 (batch_size, seq_length, embed_dim)
x = torch.rand(batch_size, seq_length, embed_dim)

# 模拟填充后的序列
padded_x = nn.utils.rnn.pad_sequence(
    [torch.rand(3, embed_dim), torch.rand(4, embed_dim), torch.rand(5, embed_dim), torch.rand(6, embed_dim)], 
    batch_first=True
)

# 创建 SelfAttention 实例
self_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Attention mask: 遮蔽未来的位置(上三角矩阵)
attn_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()

# Key padding mask: 遮蔽填充值(padding)
key_padding_mask = torch.tensor([
    [False, False, False, True, True, True],  # 只有前三个位置有效
    [False, False, False, False, True, True], # 只有前四个位置有效
    [False, False, False, False, False, True],# 只有前五个位置有效
    [False, False, False, False, False, False]# 所有位置都有效
])

# 计算注意力
attn_output, attn_weights = self_attention(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)

print("Attention Output:", attn_output)
print("Attention Weights:", attn_weights)

解释代码

  1. 模拟输入数据:创建一个随机的输入张量 x,形状为 (batch_size, seq_length, embed_dim)
  2. 模拟填充后的序列:使用 nn.utils.rnn.pad_sequence 创建一个填充后的序列 padded_x
  3. 创建 SelfAttention 实例:实例化 nn.MultiheadAttention,设置 embed_dimnum_heads
  4. 创建 Attention Maskattn_mask 是一个上三角矩阵,用于屏蔽未来的位置。
  5. 创建 Key Padding Maskkey_padding_mask 是一个布尔型张量,用于屏蔽填充值的位置。
  6. 计算注意力:调用 self_attention,传入输入张量、attn_maskkey_padding_mask

输出

  • Attention Output:经过注意力机制后的输出张量。
  • Attention Weights:注意力权重,展示了每个位置的权重值。

这样,你可以通过 attn_maskkey_padding_mask 灵活地控制注意力机制的行为,以适应不同的任务需求。

标签:dim,attention,False,torch,mask,padding,attn
From: https://www.cnblogs.com/waterrr404/p/18204669

相关文章

  • 基于GWO灰狼优化的CNN-LSTM-Attention的时间序列回归预测matlab仿真
    1.算法运行效果图预览优化前    优化后     2.算法运行软件版本matlab2022a  3.算法理论概述       时间序列回归预测是数据分析的重要领域,旨在根据历史数据预测未来时刻的数值。近年来,深度学习模型如卷积神经网络(ConvolutionalNeuralN......
  • 探索大语言模型:理解Self Attention
    一、背景知识在ChatGPT引发全球关注之后,学习和运用大型语言模型迅速成为了热门趋势。作为程序员,我们不仅要理解其表象,更要探究其背后的原理。究竟是什么使得ChatGPT能够实现如此卓越的问答性能?自注意力机制的巧妙融入无疑是关键因素之一。那么,自注意力机制究竟是什么,它是如何创造......
  • 经典译文:Transformer--Attention Is All You Need
    经典译文:Transformer--AttentionIsAllYouNeed来源  https://zhuanlan.zhihu.com/p/689083488 本文为Transformer经典论文《AttentionIsAllYouNeed》的中文翻译:https://arxiv.org/pdf/1706.03762.pdf注意力满足一切[email protected]......
  • 基于WOA优化的CNN-LSTM-Attention的时间序列回归预测matlab仿真
    1.算法运行效果图预览优化前:    优化后:   2.算法运行软件版本matlab2022a 3.算法理论概述       时间序列回归预测是数据分析的重要领域,旨在根据历史数据预测未来时刻的数值。近年来,深度学习模型如卷积神经网络(ConvolutionalNeuralNetwork,C......
  • CSS mask 与 切图艺术
    一、“切图”的局限性传统的“切图”简单暴力,但往往缺少适应性。适应性一般有两种,一是尺寸自适应,二是颜色可以自定义。举个例子,有这样一个优惠券样式关于这类样式实现技巧,之前在这篇文章中有详细介绍:CSS实现优惠券的技巧不过这里略微不一样的地方是,两个凹陷处都是平滑处理......
  • Apache Shiro 721反序列化漏洞Padding Oracle Attack
    目录漏洞原理复现修复方式漏洞原理Shiro的RememberMeCookie使用的是AES-128-CBC模式加密。其中128表示密钥长度为128位,CBC代表CipherBlockChaining,这种AES算法模式的主要特点是将明文分成固定长度的块,然后利用前一个块的密文对当前块的明文进行加密处理。这种模式的加......
  • 基于WOA优化的CNN-GRU-Attention的时间序列回归预测matlab仿真
    1.算法运行效果图预览woa优化前      woa优化后    2.算法运行软件版本matlab2022a 3.算法理论概述      时间序列回归预测是数据分析的重要领域,旨在根据历史数据预测未来时刻的数值。近年来,深度学习模型如卷积神经网络(ConvolutionalNeur......
  • 根据bounding box坐标框绘制mask
    根据boundingbox坐标框绘制maskimportosfromPILimportImage,ImageDraw#定义图像和标注文件夹路径image_folder_path=r'F:\Liang\Datasets\Text_dataset\Tampered-IC13\train_img'annotation_folder_path=r'F:\Liang\Datasets\Text_dataset\Tampered-IC1......
  • 基于PSO优化的CNN-LSTM-Attention的时间序列回归预测matlab仿真
    1.算法运行效果图预览PSO优化前:      PSO优化后:   2.算法运行软件版本MATLAB2022A  3.算法理论概述       时间序列回归预测是数据分析的重要领域,旨在根据历史数据预测未来时刻的数值。近年来,深度学习模型如卷积神经网络(ConvolutionalNe......
  • EditorGUI.MaskField实现多选枚举
    效果枚举publicenumMyFontStyleMask{Bold=1,Italic=1<<1,Outline=1<<2,}标签类usingUnityEngine;publicclassMyEnumMaskAttribute:PropertyAttribute{}PropertyDrawer#ifUNITY_EDITORusingSystem;usingUnityEd......