首页 > 其他分享 >使用简化VGGnet对MNIST数据集进行训练

使用简化VGGnet对MNIST数据集进行训练

时间:2024-07-23 09:28:48浏览次数:7  
标签:loss nn VGGnet dataloader 简化 test model MNIST size

 

目录

1.VGGNet特点

2.注意点

3.导入数据集

4.定义简化版的VGG网络结构

5.定义训练和验证函数

6.调用函数

7.多批次训练

8.结果


  VGGNet 是由牛津大学的视觉几何组(Visual Geometry Group)在 2014 年提出的一个深度卷积神经网络。它在 ImageNet 竞赛中取得了很好的成绩。VGGNet 的主要贡献是展示了网络深度对于性能的提升有显著影响。

1.VGGNet特点

  1. 网络深度:VGGNet 有多个版本,其中最常见的是 VGG16 和 VGG19,分别包含 13 层和 16 层卷积层。
  2. 卷积核大小:VGGNet 使用 3x3 的小卷积核,这使得网络更深但参数更少。
  3. 步长和填充:卷积层使用 1x1 的步长和 1 像素的填充,保持特征图的尺寸。
  4. 池化层:每两个卷积层后跟一个 2x2 的最大池化层,步长为 2。
  5. 全连接层:在卷积层之后,VGGNet 使用几个全连接层,最后是一个 softmax 层进行分类。

2.注意点

  VGG16结构复杂,而MNIST数据集图像太小,在经过过多的池化层后,维度会简化到0然后报错,所以本代码将使用两个卷积层和两个池化层的简化版VGGnet。

3.导入数据集

import torch
from torch import nn
#导入神经网络模块
from torch.utils.data import DataLoader #数据包管理工具,打包数据
from torchvision import datasets #封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor

training_data = datasets.MNIST(#跳转到国数的内部源代码,pycharm 按ctrl +鼠标点击
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),#张量,图片是不能直接传入神经网络模型
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),#Tensor是在深度学习中提出并厂泛应用的数据类型,它与深度学习框架(如PyTorch、TensorFlow) 紧密集成
)

train_dataloader = DataLoader(training_data, batch_size=64) #64张图片为一个包
test_dataloader = DataLoader(test_data, batch_size=64)

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

4.定义简化版的VGG网络结构

class VGG_MNIST(nn.Module):
    def __init__(self):
        super(VGG_MNIST, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 可以继续添加层,但考虑到MNIST的简单性,这里就足够了
        )
        self.classifier = nn.Sequential(
            nn.Linear(128 * 7 * 7, 1024),  # 根据特征图大小调整这个值
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 10),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

model = VGG_MNIST().to(device)  # 把刚刚创建的模型传入到 Gpu

5.定义训练和验证函数

def train(dataloader, model, loss_fn, optimizer):
    model.train()  # #告诉模型要开始训练,模型中ω进行随机化操作,以及更新ω,在训练过程中,ω会被修改
    batch_size_num = 1
    for X, y in dataloader: # 其中batch为每个数据的编号
        X, y = X.to(device), y.to(device) # 把训练数据集和标签传入cpu或GPU
        pred = model.forward(X) # 自动初始化ω权值
        loss = loss_fn(pred, y) # 通过交叉熵损失函数计算损失值loss
        optimizer.zero_grad()  # 梯度值清零
        loss.backward()  # 反向传播计算得到每个参数的梯度值
        optimizer.step()  # 根据梯度更新网络参数

        loss_value = loss.item()  # item从tensor数据中提取数据出来,tensor获取损失值
        print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval() # 测试,ω就不能再更新。
    test_loss, correct = 0, 0
    with torch.no_grad():# 一个上下文管理器,关闭梯度计算。当确认不会调用Tensor.backward()的时候。这可以减少计算所用内存消耗。
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)  # 数据传到GPU或者CPU中
            pred = model.forward(X)
            test_loss += loss_fn(pred, y).item() # test_loss 会自动累加每一个批次的损失值
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches  # 能来衡量模型测试的好坏。
    correct /= size  # 平均的正确率
    print(f"Test result: \n Accuracy: {(100*correct)}%,Avg loss: {test_loss}")

6.调用函数

loss_fn = nn.CrossEntropyLoss()  # 创建交叉熵损失函数对象,因为手写字识别中共有10个数字,输出会有10个结果
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train(train_dataloader, model, loss_fn, optimizer)  # 训练1次完整的数据,多轮训练
test(test_dataloader, model, loss_fn)

7.多批次训练

epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n--------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

8.结果

标签:loss,nn,VGGnet,dataloader,简化,test,model,MNIST,size
From: https://blog.csdn.net/2301_77444219/article/details/140620288

相关文章

  • sympy 的简化是否可用于未知函数?
    以下代码意外地将f(1)=xf(0)简化为f(1)=0。是因为没有进一步假设就不能使用未定义的函数吗?运行此命令不应更改表达式,但它会给出Eq(f(1),0)fromsympyimport*x,y=symbols("xy")f=Function("f")print(Eq(f(1),x*f(0)).simplify(rational=True,doi......
  • 想让字典操作更优雅?自定义Python字典类型,简化你的代码库!
    目录1、继承dict类......
  • 基于mnist数据集的手写数字识别模型的训练可视化预测
    使用 tensorflow库创建训练模型数据集使用公开的mnist 一、构建模型fromtensorflow.keras.layersimportDense,DropoutimporttensorflowastfdefmnistModel():model=tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28,28)),#对......
  • 使用Java和Google Guava简化开发
    使用Java和GoogleGuava简化开发大家好,我是微赚淘客系统3.0的小编,是个冬天不穿秋裤,天冷也要风度的程序猿!GoogleGuava是Google开发的一个Java开源库,它提供了许多工具和库来简化Java开发。Guava提供了从集合类到缓存、字符串处理、并发工具等多种功能。本篇文章将介绍如......
  • 简化Android数据管理:深入探索SQLite数据库
    SQLite数据库在Android中的使用SQLite是一种精巧的、轻量级的、无服务器的、零配置的、事务性SQL数据库引擎。相较于其他数据库系统,SQLite更适用于需要轻量级解决方案的移动应用场景。本文将详细介绍SQLite数据库在Android中的使用,包括数据库的创建、表的建立、数据的增删......
  • jQuery:简化DOM操作的利器
    ......
  • 使用Spring Data JPA实现持久化层的简化开发
    使用SpringDataJPA实现持久化层的简化开发大家好,我是微赚淘客系统3.0的小编,是个冬天不穿秋裤,天冷也要风度的程序猿!在现代的Java应用开发中,SpringDataJPA为我们提供了一种简单而强大的方式来操作数据库,本文将深入探讨如何利用SpringDataJPA简化持久化层的开发。一、Spring......
  • 汽车通用微控制器S32K324NHT1MPBIR、S32K324NHT1MMMSR、S32K314EHT1MMMSR可为汽车开发
    S32K3系列32位微控制器(MCU)提供基于Arm®Cortex®-M7的MCU,支持单核、双核和锁步内核配置。S32K3系列具有内核、内存和外设数量方面的可扩展性,能够实现高性能和功能安全,符合ISO26262标准,达到ASILD安全等级。S32K3系列提供全面的端到端解决方案,涵盖从开发到生产的各个环节。S32K......
  • Java性能优化-if-else简化技巧
    场景Java性能优化-switch-case和if-else速度性能对比,到底谁快?:https://blog.csdn.net/BADAO_LIUMANG_QIZHI/article/details/140376572如果单纯是做情景选择,建议使用switch,如果必须使用if-else,过多的if-else会让人看着很难受,可以使用如下几个小技巧来简化过多的if-else。注:......
  • python库(13):Tablib库简化数据处理
    1 Tablib简介数据处理是一个常见且重要的任务。无论是数据科学、机器学习,还是日常数据分析,都需要处理和管理大量的数据。然而,标准库中的工具有时显得不够直观和简便。这时,我们可以借助第三方库来简化数据处理流程。Tablib就是这样一个强大的数据处理库,它提供了一套简单易用......