首页 > 编程语言 >强化学习——DQN算法

强化学习——DQN算法

时间:2023-07-31 10:11:06浏览次数:26  
标签:target torch state next 算法 action DQN reward 强化

1、DQN算法介绍

DQN算与sarsa算法和Q-learning算法类似,对于sarsa和Q-learning,我们使用一个Q矩阵,记录所有的state(状态)和action(动作)的价值,不断学习更新,最后使得机器选择在某种状态下,价值最高的action进行行动。但是当state和action的数量特别大的时候,甚至有限情况下不可数时,这时候再用Q矩阵去记录对应价值就会有很大的局限性,而DQN就是在这一点上进行了改进。在DQN中,对于每种state下采取的action所对应价值,我们将会使用神经网络来进行计算。

2、平衡车游戏的实例

平衡车游戏在gym库中,这里需要下载gym库,这个游戏中,状态由四个数字来进行表示(我也不知道这四个数字代表什么,但是无伤大雅),接着只会有两种行动,并且reward并不需要我们进行设置,这个游戏进行过程中会自己返回reward。现在先搭建环境

env = gym.make('CartPole-v1')
env.reset()
#打印游戏
def show():
plt.imshow(env.render(mode='rgb_array'))
plt.axis('off')
plt.show()
#show()
搭建环境后再创建两个神经网络,并且要让两个神经网络的参数一致,在后续的过程中,一个神经网络会延迟更新。这两个神经网络会以四个状态参数作为输入,然后以两个动作的评分作为输出

#创建神经网络

#计算动作模型,也就是真正需要使用的模型
model = torch.nn.Sequential(
torch.nn.Linear(4,128),
torch.nn.ReLU(),
torch.nn.Linear(128,2),
)

#经验网络,用于评估状态分数
next_model = torch.nn.Sequential(
torch.nn.Linear(4,128),
torch.nn.ReLU(),
torch.nn.Linear(128,2)
)

#把两个神经网络的参数统一一下
next_model.load_state_dict(model.state_dict())

#print(model,next_model)

接下来我们需要创建一个样本池,神经网络会在这个样本池中进行学习,随后需要不断更新我们的样本池,当我们有一个新的行动时,我们应该添加新的样本,删除旧的样本,保持样本池最大数量不变

#想样本池中添加一些数据,删除一些古老的数据
def update_date():
old_count = len(datas)

while len(datas) - old_count<200:
#初始化
state = env.reset()

over = False
while not over:
#获取当前状态得到一个动作
action = get_action(state)

#执行动作,得到反馈
next_state,reward,over,_ = env.step(action)

#记录样本
datas.append((state,action,reward,next_state,over))

#更新状态
state = next_state

update_count = len(datas) - old_count
drop_count = max(len(datas)-10000,0)

while len(datas)>10000:
datas.pop(0)
return update_count,drop_count

接下来需要进行采样,并将样本格式转换为所需要的格式

#获取一批数据样本
def get_sample():
# 从样本池中采样
samples = random.sample(datas, 64)

state = np.array([i[0] for i in samples])
action = np.array([i[1] for i in samples])
reward = np.array([i[2] for i in samples])
next_state = np.array([i[3] for i in samples])
over = np.array([i[4] for i in samples])

state = torch.FloatTensor(state).reshape(-1, 4)
action = torch.LongTensor(action).reshape(-1, 1)
reward = torch.FloatTensor(reward).reshape(-1, 1)
next_state = torch.FloatTensor(next_state).reshape(-1, 4)
over = torch.LongTensor(over).reshape(-1, 1)

return state, action, reward, next_state, over

接着是价值函数(直接交给神经网络就可以了)

def get_value(state,action):
value = model(state)

value = value.gather(dim=1,index=action)
return value
接下来是获取target函数,这个函数的意义在于,我们是不知道游戏的全貌的,这样的话在一个状态下所采取的行动,不仅仅至于它本身有关,更和接下来所到达的状态和接下来应该采取的行动有关,价值value应该要想target靠近

def get_target(reward,next_state,over):
with torch.no_grad():
target = next_model(next_state)

target = target.max(dim = 1)[0]
target = target.reshape(-1,1)

target *= 0.98

target *= (1-over)#游戏结束了就不用玩了

target += reward

return target
然后就是测试函数并开始训练了

def train():
model.train()
optimizer = torch.optim.Adam(model.parameters(),lr=2e-3)
loss_fn = torch.nn.MSELoss()


for epoch in range(500):
update_count,drop_count = update_date()

for i in range(200):
state,action,reward,next_state,over = get_sample()

value = get_value(state,action)
target = get_target(reward,next_state,over)

loss = loss_fn(value,target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

if (i+1)%10==0:
next_model.load_state_dict(model.state_dict())

if epoch%50==0:
test_result = sum([tes(play=False) for _ in range(20)])/20
print(f"Epoch: {epoch}, Data Size: {len(datas)}, Update: {update_count}, Drop: {drop_count}, Test Reward: {test_result}")

接下来是完整代码

#这里会使用神经网络
import gym
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
from matplotlib.animation import FuncAnimation

env = gym.make('CartPole-v1')
env.reset()
#打印游戏
def show():
plt.imshow(env.render(mode='rgb_array'))
plt.axis('off')
plt.show()
#show()

#创建神经网络

#计算动作模型,也就是真正需要使用的模型
model = torch.nn.Sequential(
torch.nn.Linear(4,128),
torch.nn.ReLU(),
torch.nn.Linear(128,2),
)

#经验网络,用于评估状态分数
next_model = torch.nn.Sequential(
torch.nn.Linear(4,128),
torch.nn.ReLU(),
torch.nn.Linear(128,2)
)

#把两个神经网络的参数统一一下
next_model.load_state_dict(model.state_dict())

#print(model,next_model)


#定义动作函数
def get_action(state):
if random.random()<0.01:
return random.choice([0,1])

state = torch.FloatTensor(state).reshape(1,4)

return model(state).argmax().item()

#样本池
datas = []

#想样本池中添加一些数据,删除一些古老的数据
def update_date():
old_count = len(datas)

while len(datas) - old_count<200:
#初始化
state = env.reset()

over = False
while not over:
#获取当前状态得到一个动作
action = get_action(state)

#执行动作,得到反馈
next_state,reward,over,_ = env.step(action)

#记录样本
datas.append((state,action,reward,next_state,over))

#更新状态
state = next_state

update_count = len(datas) - old_count
drop_count = max(len(datas)-10000,0)

while len(datas)>10000:
datas.pop(0)
return update_count,drop_count

#获取一批数据样本
def get_sample():
# 从样本池中采样
samples = random.sample(datas, 64)

state = np.array([i[0] for i in samples])
action = np.array([i[1] for i in samples])
reward = np.array([i[2] for i in samples])
next_state = np.array([i[3] for i in samples])
over = np.array([i[4] for i in samples])

state = torch.FloatTensor(state).reshape(-1, 4)
action = torch.LongTensor(action).reshape(-1, 1)
reward = torch.FloatTensor(reward).reshape(-1, 1)
next_state = torch.FloatTensor(next_state).reshape(-1, 4)
over = torch.LongTensor(over).reshape(-1, 1)

return state, action, reward, next_state, over


def get_value(state,action):
value = model(state)

value = value.gather(dim=1,index=action)
return value

def get_target(reward,next_state,over):
with torch.no_grad():
target = next_model(next_state)

target = target.max(dim = 1)[0]
target = target.reshape(-1,1)

target *= 0.98

target *= (1-over)

target += reward

return target

#测试函数
def tes(play):
#初始化
state = env.reset()

#记录reward之和,越大越好
reward_sum = 0

over = False
while not over:
#获取动作
action = get_action(state)

#执行动作
state,reward,over,_ = env.step(action)
reward_sum +=reward

#打印动画
if play and random.random()<0.2:#跳帧
display.clear_output(wait=True)
show()
return reward_sum

def train():
model.train()
optimizer = torch.optim.Adam(model.parameters(),lr=2e-3)
loss_fn = torch.nn.MSELoss()


for epoch in range(500):
update_count,drop_count = update_date()

for i in range(200):
state,action,reward,next_state,over = get_sample()

value = get_value(state,action)
target = get_target(reward,next_state,over)

loss = loss_fn(value,target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

if (i+1)%10==0:
next_model.load_state_dict(model.state_dict())

if epoch%50==0:
test_result = sum([tes(play=False) for _ in range(20)])/20
print(f"Epoch: {epoch}, Data Size: {len(datas)}, Update: {update_count}, Drop: {drop_count}, Test Reward: {test_result}")


train()
为什么需要两个神经网络?

这个两个神经网络其实是一致的,只是其中一个会延迟更新参数。

这个原因在于假如在一个游戏中,我们的目标状态并不是固定,可能是一直变换的,就如这个游戏中,平衡的状态是多种多样的,那么我们一直跟踪这个目标就会变得困难,这时我们不妨固定住某一个曾经是目标的状态,让机器先尝试去达到这种状态,再去跟踪下一个固定目标状态,这样的方式会使得机器更容易找到目标状态。这也是为什么需要一个一样的,但是延迟更新的神经网络。

标签:target,torch,state,next,算法,action,DQN,reward,强化
From: https://www.cnblogs.com/humanplug/p/17592707.html

相关文章

  • 基于内容的个性化推荐算法-电影推荐系统
    之前在博客中介绍了协同过滤算法在电影推荐系统中的应用。今天我将向大家分享另一种常见的推荐算法——基于内容的推荐算法,并使用它来实现一个个性化的电影推荐系统。基于内容的推荐算法原理:基于内容的推荐算法是一种常用的推荐方法,它通过分析电影本身的特征来进行推荐。在电影推荐......
  • 基于标签的个性化推荐算法-电影推荐系统
    之前在博客中介绍了协同过滤算法和基于内容的推荐算法在电影推荐系统中的应用。今天我将向大家介绍另一种常见的推荐算法——基于标签的推荐算法,并使用它来实现一个更加个性化的电影推荐系统。基于标签的推荐算法原理:基于标签的推荐算法是一种利用用户标记信息进行推荐的算法。在电......
  • [算法学习笔记] 强连通分量
    DFS生成树在介绍强连通分量前,我们先来了解一下DFS生成树。一棵DFS生成树分为树边,前向边,返祖边(一说反向边),横叉边。我们来画图解释一下:在这棵DFS生成树中,黑色为树边,它是在DFS遍历时获得的,红色为返祖边,顾名思义,从儿子指向父亲或祖先。蓝色为横叉边,它是在搜索的时候遇到子树中的节......
  • C#冒泡排序算法
    冒泡排序实现原理冒泡排序是一种简单的排序算法,其原理如下:从待排序的数组的第一个元素开始,依次比较相邻的两个元素。如果前面的元素大于后面的元素(升序排序),则交换这两个元素的位置,使较大的元素“冒泡”到右侧。继续比较下一对相邻元素,重复步骤2,直到遍历到数组的倒数第二......
  • 代码随想录算法训练营第四天| LeetCode 24. 两两交换链表中的节点 19.删除链表的倒
    24.两两交换链表中的节点     卡哥建议:用虚拟头结点,这样会方便很多。 本题链表操作就比较复杂了,建议大家先看视频,视频里我讲解了注意事项,为什么需要temp保存临时节点。   题目链接/文章讲解/视频讲解:https://programmercarl.com/0024.%E4%B8%A4%E4%B8%A4%E4%BA%......
  • 文心一言 VS 讯飞星火 VS chatgpt (68)-- 算法导论6.5 7题
    文心一言VS讯飞星火VSchatgpt(68)--算法导论6.57题七、试说明如何使用优先队列来实现一个先进先出队列,以及如何使用优先队列来实现栈(队列和栈的定义见10.1节。)文心一言:优先队列是一种数据结构,其中每个元素都有一个关联的优先级或权值。在优先队列中,元素按照它们的优先级被......
  • JavaScript学习 -- 对称加密算法DES
    在现代的互联网时代,数据安全性备受关注。为了保护敏感数据的机密性,对称加密算法是一种常用的方法。在JavaScript中,DES(DataEncryptionStandard)是一种常用的对称加密算法。本篇博客将为您展示如何在JavaScript中使用DES算法进行加密和解密,并提供一个实际的例子。首先,我们需要使用Cr......
  • 文心一言 VS 讯飞星火 VS chatgpt (68)-- 算法导论6.5 7题
    文心一言VS讯飞星火VSchatgpt(68)--算法导论6.57题七、试说明如何使用优先队列来实现一个先进先出队列,以及如何使用优先队列来实现栈(队列和栈的定义见10.1节。)文心一言:优先队列是一种数据结构,其中每个元素都有一个关联的优先级或权值。在优先队列中,元素按照它们的优先......
  • 408-数据结构算法题笔记
    常用基本操作1.定义整数无穷大#defineINT_MAX=0x7f7f7f7f;2.绝对值函数intabs_(intx){ if(x<0)return-x; returnx;}3.最大最小值函数(一般可以直接写吧)intmin(inta,intb){ if(a<b)returna; returnb;}说明时空间复杂度可以先设neg:代码规范1.函......
  • 基于Alexnet深度学习神经网络的人脸识别算法matlab仿真
    1.算法理论概述       人脸识别是计算机视觉领域中一个重要的研究方向,其目的是识别不同人的面部特征以实现自动身份识别。随着深度学习神经网络的发展,基于深度学习神经网络的人脸识别算法已经成为了当前最先进的人脸识别技术之一。本文将详细介绍基于AlexNet深度学习神经......