首页 > 其他分享 >RL 基础 | 如何复现 PPO,以及一些踩坑经历

RL 基础 | 如何复现 PPO,以及一些踩坑经历

时间:2024-11-21 16:33:31浏览次数:1  
标签:advantage self PPO policy 复现 Delta RL theta pi


最近在复现 PPO 跑 MiniGrid,记录一下…

这里跑的环境是 Empty-5x5 和 8x8,都是简单环境,主要验证 PPO 实现是否正确。

01 Proximal policy Optimization(PPO)

(参考:知乎 | Proximal Policy Optimization (PPO) 算法理解:从策略梯度开始

首先,策略梯度方法 的梯度形式是

\[\nabla_\theta J(\theta)\approx \frac1n \sum_{i=0}^{n-1} R(\tau_i) \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t|s_t) \tag1 \]

然而,传统策略梯度方法容易一步走的太多,以至于越过了中间比较好的点(在参考知乎博客里称为 overshooting)。一个直观的想法是限制策略每次不要更新太多,比如去约束 新策略 旧策略之间的 KL 散度(公式是 plog(p/q)):

\[D_{KL}(\pi_\theta | \pi_{\theta+\Delta \theta}) = \mathbb E_{s,a} \pi_\theta(a|s)\log\frac{\pi_\theta(a|s)}{\pi_{\theta+\Delta \theta}(a|s)} \le \epsilon \tag2 \]

我们把这个约束进行拉格朗日松弛,将它变成一个惩罚项:

\[\Delta\theta^* = \arg\max_{\Delta\theta} J(\theta+\Delta\theta) - \lambda [D_{KL}(\pi_\theta | \pi_{\theta+\Delta \theta})-\epsilon] \tag3 \]

然后再使用一些数学近似技巧,可以得到自然策略梯度(NPG)算法。

NPG 算法貌似还有种种问题,比如 KL 散度的约束太紧,导致每次更新后的策略性能没有提升。我们希望每次策略更新后都带来性能提升,因此计算 新策略 旧策略之间 预期回报的差异。这里采用计算 advantage 的方式:

\[J(\pi_{\theta+\Delta\theta})=J(\pi_{\theta})+\mathbb E_{\tau\sim\pi_{\theta+\Delta\theta}}\sum_{t=0}^\infty \gamma^tA^{\pi_{\theta}}(s_t,a_t) \tag{4} \]

其中优势函数(advantage)的定义是:

\[A^{\pi_{\theta}}(s_t,a_t)=\mathbb E(Q^{\pi_{\theta}}(s_t,a_t)-V^{\pi_{\theta}}(s_t)) \tag{5} \]

在公式 (4) 中,我们计算的 advantage 是在 新策略 的期望下的。但是,在新策略下蒙特卡洛采样(rollout)来算 advantage 期望太麻烦了,因此我们在原策略下 rollout,并进行 importance sampling,假装计算的是新策略下的 advantage。这个 advantage 被称为替代优势(surrogate advantage):

\[\mathcal{L}_{\pi_{\theta}}\left(\pi_{\theta+\Delta\theta}\right) = J\left(\pi_{\theta+\Delta\theta}\right)-J\left(\pi_{\theta}\right)\approx E_{s\sim\rho_{\pi\theta}}\frac{\pi_{\theta+\Delta\theta}(a\mid s)}{\pi_{\theta}(a\mid s)} A^{\pi_{\theta}}(s, a) \tag6 \]

所产生的近似误差,貌似可以用两种策略之间最坏情况的 KL 散度表示:

\[J(\pi_{\theta+\Delta\theta})-J(\pi_{\theta})\geq\mathcal{L}_{\pi\theta}(\pi_{\theta+\Delta\theta})-CD_{KL}^{\max}(\pi_{\theta}||\pi_{\theta+\Delta\theta}) \tag7 \]

其中 C 是一个常数。这貌似就是 TRPO 的单调改进定理,即,如果我们改进下限 RHS,我们也会将目标 LHS 改进至少相同的量。

基于 TRPO 算法,我们可以得到 PPO 算法。PPO Penalty 跟 TRPO 比较相近:

\[\Delta\theta^{*}=\underset{\Delta\theta}{\text{argmax}} \Big[\mathcal{L}_{\theta+\Delta\theta}(\theta+\Delta\theta)-\beta\cdot \mathcal{D}_{KL}(\pi_{\theta}\parallel\pi_{\theta+\Delta\theta})\Big] \tag 8 \]

其中,KL 散度惩罚的 β 是启发式确定的:PPO 会设置一个目标散度 \(\delta\),如果最终更新的散度超过目标散度的 1.5 倍,则下一次迭代我们将加倍 β 来加重惩罚。相反,如果更新太小,我们将 β 减半,从而扩大信任域。

接下来是 PPO Clip,这貌似是目前最常用的 PPO。PPO Penalty 用 β 来惩罚策略变化,而 PPO Clip 与此不同,直接限制策略可以改变的范围。我们重新定义 surrogate advantage:

\[\begin{aligned} \mathcal{L}_{\pi_{\theta}}^{CLIP}(\pi_{\theta_{k}}) = \mathbb E_{\tau\sim\pi_{\theta}}\bigg[\sum_{t=0}^{T} \min\Big( & \rho_{t}(\pi_{\theta}, \pi_{\theta_{k}})A_{t}^{\pi_{\theta_{k}}}, \\ & \text{clip} (\rho_{t}(\pi_{\theta},\pi_{\theta_{k}}), 1-\epsilon, 1+\epsilon) A_{t}^{\pi_{\theta_{k}}} \Big)\bigg] \end{aligned} \tag 9 \]

其中, \(\rho_{t}\) 为重要性采样的 ratio:

\[\rho_{t}(\theta)=\frac{\pi_{\theta}(a_{t}\mid s_{t})}{\pi_{\theta_{k}}(a_{t}\mid s_{t})} \tag{10} \]

公式 (9) 中,min 括号里的第一项是 ratio 和 advantage 相乘,代表新策略下的 advantage;min 括号里的第二项是对 ration 进行的 clip 与 advantage 的相乘。这个 min 貌似可以限制策略变化不要太大。

02 如何复现 PPO(参考 stable baselines3 和 clean RL)

代码主要结构如下,以 stable baselines3 为例:(仅保留主要结构,相当于伪代码,不保证正确性)

import torch
import torch.nn.functional as F
import numpy as np

# 1. collect rollout
self.policy.eval()
rollout_buffer.reset()
while not done:
    actions, values, log_probs = self.policy(self._last_obs)
    new_obs, rewards, dones, infos = env.step(clipped_actions)
    rollout_buffer.add(
        self._last_obs, actions, rewards,
        self._last_episode_starts, values, log_probs,
    )
    self._last_obs = new_obs
    self._last_episode_starts = dones

with torch.no_grad():
    # Compute value for the last timestep
    values = self.policy.predict_values(obs_as_tensor(new_obs, self.device)) 

rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)


# 2. policy optimization
for rollout_data in self.rollout_buffer.get(self.batch_size):
    actions = rollout_data.actions
    values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
    advantages = rollout_data.advantages
    # Normalize advantage
    if self.normalize_advantage and len(advantages) > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # ratio between old and new policy, should be one at the first iteration
    ratio = torch.exp(log_prob - rollout_data.old_log_prob)

    # clipped surrogate loss
    policy_loss_1 = advantages * ratio
    policy_loss_2 = advantages * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)
    policy_loss = -torch.min(policy_loss_1, policy_loss_2).mean()

    # Value loss using the TD(gae_lambda) target
    value_loss = F.mse_loss(rollout_data.returns, values_pred)

    # Entropy loss favor exploration
    entropy_loss = -torch.mean(entropy)

    loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss

    # Optimization step
    self.policy.optimizer.zero_grad()
    loss.backward()
    # Clip grad norm
    torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
    self.policy.optimizer.step()

大致流程:收集当前策略的 rollout → 计算 advantage → 策略优化。

计算 advantage 是由 rollout_buffer.compute_returns_and_advantage 函数实现的:

rb = rollout_buffer
last_gae_lam = 0
for step in reversed(range(buffer_size)):
    if step == buffer_size - 1:
        next_non_terminal = 1.0 - dones.astype(np.float32)
        next_values = last_values
    else:
        next_non_terminal = 1.0 - rb.episode_starts[step + 1]
        next_values = rb.values[step + 1]
    delta = rb.rewards[step] + gamma * next_values * next_non_terminal - rb.values[step]  # (1)
    last_gae_lam = delta + gamma * gae_lambda * next_non_terminal * last_gae_lam  # (2)
    rb.advantages[step] = last_gae_lam
rb.returns = rb.advantages + rb.values

其中,

  • (1) 行通过类似于 TD error 的形式(A = r + γV(s') - V(s)),计算当前 t 时刻的 advantage;
  • (2) 行则是把 t+1 时刻的 advantage 乘 gamma 和 gae_lambda 传递过来。

03 记录一些踩坑经历

  1. PPO 在收集 rollout 的时候,要在分布里采样,而非采用 argmax 动作,否则没有 exploration。(PPO 在分布里采样 action,这样来保证探索,而非使用 epsilon greedy 等机制;听说 epsilon greedy 机制是 value-based 方法用的)
  2. 如果 policy 网络里有(比如说)batch norm,rollout 时应该把 policy 开 eval 模式,这样就不会出错。
  3. (但是,不要加 batch norm,加 batch norm 性能就不好了。听说 RL 不能加 batch norm)
  4. minigrid 简单环境,RNN 加不加貌似都可以(?)
  5. 在算 entropy loss 的时候,要用真 entropy,从 Categorical 分布里得到的 entropy;不要用 -logprob 近似的,不然会导致策略分布 熵变得很小 炸掉。


标签:advantage,self,PPO,policy,复现,Delta,RL,theta,pi
From: https://www.cnblogs.com/moonout/p/18561027

相关文章

  • java.lang.IllegalArgumentException: Unsupported class file major version xx解决
    在一次项目打包中遇到了这个问题,这个问题的本质是打包时,你依赖的包或这些依赖的间接依赖中含有高于当前项目构建jdk版本编译出来的类,导致打包失败。1.majorversion和jdk各版本对应关系可以自行搜索,当前主要版本的对应关系是c:55对应java11majorversion:52对应java8maj......
  • Baichuan2 模型详解,附实验代码复现
    简介近年来,大规模语言模型(LLM)领域取得了令人瞩目的进展。语言模型的参数规模从早期的数百万(如ELMo、GPT-1),发展到如今的数十亿甚至上万亿(如GPT-3、PaLM和SwitchTransformers)。随着模型规模的增长,LLM的能力显著提升,展现出更接近人类的语言流畅性,并能执行多样化的自然语......
  • 在浏览器中输入url到页面显示出来的过程发生了什么?
    在浏览器中输入URL到页面显示出来,这中间经历了一系列复杂的过程,可以概括为以下几个主要步骤:URL解析:浏览器首先会解析你输入的URL,检查其语法是否正确,并提取出协议(如HTTP或HTTPS)、域名、端口、路径、查询参数和片段标识符等信息。DNS查询:浏览器会向DNS服务器查......
  • Altenergy电力系统控制软件 status_zigbee SQL注入漏洞复现(CVE-2024-11305)
     0x01阅读须知        技术文章仅供参考,此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等(包括但不限于)进行检测或维护参考,未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失,均由使用......
  • Android CoordinatorLayout使用示例记录
    原文链接:AndroidCoordinatorLayout使用示例记录-Stars-One的杂货小窝简单记录下常用CoordinatorLayout的几个效果代码示例,方便后续有需求的时候参照实现开始之前,注意下项目material版本,下文提到的某些属性是在后续版本才有的implementation("com.google.android.materia......
  • admin.site.urls是什么
    admin.site.urls是Django框架中用来注册管理后台(AdminSite)的URL配置的一个属性。它通常在项目的主URL配置文件(urls.py)中引用,用于将Django的管理后台功能添加到项目的路由中。示例fromdjango.contribimportadminfromdjango.urlsimportpathurlpatterns=[......
  • Spring中使用BeanUtils.copyProperties()导致Hessian/Burlap:ClassNotFoundException
    背景遇到一个问题:dubbo服务客户端发现提示警告异常[NewI/Oworker#4]WARNc.a.c.c.hessian.io.SerializerFactory-Hessian/Burlap:'XX.XX.XBean'isanunknownclassinjava.net.URLClassLoader@988246e:java.lang.ClassNotFoundException:XX.XX.XBean但是根据代码......
  • C++零基础入门:趣味学信息学奥赛从“Hello World”开始
    编程学习的第一步,往往从“HelloWorld”开始。这不仅是程序员的“入门仪式”,更是打开编程世界的一把钥匙。结合树莓派Pico开发板的实际操作,这篇文章将为C++零基础的学生和信息学奥赛爱好者讲解如何通过一个简单的“HelloWorld”项目,学会基础语法、编程思维,以及软硬件结合的实......
  • Ubuntu 24.04上安装JupyterLab并远程访问
    更新你的Ubuntu软件包索引。 sudoaptupdate 现在通过Ubuntu软件源安装Python3和Node.js,方法如下--输入"Y"确认安装: sudoaptinstallpython3-devpython3-pippython3-venvnodejsnpm安装Jupyter 安装完依赖项后,您将在Python虚拟环境中通过Pip安装Jup......
  • PbRL | Christiano 2017 年的开山之作,以及 Preference PPO / PrefPPO
    PrefPPO首次(?)出现在PEBBLE,作为pebble的一个baseline,是用PPO复现Christianoetal.(2017)的PbRL算法。Forevaluation,wecomparetoChristianoetal.(2017),whichisthecurrentstate-of-the-artapproachusingthesametypeoffeedback.Theprimarydif......