首页 > 其他分享 >InstructGPT: Training language models to follow instructions with human feedback 原理详解

InstructGPT: Training language models to follow instructions with human feedback 原理详解

时间:2024-08-16 16:27:25浏览次数:17  
标签:10 Training feedback language SFT 训练 模型 InstructGPT log

文章目录


1. InstructGPT目标

InstructGPT 探讨了如何通过人类反馈来训练语言模型以更好地遵循用户的意图。通过对模型进行监督学习和强化学习,使其能够更好地满足用户的需求。实验结果表明,尽管 InstructGPT 仍然会犯一些简单的错误,但它在遵循用户意图、减少有毒输出等方面表现出色。

InstructGPT 原文:《Training language models to follow instructions
with human feedback》

图中是 InstructGPT 训练的三个步骤,蓝色箭头表示数据训练的流向:
在这里插入图片描述

(1)监督微调(SFT)。根据标注人员提供的数据对 GPT-3 模型进行微调;GPT 1-3 详解参考:《详解GPT-1到GPT-3的论文亮点以及实验结论》

(2)奖励模型(RM)训练。给定输入并对模型的多个输出进行打分排序,训练一个奖励模型来预测人类更喜欢的输出;

(3)近端策略优化(PPO)强化学习。用 RM 的输出作为奖励,使用 PPO 算法优化奖励策略。

PPO 论文《Proximal Policy Optimization Algorithms》。PPO(Proximal Policy Optimization,近端策略优化)是一种强化学习算法,由 John Schulman 等人在2017年提出。PPO 属于策略梯度方法,这类方法直接对策略(即模型的行为)进行优化,试图找到使得期望回报最大化的策略。PPO 已被广泛应用于各种强化学习场景,包括游戏、机器人控制以及自然语言处理中的序列决策问题,是目前最流行的强化学习算法之一。

2. 数据集

下表中是用于训练SFT、RM 和 PPO 模型的数据集的大小,labeler 表示来源于雇佣的标注人员,customer 表示来源于 API 接口用户。

在这里插入图片描述

2.1 SFT数据集

SFT 数据集用于训练第一步的 GPT-3 模型。SFT 数据集是通过一组由人类编写并用于指导语言模型行为的提示,由提示-答复对组成。来源包括:

(1)OpenAI 的 PlayGround 的用户;

(2)OpenAI 雇佣的40名标注人员,根据内容编写提示,内容包含以下部分:

  • 简单任务:标签人员编写一个任意任务,确保这些任务具有足够的多样性;
  • Few-shot 任务:标签人员编写多个查询/响应对;
  • 用户相关的任务:标签人员根据 OpenAI 的接口用例编写提示。

其中具体的内容类型:

在这里插入图片描述

例如:

类型内容
generationWrite a short story where a brown bear to the beach, makes friends with a seal, and then return home.
brainstormingList five ideas for how to regain enthusiasm for my career
extractExtract all place names from the article below: {news article}

2.2 RM数据集

RM 数据集用于训练第二步中的奖励模型。InstructGPT 的方法是让模型生成一系列候选文本,然后由标注人员根据内容质量对这些文本进行排序,通过人工给那些有毒有害的生成内容打低分,以避免模型产生不受欢迎的内容。这个数据集用于进一步优化模型的行为,使其更符合人类期望。

例如,下图是 OpenAI 雇佣的标注人员对输出进行排序的界面

在这里插入图片描述

2.3 PPO数据集

PPO数据集是通过强化学习从人类反馈中获得的数据集。它包含了人类评估员给出的奖励信号,用于调整模型参数以改进其性能。

3. 训练细节

(1)所有模型架构都使用了 GPT-3 架构;

(2)对于 RM 模型,原始模型的 decode 层被替换为投影层以输出标量值;

(3)所有模型都使用 fp16 权重和激活值,并且具有 fp32 权重主副本;

(4)所有的语言模型和强化学习策略都有一个上下文长度为 2k 个标记。将超过 1k 个标记的提示过滤掉,并将最大响应长度限制为 1k 个标记。

(5)所有模型都使用 Adam 优化器进行训练,其中 β 1 = 0.9 , β 2 = 0.95 β_1=0.9,β_2=0.95 β1​=0.9,β2​=0.95

3.1 SFT训练

(1) e p o c h = 16 epoch=16 epoch=16, d r o p o u t = 0.2 dropout=0.2 dropout=0.2;

(2)学习率采用使用余弦,学习率衰减至原始学习率的 10%,并且不进行学习率预热;

(3)对于 1.3B 参数和 6B 参数模型,初始学习率为 9.65 × 1 0 − 6 9.65×10^-6 9.65×10−6,batchsize 为 32;

(4)对于 175B 参数模型,初始学习率为 5.03 × 1 0 − 6 5.03×10^-6 5.03×10−6 的学习率,batchsize 为 8;

3.2 RM训练

RM 在初始 6B 的 GPT-3 上训练。该 GPT-3 模型已经在多种公共 NLP 数据集(ARC、BoolQ、CoQA、DROP、MultiNLI、OpenBookQA、QuAC、RACE 和 Winogrande)上进行了微调。

(1)采用 6B 是因为在该参数下 RM 模型在广泛的学习率下都很稳定,训练效率高,且可生成同等能力的 PPO 模型;

(2)学习率采用使用余弦,学习率衰减至原始学习率的 10%;

(3)学习率初始为 9 × 1 0 − 6 9×10^-6 9×10−6,batchsize 为 64;

(4)损失函数的数学表达式:

loss ( θ ) = − 1 ( K 2 ) E ( x , y w , y l ) ∼ D [ log ⁡ ( σ ( r θ ( x , y w ) − r θ ( x , y l ) ) ) ] \text{loss}(\theta) = -\frac{1}{K \choose 2} \mathbb{E}_{(x, y_w, y_l) \sim D} \left[ \log \left( \sigma(r_\theta(x, y_w) - r_\theta(x, y_l)) \right) \right] loss(θ)=−(2K​)1​E(x,yw​,yl​)∼D​[log(σ(rθ​(x,yw​)−rθ​(x,yl​)))]

其中:

  • r θ ( x , y ) r_\theta(x, y) rθ​(x,y) 是奖励模型对于给定的提示 $ x $ 和完成 $ y $ 的标量输出,参数为 $ \theta $;
  • y w y_w yw​ 和 y l y_l yl​ 同一个提示下的不同响应,用于对比输出结果;
  • D D D 是人类标注的比较数据集;
  • K K K 是输出的响应数量;
  • σ \sigma σ 是 sigmoid 函数,用于将输出映射到0和1之间,表示概率。

3.3 RLHF训练

RLHF(Reinforcement Learning from Human Feedback,从人类反馈中进行强化学习)。从预训练好的 GPT-3 模型初始化 RLHF 模型,并在数据集上对它进行两轮有监督微调,在微调过程中还混合了 10% 的预训练数据。

微调阶段:

(1)学习率采用使用余弦,学习率衰减至原始学习率的 10%;

(2)对于 1.3B 参数和 6B 参数模型,初始学习率为 5 × 1 0 − 6 5×10^-6 5×10−6, 1.04 × 1 0 − 6 1.04×10^-6 1.04×10−6,batchsize 为 32;

(3)对于 175B 参数模型,初始学习率为 2.45 × 1 0 − 6 2.45×10^-6 2.45×10−6 的学习率,batchsize 为 8;

RLHF训练:

(1)从上述微调的模型中初始化强化学习策略,计算 KL 奖励;

(2)batchsize 为512,每个 batch 被随机拆分为8个 minibatch;

(3)在前10次迭代中,从峰值学习率的十分之一开始,应用恒定的学习率进行预热。权重采用指数移动平均值,衰减率为0.992;

(4)PPO剪切比设置为0.2,rollout 采样温度为1;

(5)1.3B 和6B 策略的价值函数采用固定学习率为 9 × 1 0 − 6 9×10^-6 9×10−6,而175B策略采用 5 × 1 0 − 6 5×10^-6 5×10−6;

(6)强化学习的目标函数如下:

objective ( ϕ ) = E ( x , y ) ∼ π ϕ R L [ r θ ( x , y ) − β log ⁡ ( π ϕ R L ( y ∣ x ) π SFT ( y ∣ x ) ) ] + γ E x ∼ D pretrain [ log ⁡ ( π ϕ R L ( x ) ) ] \text{objective}(\phi) = \mathbb{E}_{(x, y) \sim \pi_{\phi}^{RL}} \left[ r_\theta(x, y) - \beta \log \left( \frac{\pi_{\phi}^{RL}(y | x)}{\pi^{\text{SFT}}(y | x)} \right) \right] + \gamma \mathbb{E}_{x \sim D_{\text{pretrain}}} \left[ \log(\pi_{\phi}^{RL}(x)) \right] objective(ϕ)=E(x,y)∼πϕRL​​[rθ​(x,y)−βlog(πSFT(y∣x)πϕRL​(y∣x)​)]+γEx∼Dpretrain​​[log(πϕRL​(x))]

其中:

  • 奖励模型输出(Reward Model Output): r θ ( x , y ) r_\theta(x, y) rθ​(x,y) 是奖励模型对输入 x x x 和输出 y y y 的评分,反映了人类标注者对输出的偏好。

  • 策略梯度调整项(Policy Gradient Adjustment): 这一项 − β log ⁡ ( π ϕ R L ( y ∣ x ) π SFT ( y ∣ x ) ) -\beta \log \left( \frac{\pi_{\phi}^{RL}(y | x)}{\pi^{\text{SFT}}(y | x)} \right) −βlog(πSFT(y∣x)πϕRL​(y∣x)​) 使用 KL 散度惩罚来调整 RL 策略 π ϕ R L \pi_{\phi}^{RL} πϕRL​ 与监督学习策略 π SFT \pi^{\text{SFT}} πSFT 之间的差异。 β \beta β 是控制这种调整强度的超参数。

  • 预训练分布一致性项(Pretraining Distribution Consistency): 这一项 γ E x ∼ D pretrain [ log ⁡ ( π ϕ R L ( x ) ) ] \gamma \mathbb{E}_{x \sim D_{\text{pretrain}}} \left[ \log(\pi_{\phi}^{RL}(x)) \right] γEx∼Dpretrain​​[log(πϕRL​(x))] 通过最大化预训练数据分布 D pretrain D_{\text{pretrain}} Dpretrain​ 下的对数似然,来确保 RL 策略生成的输出与预训练模型的输出保持一致。 γ \gamma γ 是控制这一部分贡献的超参数。

4. 结论

(1)标注人员偏好 InstructGPT 输出:在测试集上,人类标注者显著偏好 InstructGPT 模型的输出而不是 GPT-3。即使是参数数量少100倍的1.3B 参数的 InstructGPT 模型,其输出也比175B 参数的GPT-3更受偏好;

(2)InstructGPT 在真实性方面表现更好:在 TruthfulQA 基准测试中,InstructGPT 生成真实且提供信息的回答的频率是 GPT-3 的两倍;

(3)InstructGPT 在减少毒性输出方面有改进,但在减少偏见方面没有显著改进:使用 RealToxicityPrompts 数据集进行自动和人类评估,发现当被提示要尊重时,InstructGPT 模型生成的有毒输出比 GPT-3少约25%。然而,在 Winogender 和 CrowSPairs 数据集上,InstructGP T并没有比 GPT-3 有显著的改进;

(4)通过修改 RLHF 微调过程,可以最大程度上降低在公共数据集上的性能退化:在 RLHF 微调期间,在某些公共 NLP 数据集上观察到性能退化。通过将 PPO 更新与增加预训练分布的对数似然的更新混合,可以大大减少这些数据集上的性能退化,同时不损害标注者的偏好分数;

(5)InstructGPT 模型输出能引起广泛人员的偏好:通过额外人员实验,发现他们偏好 InstructGPT 输出的程度与训练标注者大致相同;

(6)InstructGPT 仍然会犯简单的错误:例如,InstructGPT 可能无法遵循指令、编造事实、对简单问题给出长篇的犹豫不决的回答,或者无法检测到包含虚假前提的指令。


欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤

请添加图片描述

标签:10,Training,feedback,language,SFT,训练,模型,InstructGPT,log
From: https://blog.csdn.net/qq_36803941/article/details/141262533

相关文章

  • M3KE: A Massive Multi-Level Multi-Subject Knowledge Evaluation Benchmark for Chi
    文章目录题目摘要简介相关工作M3KE实验结论题目M3KE:面向中文大型语言模型的海量多层次多学科知识评估基准论文地址:https://arxiv.org/abs/2305.10263项目地址:https://github.com/tjunlp-lab/M3KE摘要    大型语言模型最近在跨任务泛化、指令跟随等多个......
  • 新手常见错误:Language level is invalid or missing in pom.xml. Current project JDK
    目录Blue留声机:分析报错 Blue留声机:今天开一个maven的时候遇到这样一个报错,这个报错对于我来言是一个并不陌生的报错,早期学习spring框架的时候,遇到过这个问题,当时怎么也弄不出来(现在想想那个时候的我真菜),现在却对这种问题的解决游刃有余。好了,不多bb了,看看我一般处理bu......
  • 【学习日记3】DAIL-SQL论文:Text-to-SQL Empowered by Large Language Models: A Bench
    PS:自己回顾用的ABSTRACT        大型语言模型(LLMs)已成为Text-to-SQL任务的新模式。然而,缺乏系统的基准测试限制了有效、高效和经济的基于LLM的Text-to-SQL方案的发展。为了解决这一挑战,本文首先对现有的提示工程方法进行了系统且广泛的比较,包括问题表示、示例......
  • Enhancing Question Answering for Enterprise Knowledge Bases using Large Language
    本文是LLM系列文章,针对《EnhancingQuestionAnsweringforEnterpriseKnowledgeBasesusingLargeLanguageModels》的翻译。使用大型语言模型增强企业知识库的问答能力摘要1引言2相关工作3前言4方法5实验6结论摘要高效的知识管理在提高企业和组......
  • Large Language Models meet Collaborative Filtering
    本文是LLM系列文章,针对《LargeLanguageModelsmeetCollaborativeFiltering:AnEfficientAll大型语言模型与协同过滤:一个高效的基于LLM的全方位推荐系统摘要1引言2相关工作3问题定义4提出的方法5实验6结论摘要协同过滤推荐系统(CFRecSys)在增强社......
  • [paper阅读笔记][2023]CorpusLM: Towards a Unified Language Model on Corpusfor Kno
    文章链接:https://arxiv.org/pdf/2402.01176v2Paper的任务处理各种知识密集型任务任务的科学问题本文任务虽然是:提出一个统一的语言模型来处理各种知识密集型任务,但其实其本质科学问题是:如何提高LLMs在知识密集型任务中的检索效率。原因是:LLMs在生成文本时容易出现错误信......
  • Pixel Aligned Language Models论文阅读笔记
    Motivation&Abs近年来,大语言模型在视觉方面取得了极大的进步,但其如何完成定位任务(如wordgrounding等)仍然不清楚。本文旨在设计一种模型能够将一系列点/边界框作为输入或者输出。当模型接受定位信息作为输入时,可以进行以定位为condition的captioning。当生成位置作为输出时,模型......
  • 易优CMS模板标签language语言列表罗列所有语言列表
    【基础用法】标签:languagename值:web_language_switch描述:语言列表标签,获取多语言列表内容。用法:{eyou:languagetype='default'}<ahref="{$field.url}"><imgsrc="{$field.logo}"alt="{$field.title}">{$field.title}</a>{/eyou:......
  • Codeforces Round 929 (Div. 3)---->E. Turtle vs. Rabbit Race: Optimal Trainings
    https://codeforces.com/contest/1933/problem/E#include<bits/stdc++.h>#definexfirst#defineysecondusingnamespacestd;typedeflonglongll;typedef__int128i128;typedefpair<int,int>pii;constintN=2e5+10,M=110;intn,q;inta[N];ll......
  • USACO Training辅导课    刷题记录
    Chapter1入门Section1.1介绍Section1.2提交解决方案,任务类型,特殊问题AcWing1339.你的旅途由此开始753人打卡AcWing1340.贪婪的送礼者581人打卡AcWing1341.十三号星期五521人打卡AcWing1342.断开的项链447人打卡Section1.3完全搜索AcWing1343.挤牛奶472......