首页 > 其他分享 >深度学习(六)——神经网络的基本骨架:nn.Module的使用

深度学习(六)——神经网络的基本骨架:nn.Module的使用

时间:2023-07-14 23:00:18浏览次数:75  
标签:__ nn self torch init Module 神经网络

一、torch.nn简介

官网地址:

torch.nn — PyTorch 2.0 documentation

1. torch.nn中的函数简介

  • Containers:神经网络的骨架

  • Convolution Layers:卷积层

  • Pooling layers:池化层

  • Padding Layers:Padding

  • Non-linear Activations:非线性激活

  • Normalization Layers:正则化层

还有其他函数,详情可以看官方文档。以上这些函数构成了神经网络的基本操作。

2. torch.nn中Containers函数的介绍

Containers一共有六个模块:

  • Module:对于所有神经网络提供一个基本的骨架,一般定义一个神经网络用如下代码。其中,Model代表模型的名称,nn.Module就是继承了这个类的模板。然后我们先用__init__初始化,其中super(Model,self).__init__()指的是对父类进行初始化,后面的部分是根据自己构建的神经网络个性化定制的。之后我们使用forword函数对输入数据进行计算,也可以这么理解:对于一个神经网络,首先输入数据-->使用forword函数计算数据-->输出数据,这个过程也叫前向传播
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
  • Sequential

  • ModuleList

  • ModuleDict

  • ParameterList

  • ParameterDict

二、实操nn.Module

1. 构建一个简单的神经网络

  • 一些小技巧:在写__init__super函数时,pycharm点击下面这个按钮就可以自动补全:

  • 下面构建一个很简单的神经网络,具体作用就是把输入数据+1然后返回,之后调用这个神经网络:

from torch import nn
import torch

#构建一个叫Demo的神经网络
class Demo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self,input):
        output=input+1  #对输入神经网络的数据+1,然后返回
        return output

#调用神经网络
demo=Demo()
x=torch.tensor(1.0)  #输入神经网络的数据
output=demo(x)
print(output)  #输出神经网络的数据

[Run] tensor(2.)

2. 神经网络运行过程

为了更好地说明上面代码的运行过程,把debug打到第14行的demo=Demo()代码上,并点击Step into My Code

之后一直点击Step into My Code,就可以看到代码的运行过程如下:

  • 在调用demo=Demo()后,首先使用super().__init__()对\(nn.Module\)进行初始化

  • 然后设定输入值x,并使用demo(x)将该值传入到forword函数中

  • forword函数将该值进行加一,并返回output

  • 最后将返回的output输出

标签:__,nn,self,torch,init,Module,神经网络
From: https://www.cnblogs.com/zoubilin/p/17555199.html

相关文章

  • vue3项目 运行 报错 Cannot assign to "b" because it is a constant
    环境依赖node18.16.0vite4.4.4vue 3.2.47背景当前错误与环境依赖关系不大,是由于我在打包的文件写的代码错误导致的,一般情况不会有这个错报错信息X[ERROR]Cannotassignto"b"becauseitisaconstantThesymbol"b"wasdeclaredaconstanthere:原因将r......
  • AtCoder Beginner Contest 294
    A-Filter#include<bits/stdc++.h>usingnamespacestd;#defineintlonglongint32_tmain(){ios::sync_with_stdio(false),cin.tie(nullptr),cout.tie(nullptr);intn;cin>>n;for(intx;n;n--){cin>&......
  • pytorch+CRNN实现
    最近接触了一个仪表盘识别的项目,简单调研以后发现可以用CRNN来做。但是手边缺少仪表盘数据集,就先用ICDAR2013试了一下。 结果遇到了一系列坑。为了不使读者和自己在以后的日子继续遭罪。我把正确的代码发到下面了。超参数请不要调整!!!!CRNN前期训练极其慢,需要良好的调参,loss才会......
  • compattelrunner.exe 进程会定期运行,扫描系统以收集应用程序、硬件和设备的兼容性数据
    compattelrunner.exe是Windows容错报告工具(WindowsCompatibilityTelemetry)Windows容错报告工具是Microsoft开发的一项功能,旨在帮助改进Windows的稳定性和兼容性。而compattelrunner.exe是容错报告工具的一个组成部分,它负责收集系统的兼容性数据以及硬件和驱动程序信......
  • Qt信号槽信号函数重载问题 error: C2664: “QMetaObject::Connection const”
    //connect(spinFontSize,&QSpinBox::valueChanged,this,&MainWindow::spinFontSize_valueChanged);//由于信号函数存在重载,发送者找不到正确信号函数。//改用A.Qt4带形参方式//connect(spinFontSize,SIGNAL(valueChanged(int)),this,SLOT(spinFontSize_valueChang......
  • vim E447: cannot find file iostream in path
    查看c/c++文件中的头文件,可以使用gf跳转,但是有时会出现Error447:notfoundinpath1,命名模式中输入,临时修改:setpath=.,/usr/include,,/usr/include/c++/*/2,修改vimrc增加setpath+=.,/usr/include,,/usr/include/c++/*/......
  • AtCoder Beginner Contest 309 - D(最短路)
    目录D-AddOneEdge法一:dijkstra法二:BFS+队列题目传送门:abc309前面的简单题就不放了D-AddOneEdge题意:给你一个无向图图,分为两个连通块,一个顶点数为n1(1~n1),一个顶点数为n2(n1+1~n1+n2),图中共有m条边。如果现在在两个连通块之间连接一条边,那么顶点1与顶点n1+n2......
  • is greater than this module's compileSdkVersion (android-32). Dependency: an
    实现"isgreaterthanthismodule'scompileSdkVersion(android-32)"的步骤为了解决这个问题,我们需要按照以下步骤进行操作:步骤操作1确认项目的compileSdkVersion2更新项目的compileSdkVersion3更新相关依赖库的版本下面是每一步具体需要做的操作:步骤1......
  • 【深入浅出】你必须知道的 InnoDB 锁(二)
    ......
  • YOLOX目标检测实战:LabVIEW+YOLOX ONNX模型实现推理检测(含源码)
    (文章目录)前言好长一段时间没更博了,没更新博客的这段时间博主都有在努力产出,前段时间好多朋友私信问我说自己的yolov5模型是比较老的版本,使用LabVIEW推理的时候会报错。为各位朋友新老版本都能兼容,博主这段时间做了一个LabVIEWYOLOv5的插件,里面包含了大部分的新旧版本,老版本的......