首页 > 其他分享 >PyTorch--双向长短期记忆网络(BiRNN)在MNIST数据集上的实现与分析

PyTorch--双向长短期记忆网络(BiRNN)在MNIST数据集上的实现与分析

时间:2024-08-17 22:27:56浏览次数:5  
标签:nn 示例 -- torch PyTorch BiRNN num device size

文章目录

前言

本代码实现了一个基于PyTorch的双向长短期记忆网络(BiRNN),用于对MNIST数据集中的手写数字进行分类。MNIST数据集是一个广泛使用的计算机视觉数据集,包含了大量的手写数字图像,适合用来训练和测试深度学习模型。

代码的关键特点包括:

  1. 数据加载与预处理:使用torchvision库加载MNIST数据集,并应用了标准化变换以准备数据输入模型。

  2. BiRNN模型定义:模型使用nn.LSTM模块构建双向LSTM层,能够处理序列数据,并通过nn.Linear层进行最终的分类。

  3. 设备无关性:通过torch.device自动选择GPU或CPU,提高了代码的通用性。

  4. 训练与测试:实现了模型的训练循环和测试循环,包括损失计算、反向传播和参数更新。

  5. 可视化工具:集成了数据可视化和模型架构可视化功能,使用matplotlib库展示数据样本和训练进度。

  6. 模型保存:训练完成后,使用torch.save保存模型参数,方便后续的加载和使用。

  7. 超参数设置:提供了灵活的超参数设置,包括隐藏层大小、层数、批次大小、训练轮数和学习率。

代码结构清晰,易于理解和修改,适合作为深度学习入门和实践的参考。通过本代码,用户可以了解如何使用PyTorch构建和训练一个BiRNN模型,并对MNIST数据集进行分类任务。

说明

  • 确保安装了PyTorch、torchvision和matplotlib。
  • 调整超参数以适应不同的训练需求。
  • 运行代码,观察训练过程和测试结果。
  • 使用可视化工具了解数据和模型架构。

完整代码

import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

# MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data/',
                                           train=True, 
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='../../data/',
                                          train=False, 
                                          transform=transforms.ToTensor())

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size, 
                                          shuffle=False)

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BiRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection
    
    def forward(self, x):
        # Set initial states
        h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection 
        c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
        
        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
        
        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])
        return out

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)


# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' 
                   .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

# Test the model
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, sequence_length, input_size).to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) 

# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

代码解析

1.导入库

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

这部分代码导入了编写神经网络所需的PyTorch库及其子模块。

2.设备配置

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

根据是否有可用的GPU,设置计算设备,优先使用GPU以加速训练。

3.超参数设置

sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.003

设置了模型训练所需的超参数,包括时间序列的长度、输入数据的尺寸、隐藏层的尺寸、LSTM层数、类别数、批次大小、训练轮数和学习率。

4.数据集加载

train_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(..., transform=transforms.ToTensor())

加载MNIST数据集的训练集和测试集,并使用transforms.ToTensor()将图像数据转换为张量。

5.数据加载器

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

创建了两个数据加载器,分别用于训练和测试数据的批量加载。

6.定义BiRNN模型

class BiRNN(nn.Module):
    # 定义双向循环神经网络模型

创建了一个双向LSTM的模型,包含初始化方法和前向传播方法。

7.实例化模型并移动到设备

model = BiRNN(input_size, hidden_size, num_layers, num_classes).to(device)

实例化BiRNN模型,并将模型移动到之前设置的计算设备上。

8.损失函数和优化器

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

定义了交叉熵损失函数和Adam优化器。

9.训练模型

for epoch in range(num_epochs):
    # 训练循环

在每个epoch中,遍历训练数据的每个批次,执行前向传播、计算损失、反向传播和参数更新。

10.测试模型

with torch.no_grad():
    # 测试循环

在测试阶段,关闭梯度计算,遍历测试数据的每个批次,计算模型的预测准确率。

11.保存模型

torch.save(model.state_dict(), 'model.ckpt')

保存模型的参数到文件,以便于后续的加载和使用。

这段代码实现了一个完整的训练和测试流程,适合用于分类任务,特别是涉及序列数据的任务。对于MNIST数据集,尽管它不是序列数据,但通过将图像的每一行视为序列的一部分,可以使用RNN进行处理。

常用函数

  1. torch.device

    • 格式:torch.device(device_str)
    • 参数:device_str —— 指定设备类型(如'cuda''cpu')的字符串。
    • 样式:属性访问器。
    • 示例:
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
      
  2. torchvision.datasets.MNIST

    • 格式:torchvision.datasets.MNIST(root, train, transform, download)
    • 参数:
      • root —— 数据集存放的根目录。
      • train —— 是否加载训练集。
      • transform —— 对图像进行的变换操作。
      • download —— 是否下载数据集。
    • 样式:类方法调用。
    • 示例:
      train_dataset = torchvision.datasets.MNIST(root='../../data/', train=True, transform=transforms.ToTensor(), download=True)
      
  3. torchvision.transforms.Compose

    • 格式:torchvision.transforms.Compose(transforms_list)
    • 参数:transforms_list —— 包含多个变换操作的列表。
    • 样式:类方法调用。
    • 示例:
      transform = transforms.Compose([
          transforms.ToTensor(),
          transforms.Normalize((0.1307,), (0.3081,))
      ])
      
  4. torch.utils.data.DataLoader

    • 格式:torch.utils.data.DataLoader(dataset, batch_size, shuffle)
    • 参数:
      • dataset —— 加载的数据集。
      • batch_size —— 每个批次的样本数。
      • shuffle —— 是否在每个epoch开始时打乱数据。
    • 样式:类方法调用。
    • 示例:
      train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
      
  5. nn.Module

    • 格式:class YourModelClass(nn.Module)
    • 参数:继承自nn.Module的类定义。
    • 样式:类继承。
    • 示例:
      class BiRNN(nn.Module):
          def __init__(self, ...):
              super(BiRNN, self).__init__()
              ...
      
  6. nn.LSTM

    • 格式:nn.LSTM(input_size, hidden_size, num_layers, batch_first, bidirectional)
    • 参数:
      • input_size —— 输入特征的维度。
      • hidden_size —— 隐藏层的维度。
      • num_layers —— LSTM层的数量。
      • batch_first —— 输入和输出张量的第一个维度是否为批次大小。
      • bidirectional —— 是否使用双向LSTM。
    • 样式:类方法调用。
    • 示例:
      self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
      
  7. nn.Linear

    • 格式:nn.Linear(in_features, out_features)
    • 参数:
      • in_features —— 输入特征的数量。
      • out_features —— 输出特征的数量。
    • 样式:类方法调用。
    • 示例:
      self.fc = nn.Linear(hidden_size * 2, num_classes)
      
  8. nn.CrossEntropyLoss

    • 格式:nn.CrossEntropyLoss()
    • 参数:无默认参数。
    • 样式:类方法调用。
    • 示例:
      criterion = nn.CrossEntropyLoss()
      
  9. torch.optim.Adam

    • 格式:torch.optim.Adam(params, lr)
    • 参数:
      • params —— 模型参数。
      • lr —— 学习率。
    • 样式:类方法调用。
    • 示例:
      optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
      
  10. .to(device)

    • 格式:.to(device)
    • 参数:device —— 指定的计算设备。
    • 样式:方法调用。
    • 示例:
      images = images.to(device)
      
  11. .reshape

    • 格式:.reshape(shape)
    • 参数:shape —— 要重塑成的新形状。
    • 样式:方法调用。
    • 示例:
      images = images.reshape(-1, sequence_length, input_size)
      
  12. torch.zeros

    • 格式:torch.zeros(size, device)
    • 参数:
      • size —— 张量的形状。
      • device —— 张量所在的设备。
    • 样式:函数调用。
    • 示例:
      h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(device)
      
  13. torch.max

    • 格式:torch.max(input, dim, keepdim)
    • 参数:
      • input —— 输入张量。
      • dim —— 要计算最大值的维度。
      • keepdim —— 是否保留计算维度。
    • 样式:函数调用。
    • 示例:
      _, predicted = torch.max(outputs.data, 1)
      
  14. torch.no_grad()

    • 格式:torch.no_grad()
    • 参数:无参数。
    • 样式:上下文管理器。
    • 示例:
      with torch.no_grad():
          ...
      
  15. torch.save

    • 格式:torch.save(object, filename)
    • 参数:
      • object —— 要保存的对象。
      • filename —— 文件名。
    • 样式:函数调用。
    • 示例:
      torch.save(model.state_dict(), 'model.ckpt')
      
  16. plt.imshow

    • 格式:plt.imshow(X, cmap)
    • 参数:
      • X —— 要显示的图像数据。
      • cmap —— 颜色映射。
    • 样式:函数调用。
    • 示例:
      plt.imshow(images[j].squeeze().cpu(), cmap='gray')
      
  17. plt.show

    • 格式:plt.show()
    • 参数:无参数。
    • 样式:函数调用。
    • 示例:
      plt.show()
      
  18. plt.figure

    • 格式:plt.figure(figsize)
    • 参数:figsize —— 图形的尺寸。
    • 样式:函数调用。
    • 示例:
      plt.figure(figsize=(20, 4))
      
  19. plt.subplot

    • 格式:plt.subplot(nrows, ncols, index)
    • 参数:
      • nrows —— 子图的行数。
      • ncols —— 子图的列数。
      • index —— 当前子图的索引。
    • 样式:函数调用。
    • 示例:
      plt.subplot(1, num_samples, j+1)
      

这些函数覆盖了从数据预处理、模型构建、训练、测试到结果可视化的整个流程。

标签:nn,示例,--,torch,PyTorch,BiRNN,num,device,size
From: https://blog.csdn.net/wumingzei/article/details/141287787

相关文章

  • rsync备份【基于客户端与服务端】
    一、需求1、客户端客户端提前准备存放到的备份目录,目录规则如下:/backup/nfs_IP+年/月/日客户端在本地打包备份(将etc目录中所有的普通文件打包)拷贝到目标目录/backup/nfs_IP+年/月/日客户端最后将备份的数据进行推送到备份服务器中客户端每天凌晨1点定时执行该脚本客户端服......
  • Redis中Sorted Set数据类型常用命令
    目录1.添加元素2.获取成员3.获取成员的分数4.删除元素5.获取集合的大小6.获取成员的排名7.按分数范围获取成员8.按排名范围获取成员9.增减分数10.删除指定分数范围的成员11.获取分数的范围在Redis中,SortedSet(有序集合)是一种重要的数据类型,它的每......
  • 2024.8 #6
    T1.[AGC060F]SpanningTreesofIntervalGraph我们令\(S=\sumC_{i,j}\)。我们设两个矩阵\(B_{i,j}=[[L_i,R_i]\cap[L_j,R_j]]\)以及\(A_{i,i}=\sumB_{i,j}\)。那么根据矩阵树定理,我们知道生成树的数量就是\(\det(A-B)\)。然而直接高斯消元复杂度是\(O(S^3......
  • Spring DI实现方式
    1.set注入语法:1)set方法      2)set配置:<propertynamevauleref>2、构造注入语法:1)构造方法      2)构造配置:<constructor-argnametypeindexvalueref>3、注解注入(1)@Component用于标识一个类为Spring的组件,这个类会被Spring容器管理。......
  • C#配置文件
    ini文件读取获取执行目录App.config文件读取系统信息ini文件读取ini文件是个啥?.ini文件是InitializationFile的缩写,即初始化文件,是windows的系统配置文件所采用的存储格式,统管windows的各项配置,一般用户就用windows提供的各项图形化管理界面就可实现相......
  • AI+服装电商细分赛道的落地应用:图应AI模特的进化史干货篇
    文章目录AI绘制人物的效果进化史2022年2023年2024年摄影师、设计师、模特三方在AI商拍领域的位置国家统计局的一些服装行业数据遇到的一些问题以及相应的解决方案图应AI这个产品未来可能怎么走统一回答某些投资人的一个问题AI绘制人物的效果进化史2022年还记得我20......
  • ABC 367 题解
    AtCoderBeginnerContest367题解:\(Problem\hspace{2mm}A-Shout\hspace{2mm}Everyday\)题目链接opinion:~~code:#include<bits/stdc++.h>#definelllonglong#definepiipair<int,int>usingnamespacestd;lla,b,c;intmain(){ i......
  • 聚星文社AI工具
    聚星文社AI工具是一款基于人工智能技术的文学创作辅助工具。聚星文社AI工具https://docs.qq.com/doc/DRU1vcUZlanBKR2xy它能够帮助作者生成文字内容、自动校对、提供创作灵感等功能。通过聚星文社AI工具,作者可以更快速地完成文学作品的创作,提高创作效率并且能够得到更好的......
  • 聚星文社AI工具小说推文工具
    聚星文社AI工具小说推文工具是一个帮助作者推广小说作品的工具。它可以自动生成吸引读者注意的推文内容,帮助作者在社交媒体平台上宣传自己的作品。聚星文社AI工具小说推文工具https://docs.qq.com/doc/DRU1vcUZlanBKR2xy使用这个工具,作者可以输入小说的关键信息,例如故事背景......
  • 基于Spring Boot的青年公寓服务平台的设计与实现
    目录一、前言二、技术介绍三、系统实现四、论文参考五、核心代码六、其他案例七、源码获取​​​​​​​作者介绍:✌️大厂全栈码农|毕设实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。✌️作者博客:曾几何时​​​​​​​......