池化层
池化操作
池化操作是CNN中非常常见的一种操作,池化层是模仿人的视觉系统对数据进行降维,池化操作通常也叫做子采样(Subsampling)或降采样(Downsampling),在构建卷积神经网络时,往往会用在卷积层之后,通过池化来降低卷积层输出的特征维度,有效减少网络参数的同时还可以防止过拟合现象。
池化操作的作用:将一个尺寸较大的图像通过池化操作转换成尺寸较小的图像,但是这个操作过程中尽量保留原始图像的特征。详细了解池化操作
池化操作最常见的是最大池化和平均池化,此外还有随即池化和中值池化,nn.MaxPool1d
、 nn.MaxPool2d
、nn.MaxPool3d
最大池化也被称为下采样;nn.MaxUnPool1d
、nn.MaxUnPool2d
、nn.MaxUnPool3d
为上采样,最常用的是nn.MaxPool2d
。
import torch
input= torch.tensor([[1, 2, 0, 3, 1],
[0, 1, 2, 3, 1],
[1, 2, 1, 0, 0],
[5, 2, 3, 1, 1],
[2, 1, 0, 1, 1]])
input = torch.reshape(input, (-1, 1, 5, 5))
print(input.shape)
池化层在神经网络中的使用
以 nn.MaxPool2d
为例
torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
ceil_mode
默认为False,当为True时使用ceil模式,表示当池化核覆盖的输入图像超出边缘时保留池化操作;False时使用floor模式,表示不保留。(ceil表示向上取整,floor表示向下取整)
输入和输出的尺寸计算
代码实现
# 自定义输入
import torch
from torch import nn
from torch.nn import MaxPool2d
# 最大池化对long数据型不能实现,input中以为都是整数,所以需要进行类型转换,转换为浮点数,1->1.0
input= torch.tensor([[1, 2, 0, 3, 1],
[0, 1, 2, 3, 1],
[1, 2, 1, 0, 0],
[5, 2, 3, 1, 1],
[2, 1, 0, 1, 1]], dtype=torch.float32)
input = torch.reshape(input, (-1, 1, 5, 5))
print(input.shape)
class Basempool(nn.Module):
def __init__(self):
super(Basempool, self).__init__()
self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=False)
def forward(self, input):
output = self.maxpool1(input)
return output
basempool = Basempool()
output = basempool(input)
print(output)
# 使用CIFAR 10数据集作为输入
import torch
import torchvision.datasets
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
dataset = torchvision.datasets.CIFAR10("./dataset2", train=False, download=True, transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64)
class Basempool(nn.Module):
def __init__(self):
super(Basempool, self).__init__()
self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=False)
def forward(self, input):
output = self.maxpool1(input)
return output
basempool = Basempool()
writer = SummaryWriter("logs")
step = 0
for data in dataloader:
imgs, targets = data
writer.add_images("maxpool_input", imgs, step)
output = basempool(imgs)
writer.add_images("maxpool_output", output, step)
step = step + 1
writer.close()
标签:池化层,nn,self,torch,池化,import,input
From: https://www.cnblogs.com/yq-ydky/p/17632393.html