首页 > 编程语言 >策略梯度玩 cartpole 游戏,强化学习代替PID算法控制平衡杆

策略梯度玩 cartpole 游戏,强化学习代替PID算法控制平衡杆

时间:2024-05-12 17:10:34浏览次数:21  
标签:cartpole torch PID 平衡杆 print state pygame policy returns

 

cartpole游戏,车上顶着一个自由摆动的杆子,实现杆子的平衡,杆子每次倒向一端车就开始移动让杆子保持动态直立的状态,策略函数使用一个两层的简单神经网络,输入状态有4个,车位置,车速度,杆角度,杆速度,输出action为左移动或右移动,输入状态发现至少要给3个才能稳定一会儿,给2个完全学不明白,给4个能学到很稳定的policy

 

 

策略梯度实现代码,使用torch实现一个简单的神经网络

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import pygame
import sys
from collections import deque
import numpy as np

# 策略网络定义
class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),  # 4个状态输入,128个隐藏单元
            nn.Tanh(),
            nn.Linear(10, 2),  # 输出2个动作的概率
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        # print(x)  车位置 车速度 杆角度 杆速度
        selected_values = x[:, [0,1,2,3]]  #只使用车位置和杆角度
        return self.fc(selected_values)

# 训练函数
def train(policy_net, optimizer, trajectories):
    policy_net.zero_grad()
    loss = 0
    print(trajectories[0])
    for trajectory in trajectories:
        
        # if trajectory["returns"] > 90:
        # returns = torch.tensor(trajectory["returns"]).float()
        # else:
        returns = torch.tensor(trajectory["returns"]).float() - torch.tensor(trajectory["step_mean_reward"]).float()
        # print(f"获得奖励{returns}")
        log_probs = trajectory["log_prob"]
        loss += -(log_probs * returns).sum()  # 计算策略梯度损失
    loss.backward()
    optimizer.step()
    return loss.item()

# 主函数
def main():
    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    optimizer = optim.Adam(policy_net.parameters(), lr=0.01)

    print(env.action_space)
    print(env.observation_space)
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()

    rewards_one_episode= []
    for episode in range(10000):
        
        state = env.reset()
        done = False
        trajectories = []
        state = state[0]
        step = 0
        torch.save(policy_net, 'policy_net_full.pth')
        while not done:
            state_tensor = torch.tensor(state).float().unsqueeze(0)
            probs = policy_net(state_tensor)
            action = torch.distributions.Categorical(probs).sample().item()
            log_prob = torch.log(probs.squeeze(0)[action])
            next_state, reward, done, _,_ = env.step(action)

            # print(episode)
            trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
            state = next_state

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
            step +=1
            
            # 绘制环境状态
            if rewards_one_episode and rewards_one_episode[-1] >99:
                screen.fill((255, 255, 255))
                cart_x = int(state[0] * 100 + 300)
                pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
                # print(state)
                pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))), 300 - int(50 * torch.cos(torch.tensor(state[2])))), 2)
                pygame.display.flip()
                clock.tick(200)
                

        print(f"第{episode}回合",f"运行{step}步后挂了")
        # 为策略梯度计算累积回报
        returns = 0
        
        
        for traj in reversed(trajectories):
            returns = traj["reward"] + 0.99 * returns
            traj["returns"] = returns
            if rewards_one_episode:
                # print(rewards_one_episode[:10])
                traj["step_mean_reward"] = np.mean(rewards_one_episode[-10:])
            else:
                traj["step_mean_reward"] = 0
        rewards_one_episode.append(returns)
        # print(rewards_one_episode[:10])
        train(policy_net, optimizer, trajectories)

def play():

    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()

    state = env.reset()
    done = False
    trajectories = deque()
    state = state[0]
    step = 0
    policy_net = torch.load('policy_net_full.pth')
    while not done:
        state_tensor = torch.tensor(state).float().unsqueeze(0)
        probs = policy_net(state_tensor)
        action = torch.distributions.Categorical(probs).sample().item()
        log_prob = torch.log(probs.squeeze(0)[action])
        next_state, reward, done, _,_ = env.step(action)

        # print(episode)
        trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
        state = next_state

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        
        # 绘制环境状态
        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        # print(state)
        pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))), 300 - int(50 * torch.cos(torch.tensor(state[2])))), 2)
        pygame.display.flip()
        clock.tick(60)
        step +=1

    print(f"运行{step}步后挂了")



if __name__ == '__main__':
    main() #训练
    # play() #推理

  运行效果,训练过程不是很稳定,有时候学很多轮次也学不明白,有时侯只需要几十次就可以学明白了

 

标签:cartpole,torch,PID,平衡杆,print,state,pygame,policy,returns
From: https://www.cnblogs.com/LiuXinyu12378/p/18187947

相关文章

  • rapidjson
    一、简介RapidJSON是腾讯开源的一个高效的C++JSON解析器及生成器,它是只有头文件的C++库。RapidJSON是跨平台的,支持Windows、Linux、MacOSX及iOS、Android。writer和prettywriter都是将JSON数据打包为字符串的方法。官网:https://rapidjson.org/zh-cn/index.html1.1write和pr......
  • PID 控制详解
    阶跃响应阶跃响应是指将一个阶跃输入(stepfunction)加到系统上时,系统的输出。稳态误差是指系统的响应进入稳态后﹐系统的期望输出与实际输出之差。控制系统的性能可以用稳、准、快三个字来描述。稳是指系统的稳定性(stability),一个系统要能正常工作,首先必须是稳定的,从阶跃响应上看......
  • 475-便携式手提RapidIO协议光纤发包测试仪
    便携式手提RapidIO协议光纤发包测试仪一、平台简介   便携式手提RapidIO协议光纤发包仪,以RapidIO收发卡和X86主板为基础,构建便携式的手提设备。   RapidIO收发卡是以KU060PCIeX4的双路QSFP+光纤收发卡,支持双路RapidIOX4数据的收发设计。   ......
  • AppSpider Pro 7.5.009 for Windows - Web 应用程序安全测试
    AppSpiderPro7.5.009forWindows-Web应用程序安全测试Rapid7DynamicApplicationSecurityTesting(DAST)请访问原文链接:https://sysin.org/blog/appspider/,查看最新版。原创作品,转载请保留出处。作者主页:sysin.orgappspider没有任何应用程序未经测试,没有未知风险......
  • 【Python】Q-Learning处理CartPole-v1
    上一篇配置成功gym环境后,就可以利用该环境做强化学习仿真了。这里首先用之前学习过的qlearning来处理CartPole-v1模型。CartPole-v1是一个倒立摆模型,目标是通过左右移动滑块保证倒立杆能够尽可能长时间倒立,最长步骤为500步。模型控制量是左0、右1两个。模型状态量为下面四个:......
  • 控制自行车前进/后退/平衡等动作,有必要使用在控制方面使用人工智能算法吗,还是传统的PI
    直接说答案,用不到人工智能算法做控制,现在人工智能算法主要的应用领域为感知学习,比较典型的就是图像识别和自然语言对话系统,而在控制算法上人工智能的解决方案依然不是很成熟,目前世界上唯一一个宣布可以落地的是特斯拉的老马搞出的那个FSD的自动驾驶,除此之外就没有第二个使用智能控......
  • 18--Scrapy04--CrawlSpider、源码模板文件
    Scrapy04--CrawlSpider、源码模板文件案例:汽车之家,全站抓取二手车的信息来区分Spider和CrawlSpider注意:汽车之家的访问频率要控制一下,要不然会跳验证settings.py中设置DOWNLOAD_DELAY=3一、常规Spider#spiders/Ershou.pyimportscrapyfromscrapy.linkextra......
  • 解决MySQL安装错误:`The server quit without updating PID file`
    在MySQL安装或启动过程中,你可能会遇到如下错误信息:TheserverquitwithoutupdatingPIDfile(/var/lib/mysql/your_hostname.pid).这个错误通常表明MySQL服务器尝试启动时遇到了问题,导致它异常终止而未能更新PID文件。PID文件用于存储启动的MySQL服务进程的ID。本文旨......
  • NL2SQL基础系列(1):业界顶尖排行榜、权威测评数据集及LLM大模型(Spider vs BIRD)全面对比
    NL2SQL基础系列(1):业界顶尖排行榜、权威测评数据集及LLM大模型(SpidervsBIRD)全面对比优劣分析[Text2SQL、Text2DSL]Text-to-SQL(或者Text2SQL),顾名思义就是把文本转化为SQL语言,更学术一点的定义是:把数据库领域下的自然语言(NaturalLanguage,NL)问题,转化为在关系型数据库中可以执行的......
  • scrapy框架之CrawlSpider全站爬取
    一、什么是全站爬取全站爬取(CrawltheEntireSite)是指通过网络爬虫程序遍历并获取一个网站上的所有页面或资源的过程。这种爬取方式旨在获取网站的全部内容,包括文本、图片、视频、链接等,以建立网站的完整内容索引或进行数据分析。二、全栈爬取介绍1、全站数据爬取的方式(1)通过......