首页 > 其他分享 >通俗易懂的KVcache图解

通俗易懂的KVcache图解

时间:2024-11-06 18:51:35浏览次数:3  
标签:dim self cache 通俗易懂 KVcache value key hidden 图解

在分享之前先提出三个问题:

1. 为什么KVCache不保存Q

2. KVCache如何减少计算量

3. 为什么模型回答的长度不会影响回答速度?

本文将带着这3个问题来详解KVcache

KVcache是什么

kv cache是指一种用于提升大模型推理性能的技术,通过缓存注意力机制中的键值(Key-Value)对来减少冗余计算,从而提高模型推理的速度。

不懂Self Attention的可以先去看这篇文章:

原因

首先要知道大模型进行推理任务时,是一个token一个token进行输出的。

例:给GPT一个任务 “对这个句子进行扩充:我爱“

GPT的输出为:

我爱

我爱中

我爱中国

我爱中国美

我爱中国美食

我爱中国美食,

我爱中国美食,因

我爱中国美食,因为

我爱中国美食,因为它

我爱中国美食,因为它好

我爱中国美食,因为它好吃

我爱中国美食,因为它好吃。

通过这个例子可以看出它生成句子是按token输出的(为了方便理解,假设一个字为一个token)。输出的token会与输入的tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符后结束。自回归任务中,token只能和之前的文本做attention计算。

KVcache图解原理

将这个prompt通过embedding生成QKV三个向量。

“我”只能对自己做attention。得到Z_1

“爱”的Q向量对“我”和“爱”的K向量进行计算后再对V进行加权求和算得新向量后输出Z_{\text{2}}

输入到模型后得到新的token“中国”

重复上述过程

可以发现 在此过程中,新token只与之前token的KV有关系,和之前的Q没关系,因此可以将之前的KV进行保存,就不用再次计算。这就是KVcache。

问题回答

问题1:为什么不保存Q

因为每次运算只有当前token的Q向量,之前token的Q根本不需要计算,所以缓存Q没意义。

问题2:KVCache如何减少计算量

减少的就是不用重复计算之前token的KV向量,但是每个新词的Attention还得计算。

问题3:每次推理过程的输入tokens都变长了,为什么推理FLOPs不随之增大而是保持恒定呢?

因为使用了KVcache导致第i+1 轮输入数据只比第i轮输入数据新增了一个token,其他全部相同!因此第i+1轮推理时必然包含了第 i 轮的部分计算。

代码实现

这是自己实现的一个简单的多头注意力机制+KVcache

具体如何实现多头注意力机制可以看这篇文章:

import torch
import torch.nn as nn
import math

class MyMultiheadAttentionKV(nn.Module):
    def __init__(self, hidden_dim: int = 1024, heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.heads_num = heads
        self.dropout = nn.Dropout(dropout)
        self.head_dim = hidden_dim // self.heads_num
        self.Wq = nn.Linear(self.hidden_dim, self.hidden_dim)  # (hidden_dim, heads_num * head_dim)
        self.Wk = nn.Linear(self.hidden_dim, self.hidden_dim)  # (hidden_dim, heads_num * head_dim)
        self.Wv = nn.Linear(self.hidden_dim, self.hidden_dim)  # (hidden_dim, heads_num * head_dim)
        self.outputlayer = nn.Linear(self.hidden_dim, self.hidden_dim)

    def forward(self, x, mask=None, key_cache=None, value_cache=None):
        # x = (batch_size, seq_len, hidden_dim)
        query = self.Wq(x)
        key = self.Wk(x)
        value = self.Wv(x)
        bs, seq_len, _ = x.size()

        # Reshape to (batch_size, heads_num, seq_len, head_dim)
        query = query.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
        
        key = key.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
        value = value.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
        # Cache key and value if provided
        if key_cache is not None and value_cache is not None:
            key = torch.cat([key_cache, key], dim=2)  # Append along sequence dimension
            value = torch.cat([value_cache, value], dim=2)
        

        # Update caches
        key_cache = key
        value_cache = value

        # Calculate attention scores
        score = query @ key.transpose(-1, -2) / math.sqrt(self.head_dim)  # (batch_size, heads_num, seq_len, seq_len)
        if mask is not None:
            # Mask size should match the updated sequence length after cache concatenation
            mask = mask[:, :, :key.size(3), :key.size(3)]  # Crop the mask to the new size

        
        score = torch.softmax(score, dim=-1)
        score = self.dropout(score)
        output = score @ value  # (batch_size, heads_num, seq_len, head_dim)

        # Reshape back to (batch_size, seq_len, hidden_dim)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, -1)
        output = self.outputlayer(output)

        return output, key_cache, value_cache  # Return output and updated caches

 测试代码

def test_kvcache():
    
    torch.manual_seed(42)

   
    batch_size, seq_len, hidden_dim, heads_num = 3000, 100, 128, 8
    x = torch.rand(batch_size, seq_len, hidden_dim)  # Random input data
    attention_mask = torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))  # Attention mask
    

    net = MyMultiheadAttentionKV(hidden_dim, heads_num)
    

    output, key_cache, value_cache = net(x, attention_mask)
    

    new_x = torch.rand(batch_size, seq_len, hidden_dim)  
    output, key_cache, value_cache = net(new_x, attention_mask, key_cache, value_cache)


    third_x = torch.rand(batch_size, seq_len, hidden_dim) 
    output, key_cache, value_cache = net(third_x, attention_mask, key_cache, value_cache)
    
    print(f"Output shape: {output.shape}")
    print(f"Key cache shape: {key_cache.shape}")
    print(f"Value cache shape: {value_cache.shape}")
    


# Run the test
if __name__ == "__main__":
    test_kvcache()

使用KVcache后:

其实,KV Cache 配置开启后,推理过程可以分为2个阶段:

  1. 预填充阶段:发生在计算第一个输出token过程中,这时Cache是空的,计算时需要为每个 transformer layer 计算并保存key cache和value cache,在输出token时Cache完成填充;FLOPs同KV Cache关闭一致,存在大量gemm操作,推理速度慢。
  2. 使用KV Cache阶段:发生在计算第二个输出token至最后一个token过程中,这时Cache是有值的,每轮推理只需读取Cache,同时将当前轮计算出的新的Key、Value追加写入至Cache;FLOPs降低,gemm变为gemv操作,推理速度相对第一阶段变快,这时属于Memory-bound类型计算。

总结

KV Cache是Transformer推理性能优化的一项重要工程化技术,各大推理框架都已实现并将其进行了封装(例如 transformers库 generate 函数已经将其封装,用户不需要手动传入past_key_values)并默认开启(config.json文件中use_cache=True)。

参考:https://zhuanlan.zhihu.com/p/630832593

标签:dim,self,cache,通俗易懂,KVcache,value,key,hidden,图解
From: https://blog.csdn.net/wlxsp/article/details/143575031

相关文章

  • 《图解设计模式》 读后笔记
    设计模式很早前看过,那时候囫囵吞枣,从来没有系统的总结过,因为对于面试而言问的问题总是答的不精确。这次又借助《图解设计模式》复习了一遍,自己写了一篇总结。23种设计模式看起来很多其实大多数在开发中都见到过。甚至有的设计模式对于一个初学者而言即使不知道设计模式也会应......
  • 科普文:软件架构Linux系列之【图解存储 IO性能优化与瓶颈分析】
    概叙科普文:软件架构Linux系列之【Linux的文件预读readahead】-CSDN博客科普文:软件架构Linux系列之【并发问题的根源:CPU缓存模型详解】-CSDN博客从上面冯诺依曼结构下的cpu、内存、外存之间的延迟就可以看出,磁盘I/O性能的发展远远滞后于CPU和内存,因而成为现代计算机系统的......
  • 《图解设计模式》 第九部分 避免浪费
    第二十章Flyweight模式publicclassBigcharFactory{//这里对使用到的内容进行了缓存privateHashMappool=newHashMap();//有则直接取,无则创建并保存到缓存。publicsynchronizedBigChargetBigChar(charcharname){BigCharbc=(BigChar)pool.get("......
  • 《图解设计模式》 第八部分 管理状态
    第17章Observer模式publicabstractclassNumberGenerator{privateArrayListobserverList=newArrayList();/*部分代码省略*///加入基础类,当需要通知的时候通知publicvoidaddObserver(Observerobserver){observerList.add(observer);}pub......
  • 《图解设计模式》 第七部分 简单化
    Facade模式publicclassMain{publicstaticvoidmain(String[]args){PageMaker.makeWelcomePage("[email protected]","welcom.html");}}publicclassPageMaker{publicstaticvoidmakeWelcompage(Stringmailaddr,Stringfile......
  • 《图解设计模式》 第五部分 访问数据结构
    第十三章Visotor模式publicclassfileextendsentry{/*省略*/puhblicvoidaccept(Visitorv){v.visit(this);}}publicclassMain{publicstaticvoidmain(Stringargs){Directoryrootdir=newDirctory("root");/*省略*/ro......
  • 《图解设计模式》 第五部分 一致性
    第11章Composite模式文中举例文件夹系统,简单说明:这里可以讲File和dirctory看作一种东西Entry。在Entry的基础上分化两者,构成结构。能够使容器与内容具有一致性,创造出递归结构。第12章Decorator模式publicclassMain{publicstaticvoidmain(String[]ar......
  • 《图解设计模式》 第三部分 生成实例
    第五章Singleton模式单例模式应该是开发中用的比较多的模式。这里我只记一个知识点。多线程下安全的单例模式的一个知识点publicclassSingleton{publicstaticInstanceClassinstance=null;publicstaticSingletongetInstance(){if(instance==null){......
  • MySQL8.0安装配置教程【超级详细图解】
    万分感谢.参考文章内容:https://blog.csdn.net/m0_73442728/article/details/131359479万分感谢.参考文章内容:https://blog.csdn.net/qq_40187702/article/details/130618805目录**一、MySQL下载与安装二、MySQL安装三、MySQL连接测试四、配置环境变量一、MySQL下载与安......
  • 【linux命令】史上最全Linux命令,结合用例通俗易懂
    前言:目前关于Linux命令的文章往往存在内容不全的问题,导致初学者和中级用户在使用过程中遇到困难。许多文章仅涵盖基础命令,而缺乏对系统管理、网络配置、包管理和脚本编写等重要主题的详细讲解。此外,实际操作中的常见问题及其解决方案也常常未被提及,使得用户在遇到困难时无法......