首页 > 其他分享 >pytorch中神经网络的定义方法

pytorch中神经网络的定义方法

时间:2024-12-30 18:55:49浏览次数:3  
标签:__ init Linear nn self torch 神经网络 pytorch 定义方法

1. 继承 torch.nn.Module 类(推荐方法)

最常见和推荐的方式是通过继承 torch.nn.Module 类来创建一个自定义的神经网络模型。在这种方式下,你需要定义 __init__() 方法来初始化网络层,并在 forward() 方法中定义前向传播逻辑。

示例:一个简单的全连接神经网络
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        # 定义网络层
        self.fc1 = nn.Linear(784, 128)  # 输入层:28x28 图像展平为 784
        self.fc2 = nn.Linear(128, 64)   # 隐藏层
        self.fc3 = nn.Linear(64, 10)    # 输出层:10 类分类

        # 激活函数
        self.relu = nn.ReLU()

    def forward(self, x):
        # 前向传播逻辑
        x = self.relu(self.fc1(x))  # 输入 -> 第一层 -> 激活
        x = self.relu(self.fc2(x))  # 第二层 -> 激活
        x = self.fc3(x)             # 输出层
        return x

# 创建模型实例
model = SimpleNN()
print(model)
解释:
  • __init__():在这个方法中定义了神经网络的层(如 nn.Linear),并且可以定义激活函数(如 nn.ReLU())。
  • forward():定义了数据从输入到输出的传播方式。

这种方式非常灵活,可以用于复杂的网络结构设计。

2. 使用 nn.Sequential(顺序模型)

如果你的网络是一个简单的按顺序排列的层,nn.Sequential 提供了一种更加简洁的方式来定义模型。nn.Sequential 允许你将多个层按顺序进行组合,自动处理前向传播的顺序。

示例:使用 nn.Sequential 定义一个简单的全连接神经网络
import torch
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        # 使用 nn.Sequential 顺序堆叠层
        self.model = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        return self.model(x)

# 创建模型实例
model = SimpleNN()
print(model)
解释:
  • nn.Sequential:这种方式会将层按顺序堆叠在一起,并且自动处理前向传播。
  • 适用于结构简单、每一层都执行相同操作(如全连接层 + 激活函数)的模型。

3. 使用 torch.nn.ModuleListtorch.nn.ModuleDict

如果你的网络包含多个层,但它们的顺序不是简单的顺序堆叠,或者你需要在网络中使用循环和条件语句,nn.ModuleListnn.ModuleDict 提供了更大的灵活性。

  • ModuleList:用于存储层的列表,可以通过索引访问这些层。
  • ModuleDict:用于存储层的字典,可以通过键来访问层。
示例:使用 ModuleList 定义一个多层感知机(MLP)
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        # 使用 ModuleList 来存储多个全连接层
        self.layers = nn.ModuleList([
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)  # 按顺序执行每一层
        return x

# 创建模型实例
model = MLP()
print(model)
解释:
  • ModuleListModuleList 可以存储多个层,这些层可以通过 for 循环逐一执行。
  • forward() 方法中,我们使用 for 循环按顺序执行每一层。

4. 使用 torch.nn.functional(函数式接口)

torch.nn.functional 包含了很多与神经网络相关的函数,这些函数不需要创建层实例,而是可以在 forward() 方法中直接调用。通过这种方式,你可以避免显式地使用 nn.Module 中的层类,减少代码量。

示例:使用 torch.nn.functional 定义一个简单的网络
import torch
import torch.nn.functional as F
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        # self.relu = nn.ReLU()

    def forward(self, x):
        # 使用 nn.functional 进行激活函数处理而不是在init中定义激活层
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建模型实例
model = SimpleNN()
print(model)
解释:
  • torch.nn.functional:在 forward() 中使用 F.relu() 等函数式接口,避免显式地调用 nn.ReLU() 层实例。这种方式适合你只需要用函数对数据进行操作的场景。

5. 自定义层

除了 nn.Modulenn.Sequential,你还可以通过继承 nn.Module 来定义自定义的层。这样你可以封装复杂的操作,形成可复用的模块。

示例:自定义一个激活函数层
import torch
import torch.nn as nn

class MyReLU(nn.Module):
    def __init__(self):
        super(MyReLU, self).__init__()

    def forward(self, x):
        return torch.maximum(x, torch.tensor(0.0))  # 自定义 ReLU 激活

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu = MyReLU()  # 使用自定义的激活函数层

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建模型实例
model = SimpleNN()
print(model)
解释:
  • 自定义层:你可以继承 nn.Module 来定义自己的层,并在 forward() 方法中定义自定义的前向传播行为。这种方式适用于特殊的操作,如自定义的激活函数、正则化、特殊的损失函数等。

总结

在 PyTorch 中定义神经网络的常见方法有:

  1. 继承 torch.nn.Module:适用于复杂的网络结构,最常用的方式。
  2. 使用 nn.Sequential:适用于结构简单、按顺序堆叠的层。
  3. 使用 ModuleListModuleDict:适用于网络中有循环或更复杂结构的场景。
  4. 使用 torch.nn.functional:在 forward() 方法中直接使用函数式接口来定义前向传播,减少代码量。
  5. 自定义层:封装特定的操作,形成可复用的模块,适用于需要自定义操作的场景。

标签:__,init,Linear,nn,self,torch,神经网络,pytorch,定义方法
From: https://blog.csdn.net/kaiaaaa/article/details/144832237

相关文章

  • pytorch(.pth)模型转化为 torchscript(.pt), 导出为onnx格式
    pytorch(.pth)模型转化为torchscript(.pt),导出为onnx格式1.pth模型转换为.pt模型importtorchimporttorchvisionfrommodelsimportfcnmodel=torchvision.models.vgg16()state_dict=torch.load("./checkpoint-epoch100.pth")#print(state_dict)model.load_state......
  • 上机实验五:BP 神经网络算法实现与测试
    上机实验五:BP神经网络算法实现与测试1、实验目的深入理解BP神经网络的算法原理,能够使用Python语言实现BP神经网络的训练与测试,并且使用五折交叉验证算法进行模型训练与评估。2、实验内容(1)从scikit-learn库中加载iris数据集,使用留出法留出1/3的样本作为测试集(注......
  • 电能质量扰动信号分类,基于Transformer的一维信号分类模型附PyTorch代码
    目录背景研究方法研究内容研究框架代码实现背景在电力系统中,电能质量指的是电压、电流和频率等参数的稳定性和纯净度。然而,由于设备故障、电力负载变化、电力系统故障或其他外部因素,电力系统中可能会出现各种电能质量扰动。这些扰动不仅影响电力系统的稳定运行......
  • 【故障诊断】基于贝叶斯优化卷积神经网络BO-CNN实现故障诊断附matlab代码
    研究背景在智能制造和工业4.0的背景下,设备的可靠性和安全性成为了生产过程中的关键因素。故障诊断作为维护设备正常运行的重要手段,其准确性和效率对于减少停机时间、提高生产效率和保障人员安全具有重要意义。传统的故障诊断方法,如基于规则的方法、统计方法和机器学习算法,......
  • 【故障诊断】【pytorch】基于CNN-LSTM故障分类的轴承故障诊断研究[西储大学数据](Pytho
         ......
  • [论文精读](神经网络加速)Eyerissv2原论文精读(一)整体结构分析与背景介绍
    论文链接:Eyerissv2:AFlexibleAcceleratorforEmergingDeepNeuralNetworksonMobileDevices|IEEEJournals&Magazine|IEEEXplore概述Eyeriss是MIT Yu-HsinChen 团队最早于2016年推出的神经网络加速框架,Eyerissv2是其在2019年推出的改进。相比Eyerissv1,v2......
  • 解锁风电运维新密码:深度学习神经网络助力设备寿命精准预估
    摘要:当下,风电产业蓬勃发展,可恶劣运行环境使设备故障频发,精准预估剩余寿命迫在眉睫。深度学习中的神经网络为此带来曙光,其基础源于对大数据处理需求的回应,借由神经元、层架构自动提取特征。在风电应用里,CNN、LSTM深挖多源异构数据特征,MLP等架构构建预测模型,配合优化算法训......
  • 24-12-28-pytorch深度学习中音频I/O 中遇到的问题汇总
    文章目录pytorch深度学习中音频I/O中遇到的问题汇总问题1:音频文件格式的读取问题问题2:音频文件绘图问题小结pytorch深度学习中音频I/O中遇到的问题汇总问题1:音频文件格式的读取问题参考链接:torchaudio加载wav报错Couldn‘tfindappropriatebackendtohandle......
  • 高级神经网络API——Keras 简介和一般工作流程
    概述Keras是一个高级神经网络API,它用Python语言编写,能够在TensorFlow、Theano或者CNTK等深度学习框架之上运行。它的设计理念是简单、快速地构建和实验深度学习模型。Keras提供了易于使用的接口,使得用户可以专注于模型架构的设计和训练,而不必深入了解底层复杂的计算......
  • 【神经网络训练过程可视化】
    一、直方图可视化数据分布1.知识介绍在PyTorch模型的每一层注册一个forwardhook,从而能够捕获每层的输出简单列表存储形式(只能顺序查看每层输出,下文会有改进版用字典将层名字和层输出值对应)activations=[]defhook_fn(module,input,output):activations.appe......