首页 > 编程语言 >【Python】DQN处理CartPole-v1

【Python】DQN处理CartPole-v1

时间:2024-06-08 16:11:04浏览次数:31  
标签:CartPole space Python torch state action DQN reward size

DQN是强化学习中的一种方法,是对Q-Learning的扩展。

通过引入深度神经网络、经验回放和目标网络等技术,使得Q-Learning算法能够在高维、连续的状态空间中应用,解决了传统Q-Learning方法在这些场景下的局限性。

Q-Learning可以见之前的文章

算法的几个关键点:

1. 深度学习估计状态动作价值函数:DQN利用Q-Learning算法思想,估计一个Q函数Q(s,a),表示在状态s下采取a动作得到的期望回报,估计该函数时利用深度学习的方法。

2. 经验回放:为了打破数据的相关性和提高样本效率,DQN引入了经验回放池。智能体在与环境交互时,会将每一个时间步的经验(s,a,r,s')存入回放池,每次更新网络时,随机从回放池中抽取一个小批量经验进行训练。

3. 目标网络:DQN算法使用两个神经网络:一个在线网络(用于选择动作)和一个目标网络(用于计算目标Q值)。目标网络的参数每隔一段时间才会从在线网络复制,以稳定训练过程。目标Q值的计算公式为:y = r+g*max(Q(s',a')),其中r为奖励,g为折扣因子,Q为目标网络。

代码如下:

import gym
import random
import warnings

import torch
import torch.nn as nn
import torch.optim as optim
warnings.filterwarnings("ignore")

class Net(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)
    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        x = self.linear3(x)
        return x     

if __name__ == '__main__':

    negative_reward = -10.0
    positive_reward = 10.0
    x_bound = 1.0
    gamma = 0.9
    batch_size = 64
    capacity = 1000
    buffer=[]
    env = gym.make('CartPole-v1')
    
    state_space_num = env.observation_space.shape[0]
    action_space_dim = env.action_space.n  

    q_net = Net(state_space_num, 256, action_space_dim)
    target_q_net = Net(state_space_num, 256, action_space_dim)
    
    optimizer = optim.Adam(q_net.parameters(), lr=5e-4)

    for i in range(3000):
        state = env.reset()
                
        step = 0
        while True:
           # env.render()
            step +=1
            epsi = 1.0 / (i + 1)
            if random.random() < epsi:
                action = random.randrange(action_space_dim)
            else:
                state_tensor =  torch.tensor(state, dtype=torch.float).view(1,-1)
                action = torch.argmax(q_net(state_tensor)).item()
            
            next_state, reward, done, _ = env.step(action)
            x, x_dot, theta, theta_dot = state
            if (abs(x) > x_bound):
                r1 = 0.5 * negative_reward
            else:
                r1 = negative_reward * abs(x) / x_bound + 0.5 * (-negative_reward)
            if (abs(theta) > env.theta_threshold_radians):
                r2 = 0.5 * negative_reward
            else:
                r2 = negative_reward * abs(theta) / env.theta_threshold_radians + 0.5 * (-negative_reward)
            reward = r1 + r2
            if (done) and (step < 499):
                reward += negative_reward
                   
            if len(buffer)==capacity:
                buffer.pop(0)
            buffer.append((state, action, reward, next_state))
            
            state = next_state

            if len(buffer) < batch_size:
                continue
            
            samples = random.sample(buffer,batch_size)
            s0, a0, r1, s1 = zip(*samples)

            s0 = torch.tensor( s0, dtype=torch.float)
            a0 = torch.tensor( a0, dtype=torch.long).view(batch_size, 1)
            r1 = torch.tensor( r1, dtype=torch.float).view(batch_size, 1)
            s1 = torch.tensor( s1, dtype=torch.float)
            
            q_value = q_net(s0).gather(1, a0)
            q_target = r1 + gamma * torch.max(target_q_net(s1).detach(), dim=1)[0].view(batch_size, -1)

            loss_fn = nn.MSELoss()
            loss = loss_fn(q_value, q_target)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 10==0:
                target_q_net.load_state_dict(q_net.state_dict())

            if done:
                print(i,step)
                break

    env.close()

基本在迭代100多次之后都能稳定到500步。

标签:CartPole,space,Python,torch,state,action,DQN,reward,size
From: https://www.cnblogs.com/tiandsp/p/18133282

相关文章

  • 使用Python进行容器编排【Docker Compose与Kubernetes的比较】
    ......
  • 【Python基础】集合(3997字)
    文章目录@[toc]什么是集合集合的特点元素不重复性示例无序性示例集合操作增加元素add()方法删除元素clear()方法pop()方法remove()方法交集intersection()方法&符号isdisjoint()方法并集union()方法|符号差集difference()方法-符号对称差集symmetric_difference(......
  • python -- series和 DataFrame增删改数据
    学习目标知道df添加新列的操作知道insert函数插入列数据知道drop函数删除df的行或列数据知道drop_duplicates函数对df或series进行数据去重知道unique函数对series进行数据去重知道apply函数的使用方法1DataFrame添加列注意:本文用到的数据集在文章顶部1.1......
  • 【Python】文件处理的魔法之旅
    目录 引言文件处理的重要性基本概念主体部分读取文件写入文件修改文件处理不同类型的文件文本文件CSV文件JSON文件示例代码代码解释案例研究结论参考文献引言你是否曾经面对一堆杂乱无章的文件,感到束手无策?是否曾梦想过拥有一种能力,能够轻松地读取、修改......
  • 2024华为OD机试真题-字符串分割(二)-(C++/Python)-C卷D卷-100分
    2024华为OD机试题库-(C卷+D卷)-(JAVA、Python、C++) 题目描述给定一个非空字符串S,其被N个‘-’分隔成N+1的子串,给定正整数K,要求除第一个子串外,其余的子串每K个字符组成新的子串,并用‘-’分隔。对于新组成的每一个子串,如果它含有的小写字母比大写字母多,则将这个子串的所有......
  • 2024华为OD机试真题-测试用例执行计划-(C++/Python)-C卷D卷-100分
     2024华为OD机试题库-(C卷+D卷)-(JAVA、Python、C++) 题目描述某个产品当前迭代周期内有N个特性(F1,F2,......FN)需要进行覆盖测试,每个特性都被评估了对应的优先级,特性使用其ID作为下标进行标识。设计了M个测试用例(T1,T2,......,TM),每个测试用例对应一个覆盖特性的集......
  • python系列:FASTAPI系列 01 环境准备 & FASTAPI系列 02-简单入门
    FASTAPI系列01环境准备&FASTAPI系列02-简单入门一、FASTAPI系列01环境准备前言一、FASTAPI简介二、环境准备1.快速安装fastapi以及相关依赖2.创建项目总结二、FASTAPI系列02-简单入门实现一个简单的例子一、FASTAPI系列01环境准备前言FastAPI是一......
  • 浔川贪吃蛇(完整版)——浔川python社
    废话不多说,直接上代码!#-*-coding:utf-8-*-importtkinterastkimporttkinter.messageboximportpickleimportrandom#窗口window=tk.Tk()window.title('欢迎进入python')window.geometry('450x200')#画布放置图片#canvas=tk.Canvas(window,height=300,......
  • 一篇文章学完Python基础知识
    一、数据类型和变量Python使用缩进来组织代码块,一般使用4个空格的缩进.使用#来注释一行,其他每一行都是一个语句,当语句以冒号:结尾时,缩进的语句视为代码块.Python对大小写敏感.1.1整数Python可以处理任意大小的整数,包括负整数,写法与数学上写法一致,例如:-100.如果用......
  • python-自幂数判断
    [题目描述]:自幂数是指,一个N位数,满足各位数字N次方之和是本身。例如,153153是33位数,其每位数的33次方之和,13+53+33=15313+53+33=153,因此153153是自幂数;16341634是44位数,其每位数的44次方之和,14+64+34+44=163414+64+34+44=1634,因此16341634是自幂数。现在,输入若......