首页 > 其他分享 >『NLP学习笔记』如何理解attention中的Q,K,V

『NLP学习笔记』如何理解attention中的Q,K,V

时间:2022-11-14 15:32:45浏览次数:75  
标签:NLP 768 nn self attention 笔记 hidden size


如何理解attention中的Q,K,V?

文章目录

  • ​​一. 如何理解attention中的Q,K,V?​​
  • ​​1.1. 定义三个线性变换矩阵​​
  • ​​1.2. 定义QKV​​
  • ​​1.3. 自注意力计算​​
  • ​​1.3.1. Q和K矩阵乘​​
  • ​​1.3.2. 除以根号dim​​
  • ​​1.3.3. 注意力权重和V矩阵乘​​
  • ​​1.4. 为什么叫自注意力网络​​
  • ​​1.5. 为什么注意力机制是没有位置信息​​
  • ​​二. 参考文章​​
  • 可以先看下之前的文章:​​『NLP学习笔记』Transformer技术详细介绍​​

一. 如何理解attention中的Q,K,V?

1.1. 定义三个线性变换矩阵

  • 1. 首先定义三个线性变换矩阵,query, key, value:
class BertSelfAttention(nn.Module):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
  • 注意,这里的 query, key, value 只是一种操作(线性变换)的名称,实际的 『NLP学习笔记』如何理解attention中的Q,K,V_学习

1.2. 定义QKV

  • 假设三种操作的输入都是同一个矩阵(暂且先别管为什么输入是同一个矩阵),这里暂且定为长度为L的句子,每个token的特征维度是768,那么输入就是(L, 768),每一行就是一个字,像这样:



『NLP学习笔记』如何理解attention中的Q,K,V_线性变换_02

  • 乘以上面三种操作就得到了 『NLP学习笔记』如何理解attention中的Q,K,V_学习『NLP学习笔记』如何理解attention中的Q,K,V_自然语言处理_04,维度其实没变,即此刻的 『NLP学习笔记』如何理解attention中的Q,K,V_学习



『NLP学习笔记』如何理解attention中的Q,K,V_线性变换_06

  • 代码为:
class BertSelfAttention(nn.Module):
def __init__(self, config):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768

def forward(self,hidden_states): # hidden_states 维度是(L, 768)
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)

1.3. 自注意力计算

  • 然后来实现这个操作:
    『NLP学习笔记』如何理解attention中的Q,K,V_学习_07

1.3.1. Q和K矩阵乘

  • 首先是 『NLP学习笔记』如何理解attention中的Q,K,V_人工智能_08『NLP学习笔记』如何理解attention中的Q,K,V_学习_09,(L, 768)*(L, 768)的转置=(L,L),看图:

『NLP学习笔记』如何理解attention中的Q,K,V_人工智能_10

  • 首先用 『NLP学习笔记』如何理解attention中的Q,K,V_权重_11 的第一行,即 “我”字的768特征和K中“我”字的768为特征点乘求和,得到输出(0,0)位置的数值,这个数值就代表了“我想吃酸菜鱼”中“我”字对“我”字的注意力权重,然后 显而易见输出的第一行就是“我”字对“我想吃酸菜鱼”里面每个字的注意力权重;整个结果自然就是 “我想吃酸菜鱼”里面每个字对其它字(包括自己)的注意力权重(就是一个数值)了

1.3.2. 除以根号dim

1.3.3. 注意力权重和V矩阵乘

  • 然后就是刚才的注意力权重和 『NLP学习笔记』如何理解attention中的Q,K,V_权重_11

『NLP学习笔记』如何理解attention中的Q,K,V_学习_13

  • 注意力权重 x VALUE矩阵 = 最终结果,首先是“我”这个字对“我想吃酸菜鱼”这句话里面每个字的注意力权重,和V中“我想吃酸菜鱼”里面每个字的第一维特征进行 相乘再求和,这个过程其实就 相当于用每个字的权重对每个字的特征进行加权求和,然后再用“我”这个字对对“我想吃酸菜鱼”这句话里面每个字的注意力权重和V中“我想吃酸菜鱼”里面每个字的第二维特征进行相乘再求和,依次类推,最终也就得到了(L,768)的结果矩阵,和输入保持一致~
class BertSelfAttention(nn.Module):
def __init__(self, config):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768

def forward(self,hidden_states): # hidden_states 维度是(L, 768)
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)

attention_scores = torch.matmul(Q, K.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.Softmax(dim=-1)(attention_scores)

out = torch.matmul(attention_probs, V)
return out
  • 这里对上面的一些值进行假定,给出结果
import math
import torch
from torch import nn


class BertSelfAttention(nn.Module):
def __init__(self, hidden_size=768, all_head_size=768):
super().__init__()
self.query = nn.Linear(hidden_size, all_head_size) # 输入768, 输出768
self.key = nn.Linear(hidden_size, all_head_size) # 输入768, 输出768
self.value = nn.Linear(hidden_size, all_head_size) # 输入768, 输出768

def forward(self, inputs, attention_head_size=768): # inputs 维度是(L, 768)
Q = self.query(inputs)
K = self.key(inputs)
V = self.value(inputs)
attention_scores = torch.matmul(Q, K.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(attention_head_size)
attention_probs = nn.Softmax(dim=-1)(attention_scores)
out = torch.matmul(attention_probs, V)
return out


if __name__ == '__main__':
tensor = torch.normal(0, 1, (25, 768)) # 随意模拟的一个
attention = BertSelfAttention()
out = attention(tensor)
print(out)
tensor([[ 0.0341,  0.1600,  0.0292,  ...,  0.0963, -0.0547,  0.0571],
[ 0.0587, 0.1236, 0.0760, ..., 0.0394, -0.0674, 0.1228],
[ 0.0631, 0.2530, 0.0133, ..., 0.0899, -0.0734, 0.1542],
...,
[ 0.0467, 0.1886, -0.0014, ..., 0.0197, -0.0556, 0.1075],
[ 0.0739, 0.1167, 0.0180, ..., 0.0425, -0.0303, 0.1381],
[ 0.0867, 0.2769, -0.0908, ..., 0.0613, -0.1291, 0.1641]],
grad_fn=<MmBackward0>)

Process finished with exit code 0

1.4. 为什么叫自注意力网络

  • 因为可以看到 『NLP学习笔记』如何理解attention中的Q,K,V_学习 都是通过同一句话的输入算出来的,按照上面的流程也就是一句话内每个字对其它字(包括自己)的权重分配;那如果不是自注意力呢?简单来说,来自于句A,来自于句B即可~

1.5. 为什么注意力机制是没有位置信息

  • 注意,『NLP学习笔记』如何理解attention中的Q,K,V_人工智能_15 中,如果同时替换任意两个字的位置,对最终的结果是不会有影响的,至于为什么,可以自己在草稿纸上画一画矩阵乘;也就是说注意力机制是没有位置信息的,不像CNN/RNN/LSTM;这也是为什么要引入位置embeding的原因

『NLP学习笔记』如何理解attention中的Q,K,V_人工智能_16

  • 从上图可以明显看出每个token的输出和其所在的顺序是没有关系的。

二. 参考文章


标签:NLP,768,nn,self,attention,笔记,hidden,size
From: https://blog.51cto.com/u_15866474/5849254

相关文章

  • java基础笔记
    java的数据类型分为两大类  进制前缀二进制:0b八进制:0十六进制:0xJava会直接将它们转换为十进制输出 float、double并不能准确表示每一位小数,对于有的小数只能无......
  • Navicat使用笔记08---利用Navicat进行数据迁移
    1.使用背景需要将一台服务器上mysql数据迁移到另一台服务器的mysql中2.单库迁移2.1在目标服务器中创建一个和源服务器数据库名称一样的数据库2.2创建任务开始迁移......
  • C基础学习笔记——01-C基础第13天(文件下)
    在学习C基础总结了笔记,并分享出来。01-C基础第13天(文件下)目录:(1)按照块读写文件fread、fwrite1)写文件2)读文件3)强化训练:大文件拷贝(2)文件的随机读写(3)Windows和Linux文本文件区别......
  • C基础学习笔记——01-C基础第07天(字符串处理函数和函数)
    在学习C基础总结了笔记,并分享出来。01-C基础第07天(字符串处理函数和函数)目录:一、字符串处理函数(1)gets()(2)fgets()(3)puts()(4)fputs()(5)strlen()(6)strcpy()(7)strncpy()(8)strcat()(9)str......
  • C基础学习笔记——01-C基础第02天(用户权限、VI操作、Linux服务器搭建)
    在学习C基础总结了笔记,并分享出来。01-C基础第02天(用户权限、VI操作、Linux服务器搭建) 打开终端:ctrl+alt+t清屏:ctrl+l或clear在终端中退出锁定:ctrl+c 目录3常用命令4......
  • 【做题笔记】CF1528B Kavi on Pairing Duty
    ProblemCF1528BKavionPairingDuty题目大意:在数轴上有\(2n\)个点,相邻两个点的距离为\(1\)。现在要将这些点两两匹配成\(n\)个圆弧,要求任意两个圆弧要么等长,要么......
  • 概率期望 DP 学习笔记
    期望这东西学了一次忘了,再学一次过了两天又不会了。我是鱼。故写此博客以便加深记忆及日后复习。经典问题1某事件发生概率为\(p\),则该事件首次发生的期望次数为\(\fr......
  • k8s工作原理(chrono《kubernetes入门实战课》笔记整理)
     【架构理解】k8s可以编排容器,也可以对服务器进行监管。在k8s,不会区分dev(开发人员)和ops(运维人员),而是devops(提倡开发时就要考虑运维,运维也要尽早开始考虑如何对应用进行运......
  • 【网络】安装Nginx笔记
    目录前言安装前先更新下安装依赖库下载NginxNginx编译配置编译&安装&验证nginxNginx服务配置配置SSL参考前言up安装nginx主要是为了在服务器上做反向代理。有兴趣的同学......
  • pexpect常用API笔记
    pexpect常用API笔记spawn()spawn用来执行一个程序,它返回这个程序的操作句柄,以后可以通过操作这个句柄来对这个程序进行操作。参数以及默认值如下:classpexpect.spawn(......