首页 > 其他分享 >pytorch-多头注意力

pytorch-多头注意力

时间:2023-09-14 11:44:51浏览次数:36  
标签:head 键和值 汇聚 pytorch 多头 查询 注意力

多头注意力

在实践中,当给定相同的查询、键和值的集合时,我们希望模型可以基于相同的注意力机制学习到不同的行为,然后将不同的行为作为知识组合起来,捕获序列内各种范围的依赖关系(例如,短距离依赖和长距离依
赖关系)
。因此,允许注意力机制组合使用查询、键和值的不同子空间表示(representation subspaces)可能是有益的。为此,与其只使用单独一个注意力汇聚,我们可以用独立学习得到的h组不同的线性投影(linear projections)来变换查询、键和值。然后,这h组变换后的查询、键和值将并行地送到注意力汇聚中。最后,将这h个注意力汇聚的输出拼接在一起,并且通过另一个可以学习的线性投影进行变换,以产生最终输出。这种设计被称为多头注意力(multihead attention)(Vaswani et al., 2017)。对于h个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。下图展示了使用全连接层来实现可学习的线性变换的多头注意力。
image

模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。给定查询\(q \in R^{d_q}\)、键\(k \in R^{d_k}\)和值\(v \in R^{d_v}\),每个注意力头\(h_i(i = 1, . . . , h)\)的计算方法为:

\[h_i = f(W^{(q)}_iq, W^{(k)}_ik,W^{(v)}_iv) \in R^{p_v} \]

其中,可学习的参数包括\(W^{(q)}_i \in R^{p_q×d_q}\)、\(W^{(k)}_i \in R^{p_k×d_k}\)和\(W^{(v)}_i \in R^{p_v×d_v}\),以及代表注意力汇聚的函数f。f可以的加性注意力和缩放点积注意力。多头注意力的输出需要经过另一个线性转换,它对应着h个头连结后的结果,因此其可学习参数是 \(W_o \in R^{p_o×hp_v}\):
image

总结

多头注意力机制现在的使用是非常广泛的。为什么需要比较多的head呢?可以想成相关这件事情在做Self-attention的时候,就是用q去找相关的k,但是相关这件事情有很多种不同的形式,有很多种不同的定义,所以我们不能只有一个q,应该要有多个q,不同的q负责不同种类的相关性。
我们应在怎么做呢?首先对于这个\(q_i\)我们分别乘两个矩阵变成\(q^{i,1}\)和\(q^{i,2}\)。这个可以理解为两种不同的相关性。之后q,k,v都要有两个:
image
用第一个head:
image
用第二个head:
image
将这两个接起来,然后通过一个trannsform,也就是乘上一个矩阵,得到\(b_i\)传到下一层去。
image

标签:head,键和值,汇聚,pytorch,多头,查询,注意力
From: https://www.cnblogs.com/lipu123/p/17701717.html

相关文章

  • 《动手学深度学习 Pytorch版》 5.4 自定义层
    5.4.1不带参数的层importtorchimporttorch.nn.functionalasFfromtorchimportnnclassCenteredLayer(nn.Module):def__init__(self):super().__init__()defforward(self,X):returnX-X.mean()#仅作减去均值的操作layer=......
  • TensorFlow PyTorch Transformer --GPT
    你是机器学习专家,请为我解答如下疑问:tensorflow,pytorch分别是什么?他们跟numpy的区别是什么?什么是深度学习,深度学习框架中的框架指的是什么?什么是机器学习,什么是transformer?transformer和tensorflow,pytorch的关系是什么?ChatGPTTensorFlow和PyTorchTensorFlow:这是一个由......
  • PytorchGPU版本环境配置。Anconda + Pycharm。
    2023年9月12日深度学习课程是基于Pytorch框架有的没的因为之前瞎搞,环境乱七八糟,这次DL课程最好装上GPU版本,那么就通过这次一次性弄个好吧!绝对不是因为,我之前不会弄。(是这样的)课程需要配置好环境。最后经过一段时间的瞎搞乱搞的调整,Pytorch-GPU版本成功安装好了。我是根据B......
  • ubuntu16.04安装cuda8.0+pytorch1.0.0
    1.安装cuda1.1查看ubuntu的英伟达显卡驱动nvidia-smi得到驱动版本是384.130,比较老,所以需要下载旧版本的cuda1.2查看显卡是否支持CUDA计算然后去到这里https://developer.nvidia.com/cuda-gpus查看你的显卡是否在表中,在的话你显卡就是支持CUDA计算的(CUDA-capable)。结果......
  • Pytorch深度学习零基础入门知识
    DL跑代码必须知道的事情损失值损失值的大小用于判断是否收敛,比较重要的是有收敛的趋势,即验证集损失不断下降,如果验证集损失基本上不改变的话,模型基本上就收敛了。损失值的具体大小并没有什么意义,大和小只在于损失的计算方式,并不是接近于0才好。如果想要让损失好看点,可以直接......
  • 10.3 注意力评分函数
    1.torch.bmm()的用法先说一般的矩阵乘法torch.mm()。torch.mm()用于将两个二维张量(矩阵)相乘,求它们的叉乘结果。如: 我们创建一个2*3的矩阵A,3*4的矩阵B,它们的值都初始化为均值为0方差为1的标准正态分布,用torch.mm()求它们的叉乘结果:importtorchfromtorchimportnnfromd......
  • pip 安装pytorch
    一、新建虚拟环境二、激活虚拟环境三、配置清华镜像源四、在Pytorch官网:PyTorch 选择相关配置 ......
  • PyTorch安装记录
    打开PyTorch官网,选择getstartedhttps://pytorch.org/查看系统的cuda版本nvcc-V若系统安装了cuda,则最后一行会显示cuda版本。如果返回None,则说明没有使用cuda3.选择合适的系统,安装工具以及cuda版本这里没有看到我们需要的11.4的cuda版本,选择installpreviousver......
  • 转:pytorch RoIAlign函数的用法
    图解RoIAlign以及在PyTorch中的使用(含代码示例)_虾米小馄饨的博客-CSDN博客如何在你自己的代码中使用ROIPool和ROIAlign(PyTorch1.0)_ronghuaiyang的博客-CSDN博客 ......
  • PyTorch基础知识
    PyTorchTutorialPython3中机器学习框架dataset=MyDataset(file)dataloader=DataLoader(dataset,batch_size=size,shuffle=True)Training:TrueTesting:Falsefromtorch.utils.dataimportDataset,DateLoaderclassMyDataset(Dataset):def__init__(self,......