首页 > 编程语言 >大模型推理框架 vLLM 源码解析(一)

大模型推理框架 vLLM 源码解析(一)

时间:2024-02-04 18:22:57浏览次数:34  
标签:解析 prompt seq self request token 源码 vLLM id

1. Quick Start

创建如下代码,命名为 run.py

from vllm import LLM, SamplingParams

prompts = [
	"Have you followed marsggbo in Zhihu?",
	"你一键三连了吗?"
] # 输入prompts
sampling_params = SamplingParams(temperature=0.8, top_k=50) # 采样策略
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2) # 初始化 LLM
outputs = llm.generate(prompts, sampling_params) # 完成推理
for output in outputs:
	prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

执行命令:python run.py。该脚本会自动将模型以张量并行的方式在两个 GPU 上进行推理计算。

整个推理过程大大致流程如下图所示,即
1 给定一定数量的 prompts(字符串数组)
2. vllm 会使用 Scheduler 模块自动对需要推理句子进行调度
3. 根据调度的结果,使用 tokenizer 将字符串转换成 prompt id,然后喂给 model 进行计算得到 logits 预测结果
4. 根据 logits 预测结果和提前设置好的采样策略对结果进行采样得到新的 token id
5. 将采样结果保存到 output

inferencce pipeline

2. 整体核心模块

vllm 核心模块结构
上图给出了 vLLM 核心模块之间的结构关系。接下来我们从简单的模块(即输入、采样和输出)开始介绍,最后详细介绍 LLM 模块。

3. Sequence

句子模块
如上图我们可以看到 vLLM 为输入的句子设计了很多子模块,这些模块的用处各不相同,但是有彼此之间有关系,下面分别详细介绍一下。

3.1 SequenceStatus

首先看到 SequenceStatus,其源代码如下:

class SequenceStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto() # 等待中,句子还没开始推理,或者推理还未结束
    RUNNING = enum.auto() # 运行中
    SWAPPED = enum.auto() # 已交换
    FINISHED_STOPPED = enum.auto() # 已停止
    FINISHED_LENGTH_CAPPED = enum.auto() # 已长度限制
    FINISHED_ABORTED = enum.auto() # 已中止
    FINISHED_IGNORED = enum.auto() # 已忽略

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        # 判断状态是否为已停止、已长度限制、已中止或已忽略
        return status in [
            SequenceStatus.FINISHED_STOPPED,
            SequenceStatus.FINISHED_LENGTH_CAPPED,
            SequenceStatus.FINISHED_ABORTED,
            SequenceStatus.FINISHED_IGNORED,
        ]

3.2 SequenceData

SequenceData 用于存储与序列相关的数据。这个类有三个属性:prompt_token_ids(提示词的标记ID)、output_token_ids(生成文本的标记ID)和cumulative_logprob(累计对数概率)。

class SequenceData:
    def __init__(
        self,
        prompt_token_ids: List[int],
    ) -> None:
        self.prompt_token_ids = prompt_token_ids
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

3.3 Sequence

Sequence 用于存储序列的数据、状态和块信息,且每个序列有唯一标识,即seq_id。注意看下面的代码:

  • 数据其实是通过上面的 SequenceData 保存的
  • 默认初始化状态,所有句子序列的状态都是 SequenceStatus.WAITING
  • 所谓块信息,其实就是 vLLM 会在初始化阶段预留出一定数量的CPU 和 GPU 内存,一般是以 token 为单位的,例如在初始化的时候会使用值全为 0,大小为 (256, 128)的 prompt_ids做 warm up。每个序列会按照实际大小申请 block 来记录内存使用情况,即序列 token 数越多,属性logical_token_blocks包含的 block 个数也就越多。
class Sequence:
    def __init__(
        self,
        seq_id: int,
        prompt: str,
        prompt_token_ids: List[int],
        block_size: int,
    ) -> None:
        self.seq_id = seq_id
        self.prompt = prompt
        self.block_size = block_size

        self.data = SequenceData(prompt_token_ids) # 数据

        self.logical_token_blocks: List[LogicalTokenBlock] = []
        # Initialize the logical token blocks with the prompt token ids.
        self._append_tokens_to_blocks(prompt_token_ids) # 块信息
        self.status = SequenceStatus.WAITING # 状态
		...

3.3 SequenceGroup

Sequence只是单个序列的表示方式,seq_id是它的唯一标识。SequenceGroup则是为了表示多个序列,request_id是它的唯一标识,表示是第几个请求。

具体而言,可以看到__init__函数有个参数是 seqs: List[Sequence],它表示由一个或多个 Sequence 组成的列表,然后会通过self.seqs_dict = {seq.seq_id: seq for seq in seqs}转化成字典方便管理,这个字典的 key 是每个 Sequence 的唯一标识seq_id

class SequenceGroup:
    def __init__(
        self,
        request_id: str,
        seqs: List[Sequence],
        sampling_params: SamplingParams,
        arrival_time: float,
        lora_request: Optional[LoRARequest] = None,
        prefix: Optional[Prefix] = None,
    ) -> None:
        self.request_id = request_id
        self.seqs_dict = {seq.seq_id: seq for seq in seqs}
        self.sampling_params = sampling_params
        self.arrival_time = arrival_time
		...

下面是 vLLm 中 LLMEngine 使用 Sequence 和 SequenceGroup 的场景示例:

class LLMEngine:
    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        lora_request: Optional[LoRARequest] = None,
        prefix_pos: Optional[int] = None,
    ) -> None:
        prompt_token_ids = self.encode_request(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            lora_request=lora_request) # 将字符串序列转换成 id

        # Create the sequences.
        block_size = self.cache_config.block_size
        seq_id = next(self.seq_counter)
        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
                       lora_request)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, [seq], sampling_params,
                                  arrival_time)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)

可以看到SequenceGroupseqs参数在最初阶段其实只是单个序列 ,即[seq]。但是我们知道其实一个 prompt 可以有多个输出结果,所以SequenceGroup的目的是管理一个输入 prompt的多个生成序列信息。如果我们设置SamplingParams.n=2(第 4 节会介绍),那么在推理过程中,SequenceGroup会新增一个 Sequence,这个新增的 Sequence 的 seq_id 和原来的那个 Sequence 不一样,具体的代码细节会在下一篇文章中介绍。

3.5 SequenceGroupMetadata

class SequenceGroupMetadata:
    def __init__(
        self,
        request_id: str,
        is_prompt: bool,
        seq_data: Dict[int, SequenceData],
        sampling_params: SamplingParams,
        block_tables: Dict[int, List[int]],
    ) -> None:
        self.request_id = request_id
        self.is_prompt = is_prompt
        self.seq_data = seq_data
        self.sampling_params = sampling_params
        self.block_tables = block_tables
		...

SequenceGroupMetadata 记录了一些元信息,下面代码展示了 Scheduler 模块是如何生成这些信息的:

  • request_id 就是 SequenceGroup的 request_id
  • seq_data 是一个字典,key 是每个 Sequence的 seq_id,value 则是对应的 data (即 SequenceData)
  • block_tables也是一个字典,key 也是每个 Sequence的 seq_id,value 这是对应 Sequence 申请的 block
class Scheduler:
    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        scheduler_outputs = self._schedule()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data # 单个 SequenceData
                block_tables[seq_id] = self.block_manager.get_block_table(seq) # 对应Sequence的block信息

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=scheduler_outputs.prompt_run,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
                lora_request=seq_group.lora_request,
                prefix=seq_group.prefix,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs

3.6 SequenceOutput 和 SequenceGroupOutput

SequenceOutput 和 SequenceGroupOutput的关系就类似 Sequence 和 SequenceGroup。SequenceOutput其实就是记录了上一个 输入 token id 以及对应输出的 token id。

class SequenceOutput:
    def __init__(
        self,
        parent_seq_id: int,
        output_token: int,
        logprobs: Dict[int, float],
    ) -> None:
        self.parent_seq_id = parent_seq_id
        self.output_token = output_token
        self.logprobs = logprobs

class SequenceGroupOutput:
    def __init__(
        self,
        samples: List[SequenceOutput],
        prompt_logprobs: Optional[PromptLogprobs],
    ) -> None:
        self.samples = samples
        self.prompt_logprobs = prompt_logprobs

4. SamplingParams

SamplingParams

SamplingParams 包含以下参数:

  • n:要生成的序列的数量,默认为 1。
  • best_of:从多少个序列中选择最佳序列,需要大于 n,默认等于 n。
  • temperature:用于控制生成结果的随机性,较低的温度会使生成结果更确定性,较高的温度会使生成结果更随机。
  • top_p:用于过滤掉生成词汇表中概率低于给定阈值的词汇,控制随机性。
  • top_k:选择前 k 个候选 token,控制多样性。
  • presence_penalty:用于控制生成结果中特定词汇的出现频率。
  • frequency_penalty:用于控制生成结果中词汇的频率分布。
  • repetition_penalty:用于控制生成结果中的词汇重复程度。
  • use_beam_search:是否使用束搜索来生成序列。
  • length_penalty:用于控制生成结果的长度分布。
  • early_stopping:是否在生成过程中提前停止。
  • stop:要停止生成的词汇列表。
  • stop_token_ids:要停止生成的词汇的ID列表。
  • include_stop_str_in_output:是否在输出结果中包含停止字符串。
  • ignore_eos:在生成过程中是否忽略结束符号。
  • max_tokens:生成序列的最大长度。
  • logprobs:用于记录生成过程的概率信息。
  • prompt_logprobs:用于记录生成过程的概率信息,用于特定提示。
  • skip_special_tokens:是否跳过特殊符号。
  • spaces_between_special_tokens:是否在特殊符号之间添加空格。

这些参数的设置通常取决于具体需求和模型性能。以下是一些常见的设置指导方法:

  • temperature:较低的温度(如0.2)会产生更确定性的结果,而较高的温度(如0.8)会产生更随机的结果。您可以根据您的需求进行调整。
  • presence_penalty、frequency_penalty 和 repetition_penalty:这些参数可以用于控制生成结果中的词汇分布和重复程度。您可以根据您的需求进行调整。
  • use_beam_search:束搜索通常用于生成更高质量的结果,但可能会降低生成速度。您可以根据您的需求进行调整。
  • length_penalty:这个参数可以用于控制生成结果的长度。较高的值会产生更长的结果,而较低的值会产生更短的结果。您可以根据您的需求进行调整。
  • early_stopping:如果您不希望生成过长的结果,可以设置此参数为True。
  • stop 和 stop_token_ids:您可以使用这些参数来指定生成结果的结束条件。

5. Output 模块

Output模块

Output 主要用于表示语言模型(LLM)的生成结果,包含如下两个模块:

  • CompletionOutput
  • RequestOutput

通过上面的介绍我们知道一个 request 可能包含多个序列,CompletionOutput 用来表示一个 request 中某个序列的完整输出的数据,其中下面的index就表示该序列在 request 中的索引位置

class CompletionOutput:
    def __init__(
        self,
        index: int, # 输出结果在请求中的索引
        text: str, # 生成的文本
        token_ids: List[int], # 生成的文本对应的 token ID 列表
        cumulative_logprob: float,
        logprobs: Optional[SampleLogprobs],
        finish_reason: Optional[str] = None, # 序列完成的原因(SequenceStatus)
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.index = index
        self.text = text
        self.token_ids = token_ids
        self.finish_reason = finish_reason
		...

RequestOutput则表示 request 所有序列的输出结果,有它的初始化函数可以看到它记录了对应的 request_id

class RequestOutput:
    def __init__(
        self,
        request_id: str,
        prompt: str,
        prompt_token_ids: List[int],
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.outputs = outputs
        self.finished = finished
		...

我们看看RequestOutput的from_seq_group就能很好理解CompletionOutputRequestOutput是如何使用的了。为方便理解,代码有删减,但是不影响最终结果:

class RequestOutput:
    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
        # 1. Get the top-n sequences.
        n = seq_group.sampling_params.n # 每个序列返回的生成序列数量
        seqs = seq_group.get_seqs()
		# 根据累积 logprob 值来选择出前 n 个生成序列
		sorting_key = lambda seq: seq.get_cumulative_logprob()
        sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
        top_n_seqs = sorted_seqs[:n]

        # 2. Create the outputs.
        outputs: List[CompletionOutput] = []
        for seq in top_n_seqs:
            logprobs = seq.output_logprobs
            finshed_reason = SequenceStatus.get_finished_reason(seq.status)
            output = CompletionOutput(seqs.index(seq), seq.output_text,
                                      seq.get_output_token_ids(),
                                      seq.get_cumulative_logprob(), logprobs,
                                      finshed_reason)
            outputs.append(output)

        # Every sequence in the sequence group should have the same prompt.
        prompt = seq_group.prompt
        prompt_token_ids = seq_group.prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
        return cls(seq_group.request_id,
                   prompt,
                   prompt_token_ids,
                   prompt_logprobs,
                   outputs,
                   finished,
                   lora_request=seq_group.lora_request)

RequestOutput是通过对传入的seq_group: SequenceGroup进行解析后得到的。解析过程主要有两个阶段:

  1. Get the top-n sequences:这一阶段就是对生成序列按照 cumulative_logprob 进行排序,最后选择出top-n 序列。
  2. Create the outputs:将所有top-n生成序列分别转换成 CompletionOutput列表,并作为RequestOutput的初始化参数。

微信公众号:AutoML机器学习
MARSGGBO♥原创
如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com

标签:解析,prompt,seq,self,request,token,源码,vLLM,id
From: https://www.cnblogs.com/marsggbo/p/18006753

相关文章

  • 鱼和熊掌如何兼得?一文解析 RDS 数据库存储架构升级
    在2023年云栖大会上,阿里云数据库产品事业部负责人李飞飞在主题演讲中提到,瑶池数据库推出“DB+存储”一体化能力,结合人工智能、机器学习、存储等方法和创新能力,实现BufferPoolExtension能力和智能冷温热数据分层能力。在大会的《云数据库RDS年度发布与最佳实践》演讲中,阿里......
  • Java原子性、可见性、有序性的解析
    一、原子性原子性指操作在CPU执行的过程中,不可中断,也不可在中途切换,要么执行完成要么不执行。packagelearn;/***@authorqx*@date2024/1/5*@des*/publicclassThreadAtomicityTest{privateintstep;publicintgetStep(){returnstep;......
  • 深入解析 Flink CDC 增量快照读取机制
    深入解析FlinkCDC增量快照读取机制一、Flink-CDC1.x痛点FlinkCDC1.x使用Debezium引擎集成来实现数据采集,支持全量加增量模式,确保数据的一致性。然而,这种集成存在一些痛点需要注意:一致性通过加锁保证:在保证数据一致性时,Debezium需要对读取的库或表加锁。全局锁可能导致数......
  • 卸载centos源码安装的mysql
    关闭mysql查看安装的mysqlrpm-qa|grep-imysql卸载安装的mysqlrpm-evmysql-community-client-8.0.11-1.el7.x86_64--nodeps--nodeps强制卸载删除mysql相关目录whereismysqlfind/-namemysqlrm-rf目录rm命令文件一旦通过rm命令删除,则无法恢复;删除文件......
  • 学习解析几何的启示——去掉直接联系,采用中心化标准
    目录引入案例1:找出三角形的外心案例2:证明两条线段垂直案例3:确定与一组点等距离的点的位置案例4:研究二次曲线的性质思想引入同样的几何体,不同阶段所使用的解题技巧:在初中,熟悉几何定理,需要添加辅助线在高中,需要建立坐标系,采用向量的方法,套对应的公式解析几何之所以强大,在于......
  • oracle 发生硬解析的常见原因
     经常有SQL的执行计划变差了,导致出现了性能问题,为什么会解析出了新的执行计划?首先可以肯定突然出现了新的执行计划,表明SQL进行了重新硬解析(注意重新硬解析不一定产生新的执行计划),那么为什么SQL需要重新硬解析呢?今天就列举几种常见的原因: 1.自动收集统计信息 为了保证sql......
  • CTFer——文件操作与upload-labs解析
    一、什么是webshellwebshell就是以aspx、php、jsp等网页文件形式存在的一种命令环境也可以将其称为一种网页后门 黑客在入侵网站后,通常会将jsp、aspx或php后门文件与网站服务器web目录下正常的网页文件混在一起然后就可以使用浏览器来访问后门文件,得到一个命令执行环境,以达......
  • BeautifulSoup爬虫库应用——Python 页面解析
    爬虫技术作为信息搜集的重要手段,在大数据时代发挥着至关重要的作用。通过网络爬虫,可以高效地从各种在线源头获取大规模、多样化的数据,为大数据分析和应用提供了必要的原始材料。首先,爬虫使得大数据的采集更为全面和及时。网络上存在着庞大的信息资源,包括社交媒体、新闻网站、电子......
  • 解锁教育系统源码的定制奥秘:企业培训平台开发详解
    今天,小编将为大家讲解教育系统源码的奥秘,详细解释企业培训定制开发的关键步骤和技术要点。 一、需求分析与设计阶段设计阶段则包括系统的整体架构设计、数据库设计以及用户界面设计等方面。二、技术选型与开发环境搭建通过使用版本控制系统、集成开发环境(IDE)以及一系列的测试工具,......
  • synchronized【如何保证原子性、可见性、有序性】【如何实现原子性 原理解析】【什么
    @TOC转自极客时间如何解决可见性问题?同步原理剖析什么是Monitor?什么是锁优化?......