PPO_trainer

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import torch
import torch.nn.functional as F
import time
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator

from deepspeed_chat.dschat.utils.utils import print_rank_0


def print_all_ranks(tag, value, rank):
    world_size = torch.distributed.get_world_size()
    all_tensor = torch.zeros(world_size, dtype=torch.float32).to(
        get_accelerator().current_device_name())
    all_tensor[rank] = value
    torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM)
    print_rank_0(f'{tag} {all_tensor}', rank)


def get_model_norm(model):
    with torch.no_grad():
        total = 0.0
        for param in model.parameters():
            should_gather = hasattr(
                param,
                'ds_id') and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
            with deepspeed.zero.GatheredParameters(param,
                                                   enabled=should_gather):
                total += float(param.float().norm())

    return total


def gather_log_probs(logits, labels):  # logits: [2, 511, 50272], labels: [2, 511]
    log_probs = F.log_softmax(logits, dim=-1)
    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))  # gather(dim=-1, index=labels.unsqueeze(-1)) pulls out, for each token, the log-probability at the index given by labels. labels has shape [2, 511]; unsqueeze(-1) expands it to [2, 511, 1] to satisfy gather's dimension requirement.
    return log_probs_labels.squeeze(-1)
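
# A minimal, self-contained sketch (added for illustration; not part of the original
# DeepSpeed-Chat file) of what gather_log_probs computes: for every position t it
# returns log_softmax(logits)[b, t, labels[b, t]], i.e. the log-probability the model
# assigned to the token that actually appears at that position.
def _demo_gather_log_probs():
    logits = torch.randn(2, 5, 8)           # hypothetical [batch=2, seq_len=5, vocab=8]
    labels = torch.randint(0, 8, (2, 5))    # hypothetical token ids, [2, 5]
    out = gather_log_probs(logits, labels)  # [2, 5]
    # the same computation, written out without gather
    ref = torch.stack([
        torch.stack([F.log_softmax(logits[b, t], dim=-1)[labels[b, t]]
                     for t in range(labels.shape[1])])
        for b in range(labels.shape[0])
    ])
    assert torch.allclose(out, ref)
    return out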


class DeepSpeedPPOTrainer():

    def __init__(self, rlhf_engine, args):
        self.rlhf_engine = rlhf_engine
        self.actor_model = self.rlhf_engine.actor
        self.critic_model = self.rlhf_engine.critic
        self.ref_model = self.rlhf_engine.ref
        self.reward_model = self.rlhf_engine.reward
        self.tokenizer = self.rlhf_engine.tokenizer
        self.args = args
        self.max_answer_seq_len = args.max_answer_seq_len
        self.end_of_conversation_token_id = self.tokenizer(
            args.end_of_conversation_token)['input_ids'][-1]
        self.z3_enabled = args.actor_zero_stage == 3
        self.compute_fp32_loss = self.args.compute_fp32_loss

        self.last_generated_experience = None

        self.kl_ctl = 0.1
        self.clip_reward_value = 5
        self.cliprange = 0.2
        self.cliprange_value = 0.2
        self.gamma = 1.0
        self.lam = 0.95
        self.generate_time = 0.0

    def _generate_sequence(self, prompts, mask, step):
        max_min_length = self.max_answer_seq_len + prompts.shape[1]  # max answer length (256) + prompt length (256)
        # This fix was added because a probability/NaN error occurs when do_sample is enabled:
        # https://huggingface.co/meta-llama/Llama-2-7b-hf/commit/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9
        if self.actor_model.module.config.model_type == "llama":
            kwargs = dict(do_sample=False)
        else:
            kwargs = dict()
        # the actor generates the sequence
        with torch.no_grad():
            seq = self.actor_model.module.generate(
                prompts,
                attention_mask=mask,
                max_length=max_min_length,
                pad_token_id=self.tokenizer.pad_token_id,
                synced_gpus=self.z3_enabled,
                **kwargs)

        # Filter out sequences with no answer (or a very short one). This happens when users
        # directly use a pretrained checkpoint without supervised fine-tuning.
        # NOTE: this leads to each GPU holding a different number of samples.
        batch_size = seq.shape[0]  # e.g. 2
        prompt_length = prompts.shape[1]  # 256
        self.prompt_length = prompt_length  # 256
        ans = seq[:, prompt_length:]  # positions 256..512 are the generated answer, appended right after the prompt
        valid_ans_len = (ans != self.tokenizer.pad_token_id).sum(dim=-1)

        if self.args.print_answers and (step % self.args.print_answers_interval == 0):
            print(f"--- prompt --> step={step}, rank={torch.distributed.get_rank()}, {self.tokenizer.batch_decode(prompts, skip_special_tokens=True)}")
            print(f"--- ans    --> step={step}, rank={torch.distributed.get_rank()}, {self.tokenizer.batch_decode(ans, skip_special_tokens=True)}")

        out_seq = []
        for i in range(batch_size):
            if valid_ans_len[i] <= 1:  # if the answer has 1 valid token or fewer, drop it
                print(
                    f'Dropping too short generated answer: {step=}: \n'
                    f'prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
                    f'answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}'
                )
                continue
            else:
                out_seq.append(seq[i:i + 1])  # seq is [b, 512]; take seq[0:1], seq[1:2], ... to keep the batch dimension

        if not out_seq:
            print(
                f'All generated results are too short for rank={self.args.local_rank} step={step}\n'
                f'-> prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
                f'-> answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}'
            )
            return None
        out_seq = torch.cat(out_seq, dim=0)  # out_seq is a list of [1, 512] tensors; after cat it becomes [b, 512]
        return out_seq
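
    # A worked trace (added for illustration; values hypothetical) of the length filter above:
    # with pad_token_id = 1 and prompt_length = 3, suppose the generated part is
    #   ans = [[7, 9, 1],     -> valid_ans_len = 2 -> kept
    #          [1, 1, 1]]     -> valid_ans_len = 0 -> dropped (<= 1 valid token)
    # so out_seq keeps only the first row, and different GPUs can end up with
    # different per-rank batch sizes for this step.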

    def generate_experience(self, prompts, mask, step):
        self.eval()  # all parameters stay frozen while generating experience
        generate_start = time.time()
        # the actor model runs inference once; the generated sequence will serve as the label
        seq = self._generate_sequence(prompts, mask, step)  # one answer per prompt in the batch; note seq is the concatenation of prompt and answer
        generate_end = time.time()
        if seq is None:
            assert self.last_generated_experience is not None, f'Invalid generated experience at {step=}'
            prompts = self.last_generated_experience['prompts']
            seq = self.last_generated_experience['seq']
        else:
            self.last_generated_experience = {'prompts': prompts, 'seq': seq}
        self.train()  # switch back to train mode? No weights are updated here; it is done before computing the scores below

        pad_token_id = self.tokenizer.pad_token_id
        attention_mask = seq.not_equal(pad_token_id).long()  # [2, 512]; mask = 1 wherever the token is not pad
        with torch.no_grad():  # the previous call was generate; these are forward passes
            output = self.actor_model(seq, attention_mask=attention_mask)
            output_ref = self.ref_model(seq, attention_mask=attention_mask)
            reward_score = self.reward_model.forward_value(seq, attention_mask,prompt_length=self.prompt_length)['chosen_end_scores'].detach()
            values = self.critic_model.forward_value(seq, attention_mask, return_value_only=True).detach()[:, :-1]

        logits = output.logits
        logits_ref = output_ref.logits
        if self.compute_fp32_loss:
            logits = logits.to(torch.float)
            logits_ref = logits_ref.to(torch.float)
        self.generate_time = generate_end - generate_start

        return {
            'prompts': prompts,
            'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),  # logits[:, :-1, :] drops the last position (logits are [b, 512, vocab_size]); seq[:, 1:] drops the first token
            'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:, 1:]),  # same extraction of each token's log-probability at its label position, but from the reference model
            'value': values,
            'rewards': reward_score,
            'input_ids': seq,
            "attention_mask": attention_mask
        }
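
    # A note on the shift used in 'logprobs' above (added for illustration): for a causal
    # LM, logits[:, t, :] is the predicted distribution over the token at position t + 1,
    # so gather_log_probs(logits[:, :-1, :], seq[:, 1:]) pairs the prediction made at
    # position t with the token actually observed at position t + 1, yielding one
    # log-probability per generated/observed token, shape [batch, seq_len - 1].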

    def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                        action_mask):
        """
        reward_function: compute the final reward
        Recall the relevant default values:
        self.kl_ctl = 0.1
        self.clip_reward_value = 5
        For a given prompt in the batch, its final reward is computed as follows:
        (1) First measure how close the actor's logits are to the ref_model's: -self.kl_ctl * (log_probs - ref_log_probs).
            This is easier to read as self.kl_ctl * (ref_log_probs - log_probs):
            the larger this value, the more the ref_model agrees with what the actor generated (i.e. RLHF has not drifted),
            and in that case the model deserves some reward, namely self.kl_ctl * (ref_log_probs - log_probs).
        (2) Since only the score at the last token position is used as reward_score, we only need:
            the last element of self.kl_ctl * (ref_log_probs - log_probs) plus reward_score.
        (3) reward_score is also clipped: anything above self.clip_reward_value is set to self.clip_reward_value,
            and anything below -self.clip_reward_value is set to -self.clip_reward_value.
        (4) The returned rewards tensor has shape (batch_size, sequence length); for each sample in the batch:
            - last position of the response: last element of self.kl_ctl * (ref_log_probs - log_probs) + reward_score
            - all other positions of the response: self.kl_ctl * (ref_log_probs - log_probs)
        """
        kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)  # [2, 511]
        rewards = kl_divergence_estimate  # [2, 511]
        start = prompts.shape[1] - 1  # start of the response: 255
        ends = start + action_mask[:, start:].sum(1) + 1  # end of the response, e.g. [512, 512]; only valid (non-pad) response tokens are counted
        reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                  self.clip_reward_value)  # clip the two scores, e.g. [0.14, 0.59]
        batch_size = log_probs.shape[0]
        for j in range(batch_size):
            rewards[j, start:ends[j]][-1] += reward_clip[j]  # add the clipped score reward_clip[j] to the last valid reward position; the reward is therefore the actor/ref KL term plus the reward-model score at the end

        return rewards
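
    # A worked numeric trace (added for illustration; values hypothetical) of compute_rewards
    # for one sample j, with kl_ctl = 0.1, clip_reward_value = 5 and prompt length 4 (start = 3):
    #   (ref_log_probs - log_probs)[j, 3:7] = [0.2, -0.1, 0.3, 0.4]
    #   rewards[j, 3:7]                     = [0.02, -0.01, 0.03, 0.04]   # KL term only
    #   reward_score[j] = 7.3  -> clipped to 5.0 and added at the last valid position:
    #   rewards[j, 3:7]                     = [0.02, -0.01, 0.03, 5.04]
    # every response position carries only the KL penalty/bonus; the reward-model score
    # enters only at the final valid token.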

    def train_rlhf(self, inputs):
        # train the RLHF model here
        ### process the old outputs
        prompts = inputs['prompts']  # original prompt tokens
        log_probs = inputs['logprobs']  # log-probability of each token at its actual label position (actor)
        ref_log_probs = inputs['ref_logprobs']  # the same log-probabilities from one forward pass of the reference model
        reward_score = inputs['rewards']  # [2], one score per sequence in the batch
        values = inputs['value']  # [2, 511], values for seq with the last token dropped
        attention_mask = inputs['attention_mask']  # [2, 512], mask = 1 wherever seq is not pad
        seq = inputs['input_ids']  # seq is the concatenation of prompt and answer; it serves as the label data

        start = prompts.size()[-1] - 1  # 256 - 1 = 255
        action_mask = attention_mask[:, 1:]  # mask with the first token dropped, [2, 511]

        old_values = values
        with torch.no_grad():
            old_rewards = self.compute_rewards(prompts, log_probs, ref_log_probs, reward_score, action_mask)  # [2, 511]; the rewards used to train the actor: KL term plus the reward-model score at the last token
            ends = start + action_mask[:, start:].sum(1) + 1  # final position of the valid response
            # we need to zero out the reward and value after the end of the conversation
            # otherwise the advantage/return will be wrong
            for i in range(old_rewards.shape[0]):
                old_rewards[i, ends[i]:] = 0  # e.g. old_rewards[0, 512:]
                old_values[i, ends[i]:] = 0
            advantages, returns = self.get_advantages_and_returns(old_values, old_rewards, start)

        ### process the new outputs
        batch = {'input_ids': seq, "attention_mask": attention_mask}
        actor_prob = self.actor_model(**batch, use_cache=False).logits
        actor_log_prob = gather_log_probs(actor_prob[:, :-1, :], seq[:, 1:])  # inputs are all but the last token; labels are all but the first
        actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
                                        log_probs[:, start:], advantages,
                                        action_mask[:, start:])
        self.actor_model.backward(actor_loss)
        # for T5, reaching this point should already be enough
        if not self.args.align_overflow:
            self.actor_model.step()
        # the critic model also needs to be trained, [2, 511]
        value = self.critic_model.forward_value(**batch,
                                                return_value_only=True,
                                                use_cache=False)[:, :-1]
        critic_loss = self.critic_loss_fn(value[:, start:], old_values[:,start:],
                                          returns, action_mask[:, start:])
        self.critic_model.backward(critic_loss)

        if self.args.align_overflow:
            actor_overflow = self.actor_model.optimizer.check_overflow(
                external=True)
            critic_overflow = self.critic_model.optimizer.check_overflow(
                external=True)

            rank = torch.distributed.get_rank()
            if actor_overflow and not critic_overflow:
                self.critic_model.optimizer.skip_step = True
                print_rank_0(
                    "OVERFLOW: actor overflow, skipping both actor and critic steps",
                    rank)
            elif not actor_overflow and critic_overflow:
                self.actor_model.optimizer.skip_step = True
                print_rank_0(
                    "OVERFLOW: critic overflow, skipping both actor and critic steps",
                    rank)
            elif actor_overflow and critic_overflow:
                print_rank_0(
                    "OVERFLOW: actor and critic overflow, skipping both actor and critic steps",
                    rank)
            self.actor_model.step()

        self.critic_model.step()

        return actor_loss, critic_loss

    def get_overflow(self):
        # Overflow is not expected when using bf16
        # Therefore, DeepSpeed's BF16_Optimizer does not maintain an overflow indication
        if self.args.dtype == "bf16":
            return False, False

        actor_overflow = self.actor_model.optimizer.overflow
        critic_overflow = self.critic_model.optimizer.overflow

        return actor_overflow, critic_overflow

    def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
        ## policy gradient loss
        """
        logprobs: computed on the fly for the response portion (the only input that changes as the actor is updated)
        old_logprobs: log-probs of the response under the old policy (fixed; does not change as the actor updates)
        advantages: per-token advantages of the response under the old policy (fixed; does not change as the actor updates)
        mask: mask over the response under the old policy (fixed; does not change as the actor updates)
        logprobs enters the actor_loss because we do not want any single policy update to be too large, which could wreck the model.
        self.cliprange: default 0.2
        """
        log_ratio = (logprobs - old_logprobs) * mask
        ratio = torch.exp(log_ratio)  # e^log(a/b) = a/b, [2, 256]
        pg_loss1 = -advantages * ratio  # loss = -Adv * (a/b)
        pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                             1.0 + self.cliprange)
        pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
        return pg_loss
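
    # A worked numeric trace (added for illustration; values hypothetical) of the clipped
    # surrogate above, for a single token with cliprange = 0.2:
    #   logprobs = -1.0, old_logprobs = -1.5  -> ratio = exp(0.5) ~= 1.65
    #   advantage = +2.0:
    #     pg_loss1 = -2.0 * 1.65                  = -3.30
    #     pg_loss2 = -2.0 * clamp(1.65, 0.8, 1.2) = -2.40
    #     max(pg_loss1, pg_loss2) = -2.40  -> the clipped term wins and caps the update
    #   advantage = -2.0:
    #     pg_loss1 = +3.30, pg_loss2 = +2.40, max = +3.30
    #     -> the unclipped term wins, so large moves in the wrong direction stay fully penalized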

    def critic_loss_fn(self, values, old_values, returns, mask):
        # constrain the new values with the old values
        ## value loss
        values_clipped = torch.clamp(  # [2, 256]
            values,
            old_values - self.cliprange_value,
            old_values + self.cliprange_value,
        )
        if self.compute_fp32_loss:
            values = values.float()
            values_clipped = values_clipped.float()
        # the critic loss is defined as (estimated expected return - actual expected return)**2
        vf_loss1 = (values - returns)**2#MSE
        vf_loss2 = (values_clipped - returns)**2
        vf_loss = 0.5 * torch.sum(
            torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()  # likewise, the critic loss is averaged over the response tokens
        return vf_loss
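
    # A worked numeric trace (added for illustration; values hypothetical) of the clipped
    # value loss above, for a single token with cliprange_value = 0.2:
    #   old_value = 1.0, new value = 1.6 -> values_clipped = clamp(1.6, 0.8, 1.2) = 1.2
    #   return = 1.1:
    #     vf_loss1 = (1.6 - 1.1)**2 = 0.25
    #     vf_loss2 = (1.2 - 1.1)**2 = 0.01
    #     max(vf_loss1, vf_loss2)   = 0.25
    # taking the max of the clipped and unclipped errors keeps the pessimistic (larger) term,
    # so the critic gains nothing from drifting far away from old_values in a single PPO step.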

    def get_advantages_and_returns(self, values, rewards, start):
        """
        Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134

        Advantage at time step t, before introducing GAE:
        delta_t = r_t + gamma * V_{t+1} - V_t
        where:
            - r_t is the immediate reward at time t
            - V_{t+1} is the expected return at the next time step
            - r_t + gamma * V_{t+1} can be read as the actual expected return at time t
            - V_t can be read as the estimated expected return at time t (predicted by a model, e.g. the critic)

        Advantage at time step t with GAE:
        A_t = delta_t + gamma * lambda * A_{t+1}
        Roughly: at time t we consider not only the current advantage but also future advantages.
        To know A_t we need A_{t+1}, so this implementation solves the recursion backwards with dynamic programming:
        with T the last time step, A_{T+1} = 0, hence A_T = delta_T;
        once A_T is known, we can work backwards and compute A_{T-1}, A_{T-2}, and so on.

        Actual expected return at time t with GAE:
        returns_t = A_t + V_t
                  = delta_t + gamma * lambda * A_{t+1} + V_t
                  = r_t + gamma * V_{t+1} - V_t + gamma * lambda * A_{t+1} + V_t
                  = r_t + gamma * (V_{t+1} + lambda * A_{t+1})

        Note that both advantages and returns are computed only over the response portion.
        """
        # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
        lastgaelam = 0
        advantages_reversed = []
        length = rewards.size()[-1]
        for t in reversed(range(start, length)):
            nextvalues = values[:, t + 1] if t < length - 1 else 0.0
            delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
            lastgaelam = delta + self.gamma * self.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)  # advantages
        returns = advantages + values[:, start:]  # actual return = advantage + value
        return advantages.detach(), returns
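
    # A worked trace (added for illustration; values hypothetical) of the backward GAE
    # recursion above, with gamma = 1.0, lam = 0.95, start = 0 and three response steps:
    #   rewards = [0.0, 0.0, 1.0], values = [0.5, 0.6, 0.7]
    #   t = 2: delta = 1.0 + 0.0       - 0.7 = 0.30 ; A_2 = 0.30
    #   t = 1: delta = 0.0 + 1.0 * 0.7 - 0.6 = 0.10 ; A_1 = 0.10 + 0.95 * 0.30  = 0.385
    #   t = 0: delta = 0.0 + 1.0 * 0.6 - 0.5 = 0.10 ; A_0 = 0.10 + 0.95 * 0.385 = 0.46575
    #   advantages = [0.46575, 0.385, 0.30]
    #   returns    = advantages + values = [0.96575, 0.985, 1.00]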

    def _validate_training_mode(self):
        assert self.actor_model.module.training
        assert self.critic_model.module.training

    def _validate_evaluation_mode(self):
        assert not self.actor_model.module.training
        assert not self.critic_model.module.training
        assert not self.ref_model.module.training
        assert not self.reward_model.module.training

    def train(self):
        self.actor_model.train()
        self.critic_model.train()

    def eval(self):
        self.actor_model.eval()
        self.critic_model.eval()
        self.reward_model.eval()
        self.ref_model.eval()

    def dump_model_norms(self, tag):
        actor_model_norm = get_model_norm(self.actor_model)
        ref_model_norm = get_model_norm(self.ref_model)
        critic_model_norm = get_model_norm(self.critic_model)
        reward_model_norm = get_model_norm(self.reward_model)
        print_all_ranks(f'{tag} global_actor_model_norm', actor_model_norm,
                        self.args.local_rank)
        print_all_ranks(f'{tag} global_ref_model_norm', ref_model_norm,
                        self.args.local_rank)
        print_all_ranks(f'{tag} global_critic_model_norm', critic_model_norm,
                        self.args.local_rank)
        print_all_ranks(f'{tag} global_reward_model_norm', reward_model_norm,
                        self.args.local_rank)


class DeepSpeedPPOTrainerUnsupervised(DeepSpeedPPOTrainer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def train_unsupervised(self, inputs, unsup_coef):
        # Train the unsupervised model here
        self._validate_training_mode()

        outputs = self.actor_model(**inputs, use_cache=False)
        loss = outputs.loss
        self.actor_model.backward(unsup_coef * loss)
        self.actor_model.step()

        return loss

 

From: https://www.cnblogs.com/zhangxianrong/p/18684118
