首页 > 其他分享 >一个用于强化学习的卷积神经网络基础结构示例

一个用于强化学习的卷积神经网络基础结构示例

时间:2024-03-18 18:33:31浏览次数:24  
标签:输出 示例 卷积 self 特征 神经网络 输入 size

class GomokuNet(nn.Module):  
    def __init__(self, input_dim, action_space):  
        super(GomokuNet, self).__init__()  
        # 定义网络层
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  
        self.fc4 = nn.Linear(128 * input_dim[0] // 4 * input_dim[1] // 4, 256)  
        self.policy_head = nn.Linear(256, action_space)  # 策略头  
        self.value_head = nn.Linear(256, 1)  # 价值头  
  
    def forward(self, x):  
        x = F.relu(self.conv1(x))  
        x = F.relu(self.conv2(x))  
        x = F.relu(self.conv3(x))  
        x = x.view(x.size(0), -1)  # 展平
        x = F.relu(self.fc4(x))  
        policy_logits = self.policy_head(x)  # 策略输出  
        value = self.value_head(x)  # 价值输出  
        return policy_logits, value

这个网络有两个输出头:一个用于预测动作(策略头),另一个用于估计状态的价值(价值头)。

初始化网络结构:

  • self.conv1:
    • 输入通道数(in_channels): 1
    • 输出通道数(out_channels): 32
    • 卷积核大小(kernel_size): 3x3
    • 步长(stride): 1
    • 填充(padding): 1
    • 这一层将输入的单通道图像转换为32个特征图。由于填充为1,输入和输出的空间尺寸(高度和宽度)将保持不变。
  • self.conv2:
    • 输入通道数(in_channels): 32
    • 输出通道数(out_channels): 64
    • 卷积核大小(kernel_size): 3x3
    • 步长(stride): 2
    • 填充(padding): 1
    • 这一层将32个特征图转换为64个特征图,并通过步长为2的卷积减小了空间尺寸。具体来说,高度和宽度都会减半。
  • self.conv3:
    • 输入通道数(in_channels): 64
    • 输出通道数(out_channels): 128
    • 卷积核大小(kernel_size): 3x3
    • 步长(stride): 1
    • 填充(padding): 1
    • 这一层将64个特征图转换为128个特征图。由于填充为1且步长为1,输出的空间尺寸将保持不变(相对于self.conv2的输出)。
  • self.fc4:
    • 输入特征数(in_features): 这是一个计算值,等于128(来自上一层的输出通道数)乘以输入图像高度和宽度的四分之一(因为self.conv2层使尺寸减半,而之后的层没有改变尺寸)。
    • 输出特征数(out_features): 256
    • 这是一个全连接层,它将卷积层提取的特征图展平并转换为256个特征,用于后续的策略和价值估计。
  • self.policy_head:
    • 输入特征数(in_features): 256
    • 输出特征数(out_features): action_space(动作空间的大小)
    • 这是一个全连接层,用于根据提取的特征预测动作的概率分布或确定性动作。在强化学习中,这通常被称为策略头或行为头。
  • self.value_head:
    • 输入特征数(in_features): 256
    • 输出特征数(out_features): 1
    • 这是一个全连接层,用于估计给定状态的价值。在Actor-Critic方法中,这个价值用于计算优势函数或作为基准来稳定学习。

前向传播方法 forward:

  • 输入x,表示输入的游戏状态(通常是一个或多个游戏棋盘的图像)。
  • 处理流程
    • 通过三个卷积层提取特征,并在每个卷积层后应用ReLU激活函数。
    • 将卷积层的输出展平成一个一维向量。这里的展平操作使用x.view(x.size(0), -1)实现,其中x.size(0)保持批处理大小不变,-1让程序自动计算展平后的特征数。
    • 将展平后的特征输入到全连接层self.fc4,并应用ReLU激活函数。
    • 分别通过策略头和价值头输出策略logits和价值估计。策略logits可以用于计算动作的概率分布,而价值估计可以用于评估当前状态的价值。
  • 输出:返回策略logits和价值估计。这两个输出通常用于强化学习中的策略梯度和优势函数Q值计算。

*注意:

  • 实际应用中,可以在网络中添加批归一化层(Batch Normalization)来提高性能和稳定性。此外,还可以考虑使用其他类型的激活函数(如Leaky ReLU)或优化网络结构以获得更好的性能。
  • 该网络结构适用于处理具有类似五子棋这样的棋盘游戏的任务,但也可以根据具体任务需求进行调整和扩展。比如,可以增加更多的卷积层或全连接层来提取更复杂的特征或处理更大的输入尺寸,也可以考虑使用其他类型的神经网络结构(如循环神经网络RNN或长短期记忆网络LSTM)来处理具有时序依赖性的任务等。

标签:输出,示例,卷积,self,特征,神经网络,输入,size
From: https://blog.csdn.net/YHKKun/article/details/136811046

相关文章

  • 卷积神经网络的池化层学习
    池化层简而言之就是做压缩的,最大池化进行筛选最大池化选择的都是该区域内最大的值,因为参数越大,代表特征越明显,越重要,基本只要最好的特征,因此使用maxpooling较多。且池化层中不涉及到任何矩阵的运算。只是一个筛选压缩。过滤的一个东西。......
  • python @pytest.fixture示例及用法
    python@pytest.fixture示例及用法@pytest.fixture是pytest测试框架中的一个非常有用的功能,它允许你定义可以在多个测试用例之间共享的设置和清理代码。通过使用fixture,你可以减少重复的代码,并使得测试用例更加清晰和模块化。下面是一个简单的示例,展示了如何使用@pytest.fi......
  • 【即插即用】RefConv-重聚焦卷积模块(附源码)
    论文地址: http://arxiv.org/pdf/2310.10563.pdf源码地址:GitHub-Aiolus-X/RefConv概述:作者提出了一种可重参数化的重新聚焦卷积(RefConv),作为常规卷积层的即插即用替代品,能够在不引入额外推理成本的情况下显著提高基于CNN的模型性能。RefConv利用预训练参数编码的表示作为先......
  • go语言请求http接口示例 并解析json
    本例请求了天气api接口对接流程注册一个账号,对接免费实况天气接口阅读接口文档http://tianqiapi.com/index/doc?version=day请求接口解析json开发流程创建一个json.go文件需要引入的包import( "encoding/json" "fmt" "io/ioutil" "net/http")定义Wea......
  • 递归示例-展开编号(Excel函数集团)
    展开编号=DROP(fx(COUNTA(B:B)-1),1)fx=LAMBDA(x,IF(x>0,VSTACK(fx(x-1),SEQUENCE(INDEX(Sheet4!$B:$B,x+1),,INDEX(Sheet4!$C:$C,x+1)))))使用Lambda定义x当x小于等0时,返回False,以此作为开关;当x为1时,返回False连接SEQUENCE(INDEX(Sheet4!$B:$B,2),,INDEX(Sheet4!$C:......
  • Tensorflow笔记(一):常用函数、张量操作、神经网络模型实现(鸢尾花分类)
    importpandasaspdimporttensorflowastfimportnumpyasnp#-----------------------------tensor张量-----------------------------------#创建张量a=tf.constant([1,5],dtype=tf.int64)print(a)#>tf.Tensor([15],shape=(2,),dtype=int64)#结果......
  • 【Qt】使用Qt实现Web服务器(二):QtWebApp示例源码
    1、最简使用介绍Demo2演示了最简单的用法,输入url后返回“HelloWorld!”;下面详解示例代码,先看主函数1.1主函数#a)QtWebApp库中定义的名字空间stefanfringsusingnamespacestefanfrings;intmain(intargc,char*argv[]){......
  • Java常用修饰符及示例
    Java修饰符是用来改变类、方法、变量、接口等元素的行为和可见性的关键字。Java修饰符主要分为两大类:访问修饰符和非访问修饰符。访问修饰符(AccessModifiers):public:提供最大的访问权限,任何类(无论是同一包内的还是不同包的)都可以访问到public修饰的类、方法和变量。示例......
  • C++示例:学习C++标准库,std::unordered_map无序关联容器的使用
    01std::unordered_map介绍std::unordered_map是C++标准库中的一种无序关联容器模板类,它提供了一种将键映射到值的方法。它的底层基于哈希表实现,内容是无序的,可以在平均情况下在O(1)的时间复杂度内完成插入、查找和删除操作。值得注意的是,哈希表可能存在冲突,即不同的键值......
  • 深度学习入门基于python的理论与实现-第四章神经网络的学习(个人向笔记)
    目录从数据中学习损失函数均方误差(MSE)交叉熵误差mini_batch学习mini_batch版交叉熵误差的实现从数据中学习神经网络的"学习"的学习是指从训练数据自动获取最有权重参数的过程。神经网络的特征就是可以从数据中学习即由数据自动决定权重参数的值。机器学习通常是认为确定一些......