In the CartPole game a cart carries a freely swinging pole, and the task is to keep the pole balanced: whenever the pole starts to tip toward one side, the cart moves so that the pole stays dynamically upright. The policy function is a simple two-layer neural network with 4 state inputs (cart position, cart velocity, pole angle, pole angular velocity) and an output action of move-left or move-right. Experimenting with the number of state inputs: with only 2 the policy never learns anything useful, with 3 it stays balanced for a short while, and with all 4 it learns a very stable policy.
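Before the full code, here is a minimal sketch (my own, assuming the same Gymnasium-style reset/step API used below) of how to inspect the CartPole observation and how the number of state components fed to the policy could be restricted for the experiment above; `used_indices` is a hypothetical knob, and the first `Linear` layer of the policy would then have to take `len(used_indices)` inputs:

```python
import gym
import torch

env = gym.make('CartPole-v1')
print(env.observation_space)  # Box(4,): cart position, cart velocity, pole angle, pole angular velocity
print(env.action_space)       # Discrete(2): 0 = push left, 1 = push right

# Hypothetical knob for the input experiment: feed only some state components.
used_indices = [0, 2]                              # e.g. cart position and pole angle only
state, _ = env.reset()
state_tensor = torch.tensor(state).float().unsqueeze(0)
restricted_input = state_tensor[:, used_indices]   # shape (1, len(used_indices))
print(restricted_input)
```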
Policy-gradient implementation: a simple neural network built with torch.
```python
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import pygame
import sys
from collections import deque
import numpy as np


# Policy network definition
class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),   # 4 state inputs, 10 hidden units
            nn.Tanh(),
            nn.Linear(10, 2),   # probabilities for the 2 actions
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        # state layout: cart position, cart velocity, pole angle, pole angular velocity
        selected_values = x[:, [0, 1, 2, 3]]  # use all four state components
        return self.fc(selected_values)


# Training function (REINFORCE update over one episode)
def train(policy_net, optimizer, trajectories):
    policy_net.zero_grad()
    loss = 0
    print(trajectories[0])  # debug: inspect the first transition
    for trajectory in trajectories:
        # subtract a baseline: the mean return of the last few episodes
        returns = torch.tensor(trajectory["returns"]).float() - torch.tensor(trajectory["step_mean_reward"]).float()
        log_probs = trajectory["log_prob"]
        loss += -(log_probs * returns).sum()  # policy-gradient loss
    loss.backward()
    optimizer.step()
    return loss.item()


# Main function
def main():
    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    optimizer = optim.Adam(policy_net.parameters(), lr=0.01)
    print(env.action_space)
    print(env.observation_space)

    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()

    rewards_one_episode = []
    for episode in range(10000):
        state, _ = env.reset()
        done = False
        trajectories = []
        step = 0
        torch.save(policy_net, 'policy_net_full.pth')
        while not done:
            state_tensor = torch.tensor(state).float().unsqueeze(0)
            probs = policy_net(state_tensor)
            action = torch.distributions.Categorical(probs).sample().item()
            log_prob = torch.log(probs.squeeze(0)[action])
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
            state = next_state

            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()
            step += 1

            # render only once the policy has become reasonably good
            if rewards_one_episode and rewards_one_episode[-1] > 99:
                screen.fill((255, 255, 255))
                cart_x = int(state[0] * 100 + 300)
                pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
                pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300),
                                 (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))),
                                  300 - int(50 * torch.cos(torch.tensor(state[2])))), 2)
                pygame.display.flip()
                clock.tick(200)

        print(f"Episode {episode}: fell over after {step} steps")

        # compute discounted returns for the policy gradient
        returns = 0
        for traj in reversed(trajectories):
            returns = traj["reward"] + 0.99 * returns
            traj["returns"] = returns
            if rewards_one_episode:
                traj["step_mean_reward"] = np.mean(rewards_one_episode[-10:])
            else:
                traj["step_mean_reward"] = 0
        rewards_one_episode.append(returns)

        train(policy_net, optimizer, trajectories)


def play():
    env = gym.make('CartPole-v1')
    policy_net = PolicyNetwork()
    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()
    state, _ = env.reset()
    done = False
    trajectories = deque()
    step = 0
    policy_net = torch.load('policy_net_full.pth')
    while not done:
        state_tensor = torch.tensor(state).float().unsqueeze(0)
        probs = policy_net(state_tensor)
        action = torch.distributions.Categorical(probs).sample().item()
        log_prob = torch.log(probs.squeeze(0)[action])
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        trajectories.append({"state": state, "action": action, "reward": reward, "log_prob": log_prob})
        state = next_state
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        # render the environment state
        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300),
                         (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))),
                          300 - int(50 * torch.cos(torch.tensor(state[2])))), 2)
        pygame.display.flip()
        clock.tick(60)
        step += 1
    print(f"Fell over after {step} steps")


if __name__ == '__main__':
    main()    # training
    # play()  # inference
```
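For reference, the loss computed in train() is the REINFORCE objective with a baseline; in my own notation (not from the original post), with discount factor $\gamma = 0.99$:

$$L(\theta) = -\sum_{t=0}^{T} \log \pi_\theta(a_t \mid s_t)\,\bigl(G_t - b\bigr), \qquad G_t = \sum_{k=t}^{T} \gamma^{\,k-t} r_k,$$

where $G_t$ is the discounted return stored in traj["returns"] and the baseline $b$ is step_mean_reward, the mean return of the last 10 episodes.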
Training behavior: the process is not very stable. Sometimes the policy still has not learned after many episodes, while other times it takes only a few dozen episodes to get it right.
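A common variance-reduction tweak that often helps with this instability (my suggestion, not part of the original code) is to standardize the discounted returns within each episode before forming the loss, instead of, or in addition to, the running-mean baseline:

```python
# Sketch only: standardize per-step discounted returns within an episode.
# This usually makes REINFORCE much less sensitive to the random seed.
import torch

def normalized_returns(trajectories, eps=1e-8):
    returns = torch.tensor([traj["returns"] for traj in trajectories]).float()
    return (returns - returns.mean()) / (returns.std() + eps)

# Inside train() one could then use:
#   returns = normalized_returns(trajectories)
#   log_probs = torch.stack([traj["log_prob"] for traj in trajectories])
#   loss = -(log_probs * returns).sum()
```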
Tags: cartpole, torch, PID, balance pole, print, state, pygame, policy, returns
From: https://www.cnblogs.com/LiuXinyu12378/p/18187947