Playing the CartPole game with PPO (Proximal Policy Optimization)

Posted: 2024-05-14 21:57:54
Tags: cartpole log self torch PPO value state policy proximal

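This one is fairly hard. Conceptually there are two policies: one is being updated and the other supplies the data. In practice they are the same network at successive points in time. Policy 1 collects a batch of data that is used to train policy 2, then policy 2 collects data to train policy 3, and so on, until policy N-1 collects data to train policy N. The process feels somewhat like DQN, but the model uses an actor-critic architecture (a policy network plus a value network). PPO is still essentially on-policy. However, importance sampling (the probability ratio between the new policy and the old policy that collected the data) lets it reuse each freshly collected batch for several gradient epochs. This borrows an off-policy idea and speeds up learning per sample collected. To keep the new policy from drifting too far from the old one, that ratio is clipped. In short, PPO maximizes the policy's expected return while preventing any single update from changing the policy too much.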
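The clipping rule can be seen on a few hand-picked numbers. Below is a minimal pure-Python sketch of the per-sample clipped surrogate. With ratio r = pi_new(a|s) / pi_old(a|s) and advantage A, the objective is min(r * A, clip(r, 1 - eps, 1 + eps) * A). It uses eps = 0.2 to match `clip_param` in the script. The helper name `clipped_surrogate` is made up for illustration. The script computes the same quantity batch-wise with `torch.clamp` and `torch.min`.

```python
import math


def clipped_surrogate(new_log_prob, old_log_prob, advantage, eps=0.2):
    """Per-sample PPO clipped surrogate (maximized; the training loss is its negative).

    Hypothetical helper for illustration only.
    """
    ratio = math.exp(new_log_prob - old_log_prob)  # pi_new(a|s) / pi_old(a|s)
    clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)  # clamp to [1 - eps, 1 + eps]
    # The minimum makes the objective a pessimistic bound on the unclipped one
    return min(ratio * advantage, clipped_ratio * advantage)


# Good action (A > 0) whose probability already rose from 0.4 to 0.6 (r = 1.5): capped at 1.2 * A
print(clipped_surrogate(math.log(0.6), math.log(0.4), advantage=1.0))
# Bad action (A < 0) whose probability fell from 0.4 to 0.2 (r = 0.5): held at 0.8 * A
print(clipped_surrogate(math.log(0.2), math.log(0.4), advantage=-1.0))
# Ratio inside the clip range (r = 1.1): plain importance-weighted advantage
print(clipped_surrogate(math.log(0.44), math.log(0.4), advantage=1.0))
```

In the first two cases the clipped term is selected and it does not depend on r, so its gradient is zero. Once the policy has moved more than eps in the direction the advantage favors, that sample stops pushing it further.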

 

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pygame
import sys
from collections import deque

# Policy network (actor): maps the 4-dimensional state to action probabilities
class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 2),
            nn.Tanh(),
            nn.Linear(2, 2),  # CartPole has 2 actions: push left or push right
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.fc(x)

# Value network (critic): estimates the state value V(s)
class ValueNetwork(nn.Module):
    def __init__(self):
        super(ValueNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 2),
            nn.Tanh(),
            nn.Linear(2, 1)
        )

    def forward(self, x):
        return self.fc(x)

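# Rollout buffer: holds the on-policy transitions of one episode (unlike a DQN replay buffer, it is cleared after every update)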
class RolloutBuffer:
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []
    
    def store(self, state, action, reward, done, log_prob):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.dones.append(done)
        self.log_probs.append(log_prob)
    
    def clear(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.dones = []
        self.log_probs = []

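    def get_batch(self):
        # Convert the stored lists to tensors; stacking the states with np.array first
        # avoids PyTorch's slow path for a list of ndarrays
        return (
            torch.tensor(np.array(self.states), dtype=torch.float),
            torch.tensor(self.actions, dtype=torch.long),
            torch.tensor(self.rewards, dtype=torch.float),
            torch.tensor(self.dones, dtype=torch.bool),
            torch.tensor(self.log_probs, dtype=torch.float)
        )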

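# PPO update: build returns and advantages for the collected episode, then run
# several epochs of clipped-surrogate updates on that same batch
def ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer, epochs=10, gamma=0.99, lam=0.95, clip_param=0.2):
    states, actions, rewards, dones, old_log_probs = buffer.get_batch()
    with torch.no_grad():  # critic estimates are treated as constants here
        values = value_net(states).squeeze(-1).tolist()  # V(s_t) for each of the N steps
    returns = []
    advantages = []
    G = 0.0
    adv = 0.0
    next_value = 0.0  # V(s_{t+1}) while walking backwards; nothing follows the last step
    for reward, done, value in zip(reversed(rewards.tolist()), reversed(dones.tolist()), reversed(values)):
        mask = 0.0 if done else 1.0  # stop bootstrapping at the end of an episode
        G = reward + gamma * G * mask  # Monte Carlo return, accumulated backwards
        delta = reward + gamma * next_value * mask - value  # TD error: r + gamma * V(s_{t+1}) - V(s_t)
        adv = delta + gamma * lam * adv * mask  # generalized advantage estimation (GAE)
        returns.insert(0, G)
        advantages.insert(0, adv)
        next_value = value

    returns = torch.tensor(returns, dtype=torch.float)  # critic targets
    advantages = torch.tensor(advantages, dtype=torch.float)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)  # normalize to zero mean, unit std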

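    for _ in range(epochs):
        # Re-evaluate the stored actions under the current policy
        action_probs = policy_net(states)
        dist = torch.distributions.Categorical(action_probs)
        new_log_probs = dist.log_prob(actions)
        # Importance ratio pi_new(a|s) / pi_old(a|s), computed in log space
        ratio = (new_log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()  # clipped surrogate, negated for gradient descent

        optimizer_policy.zero_grad()
        actor_loss.backward()
        optimizer_policy.step()

        # Critic regression toward the returns; squeeze (N, 1) -> (N,) so the shapes match
        # instead of broadcasting to an (N, N) matrix
        value_loss = (returns - value_net(states).squeeze(-1)).pow(2).mean()

        optimizer_value.zero_grad()
        value_loss.backward()
        optimizer_value.step()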

# Create the environment, the networks and their optimizers
env = gym.make('CartPole-v1')
policy_net = PolicyNetwork()
value_net = ValueNetwork()
optimizer_policy = optim.Adam(policy_net.parameters(), lr=3e-4)
optimizer_value = optim.Adam(value_net.parameters(), lr=1e-3)
buffer = RolloutBuffer()

# Initialize pygame for live rendering
pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock()

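draw_on = False
# Training loop: collect one episode, then run one PPO update on it
for episode in range(10000):  # number of training episodes
    state, _ = env.reset()  # gym >= 0.26: reset() returns (observation, info)
    done = False
    step = 0
    while not done:
        step += 1
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        with torch.no_grad():  # acting only; ppo_update recomputes log-probs with gradients
            action_probs = policy_net(state_tensor)
        dist = torch.distributions.Categorical(action_probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)

        # gym >= 0.26: step() returns (obs, reward, terminated, truncated, info).
        # CartPole-v1 is truncated at 500 steps, so either flag must end the episode.
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        done = terminated or truncated  # simplification: truncation is treated like termination
        buffer.store(state, action.item(), reward, done, log_prob.item())  # store a float, not a tensor

        state = next_state

        # Keep the pygame window responsive
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        if draw_on:
            # Clear the screen and redraw the cart and the pole
            screen.fill((0, 0, 0))
            cart_x = int(state[0] * 100 + 300)  # map cart position to a screen x coordinate
            pygame.draw.rect(screen, (0, 128, 255), (cart_x, 300, 50, 30))
            pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * np.sin(state[2])), 300 - int(50 * np.cos(state[2]))), 5)
            pygame.display.flip()
            clock.tick(600)

    if step >= 500:  # an episode reached CartPole-v1's step limit: start rendering
        draw_on = True
    ppo_update(policy_net, value_net, optimizer_policy, optimizer_value, buffer)
    buffer.clear()
    print(f'Episode {episode} finished after {step} steps.')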

# Clean up after training
env.close()
pygame.quit()

 

Running result: [figure not preserved]
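The advantage step can be checked in isolation. Below is a hedged pure-Python sketch of generalized advantage estimation (GAE) for one rollout. It assumes rewards[t], values[t] = V(s_t) and dones[t] (True when step t ended the episode), and uses the following recursion:

- delta_t = r_t + gamma * V(s_{t+1}) * (1 - d_t) - V(s_t)
- A_t = delta_t + gamma * lam * (1 - d_t) * A_{t+1}

The function name `compute_gae` and the `last_value` argument are illustrative and do not come from the post. `last_value` is the value of the state after the final stored step and is used only when that step was not terminal. The sketch returns lambda-returns A_t + V(s_t) as critic targets. This is a common alternative to the Monte Carlo returns G that the script uses.

```python
def compute_gae(rewards, values, dones, last_value=0.0, gamma=0.99, lam=0.95):
    """Backward GAE pass over one rollout (hypothetical helper).

    rewards[t]: reward at step t; values[t]: V(s_t); dones[t]: True if step t ended the episode.
    Returns (advantages, returns) with returns[t] = advantages[t] + values[t].
    """
    advantages = [0.0] * len(rewards)
    adv = 0.0
    next_value = last_value  # V of the state after the final stored step
    for t in reversed(range(len(rewards))):
        mask = 0.0 if dones[t] else 1.0  # no bootstrapping across an episode boundary
        delta = rewards[t] + gamma * next_value * mask - values[t]  # one-step TD error
        adv = delta + gamma * lam * mask * adv  # exponentially weighted sum of TD errors
        advantages[t] = adv
        next_value = values[t]
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


# With gamma = lam = 1 and a zero critic, advantages equal the undiscounted reward-to-go
adv, ret = compute_gae([1.0, 1.0, 1.0], [0.0, 0.0, 0.0], [False, False, True], gamma=1.0, lam=1.0)
print(adv)  # [3.0, 2.0, 1.0]
```

Setting lam = 0 reduces each advantage to the one-step TD error. lam = 1 gives the full Monte Carlo advantage. Values in between trade bias against variance.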

 

From: https://www.cnblogs.com/LiuXinyu12378/p/18192344
