首页 > 编程语言 >强化学习算法PPO实现

强化学习算法PPO实现

时间:2024-07-20 21:54:13浏览次数:14  
标签:state PPO critic actor states action 算法 tf 强化

PPO的基本思想

  1. 策略优化:PPO直接优化策略,通过限制更新幅度来保证训练稳定性。
  2. Clip方法:PPO引入了clip方法限制策略更新的幅度,避免策略过大更新导致的不稳定。
  3. 优势估计:使用优势函数来评估当前策略相对于某个基准策略的提升。

详细的训练过程

  1. 初始化:初始化策略网络(Actor)和价值网络(Critic),设置超参数和经验回放池。
  2. 交互环境:在每一回合中,智能体根据当前策略与环境进行交互,选择动作并获得奖励,存储经验。
  3. 计算优势:使用GAE(广义优势估计)方法计算优势函数,估计每个状态-动作对的优势。
  4. 更新策略网络:使用Clip-PPO方法,通过限制策略变化幅度来更新策略网络。
  5. 更新价值网络:通过最小化价值函数预测误差来更新价值网络。
  6. 训练结束:在达到设定的回合数后,结束训练过程。
import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# 超参数设置
gamma = 0.99  # 折扣因子
lambda_ = 0.95  # GAE参数
clip_ratio = 0.2  # Clip比率
actor_lr = 0.0003  # Actor学习率
critic_lr = 0.0003  # Critic学习率
batch_size = 64  # 批量大小
epochs = 10  # 训练每个回合的迭代次数
update_steps = 4000  # 更新步骤

# 环境设置
env = gym.make('Pendulum-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high[0]

# 构建Actor网络
def build_actor():
    model = tf.keras.Sequential([
        layers.Dense(64, activation='relu', input_dim=state_dim),
        layers.Dense(64, activation='relu'),
        layers.Dense(action_dim, activation='tanh')
    ])
    return model

# 构建Critic网络
def build_critic():
    model = tf.keras.Sequential([
        layers.Dense(64, activation='relu', input_dim=state_dim),
        layers.Dense(64, activation='relu'),
        layers.Dense(1)
    ])
    return model

# PPO训练
def train_ppo(episodes):
    actor = build_actor()
    critic = build_critic()

    actor_optimizer = tf.keras.optimizers.Adam(learning_rate=actor_lr)
    critic_optimizer = tf.keras.optimizers.Adam(learning_rate=critic_lr)

    def compute_advantages(rewards, values, next_values, done):
        advantages = np.zeros_like(rewards)
        gae = 0
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + gamma * next_values[t] * (1 - done[t]) - values[t]
            gae = delta + gamma * lambda_ * (1 - done[t]) * gae
            advantages[t] = gae
        return advantages

    def update_actor_and_critic(states, actions, advantages, old_log_probs, returns):
        with tf.GradientTape() as tape:
            logits = actor(states)
            new_log_probs = tf.reduce_sum(tf.math.log(tf.reduce_sum(logits * actions, axis=1)))
            ratio = tf.exp(new_log_probs - old_log_probs)
            clip_loss = tf.reduce_mean(tf.minimum(ratio * advantages, tf.clip_by_value(ratio, 1 - clip_ratio, 1 + clip_ratio) * advantages))
            actor_loss = -clip_loss

        actor_grads = tape.gradient(actor_loss, actor.trainable_variables)
        actor_optimizer.apply_gradients(zip(actor_grads, actor.trainable_variables))

        with tf.GradientTape() as tape:
            values = critic(states)
            critic_loss = tf.reduce_mean((returns - values) ** 2)

        critic_grads = tape.gradient(critic_loss, critic.trainable_variables)
        critic_optimizer.apply_gradients(zip(critic_grads, critic.trainable_variables))

    for episode in range(episodes):
        state = env.reset()
        done = False
        total_reward = 0
        states, actions, rewards, next_states, dones = [], [], [], [], []

        while not done:
            state = np.reshape(state, [1, state_dim])
            action = actor.predict(state)[0]
            noise = np.random.normal(0, action_bound * 0.1, size=action_dim)
            action = np.clip(action + noise, -action_bound, action_bound)

            next_state, reward, done, _ = env.step(action)
            next_state = np.reshape(next_state, [1, state_dim])

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            next_states.append(next_state)
            dones.append(done)

            state = next_state
            total_reward += reward

        states = np.vstack(states)
        actions = np.vstack(actions)
        rewards = np.array(rewards)
        next_states = np.vstack(next_states)
        dones = np.array(dones)

        values = critic.predict(states)
        next_values = critic.predict(next_states)
        advantages = compute_advantages(rewards, values, next_values, dones)
        returns = advantages + values

        old_log_probs = tf.reduce_sum(tf.math.log(tf.reduce_sum(actor.predict(states) * actions, axis=1)))

        for _ in range(epochs):
            update_actor_and_critic(states, actions, advantages, old_log_probs, returns)

    print(f"Episode {episode + 1} - Total Reward: {total_reward}")

train_ppo(1000)

标签:state,PPO,critic,actor,states,action,算法,tf,强化
From: https://blog.csdn.net/PeterClerk/article/details/140578270

相关文章

  • 七大排序算法的Python实现
    七大排序算法的Python实现1.冒泡排序(BubbleSort)算法思想冒泡排序通过重复交换相邻的未按顺序排列的元素来排序数组。每次迭代都将最大的元素“冒泡”到数组的末尾。复杂度分析时间复杂度:O(n^2)空间复杂度:O(1)defbubble_sort(arr):n=len(arr)for......
  • 高精度算法
    加法includeusingnamespacestd;strings1,s2;inta[101],b[101],c[101];voidstrtoint(stringstr,intdes[]){for(inti=0;i<str.size();i++){des[str.size()-i]=str[i]-'0';}}intmain(){cin>>s1>>s2;strtoint(......
  • 算法 图论最短路径
    零、写在前面本文讲述Dijkstra、Bellman-Ford、Floyd-Warshall算法一、分类G(graph):图V(vertex):点E(edge):边一个图可以用数学语言描述为。W(weights):权所以一个图也可以用数学语言描述为。二、作图2.1作图网站(推荐) 在线作图网站:图论作图网站GraphEditor用法:Undirected......
  • 蓝桥杯 算法季度赛2
    T2第一发没判最后一组后没有间隔T3WA了两发,调不出来往后看T5是线段树板子,1A了T4贺了个zfunction板子,WA了两发,调不出来剩下的题都没来得及看丑陋sol3.兽之泪II讨论选不选\(x_n\)比较好些如果讨论的是\(y_n\),在选\(y_i\)的情况下可能会选一些\(>y_i\)......
  • 2024“钉耙编程”中国大学生算法设计超级联赛(1)结题报告1 2 8
    1001循环位移字符串哈希将a展开*2对于每个长度为len_a的序列进行一次hash存储并将其插入set中对于b进行一次哈希对于每个长度为len_a的连续子串进行一次查询点击查看代码#include<bits/stdc++.h>usingnamespacestd;//22222constintN=5e6+10;constintp1......
  • 2024“钉耙编程”中国大学生算法设计超级联赛(1)
    发挥相当差,最好笑的是1h没写出一个三维偏序、30min没写出一个字符串哈希。甚至1h没意识到组合数式子推错了。A我写了点阴间东西。假设模式串为ABC,考虑一个形如ABCABCABC的东西,如果长度是\(x\),会贡献\(x-n+1\)个子串。枚举\(i\),从\(i\)把\(T\)分成两部分,一部分......
  • 字符串算法之一:朴素算法找子串
    publicclassStringAlgorithm{publicstaticvoidmain(String[]args){intresult=plainFindSubStr("12345","1234");System.out.println(result);}/***@paramstr*@parampattern*@retu......
  • 代码随想录算法训练营第33天 | 贪心4:452. 用最少数量的箭引爆气球、435. 无重叠区间
    代码随想录算法训练营第33天|贪心4:452.用最少数量的箭引爆气球、435.无重叠区间、763.划分字母区间452.用最少数量的箭引爆气球https://leetcode.cn/problems/minimum-number-of-arrows-to-burst-balloons/description/代码随想录https://programmercarl.com/0452.用最......
  • 代码随想录算法训练营第31天 | 贪心3:134.加油站、135.分发糖果、860.柠檬水找零、406.
    代码随想录算法训练营第31天|贪心3:134.加油站、135.分发糖果、860.柠檬水找零、406.根据身高重建队列134.加油站https://leetcode.cn/problems/gas-station/description/代码随想录https://programmercarl.com/0134.加油站.html135.分发糖果https://leetcode.cn/problems......
  • 强化学习入门
    原文:https://blog.csdn.net/v_JULY_v/article/details/128965854目录强化学习极简入门:通俗理解MDP、DPMCTC和Q学习、策略梯度、PPO第一部分RL基础:什么是RL与MRP、MDP1.1入门强化学习所需掌握的基本概念1.1.1什么是强化学习:依据策略执行动作-感知状态-得到奖励1.1.2RL与监督......