在分享之前先提出三个问题:
1. 为什么KVCache不保存Q
2. KVCache如何减少计算量
3. 为什么模型回答的长度不会影响回答速度?
本文将带着这3个问题来详解KVcache
KVcache是什么
kv cache是指一种用于提升大模型推理性能的技术,通过缓存注意力机制中的键值(Key-Value)对来减少冗余计算,从而提高模型推理的速度。
不懂Self Attention的可以先去看这篇文章:
原因
首先要知道大模型进行推理任务时,是一个token一个token进行输出的。
例:给GPT一个任务 “对这个句子进行扩充:我爱“
GPT的输出为:
我爱
我爱中
我爱中国
我爱中国美
我爱中国美食
我爱中国美食,
我爱中国美食,因
我爱中国美食,因为
我爱中国美食,因为它
我爱中国美食,因为它好
我爱中国美食,因为它好吃
我爱中国美食,因为它好吃。
通过这个例子可以看出它生成句子是按token输出的(为了方便理解,假设一个字为一个token)。输出的token会与输入的tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符后结束。自回归任务中,token只能和之前的文本做attention计算。
KVcache图解原理
将这个prompt通过embedding生成QKV三个向量。
“我”只能对自己做attention。得到
“爱”的Q向量对“我”和“爱”的K向量进行计算后再对V进行加权求和算得新向量后输出
输入到模型后得到新的token“中国”
重复上述过程
可以发现 在此过程中,新token只与之前token的KV有关系,和之前的Q没关系,因此可以将之前的KV进行保存,就不用再次计算。这就是KVcache。
问题回答
问题1:为什么不保存Q
因为每次运算只有当前token的Q向量,之前token的Q根本不需要计算,所以缓存Q没意义。
问题2:KVCache如何减少计算量
减少的就是不用重复计算之前token的KV向量,但是每个新词的Attention还得计算。
问题3:每次推理过程的输入tokens都变长了,为什么推理FLOPs不随之增大而是保持恒定呢?
因为使用了KVcache导致第i+1 轮输入数据只比第i轮输入数据新增了一个token,其他全部相同!因此第i+1轮推理时必然包含了第 i 轮的部分计算。
代码实现
这是自己实现的一个简单的多头注意力机制+KVcache
具体如何实现多头注意力机制可以看这篇文章:
import torch
import torch.nn as nn
import math
class MyMultiheadAttentionKV(nn.Module):
def __init__(self, hidden_dim: int = 1024, heads: int = 8, dropout: float = 0.1):
super().__init__()
self.hidden_dim = hidden_dim
self.heads_num = heads
self.dropout = nn.Dropout(dropout)
self.head_dim = hidden_dim // self.heads_num
self.Wq = nn.Linear(self.hidden_dim, self.hidden_dim) # (hidden_dim, heads_num * head_dim)
self.Wk = nn.Linear(self.hidden_dim, self.hidden_dim) # (hidden_dim, heads_num * head_dim)
self.Wv = nn.Linear(self.hidden_dim, self.hidden_dim) # (hidden_dim, heads_num * head_dim)
self.outputlayer = nn.Linear(self.hidden_dim, self.hidden_dim)
def forward(self, x, mask=None, key_cache=None, value_cache=None):
# x = (batch_size, seq_len, hidden_dim)
query = self.Wq(x)
key = self.Wk(x)
value = self.Wv(x)
bs, seq_len, _ = x.size()
# Reshape to (batch_size, heads_num, seq_len, head_dim)
query = query.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
key = key.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
value = value.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
# Cache key and value if provided
if key_cache is not None and value_cache is not None:
key = torch.cat([key_cache, key], dim=2) # Append along sequence dimension
value = torch.cat([value_cache, value], dim=2)
# Update caches
key_cache = key
value_cache = value
# Calculate attention scores
score = query @ key.transpose(-1, -2) / math.sqrt(self.head_dim) # (batch_size, heads_num, seq_len, seq_len)
if mask is not None:
# Mask size should match the updated sequence length after cache concatenation
mask = mask[:, :, :key.size(3), :key.size(3)] # Crop the mask to the new size
score = torch.softmax(score, dim=-1)
score = self.dropout(score)
output = score @ value # (batch_size, heads_num, seq_len, head_dim)
# Reshape back to (batch_size, seq_len, hidden_dim)
output = output.transpose(1, 2).contiguous().view(bs, seq_len, -1)
output = self.outputlayer(output)
return output, key_cache, value_cache # Return output and updated caches
测试代码
def test_kvcache():
torch.manual_seed(42)
batch_size, seq_len, hidden_dim, heads_num = 3000, 100, 128, 8
x = torch.rand(batch_size, seq_len, hidden_dim) # Random input data
attention_mask = torch.randint(0, 2, (batch_size, 1, seq_len, seq_len)) # Attention mask
net = MyMultiheadAttentionKV(hidden_dim, heads_num)
output, key_cache, value_cache = net(x, attention_mask)
new_x = torch.rand(batch_size, seq_len, hidden_dim)
output, key_cache, value_cache = net(new_x, attention_mask, key_cache, value_cache)
third_x = torch.rand(batch_size, seq_len, hidden_dim)
output, key_cache, value_cache = net(third_x, attention_mask, key_cache, value_cache)
print(f"Output shape: {output.shape}")
print(f"Key cache shape: {key_cache.shape}")
print(f"Value cache shape: {value_cache.shape}")
# Run the test
if __name__ == "__main__":
test_kvcache()
使用KVcache后:
其实,KV Cache 配置开启后,推理过程可以分为2个阶段:
- 预填充阶段:发生在计算第一个输出token过程中,这时Cache是空的,计算时需要为每个 transformer layer 计算并保存key cache和value cache,在输出token时Cache完成填充;FLOPs同KV Cache关闭一致,存在大量gemm操作,推理速度慢。
- 使用KV Cache阶段:发生在计算第二个输出token至最后一个token过程中,这时Cache是有值的,每轮推理只需读取Cache,同时将当前轮计算出的新的Key、Value追加写入至Cache;FLOPs降低,gemm变为gemv操作,推理速度相对第一阶段变快,这时属于Memory-bound类型计算。
总结
KV Cache是Transformer推理性能优化的一项重要工程化技术,各大推理框架都已实现并将其进行了封装(例如 transformers库 generate 函数已经将其封装,用户不需要手动传入past_key_values)并默认开启(config.json文件中use_cache=True)。
参考:https://zhuanlan.zhihu.com/p/630832593
标签:dim,self,cache,通俗易懂,KVcache,value,key,hidden,图解 From: https://blog.csdn.net/wlxsp/article/details/143575031