import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
import pygame
import sys
from collections import deque


# Define the DQN model: a small MLP mapping the 4-dimensional CartPole
# observation to Q-values for the 2 discrete actions.
class DQN(nn.Module):
    def __init__(self):
        super(DQN, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.Linear(128, 2)  # 2 actions: push cart left or right
        )

    def forward(self, x):
        return self.network(x)


# Experience replay buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*batch)
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)


# Training step: sample a batch and take one gradient step on the TD error.
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    states, actions, rewards, next_states, dones = memory.sample(BATCH_SIZE)
    states = torch.tensor(np.array(states), dtype=torch.float)
    next_states = torch.tensor(np.array(next_states), dtype=torch.float)
    actions = torch.tensor(actions, dtype=torch.long)
    rewards = torch.tensor(rewards, dtype=torch.float)
    dones = torch.tensor(dones, dtype=torch.float)

    # Q(s, a) for the actions actually taken
    current_q_values = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # Bellman target: r + gamma * max_a' Q(s', a'), zeroed out on terminal transitions
    next_q_values = model(next_states).max(1)[0].detach()
    expected_q_values = rewards + 0.99 * next_q_values * (1 - dones)

    loss = criterion(current_q_values, expected_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


# Set up the environment, model, and training hyperparameters
env = gym.make('CartPole-v1')
model = DQN()
memory = ReplayBuffer(10000)
optimizer = optim.Adam(model.parameters())
criterion = nn.MSELoss()
BATCH_SIZE = 128
EPSILON = 0.2

pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock()

# Training loop
num_episodes = 500
for episode in range(num_episodes):
    state, _ = env.reset()
    total_reward = 0
    done = False
    while not done:
        # Epsilon-greedy action selection
        if random.random() < EPSILON:
            action = env.action_space.sample()
        else:
            state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0)
            action = model(state_tensor).max(1)[1].item()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        memory.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        optimize_model()

        # Pygame visualization: draw the cart as a rectangle and the pole as a line
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        pygame.draw.line(
            screen, (255, 0, 0), (cart_x + 25, 300),
            (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))),
             300 - int(50 * torch.cos(torch.tensor(state[2])))), 5)
        pygame.display.flip()
        clock.tick(60)

    EPSILON *= 0.995  # decay the exploration rate
    print(f'Episode {episode}: Total Reward = {total_reward}')
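
Once training finishes, a quick sanity check is to run the learned policy greedily (no epsilon exploration) and look at the average episode return. The sketch below is not part of the original post; it is a minimal evaluation helper that only assumes the gym API and the trained `model` defined above.

# Hedged sketch: evaluate the trained policy greedily over a few episodes
def evaluate(model, n_episodes=10):
    eval_env = gym.make('CartPole-v1')
    returns = []
    for _ in range(n_episodes):
        state, _ = eval_env.reset()
        done = False
        episode_return = 0.0
        while not done:
            with torch.no_grad():
                state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0)
                action = model(state_tensor).argmax(1).item()  # always take the greedy action
            state, reward, terminated, truncated, _ = eval_env.step(action)
            done = terminated or truncated
            episode_return += reward
        returns.append(episode_return)
    eval_env.close()
    return sum(returns) / len(returns)

An average return from `evaluate(model)` close to the 500-step cap of CartPole-v1 indicates the policy has learned to balance the pole.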