首页 > 编程语言 >使用 PPO 算法进行 RLHF 的 N 步实现细节

使用 PPO 算法进行 RLHF 的 N 步实现细节

时间:2023-11-16 15:12:03浏览次数:45  
标签:py preferences lm PPO token batch RLHF 算法 human

当下,RLHF/ChatGPT 已经变成了一个非常流行的话题。我们正在致力于更多有关 RLHF 的研究,这篇博客尝试复现 OpenAI 在 2019 年开源的原始 RLHF 代码库,其仓库位置位于 openai/lm-human-preferences。尽管它具有 “tensorflow-1.x” 的特性,但 OpenAI 的原始代码库评估和基准测试非常完善,使其成为研究 RLHF 实现工程细节的好地方。

我们的目标是:

  1. 复现 OAI 在风格化任务中的结果,并匹配 openai/lm-human-preferences 的学习曲线。
  2. 提供一个实现细节的清单,类似于 近端优化策略的 37 个实施细节 (The 37 Implementation Details of Proximal Policy Optimization)没有痛苦折磨的调试 RL (Debugging RL, Without the Agonizing Pain) 的风格;
  3. 提供一个易于阅读且简洁的 RLHF 参考实现;

这项工作仅适用于以教育/学习为目的的。对于需要更多功能的高级用户,例如使用 PEFT 运行更大的模型, huggingface/trl 将是一个不错的选择。

  • 匹配学习曲线 中,我们展示了我们的主要贡献: 创建一个代码库,能够在风格化任务中复现 OAI 的结果,并且与 openai/lm-human-preferences 的学习曲线非常接近地匹配。

  • 然后我们深入探讨了与复现 OAI 的工作相关的实现细节。在 总体实现细节 中,我们讨论了基本细节,像如何生成奖励/值和如何生成响应。在 奖励模型实现细节 中,我们讨论了诸如奖励标准化之类的细节。在 策略训练实现细节 中,我们讨论了拒绝采样和奖励“白化”等细节。

  • 接下来,我们检查了在奖励标签由 gpt2-large 生成的情况下,训练不同基础模型 (例如 gpt2-xl, falcon-1b) 的效果。

  • 最后,我们通过讨论一些限制来总结我们的研究工作。

以下是一些重要链接:

相关文章

  • 由数据范围反推算法复杂度以及算法内容
    由数据范围反推算法复杂度以及算法内容一般ACM或者笔试题的时间限制是1秒或2秒。在这种情况下,\(\mathrm{C}++\)代码中的操作次数控制在\(10^{7}\sim10^{8}\)为最佳。下面给出在不同数据范围下,代码的时间复杂度和算法该如何选择:\(n\leq30\),指数级别,\(\mathrm{dfs......
  • 文心一言 VS 讯飞星火 VS chatgpt (136)-- 算法导论11.3 2题
    二、用go语言,假设将一个长度为r的字符串散列到m个槽中,并将其视为一个以128为基数的数,要求应用除法散列法。我们可以很容易地把数m表示为一个32位的机器字,但对长度为r的字符串,由于它被当做以128为基数的数来处理,就要占用若干个机器字。假设应用除法散列法来计算一个字符串......
  • 计算机图形:计算法向量
    目录一元向量值函数及其导数一元向量值函数概念一元值函数的导数空间曲线的切线和法平面曲面的切平面与法线示例:求椭球体表面法向量参考一元向量值函数及其导数一元向量值函数概念已知空间曲线Γ(大写的γ)参数方程:\[\tag{1}\begin{cases}x=\varphi(t),\\y=\psi(t),t\in[\al......
  • python机器学习算法原理实现——MCMC算法之gibbs采样
    【算法原理】Gibbs采样是一种用于估计多元分布的联合概率分布的方法。在MCNC(Markov Chain Monte Carlo)中,Gibbs采样是一种常用的方法。通俗理解Gibbs采样,可以想象你在一个多维空间中,你需要找到这个空间的某个特定区域(这个区域代表了你感兴趣的分布)。但是,你不能直接看到整个空间,只......
  • 机器学习算法原理实现——HMM生成序列和维特比算法
    【HMM基本概念】隐马尔可夫模型(Hidden Markov Model,HMM)是一种统计模型,用于描述一个含有未知参数(隐状态)的马尔可夫过程。在HMM中,我们不能直接观察到状态,但可以观察到每个状态产生的一些相关数据(观测值)。HMM的目标是,给定观测序列,估计出最可能的状态序列。HMM的基本假设有两个(见例子......
  • 机器学习算法原理实现——EM算法
    【EM算法简介】EM算法,全称为期望最大化算法(Expectation-Maximization Algorithm),是一种迭代优化算法,主要用于含有隐变量的概率模型参数的估计。EM算法的基本思想是:如果给定模型的参数,那么可以根据模型计算出隐变量的期望值;反过来,如果给定隐变量的值,那么可以通过最大化似然函数来估......
  • 机器学习算法原理实现——朴素贝叶斯
    【先说条件概率】条件概率是指在某个事件发生的条件下,另一个事件发生的概率。以下是一个实际的例子:假设你有一副扑克牌(不包括大小王,共52张牌),你随机抽一张牌。我们设事件A为"抽到的牌是红色的"(红心和方块为红色,共26张),事件B为"抽到的牌是心"(红心共13张)。1.首先,我们可以计算事件A和事......
  • 机器学习算法原理实现——最大熵模型
    【写在前面】在sklearn库中,没有直接称为"最大熵模型"的类,但是有一个与之非常相似的模型,那就是LogisticRegression。逻辑回归模型可以被视为最大熵模型的一个特例,当问题是二分类问题,且特征函数是输入和输出的线性函数时,最大熵模型就等价于逻辑回归模型。【最大熵模型的原理】最大熵......
  • 算法刷题记录-哈希表
    算法刷题记录-哈希表有效的字母异位词给定两个字符串*s*和*t*,编写一个函数来判断*t*是否是*s*的字母异位词。注意:若*s*和*t*中每个字符出现的次数都相同,则称*s*和*t*互为字母异位词。示例1:输入:s="anagram",t="nagaram"输出:true示例2:输入:s......
  • 若依vue启动报Error: error:0308010C:digital envelope routines::unsupported
    解决:若依vue启动报Error:error:0308010C:digitalenveloperoutines::unsupported1.描述:问题产生原因是因为node.jsV17版本中最近发布的OpenSSL3.0,而OpenSSL3.0对允许算法和密钥大小增加了严格的限制,可能会对生态系统造成一些影响.解决方法:有很多种,我把适合我的写在第一......