算法面试准备 - 手撕系列第五期 - 单头注意力机制(包括Self_atten和Cross_atten)
目录
单头注意力机制原理
原理图像
背景介绍
单头注意力机制(Single-Head Attention)是深度学习领域中广泛使用的一种注意力机制,用于从输入数据中捕获相关性和依赖性。它是多头注意力机制的一个简化版本,只有一个注意力头,但仍然能够有效地计算输入序列中不同位置之间的相关性。
原理解析
单头注意力机制的核心流程包括以下几个步骤:
1. 输入与嵌入
如果为自注意力机制则输入为一个qkv统一源的矩阵 ( X q k v ∈ R l e n ( q k v ) × d \ X_{qkv} \in \mathbb{R}^{len(qkv) \times d} Xqkv∈Rlen(qkv)×d),交叉注意力机制需要输入两个,kv的源矩阵和q的源矩阵( X q ∈ R l e n q × d \ X_q \in \mathbb{R}^{lenq \times d} Xq∈Rlenq×d, X k v ∈ R l e n ( k v ) × d \ X_{kv} \in \mathbb{R}^{len(kv) \times d} Xkv∈Rlen(kv)×d),其中:
- 其中len (n) 是输入对应序列的长度(例如,句子中的词数量)。
- 其中d 是输入向量的维度。
通过嵌入层或直接提供的特征,得到输入矩阵。
2. 线性变换
为生成查询(Query)、键(Key)和值(Value),通过可学习的权重矩阵进行线性变换,如果为自注意力机制则KQV的计算公式为:
Q
=
X
q
k
v
W
Q
,
K
=
X
q
k
v
W
K
,
V
=
X
q
k
v
W
V
Q = X_{qkv}W_Q, \quad K = X_{qkv}W_K, \quad V = X_{qkv}W_V
Q=XqkvWQ,K=XqkvWK,V=XqkvWV
如果是交叉注意力机制则计算公式为:
Q
=
X
q
W
Q
,
K
=
X
k
v
W
K
,
V
=
X
k
v
W
V
Q = X_{q}W_Q, \quad K = X_{kv}W_K, \quad V = X_{kv}W_V
Q=XqWQ,K=XkvWK,V=XkvWV
其中:
- ( W Q , W K , W V ∈ R d × d k W_Q, W_K, W_V \in \mathbb{R}^{d \times d_k} WQ,WK,WV∈Rd×dk) 是可学习的权重矩阵。
- ( d k d_k dk) 是注意力机制中查询和键的向量维度。
3. 注意力分数计算
计算查询与键之间的点积来衡量它们的相关性,并进行缩放以防止梯度过大:
Attention Scores
=
Q
K
⊤
d
k
\text{Attention Scores} = \frac{QK^\top}{\sqrt{d_k}}
Attention Scores=dk
QK⊤
4. 软max归一化
对注意力分数进行归一化处理,使其表示概率分布:
Attention Weights
=
softmax
(
Q
K
⊤
d
k
)
\text{Attention Weights} = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)
Attention Weights=softmax(dk
QK⊤)
5. 加权求和
将注意力权重与值向量相乘,得到最终的注意力输出:
Output
=
Attention Weights
⋅
V
\text{Output} = \text{Attention Weights} \cdot V
Output=Attention Weights⋅V
总结公式
单头注意力机制的完整公式为:
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
⊤
d
k
)
V
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
Attention(Q,K,V)=softmax(dk
QK⊤)V
优缺点分析
优点
- 简单高效:计算开销较低,适用于资源有限的环境。
- 捕获相关性:能够学习输入序列中不同位置的依赖关系。
缺点
- 表达能力受限:与多头注意力机制相比,单头注意力机制无法同时学习多个不同的子空间表示。
- 缺乏多样性:对于复杂任务,可能无法充分挖掘特征。
单头注意力机制代码
以下是基于 PyTorch 实现单头注意力机制的代码:
第一步,引入相关的库函数
# 该模块主要实现单头的注意力机制,输入为x,形成qkv,然后得到注意力z
'''
# Part1 引入相关的库函数
'''
import torch
from torch import nn
import math
第二步,初始化Onehead_Atten作为一个类
'''
# Part2 定义单头注意力类
'''
class Onehead_Atten(nn.Module):
def __init__(self, emd_size, q_k_size, v_size):
super(Onehead_Atten, self).__init__()
# Part1 初始化矩阵Wk,Wv,Wq
# 注意x为(batch_size,q_seq_len,emd_size),且要实现(Q*KT)所以,Q(q_len,q_k_size),K为(k_len,q_k_size)
self.Wk = nn.Linear(emd_size, q_k_size)
self.Wq = nn.Linear(emd_size, q_k_size)
self.Wv = nn.Linear(emd_size, v_size)
# Part2 得到矩阵Q(batch_size,q_len,q_k_size),K(batch_size,k_v_len,q_k_size),V(batch_size,k_v_len,q_k_size)
# softmax((Q*KT/sqrt(dk)))*V
self.softmax = nn.Softmax(dim=-1)
def forward(self, x_q, x_k_v, mask=None):
q = self.Wq(x_q) # (batch_size,q_len,q_k_size)
k = self.Wk(x_k_v) # (batch_size,k_v_len,q_k_size)
v = self.Wv(x_k_v) # (batch_size,k_v_len,v_size)
# 为了便于相乘对K转置
k = k.transpose(1, 2)
# 第一步把(Q*Kt)/根号dk
q_k = self.softmax(torch.matmul(q, k) / math.sqrt(q.size()[-1]))
# 判断是够要mask(1,seq_len_q,seq_len_k)
if mask is not None:
q_k = q_k.masked_fill(mask, 1e-9)
# 第二步和v相乘
atten_z = torch.matmul(q_k, v)
return atten_z
第三步 测试代码 -分为自注意力和交叉注意力两种注意力
if __name__ == '__main__':
# 类别1 单头的自注意力机制
# 初始化输入x(batch_size,seq_len,emding)
batch_size = 1 # 批量也就是句子的数量
emd_size = 128 # 一个token嵌入的维度
seq_len = 5 # kqv源的token长度
q_k_size = 128 # q和k的嵌入维度
v_size = 128 # v的嵌入维度
x = torch.rand(size=(batch_size, seq_len, emd_size), dtype=torch.float)
self_atten = Onehead_Atten(emd_size=emd_size, q_k_size=q_k_size, v_size=v_size)
# 初始化mask(batch,len_k,len_q)
mask = torch.randn(size=(batch_size, seq_len, seq_len))
mask = mask.bool()
print('单头的自注意力结果',self_atten(x, x, mask).size())
# 类别2 单头的交叉注意力机制
# 初始化输入x(batch_size,seq_len,emding)
batch_size=1 # 批量也就是句子的数量
emd_size=128 # 一个token嵌入的维度
q_seq_len=5 # q源的token长度
q_k_size=128 # q和k的嵌入维度/head
k_v_seq_len=7 # k_v源的token长度
v_size=128 # v的嵌入维度/head
x_q = torch.rand(size=(batch_size, q_seq_len, emd_size), dtype=torch.float)
x_k_v = torch.rand(size=(batch_size, k_v_seq_len, emd_size), dtype=torch.float)
cross_atten = Onehead_Atten(emd_size=emd_size, q_k_size=q_k_size, v_size=v_size)
# 初始化mask(batch,len_k,len_q)
mask = torch.randn(size=(batch_size, q_seq_len, k_v_seq_len))
mask = mask.bool()
print('单头的交叉注意力结果',cross_atten(x_q, x_k_v, mask).size())
参考
自己(值得纪念,终于自己会从头开始写注意力机制喽,哈(*≧▽≦)):小菜鸟博士-CSDN博客
标签:seq,atten,Self,batch,len,单头,注意力,size From: https://blog.csdn.net/m0_62030579/article/details/145164741