首页 > 编程语言 >python实战(十五)——中文手写体数字图像CNN分类

python实战(十五)——中文手写体数字图像CNN分类

时间:2025-01-22 23:29:59浏览次数:3  
标签:python 数字图像 labels test train image CNN import self

一、任务背景

        本次python实战,我们使用来自Kaggle的数据集《Chinese MNIST》进行CNN分类建模,不同于经典的MNIST数据集,我们这次使用的数据集是汉字手写体数字。除了常规的汉字“零”到“九”之外还多了“十”、“百”、“千”、“万”、“亿”,共15种汉字数字

二、python建模

1、数据读取

        首先,读取jpg数据文件,可以看到总共有15000张图像数据。

import pandas as pd
import os

path = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))

        我们也可以打印一张图片出来看看。

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# 定义图片路径
image_path = path+files[3]

# 加载图片
image = mpimg.imread(image_path)

# 绘制图片
plt.figure(figsize=(3, 3))
plt.imshow(image)
plt.axis('off')  # 关闭坐标轴
plt.show()

2、数据集构建

        加载必要的库以便后续使用,再定义一些超参数。

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_score

# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

        这里,我们看一看数据集介绍就会知道图片名称及其含义,需要从chinese_mnist.csv文件中根据图片名称中的几个数字来确定图片对应的标签。

# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]

# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)

# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):
    suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])
    return index_df.loc[(suite_id, sample_id, code), 'value']

# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}
# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]

# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)

        下面定义数据集类并完成数据的加载。

# 自定义数据集类
class CustomDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('L')  # 转换为灰度图像
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)

# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# 打印一些信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')

3、模型构建

        我们构建一个包含两层卷积层和池化层的CNN并且在池化层中使用最大池化的方式。

# 定义CNN模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, 15)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 16 * 16)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

4、模型实例化及训练

        下面我们对模型进行实例化并定义criterion和optimizer。

# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

        定义训练的代码并调用代码训练模型。

from tqdm import tqdm
# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):
    model.train()
    running_loss = 0.0
    for epoch in range(epochs):
        for data, target in tqdm(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')
        running_loss = 0.0

train(model, train_loader, criterion, optimizer, num_epochs)

5、测试模型

        定义模型测试代码,调用代码看指标可知我们所构建的CNN模型表现还不错。

# 测试模型
def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    precision = precision_score(all_targets, all_preds, average='macro')
    recall = recall_score(all_targets, all_preds, average='macro')
    f1 = f1_score(all_targets, all_preds, average='macro')
    print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')

test(model, test_loader, criterion)

三、完整代码

import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_score


path = '/kaggle/input/chinese-mnist/data/data/'
files = os.listdir(path)
print('数据总量:', len(files))


# 超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 5

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 获取所有图片文件的路径
all_images = [os.path.join(path, img) for img in os.listdir(path) if img.endswith('.jpg')]

# 读取索引-标签对应关系csv文件,并将'suite_id', 'sample_id', 'code'设置为索引列便于查找
index_df = pd.read_csv('/kaggle/input/chinese-mnist/chinese_mnist.csv')
index_df.set_index(['suite_id', 'sample_id', 'code'], inplace=True)

# 定义函数,根据各索引取值定位图片对应的数值标签value
def get_label_from_index(filename, index_df):
    suite_id, sample_id, code = map(int, filename.split('.')[0].split('_')[1:])
    return index_df.loc[(suite_id, sample_id, code), 'value']

# 构建value值对应的标签序号,用于模型训练
label_dic = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8, 9:9, 10:10, 100:11, 1000:12, 10000:13, 100000000:14}

# 获取所有图片的标签并转化为标签序号
all_labels = [get_label_from_index(os.path.basename(img), index_df) for img in all_images]
all_labels = [label_dic[li] for li in all_labels]

# 将图片路径和标签分成训练集和测试集
train_images, test_images, train_labels, test_labels = train_test_split(all_images, all_labels, test_size=0.2, random_state=2024)

# 自定义数据集类
class CustomDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('L')  # 转换为灰度图像
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# 创建训练集和测试集数据集
train_dataset = CustomDataset(train_images, train_labels, transform=transform)
test_dataset = CustomDataset(test_images, test_labels, transform=transform)

# 创建数据加载器
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# 打印信息
print(f'训练集样本数: {len(train_dataset)}')
print(f'测试集样本数: {len(test_dataset)}')

# 定义CNN模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, 15)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 16 * 16)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 初始化模型、损失函数和优化器
model = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

# 训练模型
def train(model, train_loader, criterion, optimizer, epochs):
    model.train()
    running_loss = 0.0
    for epoch in range(epochs):
        for data, target in tqdm(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch [{epoch + 1}], Loss: {running_loss / len(train_loader):.4f}')
        running_loss = 0.0

train(model, train_loader, criterion, optimizer, num_epochs)

# 测试模型
def test(model, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    precision = precision_score(all_targets, all_preds, average='macro')
    recall = recall_score(all_targets, all_preds, average='macro')
    f1 = f1_score(all_targets, all_preds, average='macro')
    print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}')

test(model, test_loader, criterion)

四、总结

        本文基于汉字手写体数字图像进行了CNN分类实战,CNN作为图像处理的经典模型,展现出了它强大的图像特征提取能力,结合更加复杂的模型框架CNN还可用于高精度人脸识别、物体识别等任务中。

标签:python,数字图像,labels,test,train,image,CNN,import,self
From: https://blog.csdn.net/ChaneMo/article/details/145285611

相关文章

  • 故障诊断 | DBO蜣螂优化算法LightGBM故障诊断(Matlab&Python)
    目录效果一览文章概述DBO蜣螂优化算法LightGBM故障诊断(Matlab&Python)DBO蜣螂优化算法LightGBM故障诊断研究一、引言1.1、研究背景及意义1.2、研究现状二、DBO蜣螂优化算法2.1、蜣螂优化算法的基本原理2.2、DBO算法的优化机制三、LightGBM模型......
  • 使用Python绘制混淆矩阵
    importnumpyasnpimportmatplotlib.pyplotaspltfromsklearn.metricsimportconfusion_matrix#模拟真实标签和预测标签(这里只是示例,实际中替换为真实数据)y_true=[0,1,0,1,1,0,1,0]y_pred=[0,1,1,1,0,0,1,1]#计算混淆矩阵cm=confusion_matr......
  • 人工智能学习(一)之python入门
    一、引言在当今的软件开发领域,面向对象编程(Object-OrientedProgramming,OOP)已经成为一种主流的编程范式。Python作为一门功能强大且简洁易读的编程语言,对面向对象编程提供了非常完善的支持。无论是开发大型项目、构建数据科学应用,还是进行自动化脚本编写,理解和掌握Python......
  • python生成随机字符串
    在Python中,可以使用random、secrets或uuid模块来生成随机字符串。以下是几种常见的方法:1.使用random生成随机字符串importrandomimportstringdefgenerate_random_string(length=10):characters=string.ascii_letters+string.digits#包含大小......
  • python中很常用的10个内置函数整理(初学必备)
    对于初学Python的小伙伴们来说,掌握内置常用函数是学好Python的重要一步。这些函数不仅能让你的代码更加简洁,还可以提高编程效率。本笔记将为大家整理62个Python中最常用的内置函数,并且给出了一些简单的示例,帮助大家更好地理解和运用这些函数。这些内置函数是Pyth......
  • python如何检查列表元素是否为零
    python检查列表元素是否为零的方法:1、使用for循环遍历列表中的每一个元素2、用if语句判断该元素是否为零;如果是则输出这个列表元素的下标完整代码如下:执行结果如下:......
  • Python 实现 macOS 系统代理的设置
    Python实现macOS系统代理的设置设置SOCKS代理在macOS系统中,可以通过networksetup工具来设置SOCKS代理。以下是Python实现的方法:使用networksetup设置SOCKS代理importsubprocessdefset_socks_proxy(server,port):"""设置macOS系统的SOCKS......
  • Python基础5-装饰器与推导式
    1.装饰器1.1引入装饰器的代码v=1v=2deffunc():passv=10v=fun#变量v指向了函数funcdefbase():print(1)defbar():print(2)bar=basebar()deffunc():definner():passreturninnerv=func()print(v)#inner函......
  • python操作mysql
    前言在Python3中,我们可以使用mysqlclient或者pymysql三方库来接入MySQL数据库并实现数据持久化操作。二者的用法完全相同,只是导入的模块名不一样。我们推荐大家使用纯Python的三方库pymysql,因为它更容易安装成功。下面我们仍然以之前创建的名为hrs的数据库为例,为大家......
  • 【Python】函数(一)
    函数是什么?编程中的函数和数学中的函数有一定的相似之处.数学上的函数,比如y=sinx,x取不同的值,y就会得到不同的结果.编程中的函数,是一段可以被重复使用的代码片段代码示例:求数列的和,不使用函数#1.求1-100的和sum=0foriinrange(1,101):......