首页 > 其他分享 >Pytorch入门(一):MNIST-手写数字识别-搭建网络模型

Pytorch入门(一):MNIST-手写数字识别-搭建网络模型

时间:2024-06-17 21:31:43浏览次数:39  
标签:nn self 28 Pytorch padding pytorch 手写 MNIST 搭建

前言

作为刚入门深度学习的一位初学者来说,各种各样的学习资料、视频让我看得头昏眼花。明明本来是想了解Pytorch使用方法,莫名其妙看了一个多小时的算法推理,原理逻辑,让人很是苦恼。于是在自己学习了一段时间后,打算做出这个pytorch的系列教程,就是想让大家基于项目进行实战,更多地了解pytorch这个深度学习框架。废话不多说,我们开始第一期的实战之旅。

搭建环境

作为一个基础教程,我们从搭建环境开始做起,本篇环境搭建适用于windows用户。

python环境安装

这个不多讲了,直接官网下载一个版本的python安装即可。

深度学习环境安装

本篇文章着重讲的部分就是安装pytorch,在此之前,你需要确保电脑上已经进行了python的安装,并且配置了系统的环境变量,否则在这一步将会出现各种问题。所有由于python没有安装导致的问题,这篇文章不进行解答。
首先,进入pytorch,根据自己的需求选择相应的安装方式。pytorch官方网站
在这里插入图片描述
一般我们选择的时候都是选stable版本,也就是稳定版,这样用的过程中不会出现一些奇奇怪怪的问题。
这里注意一下,电脑操作系统的版本、python包管理器的类型、编程语言、是否下载GPU版本的pytorch,这些选项都会影响最后红框中生成的命令。
我这里选择GPU版本的pytorch,因为自己的显卡是N卡3060,如果大家没有显卡可以下载CPU类型的pytorch,作为初学深度学习者,CPU版本也就够用了。

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

由于pytorch的包托管在境外,下载的速度会比较慢,这里建议选择国内镜像源进行下载。
下载完成之后可以进行验证,看是否能够正确识别到电脑显卡。

import torch
print(torch.cuda.is_available())

如果输出是True说明GPU版本的pytorch已经安装完成,现在就可以来进行搭建网络的操作了。

模型搭建

在搭建模型前,我们可以自己进行网络结构的设计,或者也可以参考他人的网络的结构,进行设计。这里我们从下面这张网络结构图进行设计。
在这里插入图片描述在这里插入图片描述在设计网络之前,我们首先需要了解我们的数据集是什么大小,以及我们期望的输出是什么样的。由于MNIST数据集已经提前帮我们确定好了这些,所以我们只需要记住就行。MNIST数据集是每张图片都是由1个单通道28x28像素构成的。我们最终需要网络给我们预测10中数字的概率,所以网络的最终的输出应该是10。知道了这些,我们就可以开始着手设计网络。
pytorch中提供了很方便的方式,我们只需要通过不到100行代码便可以设计出一个优雅的网络出来。

from torch import nn

class MyNet(nn.Module):
    def __init__(self):
        """
        初始化神经网络
        :param args:
        :param kwargs:
        """
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 28, 5, padding=2)
        self.maxpool1 = nn.MaxPool2d(2)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv2d(28, 28, 5, padding=2)
        self.maxpool2 = nn.MaxPool2d(2)
        self.drop2 = nn.Dropout(0.5)
        self.conv3 = nn.Conv2d(28, 64, 5, padding=2)
        self.maxpool3 = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(576, 64)
        self.linear2 = nn.Linear(64, 10)

    def forward(self, x):
        """
        前向传播函数
        :param x: input
        :return:
        """
        x = self.conv1(x)
        print(x.shape)
        x = self.maxpool1(x)
        x = self.drop1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.drop2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x

如此,我们的网络就搭建好了,我们试着实例化该网络,然后输出看下它的结构。

model = MyNet()
print(model)
//输出的网络结构如下
MyNet(
  (conv1): Conv2d(1, 28, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (drop1): Dropout(p=0.5, inplace=False)
  (conv2): Conv2d(28, 28, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (drop2): Dropout(p=0.5, inplace=False)
  (conv3): Conv2d(28, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear1): Linear(in_features=576, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=10, bias=True)
)

本次的文章到这里就结束了,后续系列将继续对MNIST手写数据网络搭建及训练+预测进行教学,我们下期再见!

标签:nn,self,28,Pytorch,padding,pytorch,手写,MNIST,搭建
From: https://blog.csdn.net/weixin_42746181/article/details/139727661

相关文章

  • PyTorch与TensorFlow模型互转指南
    在深度学习的领域中,PyTorch和TensorFlow是两大广泛使用的框架。每个框架都有其独特的优势和特性,因此在不同的项目中选择使用哪一个框架可能会有所不同。然而,有时我们可能需要在这两个框架之间进行模型的转换,以便于在不同的环境中部署或利用两者的优势。本文将详细介绍如何......
  • pytorch使用交叉熵训练模型学习笔记
    python代码:importtorchimporttorch.nnasnnimporttorch.optimasoptim#定义一个简单的神经网络模型classSimpleModel(nn.Module):def__init__(self):super(SimpleModel,self).__init__()self.fc=nn.Linear(3,2)#输入3维,输出2类......
  • 手把手教NLP小白如何用PyTorch构建和训练一个简单的情感分类神经网络
        在当今的深度学习领域,神经网络已经成为解决各种复杂问题的强大工具。本文将通过一个实际案例——对Yelp餐厅评论进行情感分类,来介绍如何使用PyTorch构建和训练一个简单的神经网络模型。我们将逐步讲解神经网络的基础概念,如激活函数、损失函数和优化器,并最终实现一......
  • 『手写Mybatis』创建简单的映射器代理工厂
    前言在阅读本文之前,我相信你已经是一个MybatisORM框架工具使用的熟练工了,那你是否清楚这个ORM框架是怎么屏蔽我们对数据库操作的细节的?比如我们使用JDBC的时候,需要手动建立数据库链接、编码SQL语句、执行数据库操作、自己封装返回结果等。但在使用ORM框架后,只需要......
  • 『手写Mybatis』实现映射器的注册和使用
    前言如何面对复杂系统的设计?我们可以把Spring、MyBatis、Dubbo这样的大型框架或者一些公司内部的较核心的项目,都可以称为复杂的系统。这样的工程也不在是初学编程手里的玩具项目,没有所谓的CRUD,更多时候要面对的都是对系统分层的结构设计和聚合逻辑功能的实现,再通过层层转换......
  • GPU版PyTorch安装、GPU版TensorFlow安装(详细教程)
    目录一、介绍PyTorch、TensorFlow 1. PyTorch2.TensorFlow二、GPU版PyTorch安装1.确定CUDA版本2.确定python版本3.安装PyTorch3.1使用官网命令安装(速度慢)3.2本地安装(速度快)4.检验是否安装成功三、GPU版TensorFlow安装1.确定CUDA版本2.确定TensorFlow版本3.安......
  • PyTorch 动态量化模型
    PyTorch动态量化模型简介PyTorch动态量化是一种模型优化技术,可以将模型参数和激活从浮点数转换为定点数,从而显著降低模型大小和提高推理速度。与静态量化不同,动态量化是在推理时进行量化,无需预先收集校准数据。动态量化工作原理动态量化主要包含以下步骤:观察:在模型推理过......
  • pytorch动态量化函数
    PyTorch动态量化APIPyTorch提供了丰富的动态量化API,可以帮助开发者轻松地将模型转换为动态量化模型。主要API包括:torch.quantization.quantize_dynamic:将模型转换为动态量化模型。torch.quantization.QuantStub:观察模型层的输入和输出分布。torch.quantization.Observer......
  • Caffe、PyTorch、Scikit-learn、Spark MLlib 和 TensorFlowOnSpark 概述
    在AI框架方面,有几种工具可用于图像分类、视觉和语音等任务。有些很受欢迎,如PyTorch和Caffe,而另一些则更受限制。以下是四种流行的AI工具的亮点。CaffeeCaffee是贾扬青在加州大学伯克利分校(UCBerkeley)时开发的深度学习框架。该工具可用于图像分类、语音和视觉。但......
  • PyTorch学习9:卷积神经网络
    文章目录前言一、说明二、具体实例1.程序说明2.代码示例总结前言介绍卷积神经网络的基本概念及具体实例一、说明1.如果一个网络由线性形式串联起来,那么就是一个全连接的网络。2.全连接会丧失图像的一些空间信息,因为是按照一维结构保存。CNN是按照图像原始结构进......