首页 > 其他分享 >深度学习入门之手写数字识别

深度学习入门之手写数字识别

时间:2025-01-14 12:10:37浏览次数:1  
标签:loss 入门 nn loader device train test 手写 识别

模型定义

我们使用 CNN 和 MLP 来定义模型:

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        """
        定义模型结构

        输入维度为 1 * 28 * 28 (C, H, W)
        """
        super(Model, self).__init__()

        # 卷积层 1
        self.conv1 = nn.Sequential(
            # 二维卷积层,输入通道数为 1,输出通道数为 16,卷积核大小为 5,填充为 2
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            # ReLU 激活函数
            nn.ReLU(),
            # 最大池化层,池化窗口大小为 2
            nn.MaxPool2d(kernel_size=2)
            # 输出维度为 16 * 14 * 14 (C, H/2, W/2)
        )

        # 卷积层 2
        self.conv2 = nn.Sequential(
            # 二维卷积层,输入通道数为 16,输出通道数为 32,卷积核大小为 5,填充为 2
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            # ReLU 激活函数
            nn.ReLU(),
            # 最大池化层,池化窗口大小为 2
            nn.MaxPool2d(kernel_size=2)
            # 输出维度为 32 * 7 * 7 (C, H/4, W/4)
        )

        # 全连接层,输入维度为 32 * 7 * 7,输出维度为 10
        self.fc = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        """
        前向传播函数,由 torch 自动调用
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

训练和测试函数

import torch

def train(model, train_loader, criterion, optimizer, device):
    # 设置模型为训练模式
    model.train()

    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # 前向传播
        output = model(data)  # 输出维度为 (batch_size, 10)
        # 计算损失
        loss = criterion(output, target)
        # 计算预测结果
        _, predicted = output.max(1)
        # 反向传播和优化
        optimizer.zero_grad()  # 清空梯度
        loss.backward()        # 反向传播
        optimizer.step()       # 更新参数

        # 统计
        total_loss += loss.item()
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        # 打印进度
        if (batch_idx + 1) % 100 == 0:  # 每 100 个 batch 打印一次
            print(f'Batch: {batch_idx + 1}/{len(train_loader)}, '
                  f'Loss: {loss.item():.4f}, '
                  f'Accuracy: {100. * correct / total:.2f}%')

        # 记录训练数据
        writer.add_scalar('Training Loss/Step',
                            loss.item(),
                            epoch * len(train_loader) + batch_idx)

    # 计算平均损失和准确率
    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * correct / total

    # 计算平均损失和准确率
    return avg_loss, accuracy


def test(model, test_loader, criterion, device):
    # 设置模型为评估模式
    model.eval()

    total_loss = 0
    correct = 0
    total = 0

    # 不计算梯度
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            # 前向传播
            output = model(data)  # 输出维度为 (batch_size, 10)
            # 计算预测结果
            _, predicted = output.max(1)  # 从维度为 1 的维度上取最大值
            # 计算损失
            loss = criterion(output, target)
            # 统计
            total_loss += loss.item()
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    # 计算平均损失和准确率
    avg_loss = total_loss / len(test_loader)
    accuracy = 100. * correct / total

    return avg_loss, accuracy

主程序

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from model import Model
from train import train, test

# 定义超参数
BATCH_SIZE = 64
EPOCHS = 10
LEARNING_RATE = 0.001

def load_data():
    """加载数据"""

    # 定义数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # 加载训练集和测试集
    train_dataset = torchvision.datasets.MNIST(
        root='./data', 
        train=True,
        transform=transform,
        download=True
    )

    test_dataset = torchvision.datasets.MNIST(
        root='./data',
        train=False,
        transform=transform,
        download=True
    )

    # 创建数据加载器
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False
    )

    return train_loader, test_loader

def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    train_loader, test_loader = load_data()

    # 定义模型
    model = Model().to(device)
    # 定义损失函数
    criterion = nn.CrossEntropyLoss()
    # 定义优化器
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    best_accuracy = 0
    for epoch in range(EPOCHS):
        print(f'\nEpoch: {epoch + 1}/{EPOCHS}')

        # 训练阶段
        train_loss, train_acc = train(model, train_loader, criterion, optimizer, device, epoch, writer)
        print(f'Training - Average Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%')

        # 测试阶段
        test_loss, test_acc = test(model, test_loader, criterion, device, epoch, writer)
        print(f'Testing - Average Loss: {test_loss:.4f}, Accuracy: {test_acc:.2f}%')

        # 保存最佳模型
        if test_acc > best_accuracy:
            best_accuracy = test_acc
            torch.save(model.state_dict(), 'mnist_model.pth')

    print(f'\nBest Test Accuracy: {best_accuracy:.2f}%')

if __name__ == '__main__':
    main()

TensorBoard 可视化

安装依赖:

pip install tensorboard torch_tb_profiler

修改程序,写入训练日志:

timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(f'runs/mnist_{timestamp}')

sample_images, _ = next(iter(train_loader))
writer.add_graph(model, sample_images.to(device))

writer.add_scalar('Testing Loss/Epoch', avg_loss, epoch)
writer.add_scalar('Testing Accuracy/Epoch', accuracy, epoch)

if epoch == 0:
    images, labels = next(iter(test_loader))
    img_grid = torchvision.utils.make_grid(images[:25])
    writer.add_image('mnist_images', img_grid)

在 VS Code 中,可以使用下面的命令启动 TensorBoard:

> Python: Launch TensorBoard

image

标签:loss,入门,nn,loader,device,train,test,手写,识别
From: https://www.cnblogs.com/Undefined443/p/18670506

相关文章

  • LLM大模型入门必读免费白皮书《从头训练大模型最佳实践》免费pdf分享
    本书介绍《CurrentBestPracticesforTrainingLLMsfromScratch》是由Weights&Biases(W&B)提供的一份关于从头开始训练大型语言模型(LLMs)的权威指南。这份白皮书深入剖析了LLMs训练的最佳实践,内容覆盖了从数据收集与处理、模型架构选择、训练技巧与优化策略,到模型评估......
  • 信息学奥赛考试大纲之CSP-J信息学奥赛考试大纲(入门级)
    信息学奥赛考试大纲CSP-J信息学奥赛考试大纲(入门级)1.计算机基础与编程环境2.C++程序设计2.1程序基本概念2.2基本数据类型2.3程序基本语句2.4基本运算2.5数学库常用函数2.6结构化程序设计2.7数组2.8字符串的处理2.9函数与递归2.10结构体与联合体2.11指针类型2.12文件及基本......
  • Python 和 Tesseract OCR 识别复杂验证码
    ​安装依赖首先,确保已安装所需的工具和库。安装Tesseract在Windows上,下载安装包并进行安装:TesseractGitHub。在Linux上,你可以通过以下命令安装:bash更多内容访问ttocr.com或联系1436423940sudoapt-getinstalltesseract-ocr安装Python库使用pip安装Python......
  • nvidia gpu结构简介和cuda编程入门
    0.前言最近本人在写硕士大论文,需要写一些GPU相关的内容作为引言,所以在此总结一下。1.NVIDIAGPU线程管理CUDA的线程模型如上图,在调用一个CUDA函数时,需要定义grid和block的形状:func<<<grid,block>>>();在程序里定义的grid和block都是dim3类型的变量。当调用一个函数时,该函......
  • 线段树入门讲解
    有一段时间没有更新了,前面比较忙,所以知识上会有一些跳跃,后面看看有没有时间去补一下吧,没有就算了那现在就开始说一下线段树线段树是一种数据结构,他主要是用于实现快速的区间修改和区间求和这两个功能,同时,有别于树状数组,线段树还有更多的是在于其功能的强大和灵活性上,就比如说,树......
  • 数位 dp 入门
    如果余生有浪漫的星河,日子慢慢过。《如果我们在余生相遇》现在发现我们理解一件事情,其实是基于我们的认知去理解的。所以一些以前理解不了的事情,或者说当时难以理解的事情,过一段时间,再去看,可能恍然大悟。#include<iostream>#include<algorithm>#include<vector>usin......
  • SQL刷题快速入门(二)
    其他章节:SQL刷题快速入门(一)承接上一章节,本章主要讲SQL的运算符、聚合函数、SQL保留小数的几种方式三个部分运算符SQL支持多种运算符,用于执行各种操作,如算术运算、比较、赋值、逻辑运算等。以下是一些常见的SQL运算符类型及其示例:算术运算符+(加)-(减)*(乘)/(除)%(取模)SELECT......
  • Kali高手都在用的环境变量技巧,学会这些就能实现隐蔽渗透!黑客技术零基础入门到精通教程
    大家好,我们今天继续更新《黑客视角下的KaliLinux的基础与网络管理》中的管理用户环境变量。为了充分利用我们的黑客操作系统KaliLinux,我们需要理解和善于使用环境变量,这样会使我们的工具更具便利,甚至具有一定的隐蔽性。1.环境变量基础概念1.1什么是变量?变量在计算机......
  • 基于java的停车场车牌识别系统
    一、系统背景与意义随着城市化进程的加速,停车场管理面临着越来越大的挑战。传统的手工记录车牌号方式不仅费时费力,还容易出错。而基于Java的停车场车牌识别系统的出现,则有效地解决了这一问题。该系统能够自动识别进出停车场的车辆车牌号,实现快速、准确的车辆管理,提高了停车......
  • 1130: 【入门】简单a+b(字符串式子a+b)
    看到了吗,不是正常的输入a和b,然后直接相加,而是一个式子,没关系,一个字符串对于电脑而言奥秘多多,给电脑一个式子,他会反应吗?是不是不会。诶,但是让他去提取,那就是“怎么看都看不够”,嘿嘿,开个玩笑,就是提取字符串里的信息可以解决不少问题,这题就是这样。下面是代码:#include<bits/stdc......