首页 > 编程语言 >深入理解PPO算法:从原理到实现

深入理解PPO算法:从原理到实现

时间:2024-11-07 12:46:29浏览次数:5  
标签:策略 self torch PPO 更新 算法 原理

目录

1.引言

2.PPO算法的背景

3.PPO算法的核心思想

4.PPO算法的实现步骤

  4.1 PPO代码实现

  4.2 代码说明

5.为什么PPO效果如此出色?

  5.1 PPO的优势函数与GAE

  5.2 PPO的变体:PPO-Clip和PPO-KL

6.PPO算法的应用场景

7.总结


1.引言

        在强化学习领域,PPO(Proximal Policy Optimization,近端策略优化)是一种广泛使用且表现优异的算法。它由OpenAI提出,旨在解决策略优化中不稳定和样本效率低的问题。与传统策略梯度方法相比,PPO稳定性更强,且在诸多任务上表现优异。

2.PPO算法的背景

        强化学习中的策略优化方法大体可以分为两类:基于值的算法(如DQN)和基于策略的算法(如策略梯度方法)。策略梯度方法直接优化策略函数,使智能体能够在复杂、高维的环境中获得良好的决策能力。然而,直接优化策略可能会导致策略更新过大,导致学习过程不稳定或样本效率低下。

        为了解决这个问题,出现了TRPO(Trust Region Policy Optimization)算法,它通过限制策略更新的范围,避免过度更新。然而,TRPO的优化过程复杂且计算开销较大。PPO在此基础上进行改进,通过引入“剪切”(Clipping)等技术简化了优化过程,大幅度提升了算法的稳定性和样本效率。

3.PPO算法的核心思想

        PPO的核心思想是限制策略更新的范围,使其不会偏离旧策略太远。PPO主要通过两种方法来实现策略的限制更新:剪切法(Clipping)KL散度惩罚法(KL Penalty)。其中,剪切法是PPO最常用的实现方式。

        具体来说,PPO的优化目标函数为:

L^{\text{PPO}}(\theta) = \mathbb{E}_t \left[ \min(r_t(\theta) \cdot A_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \cdot A_t) \right]

这里的符号解释如下:

  • r_t(\theta) = \frac{\pi_\theta(a_t | s_t)}{\pi_{\theta_\text{old}}(a_t | s_t)}:策略更新比率,表示新策略和旧策略之间的差异。
  • A_t​:优势函数,用于衡量当前动作在当前状态下的好坏。
  • \epsilon:控制策略更新的幅度,一般为一个小值,如0.2。

        目标函数的工作原理是:限制策略更新的范围,如果策略的更新比率超过了预设的范围(即大于1+ϵ或小于1−ϵ),则该更新将被裁剪,以防止策略发生剧烈变化。

4.PPO算法的实现步骤

图1 PPO算法的基本架构图
  1. 采样数据:使用当前策略\pi_\theta​与环境交互,采集若干个轨迹,得到状态、动作、奖励和优势函数。

  2. 计算优势函数:通常使用时序差分(Temporal Difference)方法或广义优势估计(GAE)来计算优势函数A_t

  3. 计算更新比率:根据旧策略和当前策略,计算比率r_t(\theta)

  4. 更新策略参数:最小化剪切目标函数中的期望值,使策略尽可能接近“最佳策略”,并确保策略更新不会超出限定范围。

  5. 重复采样和更新:不断重复采样和策略更新,直到收敛或达到设定的迭代次数。

  4.1 PPO代码实现

        这里是PPO的简单实现,包括策略更新和优势估计部分。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Hyperparameters
learning_rate = 3e-4
gamma = 0.99          # Discount factor
lmbda = 0.95          # GAE lambda
eps_clip = 0.2        # PPO clip parameter
K_epoch = 3           # PPO update epochs
T_horizon = 20        # Rollout length

# Policy Network
class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc_pi = nn.Linear(256, action_dim)    # Actor output
        self.fc_v = nn.Linear(256, 1)              # Critic output

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        pi = torch.softmax(self.fc_pi(x), dim=0)
        v = self.fc_v(x)
        return pi, v

    def act(self, state):
        pi, _ = self.forward(state)
        action = torch.multinomial(pi, 1).item()
        return action

# PPO Algorithm
class PPO:
    def __init__(self, state_dim, action_dim):
        self.model = ActorCritic(state_dim, action_dim)
        self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

    def compute_advantage(self, rewards, values):
        deltas = [r + gamma * v_next - v for r, v_next, v in zip(rewards, values[1:], values[:-1])]
        advantages = []
        advantage = 0.0
        for delta in reversed(deltas):
            advantage = delta + gamma * lmbda * advantage
            advantages.insert(0, advantage)
        return advantages

    def update(self, rollout):
        states, actions, rewards, old_log_probs, values = rollout
        advantages = self.compute_advantage(rewards, values)

        for _ in range(K_epoch):
            pi, v = self.model(states)
            log_probs = torch.log(pi.gather(1, actions))
            ratios = torch.exp(log_probs - old_log_probs)

            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
            actor_loss = -torch.min(surr1, surr2).mean()
            critic_loss = nn.functional.mse_loss(v, rewards)

            loss = actor_loss + 0.5 * critic_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

  4.2 代码说明

  1. 策略网络(Policy Network)ActorCritic 类包含策略网络(fc_pi)和价值网络(fc_v),可以同时输出动作概率和状态值。

  2. PPO更新过程

    • 通过 compute_advantage 函数计算广义优势估计(GAE)。
    • update 函数使用剪切目标函数进行策略更新,其中 surr1surr2 表示未剪切和剪切后的损失值,取其最小值来控制策略更新幅度。
  3. 运行与优化:在 K_epoch 次循环中重复更新,以使策略能够最大化累积奖励。

5.为什么PPO效果如此出色?

  1. 更新限制:PPO通过限制策略的更新幅度,避免了过度更新带来的不稳定性问题。这种限制让PPO的训练更加平滑,学习过程更加稳定。

  2. 简单高效:相比TRPO,PPO不需要进行复杂的约束优化,而是通过简单的剪切操作实现约束,从而降低了计算复杂度和资源消耗。

  3. 广泛适用:PPO适用于离散和连续动作空间,并在不同类型的任务上取得了良好效果,如机器人控制、视频游戏等。

  5.1 PPO的优势函数与GAE

        PPO通常使用广义优势估计(Generalized Advantage Estimation, GAE)来计算优势函数。GAE是一种平衡偏差与方差的估计方法,通过衰减参数\lambda来控制估计的偏差和方差。GAE的优势在于可以更稳定地估计动作的优势值,使得策略更新的效果更好。

  5.2 PPO的变体:PPO-Clip和PPO-KL

  1. PPO-Clip:即经典的剪切法,通过将更新比率限制在[1 - \epsilon, 1 + \epsilon]的范围内,确保策略更新不超过预设范围。

  2. PPO-KL:通过在损失函数中加入KL散度惩罚项来控制更新幅度。在这种方法中,如果新旧策略之间的KL散度过大,则增加惩罚项,使得更新更加保守。尽管PPO-KL在一些应用中表现良好,但大多数场景下PPO-Clip更常用。

6.PPO算法的应用场景

        PPO算法已成功应用于多个实际场景,包括但不限于以下几个领域:

  • 游戏AI:PPO在复杂的游戏环境中表现出色,如《Dota 2》和《Atari》游戏。其稳定性和高效性使其成为游戏AI训练中的重要选择。

  • 机器人控制:在机器人操作中,PPO被广泛用于控制机器人的手臂、腿等部位。它的高样本效率使机器人能够在模拟环境中快速学习,减少了真实环境的训练成本。

  • 自动驾驶:PPO被用于训练自动驾驶中的决策模块。通过学习不同的驾驶场景,PPO可以帮助自动驾驶车辆更好地应对复杂路况。

7.总结

        PPO是一种简单且有效的策略优化算法,通过限制策略更新的范围,实现了稳定和高效的策略优化。它不仅在计算上更简单,还在多个复杂任务中取得了优异的表现。随着强化学习的不断发展,PPO已成为解决复杂决策问题的一项强大工具,未来可能会被应用到更多实际场景中。

标签:策略,self,torch,PPO,更新,算法,原理
From: https://blog.csdn.net/qq_56683019/article/details/143571904

相关文章

  • 【KMP算法】
    目录BF算法KMP算法BF算法F算法,即暴力(BruteForce)算法,是普通的模式匹配算法,BF算法的思想就是将目标串S的第一个字符与模式串T的第一个字符进行匹配,若相等,则继续比较S的第二个字符和T的第二个字符;若不相等,则比较S的第二个字符和T的第一个字符,依次比较下去,直到得出最后......
  • α-shape算法曲面重建
    目录1原理介绍α-shape的基础概念数学公式推导2.1外接圆半径2.2根据α参数筛选三角形2.3构建α-shape2.4参数调整与优化3α-shape的构建步骤4示例代码        取点云的凹边界是计算几何中的一个经典问题。凹边界与凸边界不同,它能捕捉到数据的细......
  • Java深度优先搜索(DFS)算法实现
    标题:Java深度优先搜索(DFS)算法实现引言:深度优先搜索(Depth-FirstSearch,DFS)是一种常用的图遍历算法,它通过递归地遍历图中的每个顶点,来寻找特定的路径或解决某些问题。本篇博客将介绍如何用Java语言实现深度优先搜索算法。算法思想:深度优先搜索算法的基本思想是先访问一个......
  • c++ Kruskal 最小生成树 (MST) 算法(Kruskal’s Minimum Spanning Tree (MST) Algorith
            对于加权、连通、无向图,最小生成树(MST)或最小权重生成树是权重小于或等于其他所有生成树权重的生成树。Kruskal算法简介:        在这里,我们将讨论Kruskal算法来查找给定加权图的MST。         在Kruskal算法中,按升序对给定图的所有......
  • 常考的排序算法
    冒泡排序#include<iostream>#include<string>usingnamespacestd;//voidShellsort(intA[],intn)//{//   intd,i,j;//   for(d=n/2;d>=1;d=d/2)//   {//      for(i=d+1;i<=n;i++)//      {//     ......
  • 数据库原理 第五章 事务与并发控制
    目录1.事务的基本概念1.1为什么需要事务?什么是事务?1.2数据库事务的四大特性(ACID)1.3事务涉及的基本概念1.3.1 Transaction(事务)1.3.2. Rollback(回滚)1.33. Commit(提交)1.3.4. Savepoint(保存点)1.3.5关系总结1.4MySQL事务管理的完整示例示例:转账操作2.故障2.1故障的种类2.1.1. 数......
  • JavaScript Kruskal 最小生成树 (MST) 算法(Kruskal’s Minimum Spanning Tree (MST) A
             对于加权、连通、无向图,最小生成树(MST)或最小权重生成树是权重小于或等于其他所有生成树权重的生成树。Kruskal算法简介:        在这里,我们将讨论Kruskal算法来查找给定加权图的MST。         在Kruskal算法中,按升序对给定图的所......
  • 简述大前端技术栈的渲染原理
    作者:京东物流卢旭大前端包括哪些技术栈大前端指的是涵盖所有与前端开发相关的技术和平台,应用于各类设备和操作系统上。大前端不仅包括Web开发,还包括移动端开发和跨平台应用开发,具体包括:•原生应用开发:Android、iOS、鸿蒙(HarmonyOS)等;•Web前端框架:Vue、React、Angular等;•......
  • 基于springboot框架在线生鲜商城推荐系统 java实现个性化生鲜/农产品购物商城推荐网站
    基于springboot框架在线生鲜商城推荐系统java实现个性化生鲜/农产品购物商城推荐网站爬虫、数据分析、排行榜基于协同过滤算法推荐、基于流行度热点推荐、平均加权混合推荐机器学习、大数据、深度学习OnlineShopRecommendEx一、项目简介1、开发工具和使用技术IDEA,jdk......
  • vuex、vue-router实现原理
    Vuex和VueRouter是Vue.js生态系统中非常重要的两个库,分别用于状态管理和路由管理。它们各自的实现原理如下:Vuex实现原理1.状态管理Vuex是一个专为Vue.js应用程序开发的状态管理模式。它使用集中式的存储管理所有组件的状态,并以一种可预测的方式来确保状态以一种可追......