首页 > 其他分享 >DQN——深度Q网络

DQN——深度Q网络

时间:2024-10-30 19:46:47浏览次数:8  
标签:self torch 网络 state 深度 action DQN size

目录

DQN 原理

DQN 实现代码

 代码要点


        DQN(Deep Q-Network)是一种深度强化学习算法,结合了 Q-learning 和神经网络,用于解决复杂的决策问题。它在游戏和控制任务中取得了出色的效果。DQN 的关键是利用神经网络来近似 Q 值函数,使得算法在较高维度的状态空间中也能有效工作。以下是 DQN 的原理概述及实现代码示例。


DQN 原理

  1. Q-learning:Q-learning 是一种无模型的强化学习算法。核心思想是学习一个 Q 值函数 Q(s, a),表示在状态 s 采取动作 a 后,能获得的未来总奖励的期望值。Q-learning 的更新规则为:

    Q(s, a) \leftarrow Q(s, a) + \alpha \left( r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right)

    其中 \alpha 是学习率,r 是奖励,\gamma 是折扣因子,s' 是下一个状态。

  2. 深度 Q 网络 (DQN):DQN 的核心在于用深度神经网络来近似 Q 值函数,来估计每个状态-动作对的 Q 值。Q 值函数的近似由网络参数 \theta 表示,即 Q(s, a; \theta)。训练目标是使网络输出的 Q 值与真实的 Q 值更接近。

  3. 经验回放:DQN 引入了经验回放缓冲区,用来存储代理与环境交互的经验 (state, action, reward, next_state),在训练时随机抽取小批量经验,减少样本间的相关性并稳定训练。

  4. 目标网络:DQN 中使用了两个网络:当前网络(current Q-network)和目标网络(target Q-network)。目标网络的参数每隔若干步更新为当前网络的参数,使得目标更稳定。

DQN 实现代码

        以下是基于 PyTorch 的 DQN 实现代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque

# 定义 Q 网络
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidden_size=64):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# 定义 DQN 代理
class DQNAgent:
    def __init__(self, state_size, action_size, gamma=0.99, lr=1e-3, batch_size=64, buffer_size=10000):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.batch_size = batch_size
        
        # 初始化 Q 网络和目标网络
        self.q_network = QNetwork(state_size, action_size)
        self.target_network = QNetwork(state_size, action_size)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        
        # 初始化经验回放
        self.memory = deque(maxlen=buffer_size)
        
        # 同步目标网络参数
        self.update_target_network()

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

    def choose_action(self, state, epsilon):
        if random.random() < epsilon:
            return random.choice(range(self.action_size))
        else:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                q_values = self.q_network(state_tensor)
            return torch.argmax(q_values).item()

    def store_experience(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def replay_experience(self):
        if len(self.memory) < self.batch_size:
            return
        
        # 随机抽取批量经验
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        
        # 转换为 Tensor
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)
        
        # 当前 Q 值
        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        
        # 目标 Q 值
        next_q_values = self.target_network(next_states).max(1)[0]
        target_q_values = rewards + self.gamma * next_q_values * (1 - dones)
        
        # 损失函数
        loss = nn.MSELoss()(q_values, target_q_values.detach())
        
        # 反向传播
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self, env, episodes=1000, epsilon=1.0, epsilon_min=0.1, epsilon_decay=0.995, target_update_freq=10):
        for episode in range(episodes):
            state = env.reset()
            total_reward = 0
            done = False
            while not done:
                action = self.choose_action(state, epsilon)
                next_state, reward, done, _ = env.step(action)
                self.store_experience(state, action, reward, next_state, done)
                
                # 经验回放
                self.replay_experience()
                
                state = next_state
                total_reward += reward
            
            # 动态调整 epsilon
            epsilon = max(epsilon_min, epsilon * epsilon_decay)
            
            # 更新目标网络
            if episode % target_update_freq == 0:
                self.update_target_network()
            
            print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {epsilon}")

# 使用 DQNAgent 进行训练(假设 env 已经定义好)
# state_size = env.observation_space.shape[0]
# action_size = env.action_space.n
# agent = DQNAgent(state_size, action_size)
# agent.train(env)

 代码要点

  • Q 网络QNetwork 是一个简单的三层神经网络,用于近似 Q 值。
  • DQN 代理DQNAgent 类包括了 DQN 的核心逻辑,如选择动作、存储经验、经验回放和目标网络更新。
  • 经验回放:从记忆池中随机采样经验,减少样本间的相关性。
  • 目标网络:定期将当前网络的参数复制给目标网络,以减少目标值的变化,提高训练的稳定性。

        这个代码可以用来在各种强化学习环境中实现基本的 DQN 算法。

标签:self,torch,网络,state,深度,action,DQN,size
From: https://blog.csdn.net/qq_56683019/article/details/143349524

相关文章