算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)
目录
多头注意力机制原理
多头注意力机制原理图像
背景介绍
多头注意力机制(Multi-Head Attention)是 Transformer 架构的核心模块之一,用于捕获输入序列中不同位置的复杂依赖关系。通过多个注意力头,它能够从不同的表示子空间中提取信息,从而提高模型的表达能力。
原理解析
多头注意力机制的核心思想是并行计算多个注意力机制(头),然后将它们的输出连接起来,进一步线性变换得到最终结果。
1. 输入与嵌入
如果为自注意力机制则输入为一个qkv统一源的矩阵 ( X q k v ∈ R l e n ( q k v ) × d \ X_{qkv} \in \mathbb{R}^{len(qkv) \times d} Xqkv∈Rlen(qkv)×d),交叉注意力机制需要输入两个,kv的源矩阵和q的源矩阵( X q ∈ R l e n q × d \ X_q \in \mathbb{R}^{lenq \times d} Xq∈Rlenq×d, X k v ∈ R l e n ( k v ) × d \ X_{kv} \in \mathbb{R}^{len(kv) \times d} Xkv∈Rlen(kv)×d),其中:
- 其中len (n) 是输入对应序列的长度(例如,句子中的词数量)。
- 其中(d) 是输入向量的维度。
输入被映射到查询(Query)、键(Key)和值(Value)矩阵。
2. 多头注意力的计算流程
(1) 线性变换
对于每个注意力头,使用独立的线性变换得到查询、键和值,为生成查询(Query)、键(Key)和值(Value),通过可学习的权重矩阵进行线性变换,如果为自注意力机制则KQV的计算公式为:
Q
h
=
X
q
k
v
W
Q
h
,
K
h
=
X
q
k
v
W
K
h
,
V
h
=
X
q
k
v
W
V
h
Q_h = X_{qkv}W_Q^h, \quad K_h = X_{qkv}W_K^h, \quad V_h = X_{qkv}W_V^h
Qh=XqkvWQh,Kh=XqkvWKh,Vh=XqkvWVh
如果是交叉注意力机制则计算公式为:
Q
h
=
X
q
W
Q
h
,
K
h
=
X
k
v
W
K
h
,
V
h
=
X
k
v
W
V
h
Q_h = X_{q}W_Q^h, \quad K_h = X_{kv}W_K^h, \quad V_h = X_{kv}W_V^h
Qh=XqWQh,Kh=XkvWKh,Vh=XkvWVh
其中:
- W Q h , W K h , W V h ∈ R d × d k W_Q^h, W_K^h, W_V^h \in \mathbb{R}^{d \times d_k} WQh,WKh,WVh∈Rd×dk 是第 h h h 个头的可学习权重矩阵。
- d k = d h d_k = \frac{d}{h} dk=hd 是每个头的向量维度, h h h 是头的数量。
(2) 注意力计算
每个头独立计算自注意力(Scaled Dot-Product Attention):
Attention
h
(
Q
h
,
K
h
,
V
h
)
=
softmax
(
Q
h
K
h
⊤
d
k
)
V
h
\text{Attention}_h(Q_h, K_h, V_h) = \text{softmax}\left(\frac{Q_h K_h^\top}{\sqrt{d_k}}\right) V_h
Attentionh(Qh,Kh,Vh)=softmax(dk
QhKh⊤)Vh
(3) 拼接与线性变换
将所有头的输出拼接在一起,并通过一个线性层进行变换:
MultiHead
(
Q
,
K
,
V
)
=
Concat
(
head
1
,
head
2
,
…
,
head
h
)
W
O
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h) W_O
MultiHead(Q,K,V)=Concat(head1,head2,…,headh)WO
其中:
- head i = Attention i ( Q i , K i , V i ) \text{head}_i = \text{Attention}_i(Q_i, K_i, V_i) headi=Attentioni(Qi,Ki,Vi) 是第 i i i 个头的输出。
- W O ∈ R d × d W_O \in \mathbb{R}^{d \times d} WO∈Rd×d 是输出层的权重矩阵。
总结公式
多头注意力机制的完整公式为:
MultiHead
(
Q
,
K
,
V
)
=
Concat
(
head
1
,
head
2
,
…
,
head
h
)
W
O
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h) W_O
MultiHead(Q,K,V)=Concat(head1,head2,…,headh)WO
其中:
head
i
=
softmax
(
Q
i
K
i
⊤
d
k
)
V
i
\text{head}_i = \text{softmax}\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i
headi=softmax(dk
QiKi⊤)Vi
优缺点分析
优点
- 多样性:多个注意力头可以从不同的子空间学习特征。
- 捕获长距离依赖:能够建模序列中任意位置之间的关系。
- 提升表示能力:比单头注意力机制具有更高的表达能力。
缺点
- 计算开销高:计算多个头的注意力会增加计算量和显存开销。
- 实现复杂性:需要对多个头进行并行计算和拼接。
多头注意力机制代码
以下是基于 PyTorch 实现多头注意力机制的代码:
第一步,引入相关的库函数
# 该模块实现的是多头注意力机制,和单头不一样的点
# 1. 需要把头提取出来,2. 需要对mask进行expand
'''
# Part1 引入相关的库函数
'''
import torch
from torch import nn
import math
第二步,初始化Multi_atten作为一个类
'''
# Part 2 设计一个多头注意力的类
'''
class Multi_atten(nn.Module):
def __init__(self,emd_size,q_k_size,v_size,head):
super(Multi_atten,self).__init__()
# 输入的x为(batch_size,seq_len,emd_size)
# 第一步初始化三个全连接矩阵和头的数量
self.head=head
# 初始化是head的倍数,便于提取
self.Wk=nn.Linear(emd_size,q_k_size*head)
self.Wq=nn.Linear(emd_size,q_k_size*head)
self.Wv=nn.Linear(emd_size,v_size*head)
# 初始化Softmax函数
self.softmax=nn.Softmax(dim=-1)
# 剩下的等会看看
def forward(self,x_q,x_k_v,mask):
# 首先得到kvq
q=self.Wq(x_q) # (batch_size,q_seq_len,q_size*head)
k=self.Wk(x_k_v)
v=self.Wv(x_k_v)
# 其次是把头分出来得到多头的kvq
q=q.reshape(q.size()[0],q.size()[1],self.head,-1).transpose(1,2) # (batch_size,head,q_seq_len,q_size)
k = k.reshape(k.size()[0], k.size()[1], self.head, -1).transpose(1,2)
v = v.reshape(q.size()[0], v.size()[1], self.head, -1).transpose(1,2)
# 把k进行转置
k=k.transpose(2,3) # (batch_size,head,k_seq_len,q_size)
q_k=self.softmax(torch.matmul(q,k)/math.sqrt(k.size()[2]))
# 进行mask(batch,seq_q,seq_k)
if mask is not None:
mask.unsqueeze(1).expand(-1,self.head,-1,-1)
q_k.masked_fill(mask,1e-9)
# 和v相乘
atten=torch.matmul(q_k,v) # (batch_size,head,k_seq_len,k_v_size)
# 将其进行返回原来的尺寸
atten.transpose(1,2) # (batch_size,k_seq_len,head,k_v_size)
atten=atten.reshape(atten.size()[0],atten.size()[1],-1) # (batch_size, k_seq_len, head*k_v_size)
return atten
第三步 测试代码 -分为自注意力和交叉注意力两种注意力
if __name__ == '__main__':
# 类别1 单头的自注意力机制
# 初始化输入x(batch_size,seq_len,emding)
batch_size = 1 # 批量也就是句子的数量
emd_size = 128 # 一个token嵌入的维度
seq_len = 5 # kqv源的token长度
q_k_size = emd_size//8 # q和k的嵌入维度
v_size = emd_size//8 # v的嵌入维度
x = torch.rand(size=(batch_size, seq_len, emd_size), dtype=torch.float)
self_atten = Multi_atten(emd_size=emd_size, q_k_size=q_k_size, v_size=v_size,head=8)
# 初始化mask(batch,len_k,len_q)
mask = torch.randn(size=(batch_size, seq_len, seq_len))
mask = mask.bool()
print('单头的自注意力结果', self_atten(x, x, mask).size())
# 类别2 单头的交叉注意力机制
# 初始化输入x(batch_size,seq_len,emding)
batch_size = 1 # 批量也就是句子的数量
emd_size = 128 # 一个token嵌入的维度
q_seq_len = 5 # q源的token长度
q_k_size = emd_size//8 # q和k的嵌入维度/head
k_v_seq_len = 7 # k_v源的token长度
v_size = emd_size//8 # v的嵌入维度/head
head=8 # 头的数量
x_q = torch.rand(size=(batch_size, q_seq_len, emd_size), dtype=torch.float)
x_k_v = torch.rand(size=(batch_size, k_v_seq_len, emd_size), dtype=torch.float)
cross_atten = Multi_atten(emd_size=emd_size, q_k_size=q_k_size, v_size=v_size,head=head)
# 初始化mask(batch,len_k,len_q)
mask = torch.randn(size=(batch_size, q_seq_len, k_v_seq_len))
mask = mask.bool()
print('单头的交叉注意力结果', cross_atten(x_q, x_k_v, mask).size())
参考
自己(值得纪念+1,终于自己会从头开始写多头注意力机制喽,哈(*≧▽≦)):小菜鸟博士-CSDN博客,手撕Transformer – Day3 – MultiHead Attention-CSDN博客
标签:head,seq,atten,Self,Cross,len,注意力,size From: https://blog.csdn.net/m0_62030579/article/details/145165026