首页 > 其他分享 >LLM采样后处理总结:LLM的后处理的cpp实现

LLM采样后处理总结:LLM的后处理的cpp实现

时间:2023-10-11 18:23:11浏览次数:30  
标签:last top float 后处理 score LLM cpp sum first

LLM采样后处理总结:LLM的后处理的cpp实现

在经过LLM的lm_head之后,会得到[batch, vocab_size]大小的矩阵向量,此时需要对输出的逻辑张量进行采样,除了beam_search的贪心策略,还有repetition_penalty、temperature、top_k、top_p等几种控制采样的方法。

repetition_penalty

repetition_penalty的主要作用是控制重复,这里first和last分别为vocab中的第一个元素和最后一个元素的位置,input_ids为之前输出的文本id。
也即是把之前输出过的内容全部变小,那么就可以防止文本出现不断重复的情况,penalty越小,惩罚力度越大,penalty越大,惩罚力度越小,重复概率就会增加。

void sampling_repetition_penalty(float *first, float *last, const std::vector<int> &input_ids,
                                                       float penalty) {
    std::unordered_set<int> unique_input_ids(input_ids.begin(), input_ids.end());
    for (int id : unique_input_ids) {
        if (first[id] > 0) {
            first[id] /= penalty;
        } else {
            first[id] *= penalty;
        }
    }
}

temperature

temperature是控制softmax下的平滑参数,相当于在softmax前每个逻辑值都进行了放缩。
当temp越大的时候,此时softmax值之间的差距会减小,分布就越均匀,此时采样出的结果就越随机,反之就会使得原本高概率的的变得更高低的更低减少了随机性。

void sampling_temperature(float *first, float *last, float temp) {
    float inv_temp = 1.f / temp;
    for (float *it = first; it != last; it++) {
        *it *= inv_temp;
    }
}

top_k

top_k是取前k个,直接排序拿到概率最大的前k个。

void sampling_top_k(TokenIdScore *first, TokenIdScore *kth, TokenIdScore *last) {
    std::nth_element(first, kth, last, std::greater<TokenIdScore>());
}

top_p

top_p是先对所有的值进行softmax,然后找到满足sum_p <= top_p的最小集合,然后对这个集合内的数再进行softmax和采样。
一种简单的做法是将所有值进行排序,然后贪心找到满足条件的前k个。
示例代码中使用了一种类似于快速排序的方法,每次找mid点,将大于mid和小于mid的分为两堆,要么在大的一堆要么在小的一堆。
当在大的一堆中时就mid往前移动,在小的一堆时则更新top_p = top_p-sum_p,直至找到对应的位置。
时间复杂度上会稍微比先排序快一些。

void sampling_softmax_inplace(TokenIdScore *first, TokenIdScore *last) {
    float max_score = std::max_element(first, last)->score;
    float sum = 0.f;
    for (TokenIdScore *p = first; p != last; p++) {
        float s = std::exp(p->score - max_score);
        p->score = s;
        sum += s;
    }
    float inv_sum = 1.f / sum;
    for (TokenIdScore *p = first; p != last; p++) {
        p->score *= inv_sum;
    }
}
TokenIdScore *sampling_top_p(TokenIdScore *first, TokenIdScore *last, float top_p) {
    // fast top_p in expected O(n) time complexity
    sampling_softmax_inplace(first, last);

    while (first + 1 < last) {
        float pivot_score = (last - 1)->score; // use mid score?
        TokenIdScore *mid =
            std::partition(first, last - 1, [pivot_score](const TokenIdScore &x) { return x.score > pivot_score; });
        std::swap(*mid, *(last - 1));

        float prefix_sum =
            std::accumulate(first, mid, 0.f, [](float sum, const TokenIdScore &x) { return sum + x.score; });
        if (prefix_sum >= top_p) {
            last = mid;
        } else if (prefix_sum + mid->score < top_p) {
            first = mid + 1;
            top_p -= prefix_sum + mid->score;
        } else {
            return mid + 1;
        }
    }
    return last;
}

标签:last,top,float,后处理,score,LLM,cpp,sum,first
From: https://www.cnblogs.com/wildkid1024/p/17757877.html

相关文章

  • src/param.cpp:30:26: fatal error: gsl/gsl_blas.h: No such file or directory
     001、问题:安装gemma软件报错src/param.cpp:30:26:fatalerror:gsl/gsl_blas.h:Nosuchfileordirectory 002、解决方法,安装glsa、官网下载http://mirrors.ustc.edu.cn/gnu/gsl/ b、wgethttp://mirrors.ustc.edu.cn/gnu/gsl/gsl-2.7.tar.gztar-xzfgsl-2.7......
  • Graph RAG: 知识图谱结合 LLM 的检索增强
    本文为大家揭示NebulaGraph率先提出的GraphRAG方法,这种结合知识图谱、图数据库作为大模型结合私有知识系统的最新技术栈,是LLM+系列的第三篇,加上之前的图上下文学习、Text2Cypher这两篇文章,目前NebulaGraph+LLM相关的文章一共有3篇。GraphRAG在第一篇关于上下文......
  • 【Cpp】RTTI 机制原理解析
    ReferencesBaiduWikiC++中的RTTI机制详解RTTI推荐阅读:RTTI原理推荐阅读:C++中的RTTI机制什么是RTTI机制?RTTI是“RuntimeTypeInformation”的缩写,意思是:运行时类型信息。它提供了运行时确定对象类型的方法。RTTI通过运行时类型信息程序能够使用基类的指针或引用......
  • LLM实践-在Colab上使用免费T4 GPU进行Chinese-Llama-2-7b-4bit推理
    一、配置环境1、打开colab,创建一个空白notebook,在[修改运行时环境]中选择15GB显存的T4GPU.2、pip安装依赖python包!pipinstall--upgradeaccelerate!pipinstallbitsandbytestransformers_stream_generator!pipinstalltransformers!pipinstallsentencepiece!pip......
  • Windows桌面应用程序源文件.cpp注释
     这个是visualstudio2022上利用Windows桌面应用程序模板创建的源文件注释一个Windows图形界面(GUI)应用程序通常由主窗体,对话框,控件组成。当应用程序创建一个窗体,需要调用CreateWindowEx函数,必须提供的参数1.窗体类窗体类是一个结构体。是一系列属性的集合,用来描述窗体的行为......
  • 论文阅读:iterator zero-shot llm prompting for knowledge graph construction
    Abstract知识图谱,一种相互连接和可解释的结构。生成需要更多的人力、领域知识、并需要适用于不同的应用领域。本论文提出借助LLM,通过0-shot和外部知识不可知的情况下生成知识图谱。主要贡献:迭代的prompting提取最终图的相关部分0-shot,不需要examples一个可扩展的解决方案,......
  • MaSuRCA 软件安装 swig/perl5/swig_wrap.cpp:342:20: fatal error: string.h: No such
     001、问题MaSuRCA软件安装swig/perl5/swig_wrap.cpp:342:20:fatalerror:string.h:Nosuchfileordirectory  002、原因,当前环境处于conda的base环境,可能是函数库调用混乱。  003、解决方法,推出conda基础环境安装(base)[b20223040323@admin1MaSuRCA-4......
  • C和CPP程序是如何运行起来的?
    C和CPP程序是如何运行起来的?个人见解,谨慎阅读。如有错误,欢迎指正!代码均在Linux下编译运行。1.C语言程序从源码到可执行文件的过程C语言程序从源码到可执行文件的过程主要分为以下几个步骤:预处理、编译、汇编、链接。flowchartLRA1[代码]--"预处理"-->B1[预处理文......
  • 解密Prompt系列16. LLM对齐经验之数据越少越好?LTD & LIMA & AlpaGasus
    LLMAgent中间插个队,总结下指令微调、对齐数据相关的方案,已经凑够7篇论文可以召唤神龙啦!论文都是以优化指令样本为核心,Data-Centric的观点比较一致:指令微调也就是对齐阶段的数据质量>>数量,少量+多样+高质量的对齐数据,就能让你快速拥有效果杠杠的模型。注意以上三者是充分必要关系,......
  • [CPP] CPP的编译链接过程
    手写的源代码本质上只是一串文本,但是在编译器里点一下编译就可以直接看到程序的输出,从文本到执行输出之间发生了什么 源代码到可执行程序大致经历以下几个过程         1、预编译(Preprocessing)预编译阶段主要做四件事:头文件展开,宏替换,执行预编译......