Inferllm源码解析
文件结构
- application: 放置几个不同模型的参数配置和后处理
- include: 包含抽象model类的头文件
- src/core: 核心组件,包括tensor、算法等基础算子的抽象和KV文件系统的实现
- src/graph: 包含了几种LLM模型的具体实现
- src/kern: 包含了不同硬件下的算子实现
- src: 剩余一些其他公共函数实现
继承与组合关系
chat/alpaca/chatglm是外置的application程序,通过该入口设置模型随机种子、线程数、token等inferllm::ModelConfig模型参数
利用这些模型参数构建inferllm::Model的shared_ptr,经过load、init、decode_iter等操作进行编解码,设置fix_word函数对结果后处理
关键inferllm::Model类中包含了实际执行类inferllm::ModelImp,而inferllm::ModelImp则包含了inferllms::device的unique_ptr和inferllms::Graph的shared_ptr以及Vocab的shared_ptr,在实际计算时,则通过inferllm::ModelImp中的graph进行execute,得到token,将token进行解码并返回。
alpaca
chat/baichuan/llama
chatglm
-- inferllm::mdoel
-- inferllm::ModelImp
-- inferllms::Tensor
-- inferllms::device(CPUDevice, GPUdevice)
-- inferllm::ThreadPool
-- inferllm::Kernel
-- inferllm::KernelID
-- inferllm::LlmParams
-- inferllm::Graph(llamaGraph, chatglmGraph, baichuanGraph)
-- inferllm::OprModuleBase()
-- inferllm::OpBase(LayerNorm, Embedding, SoftMax, Elemwise, MatMul, MatMulLast, LlamaAttention, GlmAttention, DiagMask)
-- inferllm::Kernel
-- inferllm::UserConfig
-- inferllm::WorkSpace
-- inferllm::Tensor
-- inferllm:: Vocab
-- inferllm::ModelConfig
-- compt_type
-- nr_thread
-- nr_ctx
-- device_id
核心类属性分析
model_imp类
所有model的抽象接口,也是所有model的基类,当其他模型运行时,使用该基类的shared_ptr对象,借助多态的方法实现模型的参数加载load函数,模型填充prefill函数以及token的encode和decode操作,
load: 加载模型文件,将权重加载到graph中,并将最终输出logist重置为对应的vocab_len
prefill: 类似于warmup函数,将token填入到模型中
decode: 第一次运行的token,放于网络中运行并解码
decode_iter:非第一次的模型运行tokens,取top_k个返回
sample_and_update:加入惩罚因子对输出进行惩罚,并选出top_p个作为下一个token
tokenize: 将文本编码为对应的input_ids
decode_summary: 将编解码效率和速度总结输出
OprModuleBase类
是所有opr操作的基础类,有添加opr函数,get_all_weights函数,获取inputs函数,获取output函数,name名字,device设备等函数。
execute:有pre_execute、execute、end_execute对op的输入输出进行前处理和后处理
get_workspace_in_byte:得到当前Module中的所有op中最大占用空间
LlamaFFNModule类
包含了2次matmul乘积silu激活函数并进行残差操作,然后与w3进行matmul得到最终输出结果。
GlmFFNModule类
包含了一次matmul乘积gelu激活函数,然后与w2相乘得到最后输出。
HeadModule类
包含了一次layer_norm,对输入进行norm和matmul得到最终输出。
EmbdModule
Embedding模块,通过input_ids找到对应的embedding编码。
OpBase类
所有Op基础类,包含了执行预处理操作,执行操作,后执行操作、set_name、set_outputs、set_weights等操作。
LayerNorm类、Embedding类、SoftMax类、Elemwise类、MatMul类、MatMulLast类、LlamaAttention类、GlmAttention类、DiagMask类
调用两种不同的kernel,RmsNormFloat和NormFloat,对输入数据进行归一化处理。
调用EmbeddingGetFloatFloat kernel进行使用不同的device计算。
调用softmax类在不同的device上进行操作
Kernel实现
llm_elemwise_broadcast_dim0_src1_compute_float_add_gpu、llm_elemwise_compute_float_scale_gpu
实现elemwise的乘法和加法,每个线程计算一次乘法或者加法,并且当第二个数的维度小于第一个的维度时,采用dim0上广播expand。
例如:
[[1, 2, 3,], [4, 5, 6]] * [[7, 8, 9]]
这时,如果乘完计算123之后,会循环计算,第二个矩阵就在dim0上进行广播扩充。
llm_elemwise_broadcast_dim0_src1_compute_float
计算当前机器的blocks,设置不同的blocks,找到对应不同的乘加算法。
ApplyFunction
函数模板,根据不同的函数function去实现不同的计算逻辑
LaunchKernel
带安全检查和block参数配置的应用函数
llm_softmax_compute_float_gpu
先找到最大值,val = exp(v-max_v), sum += val_all, softmax(val_i / sum)
并行策略:每次计算一行上的softmax,多线程并行计算多行,线程数为行数
llm_norm_compute_float_gpu
得到每一行中的方差scale,v[i] *= scale
并行策略:每次计算一个seq_len上的norm,并行线程数为seq_len
llm_embedding_get_float_float
将embedding的头指针拷贝到cuda中,并将每一行对应的数据拷贝到cuda中
dequantize_row_q4_0_reference_gpu
__restrict用法
反量化操作,加了unroll操作,每32个数计算一次,其中反量化算子下高4位为第一个数,低4位为第二个数,并在计算时加入了scale因子。
并行策略:在embedding量化中并行线程数为seq_len
SiluFunctor、GeluFunctor、AddFunctor、MulFunctor
cpu计算函数,一些常见的函数
llm_rms_norm_compute_float_gpu
得到每一行中的方差scale,v[i] *= scale
并行策略:每次计算一个seq_len上的norm,并行线程数为seq_len
llm_rope_compute_float_gpu
计算rotate_position_embedding,每次计算rotate_scale, 将position分为两半,前后两部分分别计算x0 * cos_scale - x1 * sin_scale
并行策略:每次计算每个seq_len下每个head每个rotate下的旋转位置编码,并行线程数seqlen * head * (n_rot / 2)次
llm_matmul_compute_float_float_gpu
矩阵乘积运算,每次只进行一行一列的计算
并行策略:并行线程数M×N
llm_matmul_compute_int4_float_step1_gpu
分两步进行计算,第一步计算两个张量的scale尺度,保存在d中,第二步将对应位置上的x,y取出来,int4乘加后并保存在float中,其中val = x_int4_sum * y_int4_sum * d1 * d2
llm_scale_diag_mask_inf_float_gpu、llm_diag_mask_inf_float_gpu
将矩阵对角线以上的mask设置为无限大,否则乘以scale
并行策略:并行线程数为head * seqlen * (n_past + seqlen)
llm_matmul_compute_with_head_stride_float、llm_head_batched_matmul_compute_float_gpu
用于计算多头、多batch矩阵乘积,并行数量上比普通矩阵乘法多了一个head num
标签:scale,Inferllm,--,float,源码,llm,gpu,解析,inferllm From: https://www.cnblogs.com/wildkid1024/p/17609154.html