markdown
代码解读
复制代码
# 手写阿拉伯数字识别 本项目将使用 TensorFlow 和 Keras 构建一个卷积神经网络(CNN)模型来识别手写阿拉伯数字。以下是各个步骤的详细说明。
python
代码解读
复制代码
# 导入必要的库 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt from tensorflow.keras.preprocessing.image import ImageDataGenerator from skimage import io from skimage.transform import resize # 设置中文字体 plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False
markdown
代码解读
复制代码
## 步骤1:加载 MNIST 手写阿拉伯数字数据 我们将使用 Keras 提供的 MNIST 数据集,它包含了 60000 个训练样本和 10000 个测试样本,每个样本是一个 28x28 像素的灰度图像。
python
代码解读
复制代码
mnist = tf.keras.datasets.mnist (x_train, y_train), (x_test, y_test) = mnist.load_data()
markdown
代码解读
复制代码
## 步骤2:数据清理 此步骤无需进行,因为 MNIST 数据集已经过清理和处理。
markdown
代码解读
复制代码
## 步骤3:特征工程 将图像数据的像素值缩放到 (0, 1) 之间,以提高模型的训练效果。
python
代码解读
复制代码
x_train_norm, x_test_norm = x_train / 255.0, x_test / 255.0
markdown
代码解读
复制代码
## 步骤4:数据分割 加载 MNIST 数据时已经完成了数据分割。
markdown
代码解读
复制代码
## 步骤5:建立改进的模型结构 我们将构建一个卷积神经网络(CNN),包含多个卷积层、池化层和全连接层。
python
代码解读
复制代码
model = tf.keras.models.Sequential([ tf.keras.layers.Reshape((28, 28, 1), input_shape=(28, 28)), tf.keras.layers.Conv2D(32, (3, 3), activation='relu'), tf.keras.layers.MaxPooling2D((2, 2)), tf.keras.layers.Conv2D(64, (3, 3), activation='relu'), tf.keras.layers.MaxPooling2D((2, 2)), tf.keras.layers.Conv2D(64, (3, 3), activation='relu'), tf.keras.layers.Flatten(), tf.keras.layers.Dense(256, activation='relu'), tf.keras.layers.BatchNormalization(), tf.keras.layers.Dropout(0.3), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.BatchNormalization(), tf.keras.layers.Dropout(0.3), tf.keras.layers.Dense(10, activation='softmax') ])
markdown
代码解读
复制代码
## 编译模型 我们使用 Adam 优化器和稀疏分类交叉熵损失函数来编译模型,并选择准确率作为评估指标。
python
代码解读
复制代码
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
markdown
代码解读
复制代码
## 步骤6:使用数据增强进行模型训练 我们将使用图像数据生成器对训练数据进行数据增强,以提高模型的泛化能力。
python
代码解读
复制代码
datagen = ImageDataGenerator( rotation_range=10, width_shift_range=0.1, height_shift_range=0.1, zoom_range=0.1, shear_range=0.1 ) # 训练模型 history = model.fit(datagen.flow(x_train_norm.reshape(-1, 28, 28, 1), y_train, batch_size=128), steps_per_epoch=len(x_train_norm) // 128, epochs=20, validation_data=(x_test_norm.reshape(-1, 28, 28, 1), y_test))
markdown
代码解读
复制代码
## 绘制训练过程的准确率和损失 我们将绘制模型在训练过程中的准确率和损失,以便于观察模型的训练情况。
python
代码解读
复制代码
plt.figure(figsize=(12, 4)) plt.subplot(1, 2, 1) plt.plot(history.history['accuracy'], 'r', label='训练准确率') plt.plot(history.history['val_accuracy'], 'g', label='验证准确率') plt.title('模型准确率') plt.legend() plt.subplot(1, 2, 2) plt.plot(history.history['loss'], 'r', label='训练损失') plt.plot(history.history['val_loss'], 'g', label='验证损失') plt.title('模型损失') plt.legend() plt.show()
markdown
代码解读
复制代码
## 步骤7:评分 我们将使用测试数据集对模型进行评分,以评估其性能。
python
代码解读
复制代码
score = model.evaluate(x_test_norm.reshape(-1, 28, 28, 1), y_test, verbose=0) print(f'测试准确率: {score[1]:.4f}')
markdown
代码解读
复制代码
## 步骤8:评估 我们将使用测试数据集对模型进行预测,并显示前 20 个测试样本的实际值和预测值。
python
代码解读
复制代码
predictions = np.argmax(model.predict(x_test_norm.reshape(-1, 28, 28, 1)), axis=-1) print('实际值 :', y_test[:20]) print('预测值 :', predictions[:20])
markdown
代码解读
复制代码
## 步骤9:模型部署 将训练好的模型保存到文件,以便在后续的预测任务中使用。
python
代码解读
复制代码
model.save('data/DigitSense_model_improved.keras')
markdown
代码解读
复制代码
## 步骤10:新数据预测 定义一个函数,用于对新图像进行预测。
python
代码解读
复制代码
def predict_digit(file_path): # 读取图像并转为单色 image = io.imread(file_path, as_gray=True) # 缩放为 (28, 28) 大小的图像 image_resized = resize(image, (28, 28), anti_aliasing=True) # 反转颜色并reshape X = np.abs(1 - image_resized).reshape(1, 28, 28, 1) # 预测 prediction = np.argmax(model.predict(X), axis=-1) # 显示图像和预测结果 plt.imshow(image, cmap='gray') plt.title(f'预测结果: {prediction[0]}') plt.axis('off') plt.show() return prediction[0]
markdown
代码解读
复制代码
## 测试自定义图像 使用自定义图像进行测试,并显示预测结果。
python
代码解读
复制代码
for i in range(10): file_path = f'./data/images/{i}.png' predicted_digit = predict_digit(file_path) print(f'图像 {i}.png 的预测结果: {predicted_digit}')
markdown
代码解读
复制代码
## 显示模型汇总信息 显示模型的结构和参数数量。
python
代码解读
复制代码
model.summary()
标签:plt,keras,阿拉伯数字,代码,28,解读,复制,手写,识别 From: https://blog.csdn.net/weixin_47588164/article/details/142134629原文链接:https://juejin.cn/post/7412899060827111475