如何理解attention中的Q,K,V? |
文章目录
- 一. 如何理解attention中的Q,K,V?
- 1.1. 定义三个线性变换矩阵
- 1.2. 定义QKV
- 1.3. 自注意力计算
- 1.3.1. Q和K矩阵乘
- 1.3.2. 除以根号dim
- 1.3.3. 注意力权重和V矩阵乘
- 1.4. 为什么叫自注意力网络
- 1.5. 为什么注意力机制是没有位置信息
- 二. 参考文章
- 可以先看下之前的文章:『NLP学习笔记』Transformer技术详细介绍
一. 如何理解attention中的Q,K,V?
1.1. 定义三个线性变换矩阵
- 1. 首先定义三个线性变换矩阵,query, key, value:
class BertSelfAttention(nn.Module):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
- 注意,这里的 query, key, value 只是一种操作(线性变换)的名称,实际的
1.2. 定义QKV
- 假设三种操作的输入都是同一个矩阵(暂且先别管为什么输入是同一个矩阵),这里暂且定为长度为L的句子,每个token的特征维度是768,那么输入就是(L, 768),每一行就是一个字,像这样:
- 乘以上面三种操作就得到了 ,,维度其实没变,即此刻的
- 代码为:
class BertSelfAttention(nn.Module):
def __init__(self, config):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
def forward(self,hidden_states): # hidden_states 维度是(L, 768)
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)
1.3. 自注意力计算
- 然后来实现这个操作:
1.3.1. Q和K矩阵乘
- 首先是 和 ,(L, 768)*(L, 768)的转置=(L,L),看图:
- 首先用 的第一行,即 “我”字的768特征和K中“我”字的768为特征点乘求和,得到输出(0,0)位置的数值,这个数值就代表了“我想吃酸菜鱼”中“我”字对“我”字的注意力权重,然后 显而易见输出的第一行就是“我”字对“我想吃酸菜鱼”里面每个字的注意力权重;整个结果自然就是 “我想吃酸菜鱼”里面每个字对其它字(包括自己)的注意力权重(就是一个数值)了
1.3.2. 除以根号dim
- 除以根号dim,这个dim就是768,至于为什么要除以这个数值?主要是为了缩小点积范围,确保softmax梯度稳定性,具体推导可以看这里:Self-attention中dot-product操作为什么要被缩放,然后就是为什么要softmax?,一种解释是为了保证注意力权重的非负性,同时增加非线性,还有一些工作对去掉softmax进行了实验,如 苏剑林大佬:线性Attention的探索:Attention必须有个Softmax吗?
1.3.3. 注意力权重和V矩阵乘
- 然后就是刚才的注意力权重和
- 注意力权重 x VALUE矩阵 = 最终结果,首先是“我”这个字对“我想吃酸菜鱼”这句话里面每个字的注意力权重,和V中“我想吃酸菜鱼”里面每个字的第一维特征进行 相乘再求和,这个过程其实就 相当于用每个字的权重对每个字的特征进行加权求和,然后再用“我”这个字对对“我想吃酸菜鱼”这句话里面每个字的注意力权重和V中“我想吃酸菜鱼”里面每个字的第二维特征进行相乘再求和,依次类推,最终也就得到了(L,768)的结果矩阵,和输入保持一致~
class BertSelfAttention(nn.Module):
def __init__(self, config):
self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
def forward(self,hidden_states): # hidden_states 维度是(L, 768)
Q = self.query(hidden_states)
K = self.key(hidden_states)
V = self.value(hidden_states)
attention_scores = torch.matmul(Q, K.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
attention_probs = nn.Softmax(dim=-1)(attention_scores)
out = torch.matmul(attention_probs, V)
return out
- 这里对上面的一些值进行假定,给出结果
import math
import torch
from torch import nn
class BertSelfAttention(nn.Module):
def __init__(self, hidden_size=768, all_head_size=768):
super().__init__()
self.query = nn.Linear(hidden_size, all_head_size) # 输入768, 输出768
self.key = nn.Linear(hidden_size, all_head_size) # 输入768, 输出768
self.value = nn.Linear(hidden_size, all_head_size) # 输入768, 输出768
def forward(self, inputs, attention_head_size=768): # inputs 维度是(L, 768)
Q = self.query(inputs)
K = self.key(inputs)
V = self.value(inputs)
attention_scores = torch.matmul(Q, K.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(attention_head_size)
attention_probs = nn.Softmax(dim=-1)(attention_scores)
out = torch.matmul(attention_probs, V)
return out
if __name__ == '__main__':
tensor = torch.normal(0, 1, (25, 768)) # 随意模拟的一个
attention = BertSelfAttention()
out = attention(tensor)
print(out)
tensor([[ 0.0341, 0.1600, 0.0292, ..., 0.0963, -0.0547, 0.0571],
[ 0.0587, 0.1236, 0.0760, ..., 0.0394, -0.0674, 0.1228],
[ 0.0631, 0.2530, 0.0133, ..., 0.0899, -0.0734, 0.1542],
...,
[ 0.0467, 0.1886, -0.0014, ..., 0.0197, -0.0556, 0.1075],
[ 0.0739, 0.1167, 0.0180, ..., 0.0425, -0.0303, 0.1381],
[ 0.0867, 0.2769, -0.0908, ..., 0.0613, -0.1291, 0.1641]],
grad_fn=<MmBackward0>)
Process finished with exit code 0
1.4. 为什么叫自注意力网络
- 因为可以看到 都是通过同一句话的输入算出来的,按照上面的流程也就是一句话内每个字对其它字(包括自己)的权重分配;那如果不是自注意力呢?简单来说,来自于句A,来自于句B即可~
1.5. 为什么注意力机制是没有位置信息
- 注意, 中,如果同时替换任意两个字的位置,对最终的结果是不会有影响的,至于为什么,可以自己在草稿纸上画一画矩阵乘;也就是说注意力机制是没有位置信息的,不像CNN/RNN/LSTM;这也是为什么要引入位置embeding的原因。
- 从上图可以明显看出每个token的输出和其所在的顺序是没有关系的。
二. 参考文章
- 主要参考知乎大神的文章,如何理解attention中的Q,K,V:https://www.zhihu.com/question/298810062/answer/2274132657