首页 > 其他分享 >PyTorch实战深度学习——用CNN进行手写数字识别

PyTorch实战深度学习——用CNN进行手写数字识别

时间:2024-11-11 09:43:37浏览次数:3  
标签:torchvision 模型 卷积 self torch PyTorch CNN 手写 model

用CNN进行手写数字识别---计算机专业研究生的代码第一课,相当于”Hello World“,不管以后选择什么研究方向,都值得一看,欢迎大家留言交流学习!

下面手把手教大家一步一步实现该任务:

1. 环境准备

首先呢,您需要确保安装了PyTorch库。如果还没有安装,可以使用以下命令进行安装,这里默认您已经有Anaconda并创建好虚拟环境啦,如果还没有安装,可以参考其他更完整的安装pytorch的教程:

pip install torch torchvision

怎样判断是否安装成功呢,给大家几个方法,嘿嘿~

方法一:使用 pip list 检查安装列表

在终端或命令行中输入以下命令,查看安装的包列表:

pip list | grep torch

这会列出包含 torchtorchvision 的包及其版本号。如果看到了 torchtorchvision,则说明它们已安装。

方法二:使用 Python 代码检查

在 Python 环境中,尝试导入 torchtorchvision 包。如果没有报错,说明安装成功。

import torch
import torchvision

print("Torch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)

2. 导入库和加载数据

安装成功后就可以开始下一步啦~

PyTorch中提供了大量经典数据集,其中就包括MNIST,如下图所示。我们可以直接通过torchvision库加载MNIST数据集,并进行预处理。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 对图像进行标准化处理
])

# 加载训练集和测试集
train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

代码通过PyTorch加载并预处理MNIST手写数字数据集,为模型训练和评估做好准备。首先,它导入必要的库并定义数据预处理操作,将图像转换为张量并标准化到[-1, 1]的范围。接着,它使用torchvision.datasets.MNIST加载数据集,并将训练集和测试集分别存储为train_settest_set,同时设定存储路径并下载数据集(如有需要)。最后,DataLoader以批量的形式加载训练和测试数据,指定每批次包含64张图像,训练数据会被打乱以提高模型泛化性,而测试数据则按顺序加载。这些操作将数据集组织成适合神经网络的输入格式,为接下来的模型训练和评估提供便利。 


3. 定义CNN模型

我们来定义一个简单的卷积神经网络,包含卷积层、池化层和全连接层。这个模型将用于手写数字的分类任务。

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)  # 展平操作
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = CNN()

此代码定义了一个卷积神经网络(CNN)类,用于图像分类任务,主要包含卷积层、池化层和全连接层。初始化方法__init__中定义了网络的结构:首先是两个卷积层conv1conv2,分别将输入图像的通道数从1增加到32、再到64,并各自跟随一个MaxPool2d池化层,以减少特征图的尺寸。卷积层使用kernel_size=3的3x3卷积核,padding=1确保卷积后图像大小不变。接着是两层全连接层fc1fc2,其中fc1将展平的特征图(尺寸为64*7*7)转为128维,fc2将128维特征向量映射到10个输出类别(用于10分类任务)。在forward方法中,输入图像依次经过卷积、池化和非线性激活函数ReLU,并在池化层后将特征展平成向量,再通过全连接层得到最终输出。 


4. 定义损失函数和优化器

为了训练模型,我们需要定义一个损失函数和一个优化器。这里我们使用交叉熵损失函数和Adam优化器。

criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器

5. 训练模型

接下来,我们编写训练代码。在每一轮(epoch=5)中,模型会对训练集进行正向传播、计算损失、反向传播和参数更新。

num_epochs = 5
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        # 清除上一次梯度
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # 反向传播与优化
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
print("训练完成")

6. 模型评估

我们来在测试集上评估模型的性能,计算测试集的准确率。

correct = 0
total = 0
with torch.no_grad():  # 评估模式,不需要计算梯度
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'测试集准确率: {accuracy:.2f}%')

7. 保存模型

为了方便后续的部署或复用,可以将训练好的模型保存下来:

torch.save(model.state_dict(), "mnist_cnn_model.pth")
print("模型已保存")

8. 加载模型

在需要使用模型时,可以加载保存的权重,并重新实例化模型:

model = CNN()
model.load_state_dict(torch.load("mnist_cnn_model.pth"))
model.eval()  # 切换到评估模式

 


总结

本文中,我们使用PyTorch搭建了一个简单的卷积神经网络,并在MNIST手写数字识别任务上进行了训练。通过这种方式,你可以快速上手PyTorch,掌握深度学习模型的基本构建与训练流程。

 

标签:torchvision,模型,卷积,self,torch,PyTorch,CNN,手写,model
From: https://blog.csdn.net/xyaixy/article/details/143660074

相关文章

  • 深度学习(三)2.利用pytorch实现线性回归
    一、基础概念1.线性层线性层(LinearLayer)是神经网络中的一种基本层,也称为全连接层(FullyConnectedLayer)。它的工作方式类似于简单的线性方程:y=Wx+b,其中W是权重矩阵,x是输入,b是偏置项,y是输出。线性层的主要任务是将输入的数据通过权重和偏置进行线性变换,从而生成输出......
  • call(),bind(),apply(),的区别和手写
    1.call(),bind(),apply()的区别call(),bind(),和apply()是JavaScript中用于改变函数执行上下文(即this的指向)的方法,它们之间有一些区别:call():call()方法允许你调用一个具有指定this值的函数,并且允许你传递一个参数列表。它的语法是function.call(thisArg,ar......
  • Bayes-CNN-BiGRU-Att贝叶斯算法-卷机网络-双向门控循环单元-注意力机制多特分类预测 M
    %*****************************************************************************************************************************************************************************************************************%%清空环境变量warningoff%关闭报警......
  • 基于深度学习+pytorch+PyQt6+MySQL的口罩佩戴识别系统
    前言本系统是一个完整的基于深度学习+pytorch+PyQt6+MySQL的口罩佩戴识别系统。包括LeNet、AlexNet、VGG、GoogLeNet、ResNet、MobileNetV2网络模型。可以直接训练、测试、使用。也就是说,它不仅仅是一个口罩佩戴识别系统。它可以是任意识别系统!!系统演示视频登陆注册系......
  • 基于深度学习+pytorch+PyQt6+MySQL的农作物识别系统
    前言本系统是一个完整的基于深度学习+pytorch+PyQt6+MySQL的农作物识别系统。包括LeNet、AlexNet、VGG、GoogLeNet、ResNet、MobileNetV2网络模型。可以直接训练、测试、使用。也就是说,它不仅仅是一个农作物识别系统。它可以是任意识别系统!!系统演示视频登陆注册系统使......
  • 基于YOLOv8模型的安全背心目标检测系统(PyTorch+Pyside6+YOLOv8模型)
    摘要:基于YOLOv8模型的安全背心目标检测系统可用于日常生活中检测与定位安全背心目标,利用深度学习算法可实现图片、视频、摄像头等方式的目标检测,另外本系统还支持图片、视频等格式的结果可视化与结果导出。本系统采用YOLOv8目标检测算法训练数据集,使用Pysdie6库来搭建前端页面......
  • 基于YOLOv8模型和PCB电子线路板缺陷目标检测系统(PyTorch+Pyside6+YOLOv8模型)
    摘要:基于YOLOv8模型PCB电子线路板缺陷目标检测系统可用于日常生活中检测与定位PCB线路板瑕疵,利用深度学习算法可实现图片、视频、摄像头等方式的目标检测,另外本系统还支持图片、视频等格式的结果可视化与结果导出。本系统采用YOLOv8目标检测算法训练数据集,使用Pysdie6库来搭建......
  • 【造轮子】qiankun详解和手写
    说到微前端,现在最火的方案就是qiankun。qiankun的特点是易用性和完备性很高。说白了就是能很方便、快速的接入,同时bug少,功能强大。介绍微前端已经火了一段时间了,就不介绍了,直接贴图得了。话不多少,本次主要做两件事情:拆解和解析qiankun源码尝试qiankun造轮子分析qi......
  • 【CNN-GRU-Attention】基于卷积神经网络和门控循环单元网络结合注意力机制的多变量回
    ......
  • 深度学习工程实践:PyTorch Lightning与Ignite框架的技术特性对比分析
    在深度学习框架的选择上,PyTorchLightning和Ignite代表了两种不同的技术路线。本文将从技术实现的角度,深入分析这两个框架在实际应用中的差异,为开发者提供客观的技术参考。核心技术差异PyTorchLightning和Ignite在架构设计上采用了不同的方法论。Lightning通过提供高层次的抽象......