实验五:BP 神经网络算法实现与测试
一、实验目的
深入理解 BP 神经网络的算法原理,能够使用 Python 语言实现 BP 神经网络的训练与测 试,并且使用五折交叉验证算法进行模型训练与评估。
二、实验内容
(1)从 scikit-learn 库中加载 iris 数据集,使用留出法留出 1/3 的样本作为测试集(注 意同分布取样); (2)使用训练集训练 BP 神经网络分类算法; (3)使用五折交叉验证对模型性能(准确度、精度、召回率和 F1 值)进行评估和选 择; (4)使用测试集,测试模型的性能,对测试结果进行分析,完成实验报告中实验五的部分。
三、算法步骤、代码、及结果
1. 算法伪代码
算法名称:基于 BP 神经网络的鸢尾花数据集分类及性能评估
输入
鸢尾花数据集(包含特征数据以及对应的类别标签)
步骤
- 数据准备阶段
- o 从 sklearn.datasets 库中加载鸢尾花数据集,将特征数据赋值给变量 X,类别标签赋值给变量 y。
- o 使用留出法,按照测试集占总样本的 1/3 比例,将数据集划分为训练集(X_train、y_train)和测试集(X_test、y_test),设置随机种子为 42 并通过 stratify=y 保证训练集和测试集的类别分布与原始数据集相似。
- 模型构建阶段
- o 创建 BP 神经网络分类器实例 bp_clf,设置隐藏层大小为包含 100 个神经元的一层(即 (100,)),最大迭代次数为 1000 次,随机数生成器的种子设为 42。
- 模型训练阶段
- o 使用训练集数据(X_train、y_train)对 bp_clf 模型进行训练。
- 交叉验证阶段
- o 运用五折交叉验证方法,针对 bp_clf 模型在训练集(X_train、y_train)上进行性能评估,评估指标设定为准确度(accuracy),将每次折叠得到的准确度分数存储在 cv_scores 变量中。
- o 打印输出五折交叉验证的各次准确度(保留四位小数,即 np.round(cv_scores, 4) 的值)以及平均准确度(保留四位小数,即 np.round(np.mean(cv_scores), 4) 的值)。
- 训练集性能评估阶段
- o 利用训练好的 bp_clf 模型对训练集(X_train、y_train)进行预测,得到预测结果 y_pred_train。
- o 分别计算训练集的精度(precision_train)、召回率(recall_train)以及 F1 值(f1_train),计算时采用 'macro' 平均方式,对应使用 precision_score、recall_score、f1_score 函数进行计算。
- o 分别打印输出训练集的精度(保留四位小数)、召回率(保留四位小数)以及 F1 值(保留四位小数)。
- 测试集性能评估阶段
- o 利用训练好的 bp_clf 模型对测试集(X_test、y_test)进行预测,得到预测结果 y_pred_test。
- o 计算测试集的准确度(accuracy_test),使用 accuracy_score 函数对比 y_test 与 y_pred_test 得到。
- o 分别计算测试集的精度(precision_test)、召回率(recall_test)以及 F1 值(f1_test),计算时采用 'macro' 平均方式,对应使用 precision_score、recall_score、f1_score 函数进行计算。
- o 分别打印输出测试集的准确度(保留四位小数)、精度(保留四位小数)、召回率(保留四位小数)以及 F1 值(保留四位小数)。
- 模型性能分析阶段
- o 判断测试集准确度是否大于 0.8,若是,则打印 “模型性能良好,在测试集上有较高的准确度。”;若否,则打印 “模型性能有待提高,可尝试调整模型参数或增加训练数据。”
输出
- 五折交叉验证的各次准确度(保留四位小数)以及平均准确度(保留四位小数)。
- 训练集的精度(保留四位小数)、召回率(保留四位小数)以及 F1 值(保留四位小数)。
- 测试集的准确度(保留四位小数)、精度(保留四位小数)、召回率(保留四位小数)以及 F1 值(保留四位小数)。
- 根据测试集准确度给出的模型性能分析结论。
2. 算法主要代码
完整源代码\调用库方法(函数参数说明)
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score,
f1_score
# 加载 iris 数据集
iris = load_iris()
X = iris.data
y = iris.target
# 留出法划分训练集和测试集,1/3 作为测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/3,
random_state=42, stratify=y)
# 定义 BP 神经网络分类器
bp_clf = MLPClassifier(hidden_layer_sizes=(100,), max_iter=1000,
random_state=42)
# 使用训练集训练模型
bp_clf.fit(X_train, y_train)
# 五折交叉验证评估模型性能
cv_scores = cross_val_score(bp_clf, X_train, y_train, cv=5, scoring='accuracy')
print("五折交叉验证准确度:", np.round(cv_scores, 4))
print("平均准确度:", np.round(np.mean(cv_scores), 4))
y_pred_train = bp_clf.predict(X_train)
precision_train = precision_score(y_train, y_pred_train, average='macro')
recall_train = recall_score(y_train, y_pred_train, average='macro')
f1_train = f1_score(y_train, y_pred_train, average='macro')
print("训练集精度:", np.round(precision_train, 4))
print("训练集召回率:", np.round(recall_train, 4))
print("训练集 F1 值:", np.round(f1_train, 4))
# 使用测试集测试模型性能
y_pred_test = bp_clf.predict(X_test)
accuracy_test = accuracy_score(y_test, y_pred_test)
precision_test = precision_score(y_test, y_pred_test, average='macro')
recall_test = recall_score(y_test, y_pred_test, average='macro')
f1_test = f1_score(y_test, y_pred_test, average='macro')
print("测试集准确度:", np.round(accuracy_test, 4))
print("测试集精度:", np.round(precision_test, 4))
print("测试集召回率:", np.round(recall_test, 4))
print("测试集 F1 值:", np.round(f1_test, 4))
# 分析测试结果
if accuracy_test > 0.8:
print("模型性能良好,在测试集上有较高的准确度。")
else:
print("模型性能有待提高,可尝试调整模型参数或增加训练数据。")
3. 训练结果截图(包括:准确率、精度(查准率)、召回率(查全率)、F1)
四、实验结果分析
1. 测试结果截图(包括:准确率、精度(查准率)、召回率(查全率)、F1)
2. 对比分析
模型在 iris 数据集上的表现非常优秀。五折交叉验证显示大多数折的准确率为1,平均准确度为0.98,说明模型在训练集上学习得很好。在训练集上,精度、召回率和 F1 值均接近0.99,表明模型几乎完美地预测了样本。测试集的准确度同样为0.98,其他评估指标也很高,显示出模型在未见数据上的良好泛化能力。尽管结果令人满意,但仍需注意过拟合的风险。总体而言,模型在这个数据集上的表现非常出色,值得进一步测试和应用。
五、心得体会
在实现与测试 BP 神经网络算法的过程中,深刻体会到其强大的学习能力。从数据预处理到网络架构设计,每一步都充满挑战。训练过程中见证了算法不断优化参数以适应数据的过程。测试结果展示了其在分类任务上的潜力,同时也让我认识到算法的局限性。这个过程提升了我的编程和问题解决能力,激发了对深度学习领域的更大热情。
标签:score,训练,准确度,train,测试,test,10.16 From: https://www.cnblogs.com/jais/p/18647892