首页 > 其他分享 >构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

时间:2024-09-07 20:53:30浏览次数:19  
标签:10 plt nn loss self torch CIFAR CNN

深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类

引言

在计算机视觉领域中,CIFAR-10数据集是一个经典的基准数据集,广泛用于图像分类任务。本文将介绍如何使用PyTorch框架构建一个简单的卷积神经网络(CNN),并在CIFAR-10数据集上进行训练和评估。通过本文,您将了解到数据预处理、模型定义、训练过程及结果可视化的完整流程。
在这里插入图片描述

数据预处理

首先,我们需要加载并预处理CIFAR-10数据集。CIFAR-10包含60000张32x32的彩色图像,分为10个类别,每个类别有6000张图像。我们使用torchvision库来轻松加载这些数据,并应用一些基本的变换,如归一化。

import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到[-1, 1]
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
模型定义

接下来,我们定义一个简单的卷积神经网络。该网络包含三个卷积层,两个池化层,以及两个全连接层。

import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 64)  # 考虑到池化层后的尺寸
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.relu(self.conv3(x))
        x = x.view(-1, 64 * 8 * 8)  # flatten
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = ConvNet()
训练过程

我们使用Adam优化器和交叉熵损失函数来训练模型,并将模型训练10个epoch。训练过程中,我们记录每个epoch的平均损失。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

num_epochs = 10
loss_history = []  # 记录每个epoch的平均损失
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 100}')
            running_loss = 0.0

    epoch_loss = running_loss / len(trainloader)
    loss_history.append(epoch_loss)

print('Finished Training')
模型评估

训练完成后,我们在测试集上评估模型的性能,并计算准确率。

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

final_accuracy = 100 * correct / total

print(f'Accuracy of the network on the 10000 test images: {final_accuracy} %')
结果可视化

最后,我们将训练过程中的损失和最终的准确率进行可视化,以便更直观地了解模型的训练效果。

import matplotlib.pyplot as plt

# 可视化损失
plt.plot(range(1, num_epochs + 1), loss_history)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss History')
plt.show()

# 可视化准确率
plt.bar(1, final_accuracy, width=0.4, label='Final Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Final Accuracy on Test Set')
plt.legend()
plt.show()
结论

本文介绍了如何使用PyTorch构建并训练一个简单的卷积神经网络对CIFAR-10数据集进行分类。通过数据预处理、模型定义、训练及结果可视化,我们完整地展示了深度学习项目的流程。希望本文能为您提供一些有用的参考和启发,帮助您在自己的深度学习项目中取得更好的成果。

标签:10,plt,nn,loss,self,torch,CIFAR,CNN
From: https://blog.csdn.net/myTomorrow_better/article/details/141942111

相关文章

  • 数字IC验证笔面试常见100题【持续更新】
    【提要】收集整理了一些网络上和我自己在秋招、实习时遇到的题目,适合数字验证方向求职的同学进行差缺补漏或者应对八股时的速成。    对于时间比较充裕并且有条件的同学,还是强烈建议找个实习来提升自己的能力以及校招竞争性,独立完成了一两个真实项目后,能大大加深对验证......
  • LNGS1002 2024 Social Dialectology
    LNGS10022024Assignment2-SocialDialectologyRELEASED Wednesday28th AugustDUE Sunday8th September,11:59pmviaTurnitinToensureanonymousmarking,pleasedonotincludeyournameorSIDonthe assignment.WhenyousubmityourassignmentinCanv......
  • YOLOv10s训练代码解析7:TaskAlignedAssigner正负样本匹配
    本专栏会手把手带你从源码了解YOLOv10(后续会陆续介绍YOLOv8、RTDETR等模型),尽可能地完整介绍整个算法,这个专栏会持续创作与更新,大家如果想要本文PDF和思维导图,后台私信我即可(创作不易,不喜勿喷),大家如果发现任何错误和需要修改的地方都可以私信我,我会统一修改。注:训练batch为......
  • 【正点原子K210连载】第二十九章 音频录制实验 摘自【正点原子】DNK210使用指南-CanMV
    第二十九章音频录制实验本章将介绍CanMV下的音频录制通过CanMV提供的模块便能快速地实现音频录制。通过本章的学习,读者将学习到CanMV下控制I2S获取音频数和audio模块的使用。本章分为如下几个小节:29.1maix.I2S模块及audio模块介绍29.2硬件设计29.3程序设计29.4运行验证29......
  • 【正点原子K210连载】第三十章 照片拍摄实验 摘自【正点原子】DNK210使用指南-CanMV版
    第三十章照片拍摄实验在前面的章节中,已经了解了如何在CanMV下获取摄像头输出的图像数据并在LCD上进行显示,同时也了解了如何解码文件系统中的图像文件然后在LCD上进行显示,本章将通过照片拍摄实验,介绍如何通过CanMV将摄像头输出的图像数据进行图像编码保存到文件系统中。通过本章的......
  • 【Leetcode:LCR 101. 分割等和子集 + 递归 + 记忆化搜索 + dp】
    ......
  • Go - Web Application 10
    CreatingaunittestInGo,it’sstandardpracticetowriteyourtestsin*_test.gofileswhichlivedirectly alongsidethecodethatyou’retesting.So,inthiscase,thefirstthingthatwe’regoingtodo iscreateanewcmd/web/template_test.gofilet......
  • 编程技术开发105本经典书籍推荐分享
    最近整理了好多的技术书籍,对于提高自己能力来说还是很有用的,当然要有选择的看,不然估计退休了都不一定看得完,分享给需要的同学。编程技术开发105本经典书籍推荐:https://zhangfeidezhu.com/?p=753分享截图......
  • 分享10个免费的Python代码仓库,轻松实现办公自动化!
    为了帮助大家更好地利用Python实现自动化办公,我们精心挑选了10个免费的Python代码仓库。这些仓库不仅包含了实用的脚本和示例,还涵盖了从基础到进阶的各种自动化任务解决方案。无论你是Python编程的初学者,还是希望提升工作效率的职场人士,都能在这些仓库中找到适合自己的资......