class GomokuNet(nn.Module):
def __init__(self, input_dim, action_space):
super(GomokuNet, self).__init__()
# 定义网络层
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.fc4 = nn.Linear(128 * input_dim[0] // 4 * input_dim[1] // 4, 256)
self.policy_head = nn.Linear(256, action_space) # 策略头
self.value_head = nn.Linear(256, 1) # 价值头
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = x.view(x.size(0), -1) # 展平
x = F.relu(self.fc4(x))
policy_logits = self.policy_head(x) # 策略输出
value = self.value_head(x) # 价值输出
return policy_logits, value
这个网络有两个输出头:一个用于预测动作(策略头),另一个用于估计状态的价值(价值头)。
初始化网络结构:
- self.conv1:
- 输入通道数(in_channels): 1
- 输出通道数(out_channels): 32
- 卷积核大小(kernel_size): 3x3
- 步长(stride): 1
- 填充(padding): 1
- 这一层将输入的单通道图像转换为32个特征图。由于填充为1,输入和输出的空间尺寸(高度和宽度)将保持不变。
- self.conv2:
- 输入通道数(in_channels): 32
- 输出通道数(out_channels): 64
- 卷积核大小(kernel_size): 3x3
- 步长(stride): 2
- 填充(padding): 1
- 这一层将32个特征图转换为64个特征图,并通过步长为2的卷积减小了空间尺寸。具体来说,高度和宽度都会减半。
- self.conv3:
- 输入通道数(in_channels): 64
- 输出通道数(out_channels): 128
- 卷积核大小(kernel_size): 3x3
- 步长(stride): 1
- 填充(padding): 1
- 这一层将64个特征图转换为128个特征图。由于填充为1且步长为1,输出的空间尺寸将保持不变(相对于self.conv2的输出)。
- self.fc4:
- 输入特征数(in_features): 这是一个计算值,等于128(来自上一层的输出通道数)乘以输入图像高度和宽度的四分之一(因为self.conv2层使尺寸减半,而之后的层没有改变尺寸)。
- 输出特征数(out_features): 256
- 这是一个全连接层,它将卷积层提取的特征图展平并转换为256个特征,用于后续的策略和价值估计。
- self.policy_head:
- 输入特征数(in_features): 256
- 输出特征数(out_features): action_space(动作空间的大小)
- 这是一个全连接层,用于根据提取的特征预测动作的概率分布或确定性动作。在强化学习中,这通常被称为策略头或行为头。
- self.value_head:
- 输入特征数(in_features): 256
- 输出特征数(out_features): 1
- 这是一个全连接层,用于估计给定状态的价值。在Actor-Critic方法中,这个价值用于计算优势函数或作为基准来稳定学习。
前向传播方法 forward:
- 输入:
x
,表示输入的游戏状态(通常是一个或多个游戏棋盘的图像)。 - 处理流程:
- 通过三个卷积层提取特征,并在每个卷积层后应用ReLU激活函数。
- 将卷积层的输出展平成一个一维向量。这里的展平操作使用
x.view(x.size(0), -1)
实现,其中x.size(0)
保持批处理大小不变,-1
让程序自动计算展平后的特征数。 - 将展平后的特征输入到全连接层
self.fc4
,并应用ReLU激活函数。 - 分别通过策略头和价值头输出策略logits和价值估计。策略logits可以用于计算动作的概率分布,而价值估计可以用于评估当前状态的价值。
- 输出:返回策略logits和价值估计。这两个输出通常用于强化学习中的策略梯度和优势函数Q值计算。
*注意:
- 实际应用中,可以在网络中添加批归一化层(Batch Normalization)来提高性能和稳定性。此外,还可以考虑使用其他类型的激活函数(如Leaky ReLU)或优化网络结构以获得更好的性能。
- 该网络结构适用于处理具有类似五子棋这样的棋盘游戏的任务,但也可以根据具体任务需求进行调整和扩展。比如,可以增加更多的卷积层或全连接层来提取更复杂的特征或处理更大的输入尺寸,也可以考虑使用其他类型的神经网络结构(如循环神经网络RNN或长短期记忆网络LSTM)来处理具有时序依赖性的任务等。