目录
1.引言
在强化学习领域,PPO(Proximal Policy Optimization,近端策略优化)是一种广泛使用且表现优异的算法。它由OpenAI提出,旨在解决策略优化中不稳定和样本效率低的问题。与传统策略梯度方法相比,PPO稳定性更强,且在诸多任务上表现优异。
2.PPO算法的背景
强化学习中的策略优化方法大体可以分为两类:基于值的算法(如DQN)和基于策略的算法(如策略梯度方法)。策略梯度方法直接优化策略函数,使智能体能够在复杂、高维的环境中获得良好的决策能力。然而,直接优化策略可能会导致策略更新过大,导致学习过程不稳定或样本效率低下。
为了解决这个问题,出现了TRPO(Trust Region Policy Optimization)算法,它通过限制策略更新的范围,避免过度更新。然而,TRPO的优化过程复杂且计算开销较大。PPO在此基础上进行改进,通过引入“剪切”(Clipping)等技术简化了优化过程,大幅度提升了算法的稳定性和样本效率。
3.PPO算法的核心思想
PPO的核心思想是限制策略更新的范围,使其不会偏离旧策略太远。PPO主要通过两种方法来实现策略的限制更新:剪切法(Clipping)和KL散度惩罚法(KL Penalty)。其中,剪切法是PPO最常用的实现方式。
具体来说,PPO的优化目标函数为:
这里的符号解释如下:
- :策略更新比率,表示新策略和旧策略之间的差异。
- :优势函数,用于衡量当前动作在当前状态下的好坏。
- :控制策略更新的幅度,一般为一个小值,如0.2。
目标函数的工作原理是:限制策略更新的范围,如果策略的更新比率超过了预设的范围(即大于1+ϵ或小于1−ϵ),则该更新将被裁剪,以防止策略发生剧烈变化。
4.PPO算法的实现步骤
采样数据:使用当前策略与环境交互,采集若干个轨迹,得到状态、动作、奖励和优势函数。
计算优势函数:通常使用时序差分(Temporal Difference)方法或广义优势估计(GAE)来计算优势函数。
计算更新比率:根据旧策略和当前策略,计算比率。
更新策略参数:最小化剪切目标函数中的期望值,使策略尽可能接近“最佳策略”,并确保策略更新不会超出限定范围。
重复采样和更新:不断重复采样和策略更新,直到收敛或达到设定的迭代次数。
4.1 PPO代码实现
这里是PPO的简单实现,包括策略更新和优势估计部分。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Hyperparameters
learning_rate = 3e-4
gamma = 0.99 # Discount factor
lmbda = 0.95 # GAE lambda
eps_clip = 0.2 # PPO clip parameter
K_epoch = 3 # PPO update epochs
T_horizon = 20 # Rollout length
# Policy Network
class ActorCritic(nn.Module):
def __init__(self, state_dim, action_dim):
super(ActorCritic, self).__init__()
self.fc1 = nn.Linear(state_dim, 256)
self.fc_pi = nn.Linear(256, action_dim) # Actor output
self.fc_v = nn.Linear(256, 1) # Critic output
def forward(self, x):
x = torch.relu(self.fc1(x))
pi = torch.softmax(self.fc_pi(x), dim=0)
v = self.fc_v(x)
return pi, v
def act(self, state):
pi, _ = self.forward(state)
action = torch.multinomial(pi, 1).item()
return action
# PPO Algorithm
class PPO:
def __init__(self, state_dim, action_dim):
self.model = ActorCritic(state_dim, action_dim)
self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)
def compute_advantage(self, rewards, values):
deltas = [r + gamma * v_next - v for r, v_next, v in zip(rewards, values[1:], values[:-1])]
advantages = []
advantage = 0.0
for delta in reversed(deltas):
advantage = delta + gamma * lmbda * advantage
advantages.insert(0, advantage)
return advantages
def update(self, rollout):
states, actions, rewards, old_log_probs, values = rollout
advantages = self.compute_advantage(rewards, values)
for _ in range(K_epoch):
pi, v = self.model(states)
log_probs = torch.log(pi.gather(1, actions))
ratios = torch.exp(log_probs - old_log_probs)
surr1 = ratios * advantages
surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
actor_loss = -torch.min(surr1, surr2).mean()
critic_loss = nn.functional.mse_loss(v, rewards)
loss = actor_loss + 0.5 * critic_loss
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
4.2 代码说明
-
策略网络(Policy Network):
ActorCritic
类包含策略网络(fc_pi
)和价值网络(fc_v
),可以同时输出动作概率和状态值。 -
PPO更新过程:
- 通过
compute_advantage
函数计算广义优势估计(GAE)。 update
函数使用剪切目标函数进行策略更新,其中surr1
和surr2
表示未剪切和剪切后的损失值,取其最小值来控制策略更新幅度。
- 通过
-
运行与优化:在
K_epoch
次循环中重复更新,以使策略能够最大化累积奖励。
5.为什么PPO效果如此出色?
更新限制:PPO通过限制策略的更新幅度,避免了过度更新带来的不稳定性问题。这种限制让PPO的训练更加平滑,学习过程更加稳定。
简单高效:相比TRPO,PPO不需要进行复杂的约束优化,而是通过简单的剪切操作实现约束,从而降低了计算复杂度和资源消耗。
广泛适用:PPO适用于离散和连续动作空间,并在不同类型的任务上取得了良好效果,如机器人控制、视频游戏等。
5.1 PPO的优势函数与GAE
PPO通常使用广义优势估计(Generalized Advantage Estimation, GAE)来计算优势函数。GAE是一种平衡偏差与方差的估计方法,通过衰减参数来控制估计的偏差和方差。GAE的优势在于可以更稳定地估计动作的优势值,使得策略更新的效果更好。
5.2 PPO的变体:PPO-Clip和PPO-KL
-
PPO-Clip:即经典的剪切法,通过将更新比率限制在的范围内,确保策略更新不超过预设范围。
-
PPO-KL:通过在损失函数中加入KL散度惩罚项来控制更新幅度。在这种方法中,如果新旧策略之间的KL散度过大,则增加惩罚项,使得更新更加保守。尽管PPO-KL在一些应用中表现良好,但大多数场景下PPO-Clip更常用。
6.PPO算法的应用场景
PPO算法已成功应用于多个实际场景,包括但不限于以下几个领域:
游戏AI:PPO在复杂的游戏环境中表现出色,如《Dota 2》和《Atari》游戏。其稳定性和高效性使其成为游戏AI训练中的重要选择。
机器人控制:在机器人操作中,PPO被广泛用于控制机器人的手臂、腿等部位。它的高样本效率使机器人能够在模拟环境中快速学习,减少了真实环境的训练成本。
自动驾驶:PPO被用于训练自动驾驶中的决策模块。通过学习不同的驾驶场景,PPO可以帮助自动驾驶车辆更好地应对复杂路况。
7.总结
PPO是一种简单且有效的策略优化算法,通过限制策略更新的范围,实现了稳定和高效的策略优化。它不仅在计算上更简单,还在多个复杂任务中取得了优异的表现。随着强化学习的不断发展,PPO已成为解决复杂决策问题的一项强大工具,未来可能会被应用到更多实际场景中。
标签:策略,self,torch,PPO,更新,算法,原理 From: https://blog.csdn.net/qq_56683019/article/details/143571904