首页 > 其他分享 >图像分类实战:深度学习在CIFAR-10数据集上的应用

图像分类实战:深度学习在CIFAR-10数据集上的应用

时间:2024-03-30 09:32:18浏览次数:22  
标签:10 plt self CIFAR test 集上 model history

1.前言

        图像分类是计算机视觉领域的一个核心任务,算法能够自动识别图像中的物体或场景,并将其归类到预定义的类别中。近年来,深度学习技术的发展极大地推动了图像分类领域的进步。CIFAR-10数据集作为计算机视觉领域的一个经典小型数据集,为研究者提供了一个理想的实验平台,用于验证和比较不同的图像分类算法。本文将介绍CIFAR-10数据集的基本情况和加载方法,并展示如何构建与训练一个卷积神经网络(CNN)模型来进行图像分类,最后对模型的性能进行评估与可视化。

2.数据集介绍与加载

        CIFAR-10数据集由加拿大高等研究院(Canadian Institute for Advanced Research, CIFAR)发布,是计算机视觉领域广泛使用的基准数据集之一。它包含了10个类别(飞机、汽车、鸟类、猫、鹿、狗、青蛙、船、卡车、马)的彩色图像,每类有6,000张图像,共计60,000张。所有图像尺寸统一为32x32像素,且已进行标准化处理,其色彩模式为RGB。数据集被划分为50,000张训练图像和10,000张测试图像,保证了训练集与测试集的均衡分布。

        数据加载

        使用Python的tensorflow.keras.datasets模块加载CIFAR-10数据集,同时进行必要的预处理,如归一化和标签转换。

import tensorflow as tf

# 加载CIFAR-10数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# 数据归一化
x_train, x_test = x_train / 255.0, x_test / 255.0

# 将标签转换为one-hot编码
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

3.构建与训练CNN模型

        ResNet(Residual Neural Network)是一种深度残差学习网络,通过引入残差块解决了深度神经网络训练过程中的梯度消失和爆炸问题,从而能够构建和训练极深的模型,显著提升模型的性能和泛化能力。

        关于CNN模型的更多介绍,请看这篇文章:

卷积神经网络(CNN):图像识别的强大工具-CSDN博客文章浏览阅读795次,点赞9次,收藏18次。卷积神经网络是一种强大的图像识别工具,它能够自动学习图像的特征,并在各种图像识别任务中取得出色的效果。通过使用深度学习框架和大量的训练数据,我们可以构建出高效准确的卷积神经网络模型,实现对图像的分类、识别等任务。希望这篇文章能够帮助你更好地理解卷积神经网络在图像识别中的应用。如果你有任何问题或需要进一步的帮助,请随时提问。https://blog.csdn.net/meijinbo/article/details/137015665

3.1.构建模型

        使用Keras构建一个适用于CIFAR-10数据集的小型ResNet模型。

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Add, MaxPooling2D, GlobalAveragePooling2D, Dense

def residual_block(input_tensor, filters, strides=1, use_projection=False):
    shortcut = input_tensor
    if use_projection:
        shortcut = Conv2D(filters, kernel_size=1, strides=strides, padding='valid')(shortcut)
        shortcut = BatchNormalization()(shortcut)

    x = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(input_tensor)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)
    x = BatchNormalization()(x)

    if strides != 1 or input_tensor.shape[-1] != filters:
        shortcut = Conv2D(filters, kernel_size=1, strides=strides, padding='valid')(shortcut)
        shortcut = BatchNormalization()(shortcut)

    x = Add()([shortcut, x])
    x = Activation('relu')(x)

    return x


def build_resnet():
    model = Sequential()
    model.add(Conv2D(16, kernel_size=3, padding='same', input_shape=(32, 32, 3)))
    model.add(BatchNormalization())
    model.add(Activation('relu'))

    for _ in range(2):
        model.add(residual_block(model.output, 16))

    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(residual_block(model.output, 32, strides=2, use_projection=True))

    for _ in range(2):
        model.add(residual_block(model.output, 32))

    model.add(GlobalAveragePooling2D())
    model.add(Dense(10, activation='softmax'))

    return model

resnet_model = build_resnet()
resnet_model.summary()

3.2.模型训练

        配置模型训练参数,启动训练过程,并监控训练进度。

resnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

history = resnet_model.fit(x_train, y_train,
                          batch_size=128,
                          epochs=100,
                          validation_data=(x_test, y_test),
                          verbose=1)

4.模型性能评估与可视化

4.1.性能评估

        评估模型在测试集上的最终性能指标。

test_loss, test_acc = resnet_model.evaluate(x_test, y_test, verbose=2)
print(f'Test accuracy: {test_acc:.4f}')

 4.2.可视化

        绘制训练过程中损失和准确率曲线,以直观了解模型收敛情况与过拟合风险。

import matplotlib.pyplot as plt

def plot_history(history):
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(history.history['accuracy'], label='Training Accuracy')
    plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.show()

plot_history(history)  # 显示训练过程中的准确率与损失曲线

        以下是基于PyTorch的实现:

import torch.nn as nn  
import torch.nn.functional as F  
  
class SimpleCNN(nn.Module):  
    def __init__(self):  
        super(SimpleCNN, self).__init__()  
        self.conv1 = nn.Conv2d(3, 6, 5)  
        self.pool = nn.MaxPool2d(2, 2)  
        self.conv2 = nn.Conv2d(6, 16, 5)  
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  
        self.fc2 = nn.Linear(120, 84)  
        self.fc3 = nn.Linear(84, 10)  
  
    def forward(self, x):  
        x = self.pool(F.relu(self.conv1(x)))  
        x = self.pool(F.relu(self.conv2(x)))  
        x = x.view(-1, 16 * 5 * 5)  
        x = F.relu(self.fc1(x))  
        x = F.relu(self.fc2(x))  
        x = self.fc3(x)  
        return x  
  
# 实例化模型、定义损失函数和优化器  
model = SimpleCNN()  
criterion = nn.CrossEntropyLoss()  
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  
  
# 训练模型  
for epoch in range(2):  # 假设我们训练两个epoch  
    running_loss = 0.0  
    for i, data in enumerate(trainloader, 0):  
        inputs, labels = data  
        optimizer.zero_grad()  
        outputs = model(inputs)  
        loss = criterion(outputs, labels)  
        loss.backward()  
        optimizer.step()  
        running_loss += loss.item()  
        if i % 2000 == 1999:  # 每2

 5.总结

        通过以上步骤,我们已经完成了在CIFAR-10数据集上使用深度学习进行图像分类的全过程。从数据集的介绍与加载,到构建并训练ResNet模型,再到模型性能的评估与可视化,这一系列操作展示了如何将理论知识应用于实际问题,揭示了深度学习在图像分类任务中的强大能力。实践中,可根据具体需求调整模型结构、优化策略等参数以进一步提升模型性能。

标签:10,plt,self,CIFAR,test,集上,model,history
From: https://blog.csdn.net/meijinbo/article/details/137151822

相关文章

  • 【洛谷】 P1006 [NOIP2008提高组]传纸条
    题目描述小渊和小轩是好朋友也是同班同学,他们在一起总有谈不完的话题。一次素质拓展活动中,班上同学安排坐成一个 m 行 n 列的矩阵,而小渊和小轩被安排在矩阵对角线的两端,因此,他们就无法直接交谈了。幸运的是,他们可以通过传纸条来进行交流。纸条要经由许多同学传到对方手里,......
  • 3121002754 刘栋 《需求规格说明书》
    这个作业属于哪个课程<软件工程2024-双学位>这个作业要求在哪里<团队作业2——需求说明文档>这个作业的目标完成需求文档目录团队作业2-需求说明文档需求说明面向用户分析功能性需求预期用户数量系统价值gitcode链接时间安排原安排表校正后安排感想团队作业2-......
  • 第16期 Double Commander 开源免费的Total Commander替代型【体验100款文件管理工具】
     体验背景:我们正在做一款文件版本管理软件,追光几何(追光几何),期待以最无感的方式,解决新一代工程师文件管理的问题,让大家有更多时间去做快乐和有成就感的事情。所以打算体验100款文件管理软件,来取长补短。真实1h体验DoubleCommander是一款开源的跨平台文件管理软件,灵感来源......
  • 代码随想录算法训练营第二十三天(二叉树9)|669. 修剪二叉搜索树、108. 将有序数组转换为
    文章目录669.修剪二叉搜索树解题思路源码108.将有序数组转换为二叉搜索树解题思路源码538.把二叉搜索树转换为累加树解题思路源码669.修剪二叉搜索树给你二叉搜索树的根节点root,同时给定最小边界low和最大边界high。通过修剪二叉搜索树,使得所有节点的值......
  • PCB的10条布线原则
    目录电气连接原则:连线精简:避免直角:差分走线:蛇形线等长:圆滑走线:数字模拟分开:3W原则:20H原则:安全载流原则:铜箔承载电流:过孔承载电流:电气连接原则:连线精简:避免直角:差分走线:蛇形线等长:圆滑走线:数字模拟分开:3W原则:20H原则:安全载流原则:铜箔承载电流......
  • 2024年1000个计算机毕业设计项目推荐(源码+论文【万字】)
    2024年最新计算机毕业设计题目推荐,项目汇总!本科、专科。项目设计、项目定制、辅导、万字文档哈喽,大家好,大四的同学马上要开始做毕业设计了,大家做好准备了吗?博主给大家详细整理了计算机毕业设计最新项目,对项目有任何疑问,都可以问博主哦~技术栈包括但不限于:Java、JavaWeb......
  • 我做【网创导师训练营】项目,从0开始做,3个月就做到月入10万+,现在手把手复制给你
    一、开门见山,先公开一下我自己操作的收益大家好!我是如今,【如今笔记】的主理人。今天给大家带来的项目是:网创导师训练营。我是从0开始做的这个项目,三个月的时间,我做到了月收益10万以上,而且现在每月的收益都还在增加。 这个项目比较简单,不需要你有很多的设备,一台手机+一台电......
  • ZCMU-1038
    其实感觉不太难,读懂题意就行,我一开始没有仔细去读感觉就很懵。其题目意思就是一段字符串含有数字和'<'或者'>',一开始从左开始遍历,遇到'>'这类东西换方向,如果有多次遇到就删之前那一个;遇到数字就记下,并减去,一直减到0,就删掉思路:无非用一个int类型的数组存放数字打印个数,以及模拟......
  • P2107
    小Z的AK计划题目描述在小Z的家乡,有机房一条街,街上有很多机房。每个机房里都有一万个人在切题。小Z刚刷完CodeChef,准备出来逛逛。机房一条街有$n$个机房,第$i$个机房的坐标为$x_i$,小Z的家坐标为$0$。小Z在街上移动的速度为$1$,即从$x_1$到$x_2$所耗费的时间......
  • 《引流108招》第2招:众筹兴风作浪
    有好项目却找不到意向客户,只能看着别人赚钱?平台引流总是违规,账号老被封?《引流108招》:108个最新最有效的,0基础小白都能学会的引流获客秘籍,帮你掌握流量密码,实现简单粗暴每日引流100+,疯狂收钱你好,我是独立开发者小黑今天我们跟大家讲一招微信体系内的引流玩法,我给它取名为"......