首页 > 其他分享 >【防忘笔记】一个例子理解Pytorch中一维卷积nn.Conv1d

【防忘笔记】一个例子理解Pytorch中一维卷积nn.Conv1d

时间:2022-08-30 16:26:55浏览次数:88  
标签:... nn Conv1d torch 防忘 channels out

一维卷积层的各项参数如下

torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)

nn.Conv1d输入

输入形状一般应为:(N, Cin, Lin) 或 (Cin, Lin), (N, Cin, Lin)

N = 批量大小,例如 32 或 64;
Cin = 表示通道数;
Lin = 它是信号序列的长度;

nn.Conv1d输出

torch.nn.Conv1d() 的输出形状为:(N, Cout, Lout) 或 (Cout, Lout)

其中,Cout由给Conv1d的参数out_channels决定,即Cout == out_channels

Lout则是使用Lin与padding、stride等参数计算后得到的结果,计算公式如下:

Lout

例子:

import torch

N = 40
C_in = 40
L_in = 100

inputs = torch.rand([N, C_in, L_in])

padding = 3
kernel_size = 3
stride = 2
C_out = 10

x = torch.nn.Conv1d(C_in, C_out, kernel_size, stride=stride, padding=padding)
y = x(inputs)
print(y)
print(y.shape)

运行上述示例后会得到以下结果

tensor([[[-0.0850,  0.3896,  0.7539,  ...,  0.4054,  0.3753,  0.2802],
         [ 0.0181, -0.0184, -0.0605,  ...,  0.0114, -0.0016, -0.0268],
         [-0.0570, -0.4591, -0.3195,  ..., -0.2958, -0.1871,  0.0635],
         ...,
         [ 0.0554,  0.1234, -0.0150,  ...,  0.0763, -0.3085, -0.2996],
         [-0.0516,  0.2781,  0.3457,  ...,  0.2195,  0.1143, -0.0742],
         [ 0.0281, -0.0804, -0.3606,  ..., -0.3509, -0.2694, -0.0084]]],
       grad_fn=<SqueezeBackward1>)
torch.Size([40, 10, 52])

y 是输出,它的形状是: 40* 10* 52

40是batchsize;10是用户设定的Cout(即out_channels),52是经过一维卷积层计算后目前序列的长度(即Lout,也可以理解为某个一维矩阵的形状)

注意:

对于一维卷积,

通道数被视为“输入向量的数量”(in_channels)和“输出特征向量的数量”(out_channels);

Lout是输出特征向量的大小不是数量);

参考:
1、https://stackoverflow.com/questions/60671530/how-can-i-have-a-pytorch-conv1d-work-over-a-vector
2、https://www.tutorialexample.com/understand-torch-nn-conv1d-with-examples-pytorch-tutorial/

标签:...,nn,Conv1d,torch,防忘,channels,out
From: https://www.cnblogs.com/DAYceng/p/16639803.html

相关文章

  • AtCoder Beginner Contest 266 一句话题解
    AandBsbt,不讲。C垃圾计算几何,问是不是一个凸包,搞份板子交就可以了。D简单dp,令\(f(i,j)\)表示第\(i\)个时间在第\(j\)个位置的最大价值,从上一个时间转移,可以......
  • channel与range、select
    channel与range、selectpackagemainimport"fmt"funcmain(){c:=make(chanint)gofunc(){fori:=0;i<5;i++{c<-i......
  • channel
    channel有缓冲与无缓冲同步问题packagemainimport("fmt""time")funcmain(){c:=make(chanint,3)//带有缓冲的channelfmt.Println("len(c......
  • channel定义与使用
    channel定义与使用packagemainimport"fmt"funcmain(){//定义一个channelc:=make(chanint)gofunc(){deferfmt.Println("goroutine结......
  • InnoDB关键特性之double write (转)
    一、脏页刷盘风险原文地址:https://www.cnblogs.com/geaozhang/p/7241744.html关于IO的最小单位:1、数据库IO的最小单位是16K(MySQL默认,oracle是8K)2、文件系统......
  • AtCoder Beginner Contest 266
    比赛链接:https://atcoder.jp/contests/abc266C-ConvexQuadrilateral题意:平面图上有一个四边形,按照逆时针顺序给定四个点的坐标,判断四边形是不是凸的。思路:求两条......
  • AtCoder Beginner Contest 179
    https://atcoder.jp/contests/abc179我的AC代码https://atcoder.jp/contests/abc179/submissions/me?f.Task=&f.LanguageName=&f.Status=AC&f.User=HinanawiTenshi这......
  • AtCoder Beginner Contest 265(D-E)
    D-IrohaandHaiku(NewABCEdition)题意:找一个最少含有三个点的区间,将区间分成三块,三块的和分别为p,q,r,问是否存在这样的区间题解:先预处理一遍前缀和,和每一个前缀......
  • vue报错Error in render: “TypeError: Cannot read property ‘length‘ of undefine
    最近弄安卓开发,uniapp,开发,微信小程序无任何报错,但安卓端,报错,而且,一个错误会再报很多不相干的错误;并不会显示代码具体报错的行数。排查费劲!!![Vuewarn]:Errorinrender:......
  • Ubuntu18.04 开机卡“A start job is running for wait for network to be Configured
    Ubuntu开机卡在这里迟迟无法开机,要等倒计时完以后才会顺利开机。原因可能是系统开机初始化网络配置出错,加上系统默认配置有等待时间,导致系统会一直进行一些无用的尝试,直到......