首页 > 编程语言 >算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)

算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)

时间:2025-01-17 10:33:04浏览次数:3  
标签:head seq atten Self Cross len 注意力 size

算法面试准备 - 手撕系列第六期 - 多头注意力机制(包括Self_atten和Cross_atten)

目录

多头注意力机制原理

多头注意力机制原理图像

在这里插入图片描述

多头注意力机制原理图

背景介绍

多头注意力机制(Multi-Head Attention)是 Transformer 架构的核心模块之一,用于捕获输入序列中不同位置的复杂依赖关系。通过多个注意力头,它能够从不同的表示子空间中提取信息,从而提高模型的表达能力。


原理解析

多头注意力机制的核心思想是并行计算多个注意力机制(头),然后将它们的输出连接起来,进一步线性变换得到最终结果。

1. 输入与嵌入

如果为自注意力机制则输入为一个qkv统一源的矩阵 (   X q k v ∈ R l e n ( q k v ) × d \ X_{qkv} \in \mathbb{R}^{len(qkv) \times d}  Xqkv​∈Rlen(qkv)×d),交叉注意力机制需要输入两个,kv的源矩阵和q的源矩阵(   X q ∈ R l e n q × d \ X_q \in \mathbb{R}^{lenq \times d}  Xq​∈Rlenq×d,   X k v ∈ R l e n ( k v ) × d \ X_{kv} \in \mathbb{R}^{len(kv) \times d}  Xkv​∈Rlen(kv)×d),其中:

  • 其中len (n) 是输入对应序列的长度(例如,句子中的词数量)。
  • 其中(d) 是输入向量的维度。

输入被映射到查询(Query)、键(Key)和值(Value)矩阵。


2. 多头注意力的计算流程

(1) 线性变换

对于每个注意力头,使用独立的线性变换得到查询、键和值,为生成查询(Query)、键(Key)和值(Value),通过可学习的权重矩阵进行线性变换,如果为自注意力机制则KQV的计算公式为:
Q h = X q k v W Q h , K h = X q k v W K h , V h = X q k v W V h Q_h = X_{qkv}W_Q^h, \quad K_h = X_{qkv}W_K^h, \quad V_h = X_{qkv}W_V^h Qh​=Xqkv​WQh​,Kh​=Xqkv​WKh​,Vh​=Xqkv​WVh​
如果是交叉注意力机制则计算公式为:
Q h = X q W Q h , K h = X k v W K h , V h = X k v W V h Q_h = X_{q}W_Q^h, \quad K_h = X_{kv}W_K^h, \quad V_h = X_{kv}W_V^h Qh​=Xq​WQh​,Kh​=Xkv​WKh​,Vh​=Xkv​WVh​
其中:

  • W Q h , W K h , W V h ∈ R d × d k W_Q^h, W_K^h, W_V^h \in \mathbb{R}^{d \times d_k} WQh​,WKh​,WVh​∈Rd×dk​ 是第 h h h 个头的可学习权重矩阵。
  • d k = d h d_k = \frac{d}{h} dk​=hd​ 是每个头的向量维度, h h h 是头的数量。
(2) 注意力计算

每个头独立计算自注意力(Scaled Dot-Product Attention):
Attention h ( Q h , K h , V h ) = softmax ( Q h K h ⊤ d k ) V h \text{Attention}_h(Q_h, K_h, V_h) = \text{softmax}\left(\frac{Q_h K_h^\top}{\sqrt{d_k}}\right) V_h Attentionh​(Qh​,Kh​,Vh​)=softmax(dk​ ​Qh​Kh⊤​​)Vh​

(3) 拼接与线性变换

将所有头的输出拼接在一起,并通过一个线性层进行变换:
MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h) W_O MultiHead(Q,K,V)=Concat(head1​,head2​,…,headh​)WO​
其中:

  • head i = Attention i ( Q i , K i , V i ) \text{head}_i = \text{Attention}_i(Q_i, K_i, V_i) headi​=Attentioni​(Qi​,Ki​,Vi​) 是第 i i i 个头的输出。
  • W O ∈ R d × d W_O \in \mathbb{R}^{d \times d} WO​∈Rd×d 是输出层的权重矩阵。

总结公式

多头注意力机制的完整公式为:
MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \dots, \text{head}_h) W_O MultiHead(Q,K,V)=Concat(head1​,head2​,…,headh​)WO​
其中:
head i = softmax ( Q i K i ⊤ d k ) V i \text{head}_i = \text{softmax}\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i headi​=softmax(dk​ ​Qi​Ki⊤​​)Vi​


优缺点分析

优点

  • 多样性:多个注意力头可以从不同的子空间学习特征。
  • 捕获长距离依赖:能够建模序列中任意位置之间的关系。
  • 提升表示能力:比单头注意力机制具有更高的表达能力。

缺点

  • 计算开销高:计算多个头的注意力会增加计算量和显存开销。
  • 实现复杂性:需要对多个头进行并行计算和拼接。

多头注意力机制代码

以下是基于 PyTorch 实现多头注意力机制的代码:

第一步,引入相关的库函数

# 该模块实现的是多头注意力机制,和单头不一样的点
# 1. 需要把头提取出来,2. 需要对mask进行expand

'''
# Part1 引入相关的库函数
'''
import torch
from torch import nn
import math

第二步,初始化Multi_atten作为一个类

'''
# Part 2 设计一个多头注意力的类
'''
class Multi_atten(nn.Module):
    def __init__(self,emd_size,q_k_size,v_size,head):
        super(Multi_atten,self).__init__()
        # 输入的x为(batch_size,seq_len,emd_size)
        # 第一步初始化三个全连接矩阵和头的数量
        self.head=head
        # 初始化是head的倍数,便于提取
        self.Wk=nn.Linear(emd_size,q_k_size*head)
        self.Wq=nn.Linear(emd_size,q_k_size*head)
        self.Wv=nn.Linear(emd_size,v_size*head)

        # 初始化Softmax函数
        self.softmax=nn.Softmax(dim=-1)

        # 剩下的等会看看
    def forward(self,x_q,x_k_v,mask):

        # 首先得到kvq
        q=self.Wq(x_q) # (batch_size,q_seq_len,q_size*head)
        k=self.Wk(x_k_v)
        v=self.Wv(x_k_v)

        # 其次是把头分出来得到多头的kvq
        q=q.reshape(q.size()[0],q.size()[1],self.head,-1).transpose(1,2) # (batch_size,head,q_seq_len,q_size)
        k = k.reshape(k.size()[0], k.size()[1], self.head, -1).transpose(1,2)
        v = v.reshape(q.size()[0], v.size()[1], self.head, -1).transpose(1,2)


        # 把k进行转置
        k=k.transpose(2,3) # (batch_size,head,k_seq_len,q_size)

        q_k=self.softmax(torch.matmul(q,k)/math.sqrt(k.size()[2]))


        # 进行mask(batch,seq_q,seq_k)
        if mask is not None:
            mask.unsqueeze(1).expand(-1,self.head,-1,-1)
        q_k.masked_fill(mask,1e-9)

        # 和v相乘
        atten=torch.matmul(q_k,v) # (batch_size,head,k_seq_len,k_v_size)

        # 将其进行返回原来的尺寸
        atten.transpose(1,2) # (batch_size,k_seq_len,head,k_v_size)
        atten=atten.reshape(atten.size()[0],atten.size()[1],-1) # (batch_size, k_seq_len, head*k_v_size)

        return atten

第三步 测试代码 -分为自注意力和交叉注意力两种注意力

if __name__ == '__main__':
    # 类别1 单头的自注意力机制
    # 初始化输入x(batch_size,seq_len,emding)
    batch_size = 1  # 批量也就是句子的数量
    emd_size = 128  # 一个token嵌入的维度
    seq_len = 5  # kqv源的token长度
    q_k_size = emd_size//8  # q和k的嵌入维度
    v_size = emd_size//8  # v的嵌入维度

    x = torch.rand(size=(batch_size, seq_len, emd_size), dtype=torch.float)

    self_atten = Multi_atten(emd_size=emd_size, q_k_size=q_k_size, v_size=v_size,head=8)
    # 初始化mask(batch,len_k,len_q)
    mask = torch.randn(size=(batch_size, seq_len, seq_len))
    mask = mask.bool()

    print('单头的自注意力结果', self_atten(x, x, mask).size())

    # 类别2 单头的交叉注意力机制
    # 初始化输入x(batch_size,seq_len,emding)
    batch_size = 1  # 批量也就是句子的数量
    emd_size = 128  # 一个token嵌入的维度
    q_seq_len = 5  # q源的token长度
    q_k_size = emd_size//8  # q和k的嵌入维度/head
    k_v_seq_len = 7  # k_v源的token长度
    v_size = emd_size//8  # v的嵌入维度/head
    head=8 # 头的数量

    x_q = torch.rand(size=(batch_size, q_seq_len, emd_size), dtype=torch.float)
    x_k_v = torch.rand(size=(batch_size, k_v_seq_len, emd_size), dtype=torch.float)

    cross_atten = Multi_atten(emd_size=emd_size, q_k_size=q_k_size, v_size=v_size,head=head)
    # 初始化mask(batch,len_k,len_q)
    mask = torch.randn(size=(batch_size, q_seq_len, k_v_seq_len))
    mask = mask.bool()

    print('单头的交叉注意力结果', cross_atten(x_q, x_k_v, mask).size())

参考

自己(值得纪念+1,终于自己会从头开始写多头注意力机制喽,哈(*≧▽≦)):小菜鸟博士-CSDN博客手撕Transformer – Day3 – MultiHead Attention-CSDN博客

标签:head,seq,atten,Self,Cross,len,注意力,size
From: https://blog.csdn.net/m0_62030579/article/details/145165026

相关文章

  • About Myself ver.2025
    前言Jabber如果有人有这个闲工夫的话,上一个关于是写于2023.9月的,那时的我刚刚步入大学。一眨眼,现在就已经大二上学期结束了,时间过的真快啊。在过去的一年半里,这博客基本处于半报废的状态大一上开学写过一点点时间,大一下开学写过一点点时间,但都没有坚持下来本人呢在客观上,没......
  • 算法面试准备 - 手撕系列第二期 - 交叉熵损失(Cross Entropy Loss)
    算法面试准备-手撕系列第二期-交叉熵损失(CrossEntropyLoss)目录算法面试准备-手撕系列第二期-交叉熵损失(CrossEntropyLoss)交叉熵原理图交叉熵损失实现代码-不同y_pre版本参考交叉熵原理图Softmax原理图交叉熵损失实现代码-不同y_pre版本......
  • 解决生成图像质量和美学问题!《VMix: Improving Text-to-Image Diffusion Model with C
    为了解决扩散模型在文生图的质量和美学问题,字节跳动&中科大研究团队提出VMix美学条件注入方法,通过将抽象的图像美感拆分成不同维度的美学向量引入扩散模型,从而实现细粒度美学图像生成。论文基于提出的方法训练了一个即插即用的模块,无需再训练即可应用于不同的开源模型,提升模型......
  • FlashAttention的原理及其优势
    在深度学习领域,尤其是自然语言处理(NLP)和计算机视觉(CV)任务中,注意力机制(AttentionMechanism)已经成为许多模型的核心组件。然而,随着模型规模的不断扩大,注意力机制的计算复杂度和内存消耗也急剧增加,成为训练和推理的瓶颈。为了解决这一问题,研究人员提出了FlashAttention,一种高......
  • 深入探索 DeepSeek-V3 的算法创新:Multi-head Latent Attention 的实现与细节
    引言在当今的大规模语言模型(LLM)领域,随着模型参数规模的指数级增长,如何在保证性能的同时优化计算效率和内存使用成为了一个核心挑战。DeepSeek-V3模型以其创新的架构和训练策略脱颖而出,其中Multi-headLatentAttention(MLA)是其关键技术之一。MLA的引入不仅解决了传统......
  • 基于Informer网络实现电力负荷时序预测——cross validation交叉验证与Hyperopt超参数
    前言系列专栏:【深度学习:算法项目实战】✨︎涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记忆......
  • (即插即用模块-Attention部分) 四十一、(2023) MLCA 混合局部通道注意力
    文章目录1、MixedLocalChannelAttention2、代码实现paper:MixedlocalchannelattentionforobjectdetectionCode:https://github.com/wandahangFY/MLCA1、MixedLocalChannelAttention现有通道注意力机制的局限性:大多数通道注意力机制只关注通道特征信......
  • 【论文阅读】Integrating single-cell multi-omics data through self-supervised clu
    论文地址:Integratingsingle-cellmulti-omicsdatathroughself-supervisedclustering-ScienceDirect代码地址:https://github.com/biomed-AI/scFPN摘要单细胞测序技术的进步使得个体细胞能够同时在多种组学层面进行测序,例如转录组学、表观基因组学和蛋白质组学。整合......
  • YOLOv11改进,YOLOv11添加HAttention注意机制用于图像修复的混合注意力转换器,CVPR2023,超
    摘要基于Transformer的方法在低层视觉任务中表现出色,例如图像超分辨率。然而,作者通过归因分析发现,这些网络只能利用有限的空间范围的输入信息。这意味着现有网络尚未充分发挥Transformer的潜力。为了激活更多的输入像素以获得更好的重建效果,作者提出了一种新型的混合注......
  • YOLOv11改进,YOLOv11自研检测头融合HyCTAS的Self_Attention自注意力机制(2024),并添加小目
    摘要论文提出了一种新的搜索框架,名为HyCTAS,用于在给定任务中自动搜索高效的神经网络架构。HyCTAS框架结合了高分辨率表示和自注意力机制,通过多目标优化搜索,找到了一种在性能和计算效率之间的平衡。#理论介绍自注意力(Self-Attention)机制是HyCTAS框架中的一个重要组......