首页 > 其他分享 >基于深度学习的手写文本识别系统

基于深度学习的手写文本识别系统

时间:2025-01-14 16:57:40浏览次数:3  
标签:torch nn self 识别系统 grid 手写 文本 model size

文章目录


前言

用chatgpt"实现基于深度学习的手写文本识别系统 | Python, PyTorch" :

设计并实现了基于卷积神经网络(CNN)的手写文字识别系统,支持数字(0-9)和15个常用汉字数字的识别,通过Tkinter构建交互界面,实现手写输入和实时识别。采用CNN架构进行特征提取和分类:利用两层卷积层提取图像特征,通过最大池化层降维,再经全连接层完成分类,在测试集上识别准确率达到90%

用户交互功能:支持鼠标手写输入,实时显示识别结果及置信度,提供清除和重绘功能


一、准备

chatgpt! 有了它再也不怕写不出代码了,你只需要头里有个大体框架,代码交给它就行,有报错告诉它让他去改。
Python, PyTorch

我是用电脑显卡跑的,电脑没有显卡cpu也可以跑,不过会慢点。相关环境安装教程链接:link


二、(0-9)数字识别模型代码

1.引入库

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from torchvision.datasets import MNIST
import numpy as np
import cv2
import matplotlib.pyplot as plt

2.读入数据

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为 Tensor,并自动归一化到 [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # 归一化到 [-1, 1]
])
# 加载 MNIST 数据集
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
test_data = MNIST(root='./data', train=False, download=True, transform=transform)
# 创建 DataLoader
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

3.模型训练

# 初始化模型、损失函数和优化器
class SimpleCNN(torch.nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.fc1 = torch.nn.Linear(64 * 7 * 7, 128)
        self.fc2 = torch.nn.Linear(128, 10)
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# 确定是否有 GPU 可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()  # 多类别分类损失函数
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
def train_model(model, train_loader, criterion, optimizer, epochs=5):
    model.train()  # 设置模型为训练模式
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)  # 将数据移到GPU(如果可用)
            
            optimizer.zero_grad()  # 清除梯度
            outputs = model(images)  # 模型预测
            loss = criterion(outputs, labels)  # 计算损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新参数
            
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

# 训练模型
train_model(model, train_loader, criterion, optimizer, epochs=5)

4.模型测试

# 测试模型
def test_model(model, test_loader):
    model.eval()  # 设置模型为评估模式
    correct = 0
    total = 0
    with torch.no_grad():  # 关闭梯度计算
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)  # 将数据移到GPU(如果可用)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)  # 预测类别
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")

# 测试模型
test_model(model, test_loader)

5.模型权重保存(不用重复训练)

# 只保存模型权重
torch.save(model.state_dict(), 'model_weights.pth')

6.交互式界面

import tkinter as tk
from tkinter import Canvas
import numpy as np
import torch
from PIL import ImageGrab, Image, ImageTk
from torchvision import transforms
# 加载训练好的 CNN 模型
class SimpleCNN(torch.nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.fc1 = torch.nn.Linear(64 * 7 * 7, 128)
        self.fc2 = torch.nn.Linear(128, 10)
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
# 加载模型
model = SimpleCNN()
model.load_state_dict(torch.load("model_weights.pth", map_location=torch.device("cpu")))
model.eval()

# Tkinter 界面
class DigitRecognizerApp:
    def __init__(self, root):
        self.root = root
        self.root.title("手绘数字识别")
        # 每个像素块的大小(加大尺寸)
        self.pixel_size = 10  # 每个像素块的大小,增大方便手绘
        self.canvas_size = 28  # 28x28 像素网格
        self.canvas = Canvas(root, width=self.pixel_size * self.canvas_size, height=self.pixel_size * self.canvas_size, bg="black")
        self.canvas.grid(row=0, column=0, padx=10, pady=10)
        self.canvas.bind("<B1-Motion>", self.draw)
        # 按钮
        self.predict_button = tk.Button(root, text="识别", command=self.predict)
        self.predict_button.grid(row=1, column=0, pady=0)
        self.clear_button = tk.Button(root, text="清空", command=self.clear_canvas)
        self.clear_button.grid(row=2, column=0, pady=0)
        # 结果显示
        self.result_label = tk.Label(root, text="预测结果: ", font=("Arial", 14))
        self.result_label.grid(row=0, column=1, padx=10, pady=10)
        self.prob_label = tk.Label(root, text="", font=("Arial", 10), justify="left")
        self.prob_label.grid(row=1, column=1, padx=10, pady=10)
        # 用于显示 28x28 图像
        self.predicted_img_label = tk.Label(root)
        self.predicted_img_label.grid(row=0, column=2, padx=10, pady=10)
        # 创建空白的28x28网格
        self.grid = np.zeros((self.canvas_size, self.canvas_size), dtype=int)
    def draw(self, event):
        # 获取鼠标的位置并计算在哪个像素块内
        x, y = event.x, event.y
        grid_x = x // self.pixel_size
        grid_y = y // self.pixel_size
        # 绘制白色方块
        self.canvas.create_rectangle(grid_x * self.pixel_size, grid_y * self.pixel_size,
                                     (grid_x + 1) * self.pixel_size, (grid_y + 1) * self.pixel_size,
                                     fill="white", outline="white")
        # 更新网格数据
        self.grid[grid_y, grid_x] = 1  # 填充该像素点
    def clear_canvas(self):
        self.canvas.delete("all")
        self.result_label.config(text="预测结果: ")
        self.prob_label.config(text="")
        self.predicted_img_label.config(image="")
        self.grid = np.zeros((self.canvas_size, self.canvas_size), dtype=int)  # 清空网格
    def predict(self):
        # 将绘制的图像转换为28x28图像
        img = Image.fromarray(self.grid.astype(np.uint8) * 255)  # 将网格转为黑白图像
        img_resized = img.resize((28, 28))  # 确保大小为28x28
        # 显示 28x28 图像
        img_tk = ImageTk.PhotoImage(img_resized)
        self.predicted_img_label.config(image=img_tk)
        self.predicted_img_label.image = img_tk
        # 预处理图像
        img_transformed = transform(img_resized).unsqueeze(0)
        # 模型预测
        with torch.no_grad():
            output = model(img_transformed)
            probs = torch.nn.functional.softmax(output, dim=1).numpy()[0]
            predicted_digit = np.argmax(probs)
        # 显示预测结果
        self.result_label.config(text=f"预测结果: {predicted_digit}")      
        # 显示预测概率
        prob_text = "\n".join([f"{i}: {prob:.2%}" for i, prob in enumerate(probs)])
        self.prob_label.config(text=prob_text)
# 启动应用
root = tk.Tk()
app = DigitRecognizerApp(root)
root.mainloop()

三、结果展示

5.1(0-9)数字识别
在这里插入图片描述在这里插入图片描述

5.2汉字识别
在这里插入图片描述在这里插入图片描述

四、jupyter代码下载

我用的是jupyter编辑,感觉它的交互式页面挺方便的,能分段运行代码,方便确定每一个步骤的代码对不对。代码基本都是chatgpt写出来的,中间有报错就问ai,然后修修改改最后跑通,只能说ai这个工具对于代码小白来说太友好了,只需要学习ptython基本语法然后就能做出一个模型来!代码下载链接放到文章里面了。

标签:torch,nn,self,识别系统,grid,手写,文本,model,size
From: https://blog.csdn.net/woshiaoligei/article/details/145119701

相关文章

  • NLP 进阶:BERT + CNN 结合打造高效文本分类模型!
    引言:在自然语言处理(NLP)中,文本分类任务是一个核心问题,涵盖了情感分析、新闻分类、垃圾邮件检测等多个领域。传统的深度学习方法虽然取得了一定的成效,但随着BERT(BidirectionalEncoderRepresentationsfromTransformers)和CNN(ConvolutionalNeuralNetworks)技术的出现,文本分......
  • 深度学习入门之手写数字识别
    模型定义我们使用CNN和MLP来定义模型:importtorch.nnasnnclassModel(nn.Module):def__init__(self):"""定义模型结构输入维度为1*28*28(C,H,W)"""super(Model,self).__init__()#卷积......
  • 使用VoyageAI进行高效文本嵌入与重新排序
    在自然语言处理(NLP)任务中,文本嵌入和重新排序是两项关键技术。VoyageAI提供了针对特定领域和公司的定制化嵌入模型,以提高检索质量。本文将详细讲解如何使用VoyageAI进行文本嵌入和重新排序。技术背景介绍文本嵌入是一种将文本转换为数值向量的方法,使其能够在机器学习模型......
  • 文本预处理是指在将文本数据用于模型训练或分析之前,对其进行的一系列清洗、转换和处理
    文本预处理是指在将文本数据用于模型训练或分析之前,对其进行的一系列清洗、转换和处理操作。这些操作旨在消除文本中的噪声和不必要的信息,并将其转化为适合后续处理的格式。以下是文本预处理的一些常见方法:一、文本清洗去除HTML标记和特殊字符:移除文本中的HTML标签(如、等)......
  • 使用OpenAI API进行文本生成的实践指南
    在AI技术日新月异的发展中,文本生成已经成为一项重要应用。通过使用OpenAI的API,开发者可以轻松地实现复杂的文本生成任务。在本文中,我们将深入探讨如何使用OpenAIAPI进行文本生成,从技术背景、核心原理到实际代码实现,并结合应用场景提供实践建议。技术背景介绍文本生成是自......
  • 基于java的停车场车牌识别系统
    一、系统背景与意义随着城市化进程的加速,停车场管理面临着越来越大的挑战。传统的手工记录车牌号方式不仅费时费力,还容易出错。而基于Java的停车场车牌识别系统的出现,则有效地解决了这一问题。该系统能够自动识别进出停车场的车辆车牌号,实现快速、准确的车辆管理,提高了停车......
  • 每天一个优秀提示词学习收藏 - 文本选题篇(三)
    ......
  • 基于YOLOv8与CGNet的鸟类智能识别系统 深度学习图像分类 鸟类目标检测与分类 图像特征
    博主介绍:  ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W+粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台的优质作者。通过长期分享和实战指导,我致力于帮助更多学生......
  • 基于YOLOv5的手语识别系统:深度学习应用与实现
    手语是聋人和听力障碍者与他人交流的主要方式之一。随着社会的进步,手语的识别技术逐渐成为研究的热点,尤其在智能助残设备和多模态人机交互中,手语识别的应用越来越广泛。尽管手语是一种自然语言,但其表达方式非常丰富,包括了不同的手势、姿势、动作轨迹和面部表情等。为了能够......
  • 实现单行文本居中和多行文本左对齐并超出显示"..."
    在前端开发中,你可以使用CSS来实现单行文本居中和多行文本左对齐并超出显示"..."的效果。以下是一个示例:<!DOCTYPEhtml><html><head><style>.single-line{text-align:center;white-space:nowrap;overflow:hidden;text-overflow:ellipsis;}.multi-line{......