首页 > 其他分享 >torch.nn.MaxPool2d()

torch.nn.MaxPool2d()

时间:2022-12-29 12:33:48浏览次数:40  
标签:窗口 nn 填充 torch MaxPool2d stride size


torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

\(2D\) 最大池化。


参数:

  • kernel_size:最大池化的窗口大小,可以是单个值,也可以是 \(tuple\) 元组。
  • stride:步长,可以是单个值,也可以是 \(tuple\) 元组。
  • padding:填充,可以是单个值,也可以是 \(tuple\) 元组。
  • dilation:控制窗口中元素步幅。
  • return_indices:布尔类型,返回最大值位置索引。
  • ceil_mode:布尔类型,为 \(True\),用向上取整的方法,计算输出形状,默认向下取整。

参数详解:

kernel_size:

这里的kernel_size跟卷积核不是一个东西。kernel_size可以看做是一个滑动窗口,这个窗口的大小由自己指定,如果输入是单个值,例如 \(3\) ,那么窗口的大小就是 \(3 \times 3\) ,还可以输入元组,例如 \((3, 2)\),那么窗口大小就是 \(3 \times 2\) 。

最大池化的方法就是取这个窗口覆盖元素中的最大值。

stride:

这个参数来确定这个窗口如何进行滑动。如果不指定这个参数,那么默认步长跟最大池化窗口大小一致。如果指定了参数,那么将按照我们指定的参数进行滑动。例如 stride=(2,3) , 那么窗口将每次向右滑 \(3\) 个元素位置,或者向下滑动 \(2\) 个元素位置。

padding:

这参数控制如何进行填充,填充值默认为 \(0\)。如果是单个值,例如 \(1\),那么将在周围填充一圈0。还可以用元组指定如何填充,例如padding=(2, 1),表示在上下两个方向个填充两行 \(0\),在左右两个方向各填充一列 \(0\)。

return_indices:

这是个布尔类型值,表示返回值中是否包含最大值位置的索引。注意这个最大值指的是在所有窗口中产生的最大值,如果窗口产生的最大值总共有 \(5\) 个,就会有 \(5\) 个返回值。

ceil_mode:

这个也是布尔类型值,它决定的是在计算输出结果形状的时候,是使用向上取整还是向下取整。


示例:

import torch
from torch import nn

# 定义输入
# 四个参数分别表示 (batch_size, C_in, H_in, W_in)
# 分别对应,批处理大小,输入通道数,图像高度(像素),图像宽度(像素)
# 为了简化表示,我们只模拟单张图片输入,单通道图片,图片大小是4x4
X = torch.arange(16, dtype=torch.float).view((1, 1, 4, 4))
print(X)
tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]]]])
# 仅定义一个 3x3 的池化层窗口
pool2d = nn.MaxPool2d(3)
print(pool2d(X))
tensor([[[[10.]]]])


# 定义一个 3x3 的池化层窗口;
# 周围填充了一圈 0;
# 步长为 2。
pool2d = nn.MaxPool2d(3, padding=1, stride=2)
print(pool2d(X))
tensor([[[[ 5.,  7.],
          [13., 15.]]]])


# 定义一个 2x4 的池化层窗口;
# 上下方向填充两行 0, 左右方向填充一行 0;
# 窗口将每次向右滑动 3 个元素位置,或者向下滑动 2 个元素位置。
pool2d = nn.MaxPool2d((2, 4), padding=(1, 2), stride=(2, 3))
print(pool2d(X))
tensor([[[[ 1.,  3.],
          [ 9., 11.],
          [13., 15.]]]])



标签:窗口,nn,填充,torch,MaxPool2d,stride,size
From: https://www.cnblogs.com/keye/p/17012215.html

相关文章

  • HZNU Winter Trainning 7 补题 - Zeoy
    CodeForces-1660C题目传送门:https://vjudge.net/contest/535955#problem/C题意:询问一个字符串最少删去几个字符,能够把这个字符串变成aabbccdd这种两两相同的字符串题......
  • pytorch:二分类时的loss选择
    PyTorch二分类时BCELoss,CrossEntropyLoss,Sigmoid等的选择和使用这里就总结一下使用PyTorch做二分类时的几种情况:总体上来讲,有三种实现形式:最后分类层降至一维,使用sigmo......
  • .Net 7.0 AOT /usr/bin/ld: cannot find -lz
    命令:sudodotnetpublish-cRelease报错内容:MSBuildversion17.4.0+18d5aef85for.NETDeterminingprojectstorestore...Allprojectsareup-to-dateforre......
  • vue中 WebSocket connection to 'ws://192.168.10.103:8080/ws' failed 问题的解决
    首先吧 vue中WebSocketconnectionto'ws://192.168.10.103:8080/ws'failed这个报错它不会影响你代码的运行,但是报错一定程度上影响页面的美观度。   下面我们......
  • The CBO and Indexes: An Introduction (Absolute Beginners)
    OneofthemorecommonquestionsIgetaskedandoneofthemostcommonquestionsaskedinthevariousOraclerelatedforumsisthegeneralquestionofwhydoes......
  • torch.cat() 与 torch.stack() 的区别
    目录1.torch.cat()2.torch.stack()1.torch.cat()torch.cat(tensors, dim=0)在给定维度中拼接张量序列。参数:tensors:张量序列。dim:拼接张量序列的维度。impo......
  • flannel v0.20.2版内容
    地址https://github.com/flannel-io/flannel/kubectlapply-fhttps://raw.githubusercontent.com/flannel-io/flannel/v0.20.2/Documentation/kube-flannel.ymlkube......
  • 【Allwinner】---全志GPIO号 计算
    全志的GPIO号在sunxi-gpio.h中定义sunxi-gpio.h1二、GPIO号定义#defineSUNXI_PA_BASE0#defineSUNXI_PB_BASE32#defineSUNXI_PC_BASE64#defineSUNXI_PD_BAS......
  • [题解] CF1761E Make It Connected
    CF1761EMakeItConnected题目大意给一张无向图,每次操作表现为选择一个点\(x\),断开所有原来连上的边,连接所有原来断开的边。求最少需要几步使得图联通,并构造方案。思......
  • WinNTSetup V5.3.0 Bata5 单文件版
    前言WinNTSetup是一款Windows系统硬盘安装器,支持从PE和本地安装系统,支持支持NT内核的系统。WinNTSetup包括XP、Win7、Win8、Win8.1、Win10等这些系统。直接从硬盘安装......