目录
强化学习基本框架
OnPolicyAlgorithm类
OnPolicyAlgorithm
类,是稳定基线3 (Stable Baselines3) 中所有策略梯度 (On-Policy) 强化学习算法的基类,例如 A2C 和 PPO。
主要功能:
-
初始化和设置:
- 初始化算法的参数,包括策略网络、环境、学习率、折扣因子、GAE 参数、熵系数等。
- 设置学习率计划、随机种子。
- 根据观测空间类型选择合适的 rollout buffer (经验回放缓冲区)。
- 创建 rollout buffer 和策略网络,并将策略网络移动到指定的设备 (CPU 或 GPU)。
-
收集经验 (collect_rollouts):
- 使用当前策略与环境交互,收集经验数据并存储到 rollout buffer 中。
- 在收集经验的过程中,会调用回调函数进行一些操作,例如记录日志、保存模型等。
- 处理超时 (timeout) 情况,使用值函数进行 bootstrapping。
- 计算回报 (returns) 和优势 (advantage) 函数。
-
训练 (train):
- 这个方法是一个抽象方法,具体的训练逻辑由子类实现,例如 A2C 和 PPO。
- 子类需要根据 rollout buffer 中的经验数据更新策略网络的参数。
-
记录日志 (_dump_logs):
- 记录训练过程中的信息,例如迭代次数、奖励、episode 长度、成功率等。
-
学习 (learn):
- 这是算法的主循环,它会不断地收集经验和训练策略网络,直到达到指定的训练步数。
- 在训练过程中,会定期调用回调函数和记录日志。
关键概念:
- 策略梯度 (On-Policy): 这类算法直接学习策略,并根据当前策略收集的经验数据进行更新。
- Rollout Buffer: 存储经验数据的缓冲区,包括状态、动作、奖励、是否结束等信息。
- GAE (Generalized Advantage Estimation): 一种计算优势函数的方法,可以平衡偏差和方差。
- 熵系数 (Entropy Coefficient): 鼓励策略探索环境,避免过早收敛到次优解。
- 值函数系数 (Value Function Coefficient): 控制值函数损失的权重。
代码结构:
OnPolicyAlgorithm
继承自BaseAlgorithm
,BaseAlgorithm
提供了一些通用的功能,例如设置环境、策略、学习率等。OnPolicyAlgorithm
定义了一些抽象方法,例如train
,由子类实现具体的训练逻辑。OnPolicyAlgorithm
提供了一些辅助方法,例如collect_rollouts
、_dump_logs
、learn
等。
总结:
OnPolicyAlgorithm
是 Stable Baselines3 中策略梯度算法的基类,它定义了策略梯度算法的基本框架,包括初始化、收集经验、训练、记录日志等功能。具体的训练逻辑由子类实现,例如 A2C 和 PPO。 通过继承 OnPolicyAlgorithm
,可以方便地开发新的策略梯度算法。
PPO
PPO
类,是稳定基线3 (Stable Baselines3) 中实现的近端策略优化 (Proximal Policy Optimization) 算法。PPO 是一种常用的策略梯度强化学习算法,它在 A2C 的基础上进行了改进,通过引入 clipped surrogate objective 和 KL 惩罚项来提高训练的稳定性和效率。
主要功能:
-
初始化和设置:
- 初始化算法的参数,包括策略网络、环境、学习率、折扣因子、GAE 参数、clip range 等。
- 检查参数的有效性,例如
batch_size
是否大于 1,rollout buffer 的大小是否合适等。 - 设置学习率计划、clip range 计划。
-
训练 (train):
- 将策略网络设置为训练模式。
- 更新优化器的学习率。
- 计算当前的 clip range。
- 进行
n_epochs
轮训练,每轮训练都遍历一遍 rollout buffer 中的数据。 - 对于每个 minibatch 的数据,计算价值函数、动作的对数概率、熵、优势函数、比率等。
- 计算 clipped surrogate loss、价值函数损失、熵损失。
- 计算近似 KL 散度,用于提前停止训练。
- 进行梯度下降更新策略网络的参数。
- 记录训练过程中的信息,例如损失函数、KL 散度、clip fraction 等。
关键概念:
- Clipped Surrogate Objective: PPO 使用 clipped surrogate objective 来限制策略更新的幅度,防止策略更新过度偏离之前的策略。
- KL 惩罚项: PPO 可以选择使用 KL 惩罚项来限制策略更新的幅度,防止策略更新过度偏离之前的策略。
- Advantage Normalization: PPO 通常会对优势函数进行归一化,以提高训练的稳定性。
- Entropy Bonus: PPO 可以选择使用熵奖励来鼓励策略探索环境,避免过早收敛到次优解。
代码结构:
PPO
继承自OnPolicyAlgorithm
,OnPolicyAlgorithm
是策略梯度算法的基类。PPO
实现了train
方法,定义了 PPO 算法的训练逻辑。PPO
使用了RolloutBuffer
来存储经验数据。PPO
使用了ActorCriticPolicy
作为策略网络。
总结:
PPO
是 Stable Baselines3 中实现的近端策略优化算法,它是一种常用的策略梯度强化学习算法。PPO 通过引入 clipped surrogate objective 和 KL 惩罚项来提高训练的稳定性和效率。PPO
类实现了 PPO 算法的训练逻辑,并提供了一些辅助方法,例如 learn
等。