首页 > 其他分享 >10.3 注意力评分函数

10.3 注意力评分函数

时间:2023-09-09 23:34:04浏览次数:52  
标签:dim 10.3 评分 torch 矩阵 lens valid softmax 注意力

1.torch.bmm()的用法

先说一般的矩阵乘法torch.mm()。torch.mm()用于将两个二维张量(矩阵)相乘,求它们的叉乘结果。如:

 我们创建一个2*3的矩阵A,3*4的矩阵B,它们的值都初始化为均值为0方差为1的标准正态分布,用torch.mm()求它们的叉乘结果:

import torch
from torch import nn
from d2l import torch as d2l

A = torch.normal(0,1,(2,3))
B = torch.normal(0,1,(3,4))
AB = torch.mm(A,B)
print(AB)

输出:

torch.mm()是求一个矩阵乘以一个矩阵的结果,它的两个参数都是二维张量。而torch.bmm()是求一个批量的矩阵乘以一个批量的矩阵的结果,它的两个参数都是三维张量,其中第一维表示了这一批矩阵的数量。如:

A = torch.normal(0,1,(3,2,3))
B = torch.normal(0,1,(3,3,4))
AB = torch.bmm(A,B)
print(AB)

输出:

 注意bmm操作要求两个三维张量维度的第一个参数必须相等:如这里是(3,2,3)的张量和(3,3,4)的张量,它们的第一个参数都是3.

2.掩蔽softmax操作

在有些情况下,并非所有的值都是有意义的,

#@save
def masked_softmax(X, valid_lens):
    """通过在最后一个轴上掩蔽元素来执行softmax操作"""
    # X:3D张量,valid_lens:1D或2D张量
  #如果没有有效长度,即所有值均有效,那么就是直接返回softmax操作的结果 if valid_lens is None: return nn.functional.softmax(X, dim=-1) else: shape = X.shape if valid_lens.dim() == 1: valid_lens = torch.repeat_interleave(valid_lens, shape[1]) else: valid_lens = valid_lens.reshape(-1) # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0 X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6) return nn.functional.softmax(X.reshape(shape), dim=-1)

dim=-1的用法:对于三维张量(C,H,W),做softmax就有三种方式:

1.dim=0就是对每一个维度对应位置的数值构成的向量做softmax运算。

2.dim=1就是对某一个维度的一列数值构成的向量做softmax运算。

3.dim=2就是对某一个维度的一行数值构成的向量做softmax运算。

 

 

这里面的dim=-1相当于dim=2. 那为什么不直接写dim=2?应该是为了这个函数能适应不同维度的X输入,若输入X是3维的,则softmax(X,dim=-1)相当于softmax(X,dim=2),若输入X是2维的,则softmax(X,dim=-1)相当于softmax(X,dim=1)。

torch.repeat_interleave()的用法:

a = torch.arange(6).reshape(2,1,3)
res = torch.repeat_interleave(a,3,dim = 1) #张量a在第1维(行)上重复3遍
print(a)
print(res)
print(a.shape)
print(res.shape)

运行结果:

 看一下这个masked_softmax(X,valid_lens)函数的效果。考虑样本为2个2*4的矩阵,即(2,2,4)的矩阵:

 这里面valid_lens为[2,3]时,意思是:第一个矩阵中的每一行数据都是前2个是有效值,第二个矩阵中的每一行数据都是前3个是有效值。

       valid_lens为[ [1,3],[2,4] ]时,意思是:第一个矩阵中第一行前1个是有效值,第一个矩阵中第二行前3个是有效值,第二个矩阵第一行前2个是有效值,第二个矩阵中第二行是前4个是有效值。

 

3.加性注意力

其中可学习的参数是:

 它等价于将q和k拼接到一起,丢到一个隐藏层大小为h,输出大小为1的单隐藏层MLP中去。Wq、Wk拼接起来就是隐藏层的参数。

(这里的代码没怎么理解,后面再补充)

 

4.缩放点积注意力

两个同维度并且模相等的向量,它们的点积越大,就表示它们越相似。当query和key具有相同的长度d时,我们可以使用缩放点积作为评分函数。假设query和key都遵循均值为0,方差为1的标准正态分布,那么q和k的点积就遵循均值为0,方差为d的正态分布。我们为了确保无论向量长度d是多长,q和k点积的方差仍然是1,将点积再除以√d,得到缩放点积注意力的评分函数:

 这是一个查询q的情况。考虑有n个查询query,m个key-value pair,其中query和key的长度为d,value的长度为v,那么,

的缩放点积注意力为:

 这里面QKτ / √d (维度为n*m) 的第 i 行第 j 列的值就表示第 i 个query对第 j 个key的注意力分数。再做softmax得到的就是注意力权重,再乘以V得到的就是注意力汇聚的结果。

 下面实现的缩放点积注意力使用了dropout进行模型正则化:

#缩放点积注意力
#@save
class DotProductAttention(nn.Module):
    """缩放点积注意力"""
    def __init__(self, dropout, **kwargs):
     #这里的参数暂时没理解 super(DotProductAttention, self).__init__(**kwargs) self.dropout = nn.Dropout(dropout) # queries的形状:(batch_size,查询的个数,d) # keys的形状:(batch_size,“键-值”对的个数,d) # values的形状:(batch_size,“键-值”对的个数,值的维度) # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数) def forward(self, queries, keys, values, valid_lens=None): d = queries.shape[-1] # keys.transpose(1,2)的意思是将key这个三维张量在1和2这两个维度上转置() scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d) self.attention_weights = masked_softmax(scores, valid_lens) return torch.bmm(self.dropout(self.attention_weights), values)

下面演示一下这个类。在下面的代码中,我们设计了queries维度为(2,1,2),含义是一个批量有两个矩阵,每个矩阵有1个查询,长度为2,那么,key和value也应该一个批量是2个矩阵。我们设计了key-value pair的数量为10,并且key的长度等同于query的长度,value的长度为4.

queries = torch.normal(0, 1, (2, 1, 2))
keys = torch.ones((2, 10, 2))
# values的小批量,两个值矩阵是相同的
#values的维度被repeat成了(2,10,4)
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(
    2, 1, 1)
valid_lens = torch.tensor([2, 6])
attention = DotProductAttention(dropout=0.5)
attention.eval()
print(attention(queries, keys, values, valid_lens))
d2l.show_heatmaps(attention.attention_weights.reshape((1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')

运行结果:

 

 总结:

注意力分数是query和key的相似度,是没有被normalize过的,而注意力权重是注意力分数softmax的结果。两种常见的注意力分数计算:

1. 将query和key合并起来进入一个单输出单隐藏层的MLP。(加性注意力)

2.直接将query和key做内积。(点积注意力)

 

标签:dim,10.3,评分,torch,矩阵,lens,valid,softmax,注意力
From: https://www.cnblogs.com/pkuqcy/p/17689852.html

相关文章

  • 星级评分功能实现
          我这个是在前一篇文章所介绍的js脚本基础上做的修改。(请先看前面一篇《星级评分效果-js实现》)      由于把前面的脚本引入项目里,发现当鼠标移到星星图片上获取到的OY值并不在1到19之间,所以导致该功能不起作用。后来通过调试测试发现在IE中当鼠标移动到星星图......
  • weblogic-10.3.6-'wls-wsat'-XMLDecoder反序列化漏洞-(CVE-2017-10271)
    目录1.1、漏洞描述1.2、漏洞等级1.3、影响版本1.4、漏洞复现1、基础环境2、漏洞扫描nacsweblogicScanner3、漏洞验证说明内容漏洞编号CVE-2017-10271漏洞名称Weblogic<10.3.6'wls-wsat'XMLDecoder反序列化漏洞(CVE-2017-10271)漏洞评级高危影响范围10.3......
  • 少发火,注意力放在解决问题上
    近来,感觉自己的情绪控制能力比以前进步不少,主要是认识到就算发火,生气,事情也不会得到根本性的解决,甚至发火会推动事情往更坏的方向走去。有时候,发火的人才是有问题的人。所以最好好好沟通。我之前觉得家里很容易就变得很乱,我觉得是因为我妈在家不知道主动收拾,但......
  • HCL AppScan Standard v10.3.0 (Windows) - 应用程序安全测试
    HCLAppScanStandardv10.3.0(Windows)-应用程序安全测试请访问原文链接:https://sysin.org/blog/appscan-10/,查看最新版。原创作品,转载请保留出处。作者主页:sysin.orgESG技术评论:使用HCLAppScan实现持续的应用程序安全“AppScan通过直接集成到软件开发生命周期来支......
  • discuz3.4,关于安装dev8133插件(购买帖子内容),在论坛对用户组开启评分功能后,用户一点击
    漏洞修补方案一:后端措施在source/module/forum/forum_misc.php文件中,$post=C::t('forum_post')->fetch('tid:'.$_G['tid'],$_GET['pid']);//这一步调用大C的静态方法t()从表forum_post中根据tid和pid共同查询出当前要评分的帖子主体内容对此处查询出来的$post数据直接后端进......
  • 通过jsoup抓取谷歌商店评分
    背景在谷歌上面发布包,有时候要看看评分,有时候会因为总总原因被下架,希望后台能够对评分进行预警,和下架预警实现测试地址:https://play.google.com/store/apps/details?id=com.tencent.mm通过jsoup解析页面,然后获取评分;这是获取评分的:而判断包是否下架就直接判断返回......
  • 解码Transformer:自注意力机制与编解码器机制详述与代码实现
    本文全面探讨了Transformer及其衍生模型,深入分析了自注意力机制、编码器和解码器结构,并列举了其编码实现加深理解,最后列出基于Transformer的各类模型如BERT、GPT等。文章旨在深入解释Transformer的工作原理,并展示其在人工智能领域的广泛影响。作者TechLead,拥有10+年互联网服......
  • 图注意力网络论文详解和PyTorch实现
    前言 图神经网络(gnn)是一类功能强大的神经网络,它对图结构数据进行操作。它们通过从节点的局部邻域聚合信息来学习节点表示(嵌入)。这个概念在图表示学习文献中被称为“消息传递”。本文转载自P**nHub兄弟网站作者|EbrahimPichka仅用于学术分享,若侵权请联系删除欢迎关注公......
  • 图注意力网络论文详解和PyTorch实现
    图神经网络(gnn)是一类功能强大的神经网络,它对图结构数据进行操作。它们通过从节点的局部邻域聚合信息来学习节点表示(嵌入)。这个概念在图表示学习文献中被称为“消息传递”。消息(嵌入)通过多个GNN层在图中的节点之间传递。每个节点聚合来自其邻居的消息以更新其表示。这个过......
  • Python采集主播照片,实现人脸识别, 进行颜值评分,制作颜值排行榜
    昨晚一回家,表弟就神神秘秘的跟我说,发现一个高颜值网站,非要拉着我研究一下她们的颜值高低。我心想,这还得要我一个个慢慢看,太麻烦了~于是反手用Python给他写了一个人脸识别代码,把她们的照片全部爬下来,自动检测颜值打分排名。这不比手动快多了?准备工作开发环境Py......