首页 > 其他分享 >LLM 加速技巧:Muti Query Attention

LLM 加速技巧:Muti Query Attention

时间:2024-03-08 16:25:34浏览次数:45  
标签:MQA Attention key self Muti LLM tf model 512

前言 MQA 是 19 年提出的一种新的 Attention 机制,其能够在保证模型效果的同时加快 decoder 生成 token 的速度。在大语言模型时代被广泛使用,很多LLM都采用了MQA,如Falcon、PaLM、StarCoder等。

本文转载自Deephub Imba

作者:Florian June

仅用于学术分享,若侵权请联系删除

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。

CV各大方向专栏与各个部署框架最全教程整理

【CV技术指南】CV全栈指导班、基础入门班、论文指导班 全面上线!!

在介绍MQA 之前,我们先回顾一下传统的多头注意力

Multi-Head Attention(MHA)

多头注意力是transformer 模型的默认注意力机制,如下图所示:

 

 

在文本生成方面,基于transformer 的自回归语言模型存在一个问题。在训练过程中可以获得真实的目标序列,并且可以有效地实现并行化。

但是在推理过程中,每个位置的查询都要处理在该位置或之前生成的所有键值对。也就是说自注意力层在特定位置的输出影响下一个令牌的生成,所以无法并行化,这使得推理变得非常的慢。

下图是基于transformer 解码器的自回归语言模型中自注意层的解码过程:

 def MHAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
     q = tf.einsum("bd, hdk−>bhk", x, P_q)
     new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, hdk−>bhk", x, P_k), axis = 2)], axis = 2)
     new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, hdv−>bhv", x, P_v), axis = 2)], axis = 2)
     logits = tf.einsum("bhk, bhmk−>bhm", q, new_K)
     weights = tf.softmax(logits)
     O = tf.einsum("bhm, bhmv−>bhv", weights, new_V)
     Y = tf.einsum("bhv, hdv−>bd", O, P_o)
     return Y, new_K, new_V

其中:

X:当前的输入张量,m为当前步,m+1为阶跃,形状为[b, d]

P_q, P_k:查询和键投影张量,形状为[h, d, k]

P_v:值投影张量,形状为[h, d, v]

P_o:学习到的线性投影,形状为[h, d, v]

Prev_K:上一步的关键张量,形状为[b, h, m, k]

Prev_V:前一步的Value张量,形状为[b, h, m, v]

new_K:加上当前步的键张量,形状为[b, h, m+1, k]

new_V:加了当前步长的Value张量,形状为[b, h, m+1, v]

维度表示如下:

M:先前执行的步骤数

B:批量大小

D:输入和输出的尺寸

H:注意力头数

k:Q,K张量的另一个维度

v: v张量的另一个维度

Multi-Query Attention(MQA)

MQA是多头注意的一种变体。

MQA的方法是保持Q的初始头数,但K和V只有一个头,这意味着所有Q个头共享相同的K和V,因此称为Multi-Query,如下图所示:

 

 

从论文的解释中可以看到,MQA 让所有的头之间 共享 同一份 Key 和 Value 矩阵,每个头只单独保留了一份 Query 参数,从而大大减少 Key 和 Value 矩阵的参数量。

MQA解码过程的代码本质上与MHA的代码相同,只是从中删除了表示头部尺寸的字母“h”。K, V, P_k, P_v的和方程:

 def MQAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
     q = tf.einsum("bd, hdk−>bhk", x, P_q)
     new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, dk−>bk", x, P_k), axis = 2)], axis = 2)
     new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, dv−>bv", x, P_v), axis = 2)], axis = 2)
     logits = tf.einsum("bhk, bmk−>bhm", q, new_K)
     weights = tf.softmax(logits)
     O = tf.einsum("bhm, bmv−>bhv", weights, new_V)
     Y = tf.einsum("bhv, hdv−>bd", O, P_o)
     return Y, new_K, new_V

上面都是tf的代码,如果阅读有问题,我从 llm-foundry项目中找到了pytorch的代码实现,这里只做个摘抄,有兴趣的请看原项目

 class MultiheadAttention(nn.Module):
 
     def __init__(
             self,
             d_model: int,
             n_heads: int,
             device: str
        ):
         """
        Multi Head init func.
 
        Args:
            d_model (int): hidden state size, e.g. 768
            n_heads (int): 设定的注意力头数, e.g. 8
            device (str): _description_
        """
         super().__init__()
 
         self.d_model = d_model
         self.n_heads = n_heads
     
         self.Wqkv = nn.Linear(                       # Multi-Head Attention 的创建方法
             self.d_model,
             3 * self.d_model,                        # 有 query, key, value 3 个矩阵, 所以是 3 * d_model
             device=device
        )                                            # (d_model, 3 * d_model)
         self.attn_fn = scaled_multihead_dot_product_attention
         self.out_proj = nn.Linear(
             self.d_model,
             self.d_model,
             device=device
        )
 
     def forward(
         self,
         x
    ):
         """
        forward func.
 
        Args:
            x (tensor): (batch, hidden_state, d_model) e.g. -> (1, 768, 512)
 
        Returns:
            _type_: _description_
        """
         qkv = self.Wqkv(x)                            # (1, 768, 3 * 768)
 
         query, key, value = qkv.chunk(                # 每个 tensor 都是 (1, 512, 768)
             3,
             dim=2
        )    
 
         context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
             value,
             self.n_heads
        )                                             # (1, 512, 768)
 
         return self.out_proj(context), attn_weights, past_key_value
 
 
 class MultiQueryAttention(nn.Module):
     """Multi-Query self attention.
 
    Using torch or triton attention implemetation enables user to also use
    additive bias.
    """
 
     def __init__(
         self,
         d_model: int,
         n_heads: int,
         device: Optional[str] = None,
    ):
         super().__init__()
 
         self.d_model = d_model
         self.n_heads = n_heads
         self.head_dim = d_model // n_heads
 
         self.Wqkv = nn.Linear(                           # Multi-Query Attention 的创建方法
             d_model,
             d_model + 2 * self.head_dim,                 # 只创建 query 的 head 向量,所以只有 1 个 d_model
             device=device,                               # 而 key 和 value 则只共享各自的一个 head_dim 的向量
        )
 
         self.attn_fn = scaled_multihead_dot_product_attention
         self.out_proj = nn.Linear(
             self.d_model,
             self.d_model,
             device=device
        )
         self.out_proj._is_residual = True  # type: ignore
 
     def forward(
         self,
         x,
    ):
         qkv = self.Wqkv(x)                                           # (1, 512, 960)
 
         query, key, value = qkv.split(                               # query -> (1, 512, 768)
            [self.d_model, self.head_dim, self.head_dim],            # key   -> (1, 512, 96)
             dim=2                                                    # value -> (1, 512, 96)
        )
 
         context, attn_weights, past_key_value = self.attn_fn(
             query,
             key,
             value,
             self.n_heads,
             multiquery=True,
        )
 
         return self.out_proj(context), attn_weights, past_key_value

从代码中可以看到所有 头之间共享一份 key 和 value 的参数,但是如何将这 1 份参数同时让 8 个头都使用呢?

代码里使用矩阵乘法 matmul 来广播,使得每个头都乘以这同一个 tensor,以此来实现参数共享,主要是这个函数:scaled_multihead_dot_product_attention

 def scaled_multihead_dot_product_attention(
         query,
         key,
         value,
         n_heads,
         past_key_value=None,
         softmax_scale=None,
         attn_bias=None,
         key_padding_mask=None,
         is_causal=False,
         dropout_p=0.0,
         training=False,
         needs_weights=False,
         multiquery=False,
    ):
     q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)         # (1, 512, 768) -> (1, 8, 512, 96)
     kv_n_heads = 1 if multiquery else n_heads
     k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)        # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery
                                                                     # (1, 512, 96) -> (1, 1, 96, 512) if multiquery
     v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)      # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery
                                                                     # (1, 512, 96) -> (1, 1, 512, 96) if multiquery
     
     attn_weight = q.matmul(k) * softmax_scale                       # (1, 8, 512, 512)
     attn_weight = torch.softmax(attn_weight, dim=-1)                # (1, 8, 512, 512)
 
     out = attn_weight.matmul(v)                                     # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
     out = rearrange(out, 'b h s d -> b s (h d)')                    # (1, 512, 768)
 
     return out, attn_weight, past_key_value

MQA指标测试

MQA能在多大程度上提高速度?让我们看看原文中提供的结果图表:

 

 

从上表可以看出,MQA在编码器上的速度提升不是很显著,但在解码器上的速度提升是相当显著的。

论文中也有关于质量的实验,结果表明MQA的性能与基线相比只是稍微低一些。降低应该是肯定的因为毕竟共享了参数,但是只要再可接受范围内并且能够大量提升速度这个降低就是可以接受的,对吧。

为什么MQA可以实现推理加速?

在MQA中,键张量和值张量的大小分别为b * k和b * v,而在MHA中,键张量和值张量的大小分别为b * h * k和b * h * v,其中h表示头的个数。

MQA通过以下方法实现推理加速:

1、KV缓存大小减少了h(头数量),这意味着需要存储在GPU内存中的张量也减少了。节省的空间可以用来增加批大小,从而提高效率。

2、减少了从内存中读取的数据量,从而减少了计算单元的等待时间,提高了计算利用率。

3、MQA有一个相对较小的KV数量,可以放入缓存(SRAM)中。MHA则需要较大的KV数量,不能完全存储在缓存中,需要从GPU内存(DRAM)读取,这很耗时。

总结

MQA是在2019年提出的,当时的应用还没有那么广泛。这是因为以前的模型不需要关心这些方面,例如,LSTM只需要维护一个状态,而不需要保留任何缓存。

当transformer最初被提出时,它主要用于Seq2Seq任务,特别是在Encoder-Decoder模型中。由于模型的规模不是很大,也并且没有太多的实际需求,所以MQA并没有引起太多的关注。

直到近年来(尤其是2023年开始)基于transformer的大型语言模型(如GPT)得到广泛应用后,推理的瓶颈才被人们重视。所以MQA才被发现非常有用,这主要是由于对大规模gpt式生成模型的实际需求。

最后我们再回顾以下这个论文:

https://arxiv.org/abs/1911.02150

最后如果你对LLM的完整构建流程感兴趣,可以看看这个项目,文章的pytorch代码就是从这里找到的

https://github.com/mosaicml/llm-foundry

 

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。

计算机视觉入门1v3辅导班

【技术文档】《从零搭建pytorch模型教程》122页PDF下载

QQ交流群:470899183。群内有大佬负责解答大家的日常学习、科研、代码问题。

其它文章

分享一个CV知识库,上千篇文章、专栏,CV所有资料都在这了

明年毕业,还不知道怎么做毕设的请抓紧机会了

LSKA注意力 | 重新思考和设计大卷积核注意力,性能优于ConvNeXt、SWin、RepLKNet以及VAN

CVPR 2023 | TinyMIM:微软亚洲研究院用知识蒸馏改进小型ViT

ICCV2023|涨点神器!目标检测蒸馏学习新方法,浙大、海康威视等提出

ICCV 2023 Oral | 突破性图像融合与分割研究:全时多模态基准与多交互特征学习

听我说,Transformer它就是个支持向量机

HDRUNet | 深圳先进院董超团队提出带降噪与反量化功能的单帧HDR重建算法

南科大提出ORCTrack | 解决DeepSORT等跟踪方法的遮挡问题,即插即用真的很香

1800亿参数,世界顶级开源大模型Falcon官宣!碾压LLaMA 2,性能直逼GPT-4

SAM-Med2D:打破自然图像与医学图像的领域鸿沟,医疗版 SAM 开源了!

GhostSR|针对图像超分的特征冗余,华为诺亚&北大联合提出GhostSR

Meta推出像素级动作追踪模型,简易版在线可玩 | GitHub 1.4K星

CSUNet | 完美缝合Transformer和CNN,性能达到UNet家族的巅峰!

AI最全资料汇总 | 基础入门、技术前沿、工业应用、部署框架、实战教程学习

计算机视觉入门1v3辅导班

计算机视觉交流群

标签:MQA,Attention,key,self,Muti,LLM,tf,model,512
From: https://www.cnblogs.com/wxkang/p/18061244

相关文章

  • StarCoder 2:GitHub Copilot本地开源LLM替代方案
    GitHubCoPilot拥有超过130万付费用户,部署在5万多个组织中,是世界上部署最广泛的人工智能开发工具。使用LLM进行编程辅助工作不仅提高了生产力,而且正在永久性地改变数字原住民开发软件的方式,我也是它的付费用户之一。低代码/无代码平台将使应用程序创建、工作流自动化和数据分析更......
  • PokéLLMon 源码解析(二)
    .\PokeLLMon\poke_env\environment\double_battle.py#从logging模块中导入Logger类fromloggingimportLogger#从typing模块中导入Any,Dict,List,Optional,Union类型fromtypingimportAny,Dict,List,Optional,Union#从poke_env.environment.abstract......
  • PokéLLMon 源码解析(三)
    .\PokeLLMon\poke_env\environment\move_category.py#导入需要的模块fromenumimportEnum,auto,unique#定义一个枚举类MoveCategory,表示一个移动类别@uniqueclassMoveCategory(Enum):"""Enumeration,representamovecategory."""#定......
  • PokéLLMon 源码解析(四)
    .\PokeLLMon\poke_env\exceptions.py"""Thismodulecontainsexceptions."""#定义一个自定义异常类ShowdownException,继承自内置异常类ExceptionclassShowdownException(Exception):"""Thisexceptionisraisedwhena......
  • PokéLLMon 源码解析(五)
    .\PokeLLMon\poke_env\player\player.py"""Thismoduledefinesabaseclassforplayers."""importasyncioimportrandomfromabcimportABC,abstractmethodfromasyncioimportCondition,Event,Queue,Semaphorefromlogg......
  • PokéLLMon 源码解析(六)
    .\PokeLLMon\poke_env\teambuilder\teambuilder_pokemon.py"""ThismoduledefinestheTeambuilderPokemonclass,whichisusedasanintermediateformattospecifypokemonbuildsinteambuilderscustomclasses."""#导入必要的模块......
  • PokéLLMon 源码解析(一)
    .\PokeLLMon\poke_env\concurrency.py#导入必要的模块importasyncioimportatexitimportsysfromloggingimportCRITICAL,disablefromthreadingimportThreadfromtypingimportAny,List#在新线程中运行事件循环def__run_loop(loop:asyncio.AbstractEventLo......
  • LLM 加速技巧:Muti Query Attention
    MQA是19年提出的一种新的Attention机制,其能够在保证模型效果的同时加快decoder生成token的速度。在大语言模型时代被广泛使用,很多LLM都采用了MQA,如Falcon、PaLM、StarCoder等。在介绍MQA之前,我们先回顾一下传统的多头注意力Multi-HeadAttention(MHA)多头注意力是tra......
  • 论文阅读1--llm对于因果方向的影响
    论文1:CausalReasoningandLargeLanguageModels:  OpeningaNewFrontierforCausality文章根据LLM在因果发现、反事实推理和实际因果关系方面的评估结果表明,它们可以取得很好的效果,这是通过大量的文本训练模仿该领域知识模式的结果。虽然它也会产生错误,但是错误的次数极少......
  • RNN又行了!DeepMind新发布的Griffin可以与同级别的LLM性能相当
    Hawk和Griffin是DeepMind推出的新型循环神经网络(RNNs),2月刚刚发布在arxiv上。Hawk通过使用门控线性递归(gatedlinearrecurrences)超越了Mamba的性能,而Griffin则是一种混合型模型,结合了门控线性递归和局部注意力(localattention),与Llama-2的性能相当,但使用的训练数据明显较少。Griffi......