首页 > 其他分享 >像素乒乓球:深度强化学习入门实践

像素乒乓球:深度强化学习入门实践

时间:2024-10-29 13:50:13浏览次数:5  
标签:discounted observation frame epr 乒乓球 np model 像素 入门

深度强化学习(Deep Reinforcement Learning, DRL)是人工智能领域最前沿的研究方向之一,它结合了深度学习和强化学习的优点,能够让智能体在复杂环境中通过试错学习来完成任务。本文将带领读者从零开始,使用Python和NumPy实现一个简单的DRL算法,训练智能体学习玩Atari经典游戏Pong。

强化学习基础

在开始编码之前,我们先简单回顾一下强化学习的基本概念:

  • 智能体(Agent):学习和做决策的主体
  • 环境(Environment):智能体所处的外部世界
  • 状态(State):环境在某一时刻的描述
  • 动作(Action):智能体可以采取的行为
  • 奖励(Reward):环境对智能体行为的反馈
  • 策略(Policy):智能体的行为准则

强化学习的目标是让智能体通过与环境交互,不断调整策略,最大化长期累积奖励。

在这里插入图片描述

环境搭建

我们将使用OpenAI Gym库来创建Pong游戏环境。首先安装必要的依赖:

pip install gym[atari] numpy matplotlib

然后导入所需的模块:

import numpy as np
import gym
from gym import wrappers
from gym.wrappers import Monitor
import matplotlib.pyplot as plt

创建Pong环境:

env = gym.make("Pong-v0")
env = Monitor(env, "./video", force=True)

预处理游戏画面

Pong的原始画面是210x160像素的RGB图像。为了简化问题,我们需要对画面进行预处理:

  1. 裁剪无关区域
  2. 降采样为80x80
  3. 转换为灰度图
  4. 二值化处理

预处理函数如下:

def frame_preprocessing(observation_frame):
    observation_frame = observation_frame[35:195]  # 裁剪
    observation_frame = observation_frame[::2, ::2, 0]  # 降采样
    observation_frame[observation_frame == 144] = 0  # 二值化
    observation_frame[observation_frame == 109] = 0
    observation_frame[observation_frame != 0] = 1
    return observation_frame.astype(float)

构建策略网络

我们的策略网络是一个简单的前馈神经网络,结构如下:

  • 输入层:6400个神经元(80x80像素)
  • 隐藏层:200个神经元
  • 输出层:1个神经元(向上移动的概率)

使用NumPy实现前向传播:

def policy_forward(x, model):
    h = np.dot(model["W1"], x)
    h[h < 0] = 0  # ReLU激活函数
    logit = np.dot(model["W2"], h)
    p = sigmoid(logit)
    return p, h

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

实现反向传播

为了更新网络参数,我们需要实现反向传播算法:

def policy_backward(eph, epdlogp, model):
    dW2 = np.dot(eph.T, epdlogp).ravel()
    dh = np.outer(epdlogp, model["W2"])
    dh[eph <= 0] = 0
    dW1 = np.dot(dh.T, epx)
    return {"W1": dW1, "W2": dW2}

计算折扣奖励

为了平衡短期和长期收益,我们需要计算折扣奖励:

def discount_rewards(r, gamma):
    discounted_r = np.zeros_like(r)
    running_add = 0
    for t in reversed(range(0, r.size)):
        if r[t] != 0:
            running_add = 0
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r

训练循环

现在我们可以开始训练智能体了。主要步骤如下:

  1. 初始化环境和策略网络
  2. 循环进行多个episode:
    • 与环境交互,收集状态、动作和奖励
    • 计算梯度
    • 更新网络参数

以下是简化版的训练代码:

max_episodes = 3
batch_size = 3
learning_rate = 1e-4
gamma = 0.99

model = {}
model["W1"] = np.random.randn(H, D) / np.sqrt(D)
model["W2"] = np.random.randn(H) / np.sqrt(H)

grad_buffer = {k: np.zeros_like(v) for k, v in model.items()}
rmsprop_cache = {k: np.zeros_like(v) for k, v in model.items()}

episode_number = 0
while episode_number < max_episodes:
    observation = env.reset()
    prev_x = None
    xs, hs, dlogps, drs = [], [], [], []
    reward_sum = 0
    done = False
    
    while not done:
        cur_x = frame_preprocessing(observation).ravel()
        x = cur_x - prev_x if prev_x is not None else np.zeros(D)
        prev_x = cur_x
        
        aprob, h = policy_forward(x, model)
        action = 2 if np.random.uniform() < aprob else 3
        
        xs.append(x)
        hs.append(h)
        y = 1 if action == 2 else 0
        dlogps.append(y - aprob)
        
        observation, reward, done, info = env.step(action)
        reward_sum += reward
        drs.append(reward)
        
    episode_number += 1
    
    epx = np.vstack(xs)
    eph = np.vstack(hs)
    epdlogp = np.vstack(dlogps)
    epr = np.vstack(drs)
    xs, hs, dlogps, drs = [], [], [], []
    
    discounted_epr = discount_rewards(epr, gamma)
    discounted_epr -= np.mean(discounted_epr)
    discounted_epr /= np.std(discounted_epr)
    
    epdlogp *= discounted_epr
    grad = policy_backward(eph, epdlogp, model)
    for k in model:
        grad_buffer[k] += grad[k]
    
    if episode_number % batch_size == 0:
        for k, v in model.items():
            g = grad_buffer[k]
            rmsprop_cache[k] = decay_rate * rmsprop_cache[k] + (1 - decay_rate) * g**2
            model[k] += learning_rate * g / (np.sqrt(rmsprop_cache[k]) + 1e-5)
            grad_buffer[k] = np.zeros_like(v)
    
    print(f"Episode {episode_number}: Reward = {reward_sum}")

进一步改进

本文实现的是最基础的策略梯度算法,还有很大的改进空间:

  1. 使用更复杂的网络结构,如卷积神经网络
  2. 采用更先进的算法,如PPO(Proximal Policy Optimization)
  3. 引入经验回放机制
  4. 使用深度学习框架(如PyTorch或TensorFlow)加速训练

通过这个简单的例子,我们实现了一个能够从像素输入学习玩Pong的DRL智能体。虽然这个实现还很粗糙,但已经包含了DRL的核心思想:

  1. 将高维输入(游戏画面)映射到低维动作空间
  2. 通过试错和奖励信号来优化策略
  3. 使用深度神经网络作为函数逼近器

深度强化学习是一个快速发展的领域,本文只是抛砖引玉。如果你对此感兴趣,可以进一步学习更高级的算法和应用。记住,实践是最好的学习方式,所以不要犹豫,开始你自己的DRL项目吧!

标签:discounted,observation,frame,epr,乒乓球,np,model,像素,入门
From: https://blog.csdn.net/wuinb/article/details/143318196

相关文章

  • HTML入门教程1:HTML简介
    HTML的基本概念HTML不是一种编程语言,而是一种标记语言。标记语言通过标签来标记和描述内容,而不是像编程语言那样通过指令来控制计算机。HTML文档通常由一系列的标记(标签)组成,每个标签都有特定的含义和功能。HTML的发展历史HTML最初由Web的发明者TimBerners-Lee和同事Dani......
  • 【项目实战】Java中集合Collection 和 Collections入门介绍
    在Java编程语言中,Collection是一个接口,它是集合层次结构中的根接口。Collection接口定义了所有集合类型(如列表、集合和队列)所共有的基本操作方法。而Collections则是一个工具类,它提供了一系列静态方法来操作或返回集合。当你需要存储一组对象并在程序中对其进行操作时,......
  • 【项目实战】网络通信协议Socket和WebSocket入门介绍
    一、Socket1.1文件描述符详解文件描述符是在操作系统层面用来访问文件或I/O资源(如网络套接字)的一个抽象的、非负整数。每个进程在打开一个文件或创建一个套接字时,都会得到一个唯一的文件描述符。在Unix/Linux系统中,标准输入(stdin)、标准输出(stdout)和标准错误(stderr)默认......
  • C++之OpenCV入门到提高002:加载、修改、保存图像
    一、介绍今天是这个系列《C++之Opencv入门到提高》得第二篇文章。今天这个篇文章很简单,只是简单介绍如何使用Opencv加载图像、显示图像、修改图像和保存图像,先给大家一个最直观的感受。但是,不能认为很简单,只是让学习的过程没那么平滑一点,以后的路就好走了。OpenCV具......
  • 如何用3个月零基础入门网络安全?_网络安全零基础怎么学习
    ......
  • 如何用3个月零基础入门网络安全?_网络安全零基础怎么学习
    ......
  • Go入门指南-6.9应用闭包:将函数作为返回值
    在程序function_return.go中我们将会看到函数Add2和Adder均会返回签名为func(bint)int的函数:funcAdd2()(func(bint)int)funcAdder(aint)(func(bint)int)函数Add2不接受任何参数,但函数Adder接受一个int类型的整数作为参数。我们也可以将Adder......
  • Go入门指南- 7.2. 切片
    7.2.1概念切片(slice)是对数组一个连续片段的引用(该数组我们称之为相关数组,通常是匿名的),所以切片是一个引用类型(因此更类似于C/C++中的数组类型,或者Python中的list类型)。这个片段可以是整个数组,或者是由起始和终止索引标识的一些项的子集。需要注意的是,终止索引标识的......
  • Go入门指南-7.6字符串、数组和切片的应用
    7.6.1从字符串生成字节切片假设s是一个字符串(本质上是一个字节数组),那么就可以直接通过c:=[]byte(s)来获取一个字节数组的切片c。另外,您还可以通过copy函数来达到相同的目的:copy(dst[]byte,srcstring)。同样的,还可以使用for-range来获得每个元素(Listing7.1......
  • Go入门指南-8.4map 类型的切片
    假设我们想获取一个map类型的切片,我们必须使用两次make()函数,第一次分配切片,第二次分配切片中每个map元素(参见下面的例子8.4)。示例8.4maps_forrange2.go:packagemainimport"fmt"funcmain(){ //VersionA: items:=make([]map[int]int,5) fori:=ra......