文章目录
前言
用 ChatGPT 实现"基于深度学习的手写文本识别系统 | Python, PyTorch":
设计并实现了基于卷积神经网络(CNN)的手写文字识别系统,支持数字(0-9)和15个常用汉字数字的识别,通过Tkinter构建交互界面,实现手写输入和实时识别。采用CNN架构进行特征提取和分类:利用两层卷积层提取图像特征,通过最大池化层降维,再经全连接层完成分类,在测试集上识别准确率达到90%
用户交互功能:支持鼠标手写输入,实时显示识别结果及置信度,提供清除和重绘功能
一、准备
ChatGPT!有了它,再也不怕写不出代码了:你只需要脑子里有个大体框架,代码交给它写就行,遇到报错就告诉它,让它去改。
Python, PyTorch
我是用电脑显卡(GPU)跑的;电脑没有独立显卡的话用 CPU 也可以跑,只是会慢一些。相关环境安装教程链接:link
二、(0-9)数字识别模型代码
1.引入库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from torchvision.datasets import MNIST
import numpy as np
import cv2
import matplotlib.pyplot as plt
2.读入数据
# Preprocessing: ToTensor() scales pixel values to [0, 1]; Normalize with
# mean 0.5 / std 0.5 then maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST train and test splits, downloaded into ./data if not already present.
train_data, test_data = (
    MNIST(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Batches of 64; training batches are shuffled, test batches keep dataset order.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=64, shuffle=do_shuffle)
    for dataset, do_shuffle in ((train_data, True), (test_data, False))
)
3.模型训练
# 初始化模型、损失函数和优化器
class SimpleCNN(torch.nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.pool = torch.nn.MaxPool2d(2, 2)
self.fc1 = torch.nn.Linear(64 * 7 * 7, 128)
self.fc2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = x.view(-1, 64 * 7 * 7)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# Build the model, loss function and optimizer; use the GPU when CUDA is
# available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SimpleCNN()
model = model.to(device)
criterion = nn.CrossEntropyLoss()  # multi-class classification loss on raw logits
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
# Training loop.
def train_model(model, train_loader, criterion, optimizer, epochs=5, device=None):
    """Train *model* in place for *epochs* passes over *train_loader*.

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over the data.
        device: where each batch is moved before the forward pass. Defaults
            to the device of the model's first parameter (CPU if the model
            has no parameters).

    Returns:
        List of the mean batch loss of each epoch (NaN for an epoch that
        produced no batches). Each value is also printed.
    """
    if device is None:
        # Follow the model so batches end up on the same device as its weights.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()  # set training mode
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0  # counted explicitly so loaders without __len__ also work
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()  # clear gradients from the previous step
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        avg_loss = running_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch+1}, Loss: {avg_loss}")
        epoch_losses.append(avg_loss)
    return epoch_losses
# Run the training loop for 5 epochs.
train_model(model, train_loader, criterion, optimizer, 5)
4.模型测试
# Evaluation.
def test_model(model, test_loader, device=None):
    """Measure the classification accuracy of *model* on *test_loader*.

    Args:
        model: classifier returning per-class scores of shape (batch, classes);
            it is switched to evaluation mode.
        test_loader: iterable of ``(images, labels)`` batches.
        device: where each batch is moved before the forward pass. Defaults
            to the device of the model's first parameter (CPU if the model
            has no parameters).

    Returns:
        Fraction of correctly classified samples in [0, 1], or NaN when the
        loader yields no samples. The result is also printed.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()  # set evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)  # predicted class per sample
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        print("Test Accuracy: n/a (no samples)")
        return float("nan")
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return accuracy


# The name starts with "test_", so pytest would otherwise collect this helper
# as a test case whenever this module is imported from a test file.
test_model.__test__ = False
# Report accuracy on the MNIST test split.
test_model(model, test_loader=test_loader)
5.模型权重保存(不用重复训练)
# Persist only the learned parameters (the state_dict), not the module object.
torch.save(obj=model.state_dict(), f="model_weights.pth")
6.交互式界面
import tkinter as tk
from tkinter import Canvas
import numpy as np
import torch
from PIL import ImageGrab, Image, ImageTk
from torchvision import transforms
# Model architecture used by the GUI; it must match the architecture whose
# weights are loaded from model_weights.pth.
class SimpleCNN(torch.nn.Module):
    """Two conv + max-pool stages followed by two fully connected layers.

    Expects input of shape (batch, 1, 28, 28); an unbatched (1, 28, 28)
    tensor is treated as a batch of one. Each 3x3 convolution keeps the
    spatial size (padding=1) and each 2x2 max-pool halves it, so the feature
    map entering ``fc1`` is 64 x 7 x 7.

    Args:
        num_classes: number of output classes (10 for the digits 0-9).
            Layer names are unchanged, so existing 10-class state dicts load.

    Returns (from ``forward``):
        Raw logits of shape (batch, num_classes).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool2d(2, 2)
        self.fc1 = torch.nn.Linear(64 * 7 * 7, 128)
        self.fc2 = torch.nn.Linear(128, num_classes)

    def forward(self, x):
        if x.dim() == 3:
            # Unbatched (C, H, W) input: add a batch dimension of size 1.
            x = x.unsqueeze(0)
        x = self.pool(torch.relu(self.conv1(x)))  # (B, 32, 14, 14) for 28x28 input
        x = self.pool(torch.relu(self.conv2(x)))  # (B, 64, 7, 7) for 28x28 input
        # Flatten per sample. view(-1, 64*7*7) would silently fold extra
        # spatial data of a wrongly sized input into the batch dimension;
        # flattening from dim 1 makes fc1 raise instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)
# Rebuild the network and load the saved weights into it. map_location puts
# every tensor on the CPU, so a GPU is not required to run the GUI.
model = SimpleCNN()
model.load_state_dict(
    torch.load(
        "model_weights.pth",
        map_location=torch.device("cpu"),
    )
)
model = model.eval()  # eval() returns the module itself; switches to inference mode
# Tkinter GUI: hand-draw a digit on a 28x28 grid and classify it with the loaded model.
class DigitRecognizerApp:
    """Tkinter app for drawing a digit on a 28x28 cell grid and classifying it.

    Each grid cell is shown as a ``pixel_size``-pixel square. Dragging with
    the left mouse button paints cells white; the "识别" button feeds the
    grid to the module-level ``model`` and shows the predicted class plus
    the softmax probability of every class; "清空" resets the drawing.
    """

    def __init__(self, root):
        """Create the canvas, buttons and result labels inside *root*."""
        self.root = root
        self.root.title("手绘数字识别")

        # On-screen size of one grid cell in pixels, enlarged to make drawing easy.
        self.pixel_size = 10
        # Cells per side: a 28x28 grid, the size predict() feeds to the model.
        self.canvas_size = 28
        self.canvas = Canvas(
            root,
            width=self.pixel_size * self.canvas_size,
            height=self.pixel_size * self.canvas_size,
            bg="black",
        )
        self.canvas.grid(row=0, column=0, padx=10, pady=10)
        self.canvas.bind("<B1-Motion>", self.draw)

        # Buttons
        self.predict_button = tk.Button(root, text="识别", command=self.predict)
        self.predict_button.grid(row=1, column=0, pady=0)
        self.clear_button = tk.Button(root, text="清空", command=self.clear_canvas)
        self.clear_button.grid(row=2, column=0, pady=0)

        # Result display: predicted class and per-class probabilities.
        self.result_label = tk.Label(root, text="预测结果: ", font=("Arial", 14))
        self.result_label.grid(row=0, column=1, padx=10, pady=10)
        self.prob_label = tk.Label(root, text="", font=("Arial", 10), justify="left")
        self.prob_label.grid(row=1, column=1, padx=10, pady=10)

        # Shows the 28x28 image that is fed to the model.
        self.predicted_img_label = tk.Label(root)
        self.predicted_img_label.grid(row=0, column=2, padx=10, pady=10)

        # Cell state: 1 = painted, 0 = blank.
        self.grid = np.zeros((self.canvas_size, self.canvas_size), dtype=int)

    def _cell_at(self, x, y):
        """Return the ``(grid_x, grid_y)`` cell containing canvas point (x, y), or None if outside the grid."""
        grid_x = x // self.pixel_size
        grid_y = y // self.pixel_size
        if 0 <= grid_x < self.canvas_size and 0 <= grid_y < self.canvas_size:
            return grid_x, grid_y
        return None

    def draw(self, event):
        """Paint the grid cell under the mouse pointer white.

        Motion events arrive while the button is held even after the pointer
        leaves the canvas. Such points are ignored. Without the check, a
        negative coordinate would wrap to the opposite edge of the grid, and a
        coordinate past the right/bottom edge would raise IndexError.
        """
        cell = self._cell_at(event.x, event.y)
        if cell is None:
            return
        grid_x, grid_y = cell
        x0 = grid_x * self.pixel_size
        y0 = grid_y * self.pixel_size
        self.canvas.create_rectangle(
            x0, y0, x0 + self.pixel_size, y0 + self.pixel_size,
            fill="white", outline="white",
        )
        self.grid[grid_y, grid_x] = 1

    def clear_canvas(self):
        """Erase the drawing, the results and the preview image, and reset the grid."""
        self.canvas.delete("all")
        self.result_label.config(text="预测结果: ")
        self.prob_label.config(text="")
        self.predicted_img_label.config(image="")
        self.grid = np.zeros((self.canvas_size, self.canvas_size), dtype=int)

    @staticmethod
    def _to_model_input(img):
        """Convert a 2-D 8-bit grayscale image (PIL image or array) to a (1, 1, H, W) model input.

        Same as ``ToTensor()`` followed by ``Normalize((0.5,), (0.5,))``:
        pixel values 0..255 are scaled to [0, 1] and then mapped to [-1, 1].
        Done here directly so the GUI does not depend on a ``transform``
        object from the training code.
        """
        pixels = np.asarray(img, dtype=np.float32) / 255.0
        tensor = torch.from_numpy(pixels).unsqueeze(0).unsqueeze(0)
        return (tensor - 0.5) / 0.5

    def predict(self):
        """Classify the current drawing with the module-level ``model`` and display the result."""
        # Painted cells are 1 -> scale to a 0/255 grayscale image.
        img = Image.fromarray(self.grid.astype(np.uint8) * 255)
        img_resized = img.resize((28, 28))  # ensure the 28x28 size the model expects

        # Preview of the model input.
        img_tk = ImageTk.PhotoImage(img_resized)
        self.predicted_img_label.config(image=img_tk)
        # Keep a reference; otherwise the PhotoImage can be garbage-collected
        # and the label shows nothing.
        self.predicted_img_label.image = img_tk

        img_transformed = self._to_model_input(img_resized)
        with torch.no_grad():
            output = model(img_transformed)
            probs = torch.nn.functional.softmax(output, dim=1).numpy()[0]
        predicted_digit = np.argmax(probs)

        self.result_label.config(text=f"预测结果: {predicted_digit}")
        prob_text = "\n".join(f"{i}: {prob:.2%}" for i, prob in enumerate(probs))
        self.prob_label.config(text=prob_text)
# Launch the GUI: create the Tk root window, build the app in it and run the event loop.
root = tk.Tk()
app = DigitRecognizerApp(root)
app.root.mainloop()
三、结果展示
3.1 (0-9)数字识别
3.2 汉字识别
四、jupyter代码下载
我用的是 Jupyter 编辑,感觉它的交互式页面挺方便的,能分段运行代码,方便确认每一个步骤的代码对不对。代码基本都是 ChatGPT 写出来的,中间有报错就问 AI,然后修修改改,最后跑通。只能说 AI 这个工具对代码小白来说太友好了,只需要学会 Python 基本语法,就能做出一个模型来!代码下载链接放在文章里面了。
标签:torch,nn,self,识别系统,grid,手写,文本,model,size From: https://blog.csdn.net/woshiaoligei/article/details/145119701