首页 > 编程语言 >LLM面面观之RLHF平替算法DPO

LLM面面观之RLHF平替算法DPO

时间:2024-01-31 12:44:40浏览次数:34  
标签:logps 平替 reference chosen 面面观 rejected RLHF policy DPO

1. 背景

最近本qiang~老看到一些关于大语言模型的DPO、RLHF算法,但都有些云里雾里,因此静下心来收集资料、研读论文,并执行了下开源代码,以便加深印象。

此文是本qiang~针对大语言模型的DPO算法的整理,包括原理、流程及部分源码。

2. DPO vs RLHF

 

 

上图左边是RLHF算法,右边为DPO算法,两图的差异对比即可体现出DPO的改进之处。

1. RLHF算法包含奖励模型(reward model)和策略模型(policy model,也称为演员模型,actor model),基于偏好数据以及强化学习不断迭代优化策略模型的过程。

2. DPO算法不包含奖励模型和强化学习过程,直接通过偏好数据进行微调,将强化学习过程直接转换为SFT过程,因此整个训练过程简单、高效,主要的改进之处体现在于损失函数。

PS:

1. 偏好数据,可以表示为三元组(提示语prompt, 良好回答chosen, 一般回答rejected)。论文中的chosen表示为下标w(即win),rejected表示为下标l(即lose)

2. RLHF常使用PPO作为基础算法,整体流程包含了4个模型,且通常训练过程中需要针对训练的actor model进行采样,因此训练起来,稳定性、效率、效果不易控制。

1) actor model/policy model: 待训练的模型,通常是SFT训练后的模型作为初始化

2) reference model: 参考模型,也是经SFT训练后的模型进行初始化,且通常与actor model是同一个模型,且模型冻结,不参与训练,其作用是在强化学习过程中,保障actor model与reference model的分布差异不宜过大。

3) reward model: 奖励模型,用于提供每个状态或状态动作对的即时奖励信号。

4) Critic model: 作用是估计状态或状态动作对的长期价值,也称为状态值函数或动作值函数。

3. DPO算法仅包含RLHF中的两个模型,即演员模型(actor model)以及参考(reference model),且训练过程中不需要进行数据采样。

4. RLHF可以参考附件中的引文

3. DPO的损失函数

 

 

如何将RLHF的Reward model过程简化为上式,作者花了大量篇幅进行了推导,感兴趣的读者可以参考附件DPO的论文。

DPO算法的目的是最大化奖励模型(此处的奖励模型即为训练的策略),使得奖励模型对chosen和rejected数据的差值最大,进而学到人类偏好。

上式的后半部分通过对数函数运算规则,可以进行如下转化。

 

 

转化后的公式和源代码中的计算函数中的公式是一致的。

其中左半部分是训练的policy模型选择chosen优先于rejected,右半部分是冻结的reference模型选择chosen优先于rejected,二者的差值可类似于KL散度,保障actor模型的分布与reference模型的分布不会有较大的差异。

4. 微调流程

 

 

上图展示了DPO微调的大致流程,其中Trained LM即为策略模型,Frozen LM即为参考模型,二者均是先进行SFT微调得到的模型进行初始化,其中Trained LM需要进行训练,Frozen LM不参与训练。

两个模型分别针对chosen和rejected进行预测获取对应的得分,再通过DPO的损失函数进行损失计算,进而不断的迭代优化。

5. 源码

源码参考代码:https://github.com/eric-mitchell/direct-preference-optimization

5.1 DPO损失函数

 1 def preference_loss(policy_chosen_logps: torch.FloatTensor,
 2                     policy_rejected_logps: torch.FloatTensor,
 3                     reference_chosen_logps: torch.FloatTensor,
 4                     reference_rejected_logps: torch.FloatTensor,
 5                     beta: float,
 6                     label_smoothing: float = 0.0,
 7                     ipo: bool = False,
 8                     reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
 9     # policy_chosen_logps: 训练模型对于chosen经过log后logits
10     # policy_rejected_logps: 训练模型对于rejected经过log后logits
11     # reference_chosen_logps: 训练模型对于chosen经过log后logits
12     # reference_rejected_logps: 训练模型对于rejected经过log后logits
13     # beta: policy和reference的差异性控制参数
14     
15     # actor模型选择chosen优先于rejected
16     pi_logratios = policy_chosen_logps - policy_rejected_logps
17     # reference模型选择chosen优先于rejected
18     ref_logratios = reference_chosen_logps - reference_rejected_logps
19 
20     if reference_free:
21         ref_logratios = 0
22     
23     # 差值可类似于KL散度,保障actor模型的分布与reference模型的分布不会有较大的差异
24     logits = pi_logratios - ref_logratios  # also known as h_{\pi_\theta}^{y_w,y_l}
25 
26     if ipo:
27         losses = (logits - 1/(2 * beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
28     else:
29         # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
30         # label_smoothing为0,对应的DPO论文的算法
31         losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(-beta * logits) * label_smoothing
32     
33     # chosen和rejected的奖励
34     chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
35     rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
36 
37     return losses, chosen_rewards, rejected_rewards

 

5.2 批次训练过程

 1 def get_batch_metrics(self, batch: Dict[str, Union[List, torch.LongTensor]], loss_config: DictConfig, train=True):
 2     """Compute the SFT or DPO loss and other metrics for the given batch of inputs."""
 3 
 4     if loss_config.name in {'dpo', 'ipo'}:
 5         # policy模型针对chosen和rejected进行预测
 6         policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch)
 7         with torch.no_grad():
 8             # reference模型针对chosen和rejected进行预测
 9             reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(self.reference_model, batch)
10 
11         if loss_config.name == 'dpo':
12             loss_kwargs = {'beta': loss_config.beta, 'reference_free': loss_config.reference_free, 'label_smoothing': loss_config.label_smoothing, 'ipo': False}
13         elif loss_config.name == 'ipo':
14             loss_kwargs = {'beta': loss_config.beta, 'ipo': True}
15         else:
16             raise ValueError(f'unknown loss {loss_config.name}')
17         # 损失计算
18         losses, chosen_rewards, rejected_rewards = preference_loss(
19             policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, **loss_kwargs)
20 
21         reward_accuracies = (chosen_rewards > rejected_rewards).float()
22 
23     elif loss_config.name == 'sft':
24         policy_chosen_logits = self.policy(batch['chosen_input_ids'], attention_mask=batch['chosen_attention_mask']).logits.to(torch.float32)
25         policy_chosen_logps = _get_batch_logps(policy_chosen_logits, batch['chosen_labels'], average_log_prob=False)
26 
27         losses = -policy_chosen_logps
28 
29     return losses.mean()

 

5.3 LM的交叉熵计算

 1 def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False) -> torch.FloatTensor:
 2     # 经模型后的logits进行批量计算logps
 3     
 4     assert logits.shape[:-1] == labels.shape
 5     
 6     # 基于先前的token预测下一个token
 7     labels = labels[:, 1:].clone()
 8     logits = logits[:, :-1, :]
 9     loss_mask = (labels != -100)
10 
11     # dummy token; we'll ignore the losses on these tokens later
12     labels[labels == -100] = 0
13     
14     # 交叉熵函数
15     per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
16 
17     if average_log_prob:
18         return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
19     else:
20         return (per_token_logps * loss_mask).sum(-1)

 

5.4 其他注意

1. hugging face设置代理

源码会从hugging face中下载英文语料和模型,由于网络限制,因此设置代理映射,将HF_ENDPOINT设置为https://hf-mirror.com,即设置: os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

2. 如果仅想要熟悉DPO整体流程,可以下载较小的生成式模型,如BLOOM 560M,GPT2等

6. 总结

一句话足矣~

本文主要针对大语言模型的DPO算法的整理,包括原理、流程及部分源码。

此外,建议大家可以针对源码进行运行,源码的欢迎大家一块交流。

7. 参考

(1) RLHF:https://blog.csdn.net/v_JULY_v/article/details/128579457

(2) DPO论文: https://arxiv.org/pdf/2305.18290v2.pdf

(3) DPO代码: https://github.com/eric-mitchell/direct-preference-optimization

(4) DPO理解1:https://medium.com/@joaolages/direct-preference-optimization-dpo-622fc1f18707

(5) DPO理解2: https://zhuanlan.zhihu.com/p/669825918

 

 

标签:logps,平替,reference,chosen,面面观,rejected,RLHF,policy,DPO
From: https://www.cnblogs.com/mengrennwpu/p/17999027

相关文章

  • LLM成功不可或缺的RLHF基于人类反馈的强化学习是如何运作的?OJAC近屿智能带你揭秘
    基于人类反馈的强化学习(RLHF,ReinforcementLearningfromHumanFeedback)是人工智能(AI)领域的一个新兴研究领域,它将强化学习技术与人类反馈相结合,以训练能够学习复杂任务的个体。该方法在提高人工智能系统的性能方面显示出前景,使其在各种应用中更具有适应性和效率。 强化学习......
  • LLM面面观之LLM上下文扩展方案
    1.背景本qiang~这段时间调研了LLM上下文扩展的问题,并且实打实的运行了几个开源的项目,所谓实践与理论相结合嘛!此文是本qiang~针对上下文扩展问题的总结,包括解决方案的整理概括,文中参考了多篇有意义的文章,他山之石可以攻玉。大语言模型的扩展有诸多意义,如进行更长的会话、总结更......
  • RLHF · PbRL | 速通 ICLR 2024 RLHF
    检索关键词:ICLR2024、reinforcementlearning、preference、humanfeedback。https://openreview.net/search?term=ICLR+2024+reinforcement+learning+preference+human+feedback&group=all&content=all&source=forumContrastivePreferenceLearning:LearningfromH......
  • Navicat平替工具,一款免费开源的通用数据库工具
    前言前段时间有小伙伴在群里提问说:因为公司不允许使用破解版的Navicat,有好用的Navicat平替工具推荐吗?今天分享一款免费开源的通用数据库工具:DBeaver。工具介绍DBeaver是一款免费的跨平台数据库工具,适用于开发人员、数据库管理员、分析师和所有数据处理人员。它支持所有流行的S......
  • Code Review、InLineChat、RAG能力全部独家提供,这波上新CodeGeeX平替Github Copilot稳
    智谱AI2024年度的技术开放日上,CodeGeeX重磅发布第三代模型。针对CodeGeeX插件产品的系列新功能,也同时上线发布,提供给用户免费使用。一、第三代模型性能全面提升CodeGeeX第三代模型正式发布,基础能力全面提升。针对Python、Java、JavaScript、C++、Golang五种主流编程语言,代......
  • SCA面面观 | 如何生成一份软件物料清单SBOM?
    由于网络安全挑战和不断变化的威胁环境,使得软件供应链安全成为了一个重要议题。特别是近年来,软件供应链的复杂性和全球化程度的提升,第三方软件的安全性和可追溯性变得越来越重要。为了应对这一挑战,从美国政府开始,各个国家开始积极推动软件供应链安全的相关政策和措施。美国陆续发布......
  • SCA面面观 | 企业该如何选择组件检测工具?
    一般来说,一个软件应用程序可以被分解成若干部分,为软件程序解耦,以减少整个应用程序的复杂性,这些部分就是软件组件。以一种标准化的方式相互作用,使得组件可以像机器的“零部件”一样被换入或换出,因组件具有独立性、可重用行、高内聚、低耦合等优势,可以帮助企业提高开发效率和质量,减少......
  • 使用Pipenv进行Python虚拟环境管理--conda平替
    Pipenv使用教程Anaconda是一个开箱即用的Python开发环境,同时也包含虚拟环境管理工具conda。但是Anaconda的缺点包括:大型安装包:Anaconda的安装包相对较大,需要消耗较多的磁盘空间。依赖冲突:在使用Anaconda时,若安装包过多可能会出现依赖冲突的情况,需要手动解决。此时则......
  • 大模型 RLHF 实战!【OpenAI独家绝技RLHF!RLHF的替代算法DPO!Claude 暗黑科技 RAIHF!】
    大模型RLHF实战大模型RLHF实战RLHF:OpenAI独家绝技RLHF的问题DPO直接偏好优化算法:RLHF的替代算法公式1-4:KL散度下奖励的最大化目标使用DPO微调Llama2RAIHF 大模型RLHF实战RLHF(基于人类反馈的强化学习)分为3个阶段:预训练:为了生成内容,需要一个生成式的预训练语言模......
  • 找到了!GitHub Copilot的最佳免费平替
    在如今这个人工智能高速发展的时代,每个行业都在被AI技术影响而改变。层出不穷的AI辅助工具,让我们看到了机器正在取代一部分基础的日常工作。对于我们开发者而言,当前最炙手可热的就是GitHubCopilot,市面上最好的开发者辅助工具。GitHubCopilot所提供的代码补全、建议、解释等能力......