首页 > 其他分享 >Speculative Streaming:无需辅助模型的快速大模型推理

Speculative Streaming:无需辅助模型的快速大模型推理

时间:2024-08-10 11:53:34浏览次数:10  
标签:草稿 投机 模型 Streaming 推测 Speculative

人工智能咨询培训老师叶梓 转载标明出处

在自然语言处理领域,大模型(LLM)在进行推理时,由于其自回归生成的特性,往往需要较高的计算成本和内存占用。为了解决这一问题,苹果公司的研究者们提出了一种名为Speculative Streaming的新方法。这种方法通过改变目标模型的微调目标,从下一个词预测转变为未来n-gram预测,从而将草稿生成过程融合到目标模型中,无需使用辅助草稿模型。这一创新不仅简化了推理系统,还提高了推理速度,同时保持了生成质量。

投机解码(Speculative Decoding)和投机流(Speculative Streaming)的对比
(a) 投机解码需要一个单独的草稿模型来自动地推测候选结果,然后由目标模型进行验证。(b) 投机流通过在单个模型中同时执行推测和验证,显著简化了系统。

方法

在大模型(LLM)的推理过程中,由于其自回归解码的顺序特性,每个词的生成都需要完整的网络前向传播,这限制了推理速度。为了解决这个问题,研究者提出了Speculative Decoding技术,通过预测多个候选未来词元然后并行验证来加速解码。然而,现有的Speculative Decoding方法需要训练和对齐两个模型:一个小型的草稿模型用于候选推测,一个大模型用于验证,这不仅增加了系统的复杂性,还增加了参数和计算成本。

于是研究者提出了Speculative Streaming,这是一种单模型推测解码方法。他们改变了目标模型的训练目标,从传统的下一个词预测转变为预测未来的n-gram。这样的改变使得模型能够考虑到未来的词元,而不仅仅是紧接的下一个词元,从而提高了预测的质量和效率。

投机流的架构设计

图2展示了Speculative Streaming方法与标准推测解码(Speculative Decoding)之间的关键差异和操作流程。在图2(a)中,标准推测解码的流程被清晰地描绘出来:一个小型的草稿模型首先自回归地生成一系列候选词元,然后这些候选词元被传递给大型的目标语言模型(Target LLM)进行逐一验证。这个过程涉及到两个不同的模型,并且是串行进行的,意味着目标模型必须等待草稿模型完成序列生成后才能开始验证工作,这限制了整体的解码速度。

在Speculative Streaming方法中,研究者们提出了一种创新的微调技术,用于优化解码器仅预训练的语言模型,这一技术的核心在于通过低秩适配器来提高预测效率。这些适配器被训练来预测下一个目标词元,这是基于当前的上下文信息(x1...xm)和之前已经生成的目标词元序列(y1...yt)。这种微调方式不仅提升了模型对局部上下文的敏感度,而且通过修改目标模型的训练目标,将其从单一的下一个词元预测转变为n-gram预测,进一步增强了模型对文本生成流程的控制能力。

在图2(b)中,可以看到这一机制的实现。图中展示了Speculative Streaming如何在单一的目标模型内部,通过多流注意力(MSA)机制,实现对未来词元的规划。MSA允许模型在最后几层生成多个推测流,每个推测流负责预测一个未来词元,而主流则继续预测当前词元。这样的设计使得模型能够在每个时间步生成多个候选的未来词元序列,而不是单一序列,从而大大提高了生成的多样性和灵活性。

图中还展示了如何通过初始化推测流来减少计算量。推测流不是从嵌入层开始,而是从模型的中间层(N-Ns层)开始,这样可以减少模型需要处理的层数,从而降低计算复杂度。每个推测流都通过一个线性变换和一个流标识符嵌入来初始化,这些嵌入帮助模型捕捉相对位置信息,并区分推测流与主流的计算。

图2(b)中还展示了Speculative Streaming的另一个重要组件——树形草稿(Tree Drafting)。在这个过程中,推测流生成的候选词元不是简单的线性序列,而是以树形结构组织,每个分支代表了一种可能的词元序列。这种结构允许模型在每个前向传播中验证多个候选路径,并且接受最长的匹配序列,从而提高了候选词元的接受率。

为了有效管理大量的候选词元并减少计算负担,Speculative Streaming采用了并行树草稿修剪技术。这一技术通过基于父子词元之间的转移概率来剪除那些概率较低的路径,从而保留了更有可能被接受的候选词元。修剪操作不仅减少了模型的计算量,而且通过去除不太可能的候选词元,提高了整体的解码质量。

图2强调了Speculative Streaming在训练过程中的优势。由于不需要维护和对齐两个模型,Speculative Streaming的训练过程更为简单和高效。模型通过端到端的方式进行微调,自然地对齐了推测和验证阶段,确保了生成的质量和速度的最优平衡。

在多流注意力框架下,主流自我注意力机制保持不变,而额外引入了γ个自我注意力流来推测未来的词元。每个推测流在时间步t关注先前的主流隐藏状态以及推测流自身的隐藏状态,以此来生成推测词元。主流和推测流的隐藏状态分别通过MHA层和MSA层进行更新,其中MHA层用于主流自我注意力,而MSA层结合了主流和推测流的注意力。

为了提高参数效率和减少延迟开销,推测流的初始化采用了一种特别的设计。它们不是从嵌入层开始,而是从靠近顶层的N-Ns层开始,这有助于减少每前向传播的计算量。此外,推测流被训练为从主流的键/值上下文中学习上下文特征,从而避免了额外的缓存开销,使得模型能够在资源受限的设备上有效运行。

在并行推测和验证方面,Speculative Streaming通过并行化这两个过程来提高效率。在每个前向传播中,上一步生成的草稿会被验证,同时生成新的草稿。与传统的串行方法不同,Speculative Streaming允许在每个步骤中同时进行推测和验证。为了有效地生成和验证候选词元,模型采用了树形结构,其中每个路径都代表一个可能的验证候选项。通过从主流和推测流中采样词元形成树形结构,模型能够在每个前向传播中考虑多个候选路径,并通过创建一个可加的注意力掩码来提高验证的效率。

在生成推测树草稿的过程中,研究者面临一个挑战:如何高效地处理大量的候选词元。为了解决这个问题,他们引入了并行树草稿修剪技术。修剪的目的是减少需要考虑的候选词元数量,从而降低计算复杂度,避免模型在每个前向传播中处理过多的候选项,这可能会导致计算资源的浪费。

修剪操作基于父子词元之间的转移概率。这是一种衡量一个词元后面跟随另一个词元的可能性的指标。高转移概率的路径更有可能被模型接受。

为了在不使用额外代理模型的情况下估计转移概率,研究者采用了早期退出技术。这种技术通常用于模型内部,在训练过程中提前终止某些分支,以估计它们的概率。修剪层会根据转移概率对树草稿中的路径进行评估,并剪除那些概率较低的路径。这样可以减少后续步骤中需要处理的候选数量,提高效率。

训练目标是Speculative Streaming方法中的另一个关键方面,它确保了模型在生成未来n-gram时的效率和准确性。模型被训练为同时最小化下一个词的预测损失和未来γ个词的预测损失。这种联合训练方法使得模型能够同时学习生成高质量的当前词和未来词。

通过对基础模型进行微调,研究者能够使模型适应特定的下游任务,同时保持模型在推测未来词元时的准确性。尽管Speculative Streaming方法加速了解码过程,但它并没有牺牲生成的质量。通过精心设计的训练目标和微调过程,模型在加速的同时保持了生成结果的质量。

通过这些方法,研究者成功地将推测和验证统一到了一个模型中,消除了对单独草稿模型的需求,同时实现了与现有技术相当的或更好的速度提升和质量保证。Speculative Streaming方法在参数效率和部署简化方面表现出色,特别适合资源受限的设备。

实验评估

研究者们选择了结构化查询(Structured Queries)、文本摘要(Text Summarization)和语义表示(Meaning Representation)这三种对设备上的AI助手至关重要的应用场景进行测试。他们使用了Dialogsum、WikiSQL和SPIDER构建的sql-create-context以及e2e-nlg数据集。在模型配置方面,研究者们测试了不同规模的四个开源模型:Phi(1.3B)、Openllama(7B)和两种规模的OPT(1.3B, 6.7B)。

研究者们将Speculative Streaming与标准的草稿-目标推测解码(draft-target speculative decoding)和Medusa单模型推测解码框架进行了比较。对于标准的草稿-目标方法,他们使用了最小的开源OPT模型(OPT-125m)作为草稿模型。对于Medusa方法,研究者们使用预训练的基础模型和LoRA适配器作为基线,并使用具有相同基础模型、流嵌入和LoRA适配器的Speculative Streaming作为目标。

在Nvidia A100-80G GPU上,使用批量大小为1,以float16精度进行贪婪采样,并且温度参数T设为0,研究者们报告了墙钟时间加速和生成质量指标。他们使用了Exact Match(EM)准确度来评估结构化查询任务,并使用了Rouge1/RougeLSum指标来评估对话摘要和语义表示任务。表1展示了Speculative Streaming在加速比、调用减少比和额外参数数量方面的比较。结果显示,Speculative Streaming在各种下游任务中的墙钟时间加速和调用减少比与Medusa相当或更好,同时显著减少了参数开销。

标题

研究者们深入探讨了Medusa和投机流两种方法在没有辅助模型的情况下的性能差异。Medusa方法中的每个头都是独立工作的,它们从最后一层共享的隐藏状态生成每个token。这种方法可能无法很好地捕捉由Medusa头预测的投机token与基础模型在特定时间步预测的下一个token之间的依赖关系,因为Medusa缺乏注意力机制。与此同时,投机流通过让投机流自身关注主流和彼此,有效地捕获了token之间的依赖性,从而实现了比Medusa更高的调用减少比率。

在参数方面,每个Medusa头增加了相当数量的参数,这个数量与隐藏层的大小和词汇表的大小有关。而且,随着投机窗口长度的增加,Medusa头的数量也会线性增加,导致参数开销也随之线性增加。相比之下,投机流使用的投机适配器的参数数量不会随着投机窗口长度的增加而增加。尽管流标识符嵌入的参数会随着窗口长度的增加而线性增加,但在微调设置中,这些投机适配器的参数与基础模型适配器共享,因此我们的参数开销仅仅是窗口长度乘以隐藏层大小。

在有辅助模型的情况下,投机流在延迟方面持续优于标准草稿-目标投机解码方法。Table 2展示了使用OPT-125m作为草稿模型时,不同模型在各种任务上的墙钟延迟和性能指标的比较。尽管投机流对目标模型的调用次数高于基于草稿模型的投机解码,但它避免了自回归草稿生成的开销,因此在OPT-1.3b和OPT-6.7b模型上实现了更低的延迟。

Figure 3进一步说明了投机流如何通过增加内存受限自回归解码步骤的算术强度来加速解码。与Medusa风格的方法和基于草稿模型的投机解码方法相比,投机流显著提高了内核和内存的利用率。

研究者们还分析了在不同目标/草稿延迟比率下,投机流相对于基于草稿的投机解码的理论加速比。这个分析帮助我们理解在何种条件下,投机流相比于标准草稿-目标方法更有优势。Figure 4展示了在不同的延迟比率下,投机流可能提供的加速比。当延迟比率增加时,如果草稿模型足够准确并且足够小,以实现比投机流更多的每次目标模型验证步骤的token推进,那么草稿-目标方法可能提供更多的加速优势。然而,找到或创建这样的模型通常需要大量的工程努力。在下游应用设置中,由于不同应用的验证步骤差异,找到理想的草稿模型变得更加具有挑战性。如果多个应用共享草稿模型并且只训练适配器,草稿模型可能无法保持足够小的规模以满足目标-草稿延迟比率,这使得实现比投机流更多的加速变得更加困难。

Figure 5展示了在创建树草案时,从每个流采样更多候选令牌对墙钟速度提升的影响。随着采样的候选令牌数量的增加,由于候选数量的增加,速度提升会增加。但是,当模型进入计算受限阶段时,这种趋势会逆转。从树草案中剪枝不太可能的路径有助于减少计算,从而减少每次前向传播的延迟并提供更多的速度提升。

消融研究中,研究者们专注于两个关键的实验设置:投机草案的大小和多流注意力(MSA)层的数量。这些研究帮助我们理解这些参数是如何影响投机流方法的性能的。

为了提高树形草案的接受率,研究者们尝试了不同的γ值设置,即投机位置的数量,以及每个投机位置采样的候选令牌数量k。Figure 5展示了当γ固定为3时,随着k的增加,每个前向传播步骤中验证的候选令牌数量β也随之增加,这导致速度提升。然而,当k增加到一定程度时,模型将从内存受限转变为计算受限,这会增加每次前向传播的延迟,从而使得速度提升的效果降低。这是因为简单地形成树草案会导致批次大小随着k指数增长。为了解决这个问题,研究者们引入了一个树剪枝层来去除概率较低的路径,从而减少树草案的大小。剪枝树草案可以减少前向传播的延迟,并且一个校准良好的阈值可以确保只有树中的噪声路径被剪枝。当k继续增加时,树剪枝有助于减少每次前向传播的延迟,从而提供更多的速度提升,如图5所示。

研究者们还探讨了MSA层数的权衡。在决定要合并多少层MSA时,需要考虑下游生成指标、训练时间和浮点运算次数(FLOPs)的增加。随着MSA层数的增加,生成指标得到了改善,并且这一趋势在不同的下游任务中保持一致。通常,在模型的顶层中加入2到8层MSA可以在指标、FLOPs增加和训练时间之间提供良好的权衡。Figure 6展示了在结构化查询和摘要任务中,OPT-1.3b模型的生成性能随着MSA层数的增加而提高。

Speculative Streaming为大型语言模型的快速推理提供了一种有效的解决方案,它通过简化的微调过程和无需辅助模型的设计,实现了与现有技术相当的或更好的速度和质量,是资源受限设备的理想选择。

论文链接:https://arxiv.org/abs/2402.11131

标签:草稿,投机,模型,Streaming,推测,Speculative
From: https://blog.csdn.net/weixin_44292902/article/details/140930220

相关文章

  • 推理延迟:解决PyTorch模型Inference阶段的RuntimeError ⏳⚡
    推理延迟:解决PyTorch模型Inference阶段的RuntimeError⏳⚡推理延迟:解决PyTorch模型Inference阶段的RuntimeError⏳⚡摘要引言正文内容什么是RuntimeError?⏳RuntimeError的常见成因⚠️数据格式不一致内存不足模型参数不匹配解决RuntimeError的方法......
  • 被大模型折腾不行了,奉劝不要轻易入行!!!
    科技的进步,生产力就很容易提升,进而就是不需要过多的人。最近在尝试借助一些工具,提升做事的效率,初步试验感觉很不错。网络上所有的东西,确实都可以利用新平台重做一遍。现在火的东西越来越让人看不懂,一首挖呀挖火遍全网,看完后感触是什么?内容越简单越直白,其实更容易火,越是高深......
  • 下载量10w+!大型语言模型:语言理解和生成
    近年来,人工智能在新语言能力方面取得了显著进展,深度学习技术的快速发展推动了语言AI系统在文本编写和理解方面的表现。免费获取:下载量10w+!大型语言模型:语言理解和生成......
  • 《开源大模型食用指南》发布,7个小时,一杯奶茶速通大模型!
    前言《开源大模型食用指南》是一个围绕开源大模型、针对国内初学者、基于AutoDL平台的中国宝宝专属大模型教程,针对各类开源大模型提供包括环境配置、本地部署、高效微调等技能在内的全流程指导,简化开源大模型的部署、使用和应用流程,**让更多的普通学生、研究者更好地使......
  • 你觉得大模型时代该出现什么?
    大模型的概念都火了两年了,之前各种媒体吹嘘大模型的出现是类似“蒸汽机时代”、“iPhone时刻”等等。那为什么我们期待的结果都没出现呢?咱们先一起回顾下历史。1、蒸汽机时代1.1、蒸汽机历史许多人都在讨论大模型时代好像只是概念在火,但没有拿得出手的实际落地的案例。目前的应......
  • 常用的ViT模型
    常用的ViT模型有许多版本和变种,它们在不同的任务和数据规模上表现出色。以下是一些常见的ViT模型及其变种:1.ViT-B/16,ViT-B/32ViT-B/16和ViT-B/32是VisionTransformer的基本版本,"B"代表Base模型,数字16和32代表图像块的大小(如16x16或32x32)。ViT-B/16通常表现优于ViT-B/32,因......
  • 【人工智能】常用的人工智能框架、模型、使用方法、应用场景以及代码实例的概述
    人工智能(AI)领域涉及众多框架和模型,这些框架和模型为开发人员提供了强大的工具,以构建和训练各种AI应用。以下是一些常用的人工智能框架、模型、使用方法、应用场景以及代码实例的概述。一、常用框架1.TensorFlow简介:TensorFlow是一个由谷歌开发的开源深度学习框架,支持大规模......
  • 【深度学习】基于YOLOV5模型的图像识别-目标检测的性能指标详解与计算方法
    目标检测是计算机视觉中的重要任务,主要目的是在图像中识别并定位特定的物体。YOLO(YouOnlyLookOnce)系列模型作为目标检测领域的代表性方法之一,凭借其高效和准确的特点,广泛应用于实际场景中。本文通过详细介绍目标检测的性能指标及其计算方法,帮助读者更好地理解和评估YOLO......
  • 基于C# winform调用文心一言大模型实现实时聊天功能
    【软件界面】【测试通过环境】vs2019netframework4.7.2【使用步骤】由于调用百度接口需要首先去https://login.bce.baidu.com/去注册或者登录自己的账号,进去后界面如下:然后点击左上角九个点图标然后点击百度智能云千帆大模型平台点击应用接入然后选择创建应用即可......
  • Spark Structured Streaming 概论
    SparkStructuredStreaming概论与以往任何时候都不同,今天的大数据处理,对于延迟性的要求越来越高,因此流处理的基本概念与工作原理,是每一个大数据从业者必备的“技能点”。在这个模块中,按照惯例,我们还是从一个可以迅速上手的实例开始,带你初步认识Spark的流处理框架Stru......