首页 > 其他分享 >Transformer的优化

Transformer的优化

时间:2024-04-06 22:33:19浏览次数:24  
标签:MQA head seq Attention Transformer len Query 优化

本文总结 Transformer 和 attention 机制相关的 trick。留下学习痕迹。

Multi Query Attention (MQA)

早在 2019 年就被提出,但最近才被重视。

相比 Multi Head Attention,MQA 让多头注意力层的各个 head 共享同一份 Key 和 Value 参数(Query 不参与共享,各 head 独立)。如此,以不太多的精度代价,可减少参数量,提升推理速度。

目前 PyTorch 没有 MQA 的原生实现,下文的 GQA 也是。若不使用这些 trick,使用原生的 F.scaled_dot_product_attention 能够享受速度加成。

这个仓库 使用纯 Python 的方式实现了 MQA 和 GQA。据作者实验,当 \(n\) 较小时 GQA 仍然比 PyTorch 原生 MHA 更快。

这个视频 讲述了LLaMA2 模型中用到的技巧,包括 MQA 和 KV Cache。可以参考。

Grouped Query Attention (GQA)

MQA 的做法也许太极端——所有 head 只共享一份 Key 和 Value 势必会影响模型效果。

GQA 作为折中方案,实现方法非常直接。将 head 分为 \(n\) 组,各组各自共享 Key 和 Value。

很明显,当 \(n=1\) 时就是 MQA,当 \(n\) 等于 head 数时就是 Multi Head Attention。

Attention with Linear Bias(ALiBi)

为了让 transformer 模型知道输入序列的顺序关系,需要对输入序列添加 Position Embedding。

而 ALiBi 去掉了 Position Embedding 步骤,转而在计算 Query × Key 值时添加一个偏置常量来注入位置信息。实验证明,这样的修改可以让模型训练长度小于推理长度时也能获得良好推理效果。

具体实现的话,部分代码如下(参考)。

首先要获得 bias。这个 bias 要加到 Query × Key 结果中,充当位置信息。

def get_relative_positions(seq_len: int) -> torch.tensor:
    x = torch.arange(seq_len)[None, :]
    y = torch.arange(seq_len)[:, None]
    return x - y

这个函数会返回类似以下的矩阵:

tensor([[ 0,  1,  2,  3,  4],
        [-1,  0,  1,  2,  3],
        [-2, -1,  0,  1,  2],
        [-3, -2, -1,  0,  1],
        [-4, -3, -2, -1,  0]])

针对不同的 head,有着不同的 bias 权重大小。

def get_alibi_slope(num_heads):
    x = (2 ** 8) ** (1 / num_heads)
    return (
        torch.tensor([1 / x ** (i + 1) for i in range(num_heads)])
        .unsqueeze(-1)
        .unsqueeze(-1)
    )

如此,就能获得 [head, seq_len, seq_len] 大小的 bias 了:

alibi = (
    get_relative_positions(seq_len) * get_relative_positions(seq_len)
).unsqueeze(0)

为了让运算更快,可以用上 PyTorch 原生的 F.scaled_dot_product_attention参考)。

context_layer = F.scaled_dot_product_attention(
         query_layer, key_layer, value_layer, attn_mask=alibi, dropout_p=0.0
     )

参考来源

标签:MQA,head,seq,Attention,Transformer,len,Query,优化
From: https://www.cnblogs.com/chirp/p/18118055

相关文章

  • [蓝桥杯 2022 国 B] 齿轮(优化枚举)
        根据题目描述,如果采用dfs暴力做法枚举所有方案,肯定会超时,因此我们需要优化枚举,我们都知道在同一组共同转动的齿轮中,线速度相等,因此角速度的比值就是半径的反比,因此我们只需要找到对于每个齿轮作为起始齿轮,只需要找到其倍数半径是否存在即可,而倍数上限就是假设存在......
  • MySQL中的sql优化
    一、SQL优化原则1、减少数据量(表中数据太多可以分表,例如超过500万数据 双11一个小时一张订单表)2、减少数据访问量(将全表扫描可以调整为基于索引去查询)3、减少数据计算操作(将数据库中的计算拿到程序内存中计算)二、SQL优化的基本逻辑1、良好的SQL编码习惯(熟悉SQL编码规范......
  • 注意力机制 transformer
    https://jalammar.github.io/illustrated-transformer/X就是输入的向量,第一步就是创建三个输入向量qkv第二步是计算分数:分数决定了对输入句子的其他部分的关注程度。分数是通过查询向量与我们要评分的各个单词的键向量的点积来计算的。因此,如果我们处理位置#1中单词的自注意......
  • FJSP:蜣螂优化算法( Dung beetle optimizer, DBO)求解柔性作业车间调度问题(FJSP),提供MAT
    一、柔性作业车间调度问题柔性作业车间调度问题(FlexibleJobShopSchedulingProblem,FJSP),是一种经典的组合优化问题。在FJSP问题中,有多个作业需要在多个机器上进行加工,每个作业由一系列工序组成,每个工序需要在特定的机器上完成。同时,每个机器一次只能处理一个工序,且每个工......
  • FJSP:霸王龙优化算法(Tyrannosaurus optimization,TROA)求解柔性作业车间调度问题(FJSP),提供
    一、柔性作业车间调度问题柔性作业车间调度问题(FlexibleJobShopSchedulingProblem,FJSP),是一种经典的组合优化问题。在FJSP问题中,有多个作业需要在多个机器上进行加工,每个作业由一系列工序组成,每个工序需要在特定的机器上完成。同时,每个机器一次只能处理一个工序,且每个工......
  • 【MATLAB源码-第171期】基于matlab的布谷鸟优化算法(COA)无人机三维路径规划,输出做短路
    操作环境:MATLAB2022a1、算法描述布谷鸟优化算法(CuckooOptimizationAlgorithm,COA)是一种启发式搜索算法,其设计灵感源自于布谷鸟的独特生活习性,尤其是它们的寄生繁殖行为。该算法通过模拟布谷鸟在自然界中的行为特点,为解决各种复杂的优化问题提供了一种新颖的方法。从算法......
  • TVM Pass优化 -- 公共子表达式消除(Common Subexpr Elimination, CSE)
    定义(What)公共子表达式消除就是如果表达式E的值已经计算的到了,并且自计算的到值后E的值就不再改变了,就说,表达式E在后续计算中是一个公共表达式。简单说,该表达式上面已经执行过了,下面没必要再执行了举个例子:importtvmfromtvmimportrelayfromtvm.relayimporttransform......
  • 百度小程序发帖软件怎么优化SEO排名关键词
    百度小程序发帖软件怎么优化SEO排名关键词手把手教你,如何做SEO关键词优化?!#官网建站#seo#百度排名2024创业好项目:做H5响应式建站代理!排名展示推荐阅读:百度小程序发帖外推软件是什么百度不收录我们的网站怎么办呢?网站SEO优化一定要做好,今天咱们说说网站SEO优化的三......
  • cloud.heytap.com 欢太云 优化存储空间 原图
    经检查屏幕截图是不会优化/压缩的本地的屏幕截图跟云上的是同样大小的(相机拍摄的会压缩)        ......
  • ipad + mac mini 自动随航(优化方案)
    ipad+macmini自动随航(优化方案)背景在之前的文章中,介绍了搭建vscode服务器,通过ipad进行访问的方式,但实际操作下来,发现还是没有办法达到很好的效果。问题包括:网页连接vscode之后,在不同应用中切换经常会导致重连,也很容易超时重连,用来当玩具可以,但如果要用作实际工作,则......