Transformer 模型现在很火,内存优化又很重要。上周面试了一个 985 大学的女生,跟她谈到了 Transformer 模型的内存优化问题。
那么这个女生到底给出了哪些关于 Transformer 模型内存优化的独特思路呢?一起来看看。
01
什么是Transformer模型中的KV缓存?
Transformer 中文本是逐个 token 生成的,每次新的预测都会基于之前生成的所有 tokens 的上下文信息。
这种对顺序数据的依赖会减慢生成过程,因为每次预测下一个 token 都需要重新处理序列中所有之前的 tokens。
例如,要预测第 100 个 token,模型必须使用前 99 个 token 的信息,需要对这些 token 进行复杂的矩阵运算。预测第 101 个 token 时,也要对前 99 个 token 做类似计算,以及对第 100 个 token 的新计算。
如何简化呢?
答案是使用 KV 缓存。KV 缓存通过保存这些计算结果,使模型可以在生成后续 tokens 时直接访问这些结果,而不需要重新计算。
换句话说,在生成第 101 个 token 时,模型只需从 KV 缓存中检索前 99 个 token 的已存储数据,并只对第 100 个 token 执行必要的计算。
02
如何估算KV缓存消耗的内存大小?
KV 缓存通常使用 float16 或 bfloat16 数据类型以 16 位的精度存储张量。对于一个 token,KV 缓存会为每一层和每个注意力头存储一对张量(键和值)。
这些张量的大小由注意力头的维度决定,这对张量的总内存消耗(以字节为单位)可以通过以下公式计算:
层数 × KV 注意力头的数量 × 注意力头的维度 × (位宽 / 8) × 2
最后的 “2” 是因为有两组张量,也就是键和值。位宽通常为 16 位,由于 8 位是 1 字节,因此我们将位宽除以 8,这样在 KV 缓存中每 16 位参数占用 2 个字节。
我们以 Llama 3 8B 为例,这个公式就变为:
32 × 8 × 128 × 2 × 2 = 131,072
**注意:**Llama 3 8B 有 32 个注意力头,不过由于 GQA 的存在,只有 8 个注意力头用于键和值。
从上面可以看到,对于一个 token,KV 缓存占用 131,072 字节,差不多 0.1 MB。这看起来好像不大,但对于许多不同类型的应用,大模型需要生成成千上万的 tokens。
举个例子,如果我们想利用 Llama 3 8B 的全部 context 大小(8192),KV 缓存将为 8191 个 token 存储键值张量,差不多占用 1.1 G 内存。换句话说,对于一块 24G 显存的消费级 GPU,KV 缓存将占用其总内存的 4.5%。
而对于更大的模型,KV 缓存增长得更快。比如对于 Llama 3 70B,它有 80 层,公式变为:
80 × 8 × 128 × 2 × 2 = 327,680
对于 8191 个 token,Llama 3 70B 的 KV 缓存将占用 2.7 GB。并且注意,这只是单个序列的内存消耗,如果我们进行批量解码,还需要将这个值乘以 batch size。
比如 batch size=32 的 Llama 3 8B 模型,将需要 35.2 GB 的 GPU 显存,一块消费级 GPU 显然搞不定了。
因此虽然在推理阶段用 KV 缓存可以提高处理速度,并且已经是业界标准做法,但是 KV 缓存在深层模型和长序列场景下,也会占据大量 GPU 内存。
而实际开发中,我们可以通过 KV 缓存量化,来降低推理阶段的 LLM 内存需求。后面我们将通过实际的例子(Llama 3 8B 模型),来看看如何对 KV 缓存进行量化的
如何学习AI大模型 ?
“最先掌握AI的人,将会比较晚掌握AI的人有竞争优势”。
这句话,放在计算机、互联网、移动互联网的开局时期,都是一样的道理。
我在一线互联网企业工作十余年里,指导过不少同行后辈。帮助很多人得到了学习和成长。
我意识到有很多经验和知识值得分享给大家,故此将并将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。【保证100%免费】
标签:Transformer,缓存,AI,模型,985,学习,token,KV,内存 From: https://blog.csdn.net/2401_86435672/article/details/142651844