首页 > 其他分享 >卷积神经网络CNN实战:MINST手写数字识别——数据集下载与网络训练

卷积神经网络CNN实战:MINST手写数字识别——数据集下载与网络训练

时间:2024-07-22 15:29:05浏览次数:13  
标签:卷积 image MINST transform train transforms CNN import net

数据集下载

这一部分比较简单,就不过多赘述了,把代码粘贴到自己的项目文件里,运行一下就可以下载了。

from torchvision import datasets, transforms

# 定义数据转换,将数据转换为张量并进行标准化
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化
])

# 下载和加载训练集
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# 下载和加载测试集
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

该代码运行效果如下图:

下载好的数据集可以将其中的图片保存,这里给出两个代码,分别采用matplotlib库和opencv库进行可视化和保存

# matplotlib
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import os

# 创建保存图片的文件夹
os.makedirs('mnist_images', exist_ok=True)

# 定义数据转换(转换为Tensor)
transform = transforms.Compose([
    transforms.ToTensor()
])

# 下载 MNIST 数据集
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# 获取前100张图片
for i in range(100):
    image, _ = dataset[i]
    image = image.squeeze()  # 去掉单通道维度

    plt.imshow(image, cmap='gray')
    plt.axis('off')  # 不显示坐标轴
    plt.savefig(f'mnist_images/image_{i+1}.png', bbox_inches='tight', pad_inches=0)

print("前 100 张图片已保存为 PNG 文件")
# opencv
import cv2
import numpy as np
from torchvision import datasets, transforms
import os

# 创建保存图片的文件夹
os.makedirs('mnist_images', exist_ok=True)

# 定义数据转换(转换为Tensor)
transform = transforms.Compose([
    transforms.ToTensor()
])

# 下载 MNIST 数据集
dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# 获取前100张图片
for i in range(100):
    image, _ = dataset[i]
    image = image.squeeze().numpy()  # 去掉单通道维度,并转换为 numpy 数组

    # OpenCV 需要图像的范围在 0 到 255 之间
    image = (image * 255).astype(np.uint8)

    # 保存图像
    cv2.imwrite(f'mnist_images/image_{i+1}.png', image)

# 可选:显示图像
cv2.imshow('image_1', image)
cv2.waitKey(0)
cv2.destroyAllWindows()

# 如果你启用了显示图像的功能,记得在最后调用以下代码:
cv2.destroyAllWindows()

网络训练

该代码运行效果如下图

import torch

'''=============== 数据集部分 ==============='''
# 定义数据转换
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 打开已经下载的训练集和测试集
from torchvision.datasets import MNIST
train_dataset = MNIST(root='./data', train=True, download=False, transform=transform)
test_dataset = MNIST(root='./data', train=False, download=False, transform=transform)

# 创建数据加载器
batch_size = 256
from torch.utils.data import random_split
from torch.utils.data import DataLoader

# 将数据集分割为6000和剩余的数据
train_size = 6000
train_subset, _ = random_split(train_dataset, [train_size, len(train_dataset) - train_size])

train_loader = DataLoader(dataset=train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

'''=============== 网络定义 ==============='''
# 初始化网络
from net import CNN 
net = CNN()

# 初始化优化器、学习率调整器、评价函数
import torch.nn as nn
from torch import optim
learning_rate = 0.001 # 0.05 ~ 1e-6
weight_decay = 1e-4 # 1e-2 ~ 1e-8
momentum = 0.8 # 0.3~0.9
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=momentum)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)
criterion = nn.CrossEntropyLoss()

# GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device=device)
    
'''=============== 模型信息管理 ==============='''
model_path = None

if model_path is not None:
    net.load_state_dict(torch.load(model_path, map_location=device))

'''=============== 网络训练 ==============='''
epochs = 50

def train(net, device, optimizer, scheduler, criterion):
    net.train() 
    
    for epoch in range(epochs):
        epoch_loss = 0      # 集损失置0
        
        for images, labels in train_loader:
            ''' ========== 数据获取和转移 ========== '''
            images = images.to(device=device, dtype=torch.float32)
            labels = labels.to(device=device, dtype=torch.long)
            
            ''' ========== 数据操作 ========== '''
            outputs = net(images)
            # net.forward()
            loss = criterion(outputs, labels)
            epoch_loss += loss.detach().item()

            ''' ========== 反向传播 ========== '''
            optimizer.zero_grad()
            loss.requires_grad_(True)
            loss.backward() 
            
            # 梯度裁剪
            for param in net.parameters():
                if param.grad is not None and param.grad.nelement() > 0:
                    nn.utils.clip_grad_value_([param], clip_value=0.1)
                    
            optimizer.step()

        epoch_loss /= len(train_loader)
    
        # 输出每个 epoch 的平均损失
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss}')
        
train(net, device, optimizer, scheduler, criterion)

'''=============== 网络保存 ==============='''
from datetime import datetime

# 获取当前时间
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
model_path = f'./output/final_model_{current_time}.pth'

# 保存模型
torch.save(net.state_dict(), model_path)

标签:卷积,image,MINST,transform,train,transforms,CNN,import,net
From: https://www.cnblogs.com/SXWisON/p/18314373

相关文章

  • Fast R-CNN网络结构、框架原理详解
    一、FastR-CNN简介FastR-CNN是一种基于区域卷积网络(Region-basedConvolutionalNetwork)的快速目标检测方法。是R-CNN作者RossGirshick继R-CNN之后的又一力作,原文链接。与R-CNN相同,FastR-CNN同样使用VGG16作为网络的backbone,FastR-CNN训练非常深的VGG16网络比R-CN......
  • Maskrcnn学习笔记--个人向
    论文名称:MaskR-CNN论文下载地址:https://arxiv.org/abs/1703.06870在阅读本篇博文之前需要掌握FasterR-CNN、FPN以及FCN相关知识。FasterR-CNN视频讲解:FasterRCNN_哔哩哔哩_bilibiliFPN视频讲解:1.1.2FPN结构详解_哔哩哔哩_bilibiliFCN视频讲解:FCN网络结构详解(语义分割......
  • 基于卷积神经网络(CNNs)的无监督多模态子空间聚类方法
    基于卷积神经网络(CNNs)的无监督多模态子空间聚类方法引言基于卷积神经网络(CNNs)的无监督多模态子空间聚类方法是一种前沿技术,专门设计用于处理来自不同模态(如图像、文本、音频等)的高维数据,旨在自动学习表示并聚类这些数据,而无需任何标记信息。这种方法利用CNNs的特征提取能......
  • 【YOLOv5/v7改进系列】引入SAConv——即插即用的卷积块
    一、导言《DetectoRS:使用递归特征金字塔和可切换空洞卷积进行物体检测》这篇文章提出了一种用于物体检测的新方法,结合了递归特征金字塔(RecursiveFeaturePyramid,RFP)和可切换空洞卷积(SwitchableAtrousConvolution,SAC)。以下是对该研究的优缺点分析:优点:机制灵感来源于人......
  • YOLOv10有效涨点专栏目录 | 包含卷积、主干、检测头、注意力机制、Neck、二次创新、独
     ......
  • 基于 CNN(二维卷积Conv2D)+LSTM 实现股票多变量时间序列预测(PyTorch版)
    前言系列专栏:【深度学习:算法项目实战】✨︎涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记忆......
  • 卷积神经网络【CNN】--卷积层的原理详细解读
    卷积层(ConvolutionalLayer)是卷积神经网络(ConvolutionalNeuralNetwork,CNN)中的核心组件,它通过卷积运算对输入数据进行特征提取。以下是对卷积层的相关概述:一、基本概念定义:卷积层由多个卷积单元组成,每个卷积单元的参数通过反向传播算法优化得到。卷积运算的目的是提取输入......
  • 在Python中使用SWCNN去除水印
    在Python中使用SWCNN去除水印说明首次发表日期:2024-07-17SWCNNGithub官方仓库:https://github.com/hellloxiaotian/SWCNNSWCNN论文链接:https://arxiv.org/abs/2403.05807准备运行环境首先创建一个conda环境,安装SWCNN官方建议的库:condacreate-npy39torchpython=3.......
  • 北京交通大学《深度学习》专业课,实验3卷积、空洞卷积、残差神经网络实验
    一、实验要求1.二维卷积实验(平台课与专业课要求相同)⚫手写二维卷积的实现,并在至少一个数据集上进行实验,从训练时间、预测精度、Loss变化等角度分析实验结果(最好使用图表展示)⚫使用torch.nn实现二维卷积,并在至少一个数据集上进行实验,从训练时间、预测精度、Loss变化等角......
  • 算法金 | 秒懂 AI - 深度学习五大模型:RNN、CNN、Transformer、BERT、GPT 简介
    1.RNN(RecurrentNeuralNetwork)时间轴1986年,RNN模型首次由DavidRumelhart等人提出,旨在处理序列数据。关键技术循环结构序列处理长短时记忆网络(LSTM)和门控循环单元(GRU)核心原理RNN通过循环结构让网络记住以前的输入信息,使其能够处理序列数据。每个节点不仅接收当前......