首页 > 其他分享 >各种注意力评分函数的实现

各种注意力评分函数的实现

时间:2024-08-29 18:23:12浏览次数:20  
标签:函数 点积 评分 self 查询 lens valid 注意力

预备知识

本文基于MXNet进行实现,需要对于注意力机制有一定初步了解。也需要对Python有足够了解。

另外这里稍加说明,在注意力机制中,本质上是“注意”的位置,即加权计算后进行Softmax回归的结果。在Nadaraya-Watson核回归中,首先具有一个键值对(key-value),输入称为一个查询(query),对于每个查询,有对应计算,计算查询与键的关系,根据关系的大小,取键所对应的值,通过带权重的值进行预测,这就是Nadaraya-Watson核回归的基本思想。

注意力评分函数

注意力评分函数本质上是对查询和键之间的关系建模,即\hat{y}=\Sigma_i^n \alpha(x,x_i)y_i

在Nadaraya-Watson核回归中,α为查询与键的距离。将注意力评分函数的输出结果输入到softmax函数中进行运算。 通过上述步骤,将得到与键对应的值的概率分布(即注意力权重)。 最后,注意力汇聚的输出就是基于这些注意力权重的值的加权和。

准备工作

选择不同的注意力评分函数α会导致不同的注意力汇聚操作。 本节将介绍两个流行的评分函数,稍后将用他们来实现更复杂的注意力机制。

引入库

import math
from mxnet import np, npx
from mxnet.gluon import nn
from d2l import mxnet as d2l

npx.set_np()

掩蔽Softmax

为了使注意力机制的实现是有效的,可以采用掩蔽Softmax操作,仅对一定的值纳入注意力汇聚中,而无意义的值则排除掉。

def masked_softmax(X, valid_lens):
    if valid_lens is None:
        return npx.softmax(X)
    else:
        shape = X.shape
        if valid_lens.ndim == 1:
            valid_lens = valid_lens.repeat(shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
        X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
                              value=-1e6, axis=1)
        return npx.softmax(X).reshape(shape)

加性注意力机制

对于给定查询q和键k,分别乘以对应权重,连结后输入一个多层感知机,具有一个隐藏层,禁用bias项,对于这一步产生的结果再进行tanh激活函数的操作,最后通过一个权重矩阵W_v输出结果。(这里还使用了Dropout。)

大致可以理解为,输入含有若干特征x,对其进行运算,获得num_hiddens个隐藏单元,又有若干个键key,对其进行运算,获得num_hiddens个隐藏单元,进行连结,再经过tanh运算,最后乘以权重矩阵,获得输出为一个神经元的结果,这个结果是对键和查询的关系进行加权运算的结果。

class AdditiveAttention(nn.Block):
    def __init__(self, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False)
        self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False)
        self.w_v = nn.Dense(1, use_bias=False, flatten=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        features = np.expand_dims(queries, axis=2) + np.expand_dims(
            keys, axis=1)
        features = np.tanh(features)
        scores = np.squeeze(self.w_v(features), axis=-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return npx.batch_dot(self.dropout(self.attention_weights), values)

缩放点积注意力

缩放点击直接将查询和键进行点积操作,之后进行缩放,得到的值进行Softmax回归。显然,直接进行矩阵乘法操作是更加快速的,因此缩放点积注意力的运算效率远远高于加性注意力机制,不过缩放点积注意力对于输入和键的大小是有要求的,要求输入和键具有相同大小,否则不可乘。

我个人认为加性注意力机制更类似于一种一般的深度学习方法,而缩放点积注意力则是一种特殊方法。

实现过程如下,需要注意的是:
    # queries的形状:(batch_size,查询的个数,d)
    # keys的形状:(batch_size,“键-值”对的个数,d)
    # values的形状:(batch_size,“键-值”对的个数,值的维度)
    # valid_lens的形状:(batch_size,)或者(batch_size,查询的个数)

class DotProductAttention(nn.Block):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # 设置transpose_b=True为了交换keys的最后两个维度
        scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return npx.batch_dot(self.dropout(self.attention_weights), values)

说明

疑问

加性注意力中的直接学习的机制我是可以理解的,但在缩放点积注意力的点积部分我感到不解,对于既定的若干个键值对,为什么查询和键直接进行点积操作可以有效获得类似于“权重”的结果呢?

对应解答

查询和键的点积操作有效地衡量了它们之间的相关性或匹配程度。这个操作可以理解为测量查询和每个键的“相似度”或“匹配度”。点积较大的结果意味着查询和对应的键在特征空间中更接近,因此它们之间的匹配程度更高。这个相似度分数在经过缩放和 Softmax 后转化为权重,反映了查询对各个键值对的关注程度。最终,这些权重用于加权值(Value),从而产生最终的注意力输出。

标签:函数,点积,评分,self,查询,lens,valid,注意力
From: https://blog.csdn.net/2301_79335566/article/details/141651345

相关文章

  • 29:函数查询,添加,修改,删除
    #_*_coding:utf-8_*_importosdeffile_handle(filename,backend_data,record_list=None,type='fetch'):#type:fetchappendchangenew_file=filename+'_new'bak_file=filename+'_bak'iftype=='fetch':......
  • Pytorch 的 损失函数
    1.损失函数损失函数(LossFunction)是用来衡量模型预测结果与真实值之间的差异的函数。它是训练过程中最重要的组成部分之一,用来指导模型的优化过程。 作用损失函数的作用包括:衡量模型性能:通过计算预测结果与真实值的差异,损失函数可以提供一个衡量模型预测准确性的指标......
  • 类的成员静态变量和静态成员函数需要类外定义吗,举例说明
    类的成员静态变量需要在类外定义(非声明),而静态成员函数则不需要在类外额外定义。 静态变量类外定义示例 假设有一个类MyClass,它有一个静态成员变量staticVar: cppclassMyClass{public:  staticintstaticVar;//静态成员变量声明  staticvoidstaticFu......
  • C#学习笔记- 随机函数Random()的用法详解
    原文链接:https://www.jb51.net/article/90933.htmRandom.Next()返回非负随机数;Random.Next(Int)返回一个小于所指定最大值的非负随机数Random.Next(Int,Int)返回一个指定范围内的随机数,例如(-100,0)返回负数1、random(number)函数介绍random(number)返回一个0~number-1之间......
  • Python——集合基本操作以及哈希函数
    Python中的集合(Set)是一个无序的、不包含重复元素的数据结构。集合主要用于数学上的集合操作,如并集、交集、差集和对称差集等。集合使用大括号 {} 来表示,但注意空集合不能使用 {} 表示(这会创建一个空字典),而应该使用 set() 来创建。创建集合1.使用大括号 {}:这是最直接......
  • PHP8面向对象快速入门三 类的继承 类方法属性重写和final关键字 parent调用父类的方法
    在PHP中,类的继承(继承)是一种机制,允许一个类继承另一个类的属性和方法,从而实现代码的重用和扩展。继承可以帮助你创建一个基于现有类的新类,保留原有类的特性并增加或修改其功能。classAnimal{public$name='dongwu';protected$age=1;private......
  • switch&回调函数
    #include<stdio.h>//函数原型声明floatcalc(floata,floaty,constcharop);floatadd(floata,floatb);floatminus(floata,floatb);floatmultiple(floata,floatb);floatdivide(floata,floatb);floatcalc_using_callback(floata,floatb,floa......