首页 > 其他分享 >DQN玩cartpole游戏

DQN玩cartpole游戏

时间:2024-05-13 13:54:15浏览次数:9  
标签:__ cartpole 游戏 self torch next state DQN tensor

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
import pygame
import sys
from collections import deque

# 定义DQN模型
class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.Linear(128, 2)  # 2个动作
        )

    def forward(self, x):
        return self.network(x)

# 经验回放
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)

# 训练函数
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    states, actions, rewards, next_states, dones = memory.sample(BATCH_SIZE)
    
    states = torch.tensor(states, dtype=torch.float)
    next_states = torch.tensor(next_states, dtype=torch.float)
    actions = torch.tensor(actions, dtype=torch.long)
    rewards = torch.tensor(rewards, dtype=torch.float)
    dones = torch.tensor(dones, dtype=torch.float)

    current_q_values = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    next_q_values = model(next_states).max(1)[0].detach()
    expected_q_values = rewards + 0.99 * next_q_values * (1 - dones)

    loss = criterion(current_q_values, expected_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 设置环境和模型
env = gym.make('CartPole-v1')
model = DQN()
memory = ReplayBuffer(10000)
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()
BATCH_SIZE = 128
EPSILON = 0.2

pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock()

# 开始训练
num_episodes = 500
for episode in range(num_episodes):
    state = env.reset()
    total_reward = 0
    done = False
    state = state[0]
    while not done:
        if random.random() < EPSILON:
            action = env.action_space.sample()
        else:
            state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0)
            action = model(state_tensor).max(1)[1].item()
        
        next_state, reward, done, _,_ = env.step(action)
        memory.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        optimize_model()
        
        # Pygame visualization
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))), 300 - int(50 * torch.cos(torch.tensor(state[2])))), 5)
        pygame.display.flip()
        clock.tick(60)

    EPSILON *= 0.995  # 减少探索率
    print(f'Episode {episode}: Total Reward = {total_reward}')

if __name__ == '__main__':
    main()

 

标签:__,cartpole,游戏,self,torch,next,state,DQN,tensor
From: https://www.cnblogs.com/LiuXinyu12378/p/18189046

相关文章

  • actor critic 玩carpole游戏
     importgymimporttorchimporttorch.nnasnnimporttorch.optimasoptimimportpygameimportsys#定义Actor网络classActor(nn.Module):def__init__(self):super(Actor,self).__init__()self.fc=nn.Sequential(nn.Linea......
  • C120 树剖+李超树 P4069 [SDOI2016] 游戏
    视频链接:C120树剖+李超树P4069[SDOI2016]游戏_哔哩哔哩_bilibili    D12LuoguP3384【模板】轻重链剖分/树链剖分-董晓-博客园(cnblogs.com) LuoguP4069[SDOI2016]游戏//树剖+李超树O(nlognlognlogn)#include<iostream>#include<cstring>#in......
  • 45_jump Game II 跳跃游戏II
    45_jumpGameII跳跃游戏II问题描述链接:https://leetcode.com/problems/jump-game-ii/description/Youaregivena0-indexedarrayofintegersnumsoflengthn.Youareinitiallypositionedatnums[0].Eachelementnums[i]representsthemaximumlengthofafo......
  • DirectX 12 Ultimate 是微软在 DirectX 12 API 的基础上推出的一个新版本,它旨在为游戏
    DirectX12Ultimate是微软在DirectX12API的基础上推出的一个新版本,它旨在为游戏开发者提供更多的功能和支持,同时也为玩家带来更出色的游戏体验。下面我将简要介绍一下DirectX12Ultimate的特点和重要性:支持最新硬件特性:DirectX12Ultimate支持最新的硬件特性,包......
  • 策略梯度玩 cartpole 游戏,强化学习代替PID算法控制平衡杆
     cartpole游戏,车上顶着一个自由摆动的杆子,实现杆子的平衡,杆子每次倒向一端车就开始移动让杆子保持动态直立的状态,策略函数使用一个两层的简单神经网络,输入状态有4个,车位置,车速度,杆角度,杆速度,输出action为左移动或右移动,输入状态发现至少要给3个才能稳定一会儿,给2个完全学不明白,......
  • Python游戏制作大师,Pygame库的深度探索与实践
    写在前言hello,大家好,我是一点,专注于Python编程,如果你也对感Python感兴趣,欢迎关注交流。希望可以持续更新一些有意思的文章,如果觉得还不错,欢迎点赞关注,有啥想说的,可以留言或者私信交流。如果你想看什么主题的文章,欢迎留言交流,关注公众号【一点sir】,领取编程资料。如果你还不了......
  • 55-jump Game 跳跃游戏
    问题描述Youaregivenanintegerarraynums.Youareinitiallypositionedatthearray'sfirstindex,andeachelementinthearrayrepresentsyourmaximumjumplengthatthatposition.Returntrueifyoucanreachthelastindex,orfalseotherwise解释......
  • [附源码+文档]Java Swing小游戏源码合集(14款)_毕业设计必选项目
    (小众游戏塔防迷宫动作剧情类等)16款游戏源码Javaswing五子棋联网版源代码Javaswing贪吃蛇游戏开发教程+源码Javaswing超级玛丽游戏Javaswing俄罗斯方块项目源码Javaswing飞机大战游戏源码Javaswing雷电游戏源码Javaswing连连看游戏源码Javaswing模拟写字板源码......
  • AI已来,我与AI一起用Python编写了一个消消乐小游戏
    在数字化与智能化的浪潮中,目前AI(人工智能)几乎在各行各业中发挥了不可忽略的价值,今天让我们也来体验一下AI的威力:我通过命令,一步一步的教AI利用Python编程语言打造了一款富有创意和趣味性的消消乐小游戏……本文Python消消乐游戏源代码:https://gitee.com/obullxl/Pytho......
  • 爆爽,英语小白怒刷 50 课!像玩游戏一样学习英语~
    ###重点!!!(先看这)1.清楚自己学英语的`目的`,先搞清楚目标,再行动2.自身现在最需要的东西:`词汇量`?`口语`?还是`阅读能力`?3.找对应的书籍,学习资料4.往`兴趣靠拢`:网上有大量的推荐美剧学习、小说学习,不要被他们迷了眼,适合他们的不一定适合你,找到适合的你才能长期坚持下......