首页 > 编程语言 >强化学习—PPO代码实现及个人详解1(python)

强化学习—PPO代码实现及个人详解1(python)

时间:2024-04-07 09:30:06浏览次数:33  
标签:__ dim 函数 nn python self PPO 详解 hidden

上一篇文章我们已经搞定了如何搭建一个可以运行强化学习的python环境,现在我们就跑一下代码,这里我对代码加上一些个人理解,方便基础差一些的朋友进行理解和学习。

我在这段时间对强化学习进行了学习,所以知识和代码基本来自这本:磨菇书

一、定义模型

import torch.nn as nn
import torch.nn.functional as F


class ActorSoftmax(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super(ActorSoftmax, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        probs = F.softmax(self.fc3(x), dim=1)
        return probs


class Critic(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=256):
        super(Critic, self).__init__()
        assert output_dim == 1  # critic must output a single value
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        value = self.fc3(x)
        return value

先简单解释下神经网络模型,一般是分为三层,输入层、隐藏层以及输出层,如果系统很复杂,也可以添加隐藏层,但是一般情况下,2、3层就够用了。如果想深入了解神经网络,可以参考这篇文章:神经网络——最易懂最清晰的一篇文章

这一段代码,是定义了两个神经网络模型,一个是演员,另一个是评论家。在 ActorSoftmax 类中,通过输入状态 x 经过神经网络计算得到各个动作的概率分布,而在 Critic 类中,通过输入状态 x 经过神经网络计算得到该状态的价值(即给定状态下的预期累积奖励)。

先说演员,首先我们要初始化,但是调用初始化是需要super的,所以就有了那段函数  super(ActorSoftmax, self).__init__()。之后进入初始化,使用linear,创造线性层,可进行全连接操作(即对输入特征进行加权求和并加上偏置,得到隐藏层的输出结果。这个隐藏层的输出结果会成为下一层的输入,进而参与后续的非线性变换、特征提取等操作。)可以看出,一共三层,第一层是输入层到隐藏层,第二层是隐藏层到隐藏层,第三层是隐藏层到输出层。之后用forword定义前向传播过程。为了让网络学习更加复杂的特征和模式,需要用激活函数relu(也就是在神经网络模型的隐藏层中引入非线性,提高表达能力,一般作用在隐藏层),想要用这个函数,就需要torch.nn.functional,代码里已经命名为F了,而Softmax激活函数应用于输出层,它能够将神经网络输出的原始分数转换为概率分布(其实就是归一化处理)。

(关于激活函数,可以看这篇文章进行学习了解:常用的激活函数合集(详细版)

简单来说,就是演员这个大类里面有两个函数,init初始化了全连接层,而forward定义了前向传播,最终通过 softmax 函数得到输出,表示各动作的概率分布。

接下来说评论家,初始化和前向传播是一样的,就两点不同,一是初始化时输出维度只能是1,因为最终返还的是状态的价值(一个单一的数值)。

注意:

(1)隐藏层维度是一个需要指定的超参数,而输入维度和输出维度是根据具体问题来确定的,因此没有在模型内部直接赋值,而是作为参数在初始化模型实例时进行传入。(也就是说输入维度和输出维度在后面参数定义时会有)

(2)在归一化函数softmax中,如果没有dim=1,则默认对最后一个维度进行归一化处理,有的话,就对第二个维度进行归一化处理(每一行为一个样本,每一列为一个类别)。

(3)self 是一个标识符,指向类的实例对象本身。

(4)ActorSoftmax 类定义了两个函数:__init__ 函数用于初始化模型的结构,forward 函数用于定义前向传播过程。在类被实例化之前,这些函数是不会自动运行的。

(5)assert函数是当作断言的,即当assert后面的条件为False时,会触发异常,从而中断程序的执行。适当在代码中加入些断言,可以检查程序的逻辑性和正确性。

除了定义模型,后续还会分析定义经验回放、定义智能体、定义训练、定义环境、设置参数、开始训练这几部分,如果大家想先跑通代码看看结果,可以直接去我最上面的蘑菇书链接里找,之后我还会接着对代码进行分析。(以上分析均是个人拙见,可能有问题,欢迎讨论修改)

标签:__,dim,函数,nn,python,self,PPO,详解,hidden
From: https://blog.csdn.net/weixin_70267340/article/details/137332848

相关文章

  • 使用Python的turtle模块绘制美丽的樱花树
    引言Python的turtle模块是一个直观的图形化编程工具,让用户通过控制海龟在屏幕上的移动来绘制各种形状和图案。turtle模块的独特之处在于其简洁易懂的操作方式以及与用户的互动性。用户可以轻松地通过使用诸如前进、后退、左转、右转等基本命令,来编写程序控制海龟的行动路径,从而创......
  • Oracle之DBMS_LOCK包用法详解
    概述与背景某些并发程序,在高并发的情况下,必须控制好并发请求的运行时间和次序,来保证处理数据的正确性和完整性。对于并发请求的并发控制,EBS系统可以通过ConcurrentProgram定义界面的Incompatibilities功能配置实现。但是Incompatibilities功能存在其局限性,它只能把整个并发请求......
  • python 浅拷贝与深拷贝
    copy Python的赋值语句不复制对象,而是创建目标和对象的绑定关系。对于自身可变,或包含可变项的集合,有时要生成副本用于改变操作,而不必改变原始对象。浅拷贝(ShallowCopy)和深拷贝(DeepCopy)是在Python中用于复制数据结构(如列表)时经常用到的概念。浅拷贝(ShallowCopy)浅复制创建......
  • python_列表推导式_矩阵运算
    带条件的列表推导式even_number=[iforiinrange(10)ifi%2==0]even_number#output[0,2,4,6,8][0,2,4,6,8]列表推导式的嵌套matrix=[[i*jforiinrange(1,4)]forjinrange(1,4)]matrix#output=[[1,2,3],[2,4,6],[3,6,9]][[1,......
  • HART报文详解
    1.简介HART(HighwayAddressableRemoteTransducer可寻址远程传感器高速通道)协议,主要用于工业自动化领域的通信协议,专为发送和接收数字信息而设计,同时也支持模拟信号(如4-20mA信号)的传输。这种设计使得HART设备能够同时传输模拟信号和数字数据,从而提供了更加灵活和强大的通信能力......
  • dd if=devzero of=的含义是什么?Linux 下的dd命令使用详解
    ddif=/dev/zeroof=的含义是什么?Linux下的dd命令使用详解一、dd命令的解释dd:用指定大小的块拷贝一个文件,并在拷贝的同时进行指定的转换。注意:指定数字的地方若以下列字符结尾,则乘以相应的数字:b=512;c=1;k=1024;w=2参数注释:1.if=文件名:输入文件名,缺省为标准输入。即指定源文......
  • Python SciPy库
    SciPy库为Python提供了科学计算的基本算法基本操作求解非线性方程(组)scipy.optimize模块的fsolve和root可求非线性方程(组)的解fsolve或root求解非线性方程组时,先把非线性方程组写成F(x)=0这样的形式,其中,x为向量,F(x)为向量函数scipy.optimize.fsolve(func, x0, args......
  • Docker学习笔记(三)Dockerfile指令详解
    文章目录FROM指定基础镜像RUN执行命令COPY复制文件ADD高级文件复制CMD容器启动命令ENTRYPOINT入口点ENV设置环境变量ARG构建参数VOLUME定义匿名卷EXPOSE声明端口WORKDIR指定工作目录USER指定当前用户HEALTHCHECK健康检查ONBUILD构建触发器LABEL添加元数据......
  • 环境配置——已解决ModuleNotFoundError: No module named ‘cv2’(python)
    一、报错代码在网上搜到不少用Python处理图形的代码,于是复制别人的代码直接运行却报错,得到的结果却是:已解决ModuleNotFoundError:Nomodulenamed‘cv2’。(当时心里瞬间凉了一大截,最后顺利解决了,顺便记录一下希望可以帮助到更多遇到这个bug不会解决的小伙伴),代码如下:impor......
  • 环境配置——python代码打包超详细教程
    在Python开发的过程中我们经常会需要将自己的代码打包成一个可执行文件,方便将代码分享给其他人使用,下面这篇文章主要给大家介绍了关于python代码打包的相关资料,需要的朋友可以参考下一、前言网上的文章对小白都不太友好呀,讲得都比较高大上,本文章就用最简单的方式来教会......