首页 > 编程语言 >强化学习代码实战-06 Dueling DQN 算法

强化学习代码实战-06 Dueling DQN 算法

时间:2022-11-14 18:24:25浏览次数:59  
标签:06 torch next state done action DQN reward Dueling

引入优势函数A,优势函数A = 状态动作价值函数Q - 状态价值函数V。

在同一状态下,所有动作的优势值为零。因为,所有的动作的状态动作价值的期望就是状态价值。

实现代码:

import random
import gym
import torch
import numpy as np
from matplotlib import pyplot as plt
from IPython import display

env = gym.make("Pendulum-v0")
# 智能体状态
state = env.reset()
# 动作空间
actions = env.action_space
print(state, actions)
# 打印游戏
# plt.imshow(env.render(mode='rgb_array'))
# plt.show()


"""重新定义策略价值网络Q, 比DQN性能更优"""
class VAnet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.fc = torch.nn.Sequential(torch.nn.Linear(3, 128),
                                     torch.nn.ReLU())
        self.fc_A = torch.nn.Linear(128, 11)
        self.fc_V = torch.nn.Linear(128, 1)
        
    def forward(self, x):
        A = self.fc_A(self.fc(x))
        V = self.fc_V(self.fc(x))
        A_mean = A.mean(dim=1).reshape(-1, 1)
        A = A -  A_mean
        # Q值由A和V求和得到
        Q = A + V
        
        return Q
    
    
# 定义动作模型(策略网络)
model = VAnet()

# 经验网络,评估一个动作的分数(目标网络)
next_model = VAnet()
# model的参数赋予next_model
next_model.load_state_dict(model.state_dict())

# 得到一个动作
def get_action(state):
    """state: agent所处的状态。由于是连续动作,做离散化操作"""
    # 走神经网络NN,得到分值最大的那个动作。转为tensor数据
    state = torch.FloatTensor(state).reshape(1, 3)
    action = model(state).argmax().item()
    if random.random() < 0.01:
        action = random.choice(range(11))
    # 离散动作连续化
    action_continuous = action
    action_continuous /= 10
    action_continuous *= 4
    action_continuous -= 2
    
    return action, action_continuous


# 数据池
datas = []
def update_data():
    """加入新的N条数据,删除最老的M条数据"""
    count = len(datas)
    while len(datas) - count < 200:
        # 一直追加数据,尽可能多的获取环境状态
        state = env.reset()
        done = False
        while not done:
            # 由初始状态开始得到一个动作
            action, action_continuous = get_action(state)
            next_state, reward, done, _ = env.step([action_continuous])
            datas.append((state, action, reward, next_state, done))
            # 更新状态
            state = next_state
    # 此时新数据集中比原来多了大约200条样本,如果超过了最大容量,删除最开始数据
    update_count = len(datas) - count
    while len(datas) > 5000:
        datas.pop(0)
    return update_count

# 从数据池中采样
def get_sample():
    # batch size = 64, 数据类型转换为Tensor
    samples = random.sample(datas, 64)
    state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)
    action = torch.LongTensor([i[1] for i in samples]).reshape(-1, 1)
    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)
    next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)
    done = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)
    
    return state, action, reward, next_state, done

# 获取动作价值
def get_value(state, action):
    """根据网络输出找到对应动作的得分,使用策略网络"""
    action_value = model(state)
    action_value = action_value.gather(dim=1, index=action)
    
    return action_value

# 获取学习目标值
def get_target(next_state, reward, done):
    """使用next_state和reward计算真实得分。对价值的估计,使用目标网络"""
    with torch.no_grad():
        target = next_model(next_state)
        
    target = target.max(dim=1)[0].reshape(-1, 1)
    target *= (1 - done)        # 游戏结束的状态,没有奖励
    
    target = reward + target * 0.98
    
    return target

# 一局游戏得分测试
def test():
    reward_sum = 0
    
    state = env.reset()
    done = False
    
    while not done:
        _, action_continuous = get_action(state)
        next_state, reward, done, _ = env.step([action_continuous])
        reward_sum += reward
        state = next_state
        
    return reward_sum

def train():
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)
    loss_fn = torch.nn.MSELoss()
    
    for epoch in range(600):
        # 更新一批数据
        update_counter = update_data()
        
        # 更新过数据后,学习N次
        for i in range(200):
            state, action, reward, next_state, done = get_sample()
            # 计算value和target
            value = get_value(state, action)
            target = get_target(next_state, reward, done)
            
            # 参数更新
            loss = loss_fn(value, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            """周期性更新目标网络"""
            if (i + 1) % 10 == 0:
                next_model.load_state_dict(model.state_dict())
            
        if epoch % 50 == 0:
            test_score = sum([test() for i in range(50)]) / 50
            print(epoch, len(datas), update_counter, test_score)
        

 

标签:06,torch,next,state,done,action,DQN,reward,Dueling
From: https://www.cnblogs.com/demo-deng/p/16889897.html

相关文章

  • 06python序列
    数据结构是Python中一个很重要的概念,是以某种方式(如通过编号)组合起来的数据元素(如数字、字符乃至其他数据结构)的集合。在Python中,最基本的数据结构是序列(sequence)。......
  • 06基础元器件-三极管
    一、三极管的定义与分类1、定义导体三极管又称为双极结型晶体管(BJT),是一种具有三个电极的装置。实质上就一块半导体基片上的两个PN结将其隔成基区、发射区和集电区,从而......
  • 小爬爬4:12306自动登录&&pyppeteer基本使用
    超级鹰(更简单的操作验证)-超级鹰-注册:普通用户-登陆:-创建一个软件(id)-下载示例代码  1.12306自动登录#Author:studybrothersunfromsele......
  • 223201062520-软件工程基础Y- 实验一 朱旭个人项目报告
    沈阳航空航天大学软件工程基础实验报告实验名称:实验一实验题目:个人项目专业软件工程学号223201062520姓名朱旭指导教师孟桂英成绩完成......
  • day17-Servlet06
    Servlet0615.HttpServletResponse15.1HttpServletResponse介绍每次HTTP请求,Tomcat都会创建一个HttpServletResponse对象传递给Servlet程序使用HttpServletRequest表示......
  • 强化学习代码实战-06 Double DQN算法
    解决DQN的高估问题。即利用一套神经网络的输出选取价值最大的动作,但在使用该动作的价值时,用另一套神经网络计算该动作的价值。importrandomimportgymimporttorchim......
  • 力扣206 反转链表
    题目:给你单链表的头节点head,请你反转链表,并返回反转后的链表。示例:输入:head=[1,2,3,4,5]输出:[5,4,3,2,1] 双指针法:两个指针,cur指向当前节点,用来遍历,pre......
  • 223201062522黄宇轩 223201062523李凌桦-软件工程基础Y- 实验二 结对项目报告
    沈阳航空航天大学  软  件  工 程 基 础实验报告 实验名称:实验二实验题目:结对项目   专   业软件工程学   号223......
  • 223201062506 王靖榕 223201062507 王静怡-软件工程基础Y-实验二结对项目
    沈阳航空航天大学  软 件 工 程 基 础实验报告 实验名称:实验二实验题目:结对项目   专   业软件工程学   号22320......
  • 223201062521黄宇轩 223201062523李凌桦 实验二结对项目
    沈阳航空航天大学  软  件  工 程 基 础实验报告 实验名称:实验二实验题目:结对项目   专   业软件工程学   号223......