首页 > 编程语言 >强化学习代码实战-06 DQN算法(单模型)

强化学习代码实战-06 DQN算法(单模型)

时间:2022-11-13 11:11:29浏览次数:42  
标签:06 torch state next 算法 done action DQN reward

现在我们想在类似车杆的环境中得到动作价值函数,由于状态每一维度的值都是连续的,无法使用表格记录,因此一个常见的解决方法便是使用函数拟合(function approximation)的思想。由于神经网络具有强大的表达能力,因此我们可以用一个神经网络来表示函数。

import random
import gym
import torch
import numpy as np
from matplotlib import pyplot as plt
from IPython import display

env = gym.make("CartPole-v0")
# 智能体状态
state = env.reset()
# 动作空间
actions = env.action_space.n
print(state, actions)
# 打印游戏
# plt.imshow(env.render(mode='rgb_array'))
# plt.show()

# 定义动作模型
model = torch.nn.Sequential(torch.nn.Linear(4, 128),
                           torch.nn.ReLU(),
                           torch.nn.Linear(128, 2))
# 得到一个动作
def get_action(state):
    """state: agent所处的状态"""
    if random.random() < .1:
        return random.choice(range(2))
    # 走神经网络NN,得到分值最大的那个动作。转为tensor数据
    state = torch.FloatTensor(state).reshape(1, 4)
    
    return model(state).argmax().item()

# 数据池
datas = []
def update_data():
    """加入新的N条数据,删除最老的M条数据"""
    count = len(datas)
    while len(datas) - count < 200:
        # 一直追加数据,尽可能多的获取环境状态
        state = env.reset()
        done = False
        while not done:
            # 由初始状态开始得到一个动作
            action = get_action(state)
            next_state, reward, done, _ = env.step(action)
            datas.append((state, action, reward, next_state, done))
            # 更新状态
            state = next_state
    # 此时新数据集中比原来多了大约200条样本,如果超过了最大容量,删除最开始数据
    update_count = len(datas) - count
    while len(datas) > 10000:
        datas.pop(0)
    return update_count

# 从数据池中采样
def get_sample():
    # batch size = 64, 数据类型转换为Tensor
    samples = random.sample(datas, 64)
    state = torch.FloatTensor([i[0] for i in samples])
    action = torch.LongTensor([i[1] for i in samples])
    reward = torch.FloatTensor([i[2] for i in samples])
    next_state = torch.FloatTensor([i[3] for i in samples])
    done = torch.LongTensor([i[4] for i in samples])
    
    return state, action, reward, next_state, done

# 获取动作价值
def get_value(state, action):
    """根据网络输出找到对应动作的得分"""
    value = model(state)
    value = value[range(64), action]
    
    return value

# 获取学习目标值
def get_target(next_state, reward, done):
    """使用next_state和reward计算真实得分。对价值的估计"""
    with torch.no_grad():
        next_value = model(next_state)
    # 贪心选取最大价值
    target = next_value.max(dim=1)[0]
    # 如果next_state已经游戏结束,则其target得分为0
    for i in range(64):
        if done[i]:
            target[i] = 0
    target = reward + target * 0.98
    
    return target

# 一局游戏得分测试
def test():
    reward_sum = 0
    
    state = env.reset()
    done = False
    
    while not done:
        action = get_action(state)
        next_state, reward, done, _ = env.step(action)
        reward_sum += reward
        state = next_state
        
    return reward_sum

def train():
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
    loss_fn = torch.nn.MSELoss()
    
    for epoch in range(600):
        # 更新一批数据
        update_counter = update_data()
        
        # 更新过数据后,学习N词
        for i in range(200):
            state, action, reward, next_state, done = get_sample()
            # 计算value和target
            value = get_value(state, action)
            target = get_target(next_state, reward, done)
            
            # 参数更新
            loss = loss_fn(value, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        if epoch % 50 == 0:
            test_score = sum([test() for i in range(50)]) / 50
            print(epoch, len(datas), update_counter, test_score)
        

最佳平均奖励可以达到200

 

标签:06,torch,state,next,算法,done,action,DQN,reward
From: https://www.cnblogs.com/demo-deng/p/16885612.html

相关文章

  • 实验三:朴素贝叶斯算法实验
    ##【实验目的】理解朴素贝叶斯算法原理,掌握朴素贝叶斯算法框架。 ##【实验内容】针对下表中的数据,编写python程序实现朴素贝叶斯算法(不使用sklearn包),对输入数据进行预......
  • 分组加密算法的CPA安全性证明
    分组加密算法的CPA安全性证明CPA安全模型构造分组加密算法令\(F\)是伪随机函数,定义一个消息长度为\(n\)的对称加密方案如下:Gen:输入:安全参数\(n\),输出\(k\),\(k\)是一......
  • [回溯算法]leetcode40. 组合总和 II(c实现)
    题目给定一个候选人编号的集合 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。candidates 中的每个数字在每个组合中......
  • 每日算法题之买卖股票的最好时机(一)
    买卖股票的最好时机(一)描述假设你有一个数组prices,长度为n,其中prices[i]是股票在第i天的价格,请根据这个价格数组,返回买卖股票能获得的最大收益1.你可以买入一次股票和......
  • HMM算法python实现
    基础介绍,后5项为基础5元素Q=['q0','q1','q2','q3']#状态集合States,共N种状态V=['v0','v1']#观测集合Observatio......
  • 实验三:朴素贝叶斯算法实验
    实验三:朴素贝叶斯算法实验|博客班级|https://edu.cnblogs.com/campus/czu/classof2020BigDataClass3-MachineLearning||----|----|----||作业要求|https://edu.cnblogs.......
  • C++ 面经:项目常见问题 ----- nagle算法,keepalive,Linger 选项
    nagle算法应用场景:1.对于实时性要求很高的交互上,我们不能使用nagle算法,比如FPS射击类PVP对抗类游戏,或者MMO类的对实时要求很高的游戏开发来说是显而易见需要禁掉的,因为假......
  • 前端学习-CSS-06-元素显示模式
    学习时间:2022.11.12元素显示模式块级元素特点:独占一行宽度默认等于其父母标签,高度由内容撑开宽高可以设置代表元素:div,p,h系列,ul,li,dl,dt,dd,form,header,n......
  • 实验三:朴素贝叶斯算法实验
    姓名:冯莹学号:201613305【实验目的】理解朴素贝叶斯算法原理,掌握朴素贝叶斯算法框架。【实验内容】针对下表中的数据,编写python程序实现朴素贝叶斯算法(不使用sklearn包......
  • 排序函数的算法(day12)
    今天尝试了三种数字的排序法。目的为1)熟悉数组的操作2)熟悉循环笔者是做嵌入式的,不想再算法上做过多探究,自身水平和专业也不允许深入太多。现在直接给出三种排序函数。1.插值......