首页 > 其他分享 >10.16

10.16

时间:2025-01-02 15:45:50浏览次数:11  
标签:score 训练 准确度 train 测试 test 10.16

实验五:BP 神经网络算法实现与测试

一、实验目的

深入理解 BP 神经网络的算法原理,能够使用 Python 语言实现 BP 神经网络的训练与测 试,并且使用五折交叉验证算法进行模型训练与评估。

 

二、实验内容

(1)从 scikit-learn 库中加载 iris 数据集,使用留出法留出 1/3 的样本作为测试集(注 意同分布取样); (2)使用训练集训练 BP 神经网络分类算法; (3)使用五折交叉验证对模型性能(准确度、精度、召回率和 F1 值)进行评估和选 择; (4)使用测试集,测试模型的性能,对测试结果进行分析,完成实验报告中实验五的部分。

 

三、算法步骤、代码、及结果

   1. 算法伪代码

算法名称:基于 BP 神经网络的鸢尾花数据集分类及性能评估

输入

 

鸢尾花数据集(包含特征数据以及对应的类别标签)

步骤



  1. 数据准备阶段
  • o 从 sklearn.datasets 库中加载鸢尾花数据集,将特征数据赋值给变量 X,类别标签赋值给变量 y。
  • o 使用留出法,按照测试集占总样本的 1/3 比例,将数据集划分为训练集(X_train、y_train)和测试集(X_test、y_test),设置随机种子为 42 并通过 stratify=y 保证训练集和测试集的类别分布与原始数据集相似。
  1. 模型构建阶段
  • o 创建 BP 神经网络分类器实例 bp_clf,设置隐藏层大小为包含 100 个神经元的一层(即 (100,)),最大迭代次数为 1000 次,随机数生成器的种子设为 42。
  1. 模型训练阶段
  • o 使用训练集数据(X_train、y_train)对 bp_clf 模型进行训练。
  1. 交叉验证阶段
  • o 运用五折交叉验证方法,针对 bp_clf 模型在训练集(X_train、y_train)上进行性能评估,评估指标设定为准确度(accuracy),将每次折叠得到的准确度分数存储在 cv_scores 变量中。
  • o 打印输出五折交叉验证的各次准确度(保留四位小数,即 np.round(cv_scores, 4) 的值)以及平均准确度(保留四位小数,即 np.round(np.mean(cv_scores), 4) 的值)。
  1. 训练集性能评估阶段
  • o 利用训练好的 bp_clf 模型对训练集(X_train、y_train)进行预测,得到预测结果 y_pred_train。
  • o 分别计算训练集的精度(precision_train)、召回率(recall_train)以及 F1 值(f1_train),计算时采用 'macro' 平均方式,对应使用 precision_score、recall_score、f1_score 函数进行计算。
  • o 分别打印输出训练集的精度(保留四位小数)、召回率(保留四位小数)以及 F1 值(保留四位小数)。
  1. 测试集性能评估阶段
  • o 利用训练好的 bp_clf 模型对测试集(X_test、y_test)进行预测,得到预测结果 y_pred_test。
  • o 计算测试集的准确度(accuracy_test),使用 accuracy_score 函数对比 y_test 与 y_pred_test 得到。
  • o 分别计算测试集的精度(precision_test)、召回率(recall_test)以及 F1 值(f1_test),计算时采用 'macro' 平均方式,对应使用 precision_score、recall_score、f1_score 函数进行计算。
  • o 分别打印输出测试集的准确度(保留四位小数)、精度(保留四位小数)、召回率(保留四位小数)以及 F1 值(保留四位小数)。
  1. 模型性能分析阶段
  • o 判断测试集准确度是否大于 0.8,若是,则打印 “模型性能良好,在测试集上有较高的准确度。”;若否,则打印 “模型性能有待提高,可尝试调整模型参数或增加训练数据。”

输出



  1. 五折交叉验证的各次准确度(保留四位小数)以及平均准确度(保留四位小数)。
  2. 训练集的精度(保留四位小数)、召回率(保留四位小数)以及 F1 值(保留四位小数)。
  3. 测试集的准确度(保留四位小数)、精度(保留四位小数)、召回率(保留四位小数)以及 F1 值(保留四位小数)。
  4. 根据测试集准确度给出的模型性能分析结论。

 

   2. 算法主要代码

完整源代码\调用库方法(函数参数说明)

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# 加载 iris 数据集
iris = load_iris()
X = iris.data
y = iris.target

# 留出法划分训练集和测试集,1/3 作为测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/3, random_state=42, stratify=y)

# 定义 BP 神经网络分类器
bp_clf = MLPClassifier(hidden_layer_sizes=(100,), max_iter=1000, random_state=42)

# 使用训练集训练模型
bp_clf.fit(X_train, y_train)

# 五折交叉验证评估模型性能
cv_scores = cross_val_score(bp_clf, X_train, y_train, cv=5, scoring='accuracy')
print("五折交叉验证准确度:", np.round(cv_scores, 4))
print("平均准确度:", np.round(np.mean(cv_scores), 4))

y_pred_train = bp_clf.predict(X_train)
precision_train = precision_score(y_train, y_pred_train, average='macro')
recall_train = recall_score(y_train, y_pred_train, average='macro')
f1_train = f1_score(y_train, y_pred_train, average='macro')
print("训练集精度:", np.round(precision_train, 4))
print("训练集召回率:", np.round(recall_train, 4))
print("训练集 F1 值:", np.round(f1_train, 4))

# 使用测试集测试模型性能
y_pred_test = bp_clf.predict(X_test)
accuracy_test = accuracy_score(y_test, y_pred_test)
precision_test = precision_score(y_test, y_pred_test, average='macro')
recall_test = recall_score(y_test, y_pred_test, average='macro')
f1_test = f1_score(y_test, y_pred_test, average='macro')
print("测试集准确度:", np.round(accuracy_test, 4))
print("测试集精度:", np.round(precision_test, 4))
print("测试集召回率:", np.round(recall_test, 4))
print("测试集 F1 值:", np.round(f1_test, 4))

# 分析测试结果
if accuracy_test > 0.8:
    print("模型性能良好,在测试集上有较高的准确度。")
else:
    print("模型性能有待提高,可尝试调整模型参数或增加训练数据。")

 

 

   3. 训练结果截图(包括:准确率、精度(查准率)、召回率(查全率)、F1)

 

 

四、实验结果分析

1. 测试结果截图(包括:准确率、精度(查准率)、召回率(查全率)、F1)

 

 

 

 

2. 对比分析

模型在 iris 数据集上的表现非常优秀。五折交叉验证显示大多数折的准确率为1,平均准确度为0.98,说明模型在训练集上学习得很好。在训练集上,精度、召回率和 F1 值均接近0.99,表明模型几乎完美地预测了样本。测试集的准确度同样为0.98,其他评估指标也很高,显示出模型在未见数据上的良好泛化能力。尽管结果令人满意,但仍需注意过拟合的风险。总体而言,模型在这个数据集上的表现非常出色,值得进一步测试和应用。

五、心得体会

在实现与测试 BP 神经网络算法的过程中,深刻体会到其强大的学习能力。从数据预处理到网络架构设计,每一步都充满挑战。训练过程中见证了算法不断优化参数以适应数据的过程。测试结果展示了其在分类任务上的潜力,同时也让我认识到算法的局限性。这个过程提升了我的编程和问题解决能力,激发了对深度学习领域的更大热情。

 

标签:score,训练,准确度,train,测试,test,10.16
From: https://www.cnblogs.com/jais/p/18647892

相关文章

  • 10.16日报
    上午上了软件设计的课,进行了软件设计的实验实验4:抽象工厂模式本次实验属于模仿型实验,通过本次实验学生将掌握以下内容:1、理解抽象工厂模式的动机,掌握该模式的结构;2、能够利用抽象工厂模式解决实际问题。     [实验任务一]:人与肤色使用抽象工厂模式,完成下......
  • 10.16 总结
    T1赛时拿的30分暴力,没想到60分,但是预期:30pts,实际:30pts正解把一个人劈成四瓣,然后用树状数组维护不是\(i\)这个人以外的\(0,a_{(i,0)},a_{(i,1)},a_{(i,1)}+a_{(i,0)}\)以上的所有人的个数,最后除以\(16\),就行了。T2赛时时正解,然后因为没有写check然后就小样例......
  • 10.16学习日志
    一.Python函数1.定义一个函数什么是函数函数是可以重复执行的语句块,可以重复调用作用用于封装语句块,提高代码的重用性。函数是面向过程编程的最小单位1.1def语句作用用来定义(创建)函数语法说明函数代码块以def关键词开头,后接函数标识符名称和圆括......
  • 永久白嫖AWS云服务器,验证、注册指南【2024.10.16亲测可用】
    背景不知道你想不想拥有一台属于自己的云服务器呢,拥有一台自己的云服务器可以建站,可以在上面搭建个人博客,今天我就来教大家如何申请亚马逊AWS免费云服务器,这个云服务器可以长达12个月的免费。而且到期后可以继续换个账号继续白嫖。(不过呢在注册的时候是需要信用卡的,实测国......
  • 10.16
    今天我主要学习了Java中的异常处理知识。通过编写一个简单的程序,我了解了如何使用try-catch语句来处理异常,以及如何使用finally语句来确保资源的正确释放。此外,我还了解到使用二分法查找可以优化多次比较的算法,提高程序的运行效率。在实践中,我遇到了一些困难。例如,在Web界面中实......
  • 10.16
    在MySQL中,可以使用ALTERDATABASE来修改已经被创建或者存在的数据库的相关参数。修改数据库的语法格式为:ALTERDATABASE[数据库名]{[DEFAULT]CHARACTERSET<字符集名>|[DEFAULT]COLLATE<校对规则名>}语法说明如下:ALTERDATABASE用于更改数据库的全局特性。使用AL......
  • 【一周聚焦】联邦学习 10.9-10.16
    近期的联邦学习做了如下内容:大模型目前大模型是绝对的研究风口,而FL中为了降低传输开销的网络压缩技术也是可以服务于LLM的高效传输的。港科大+微众银行,10月16,FATE-LLM:AIndustrialGradeFederatedLearningFrameworkforLargeLanguageModels杨强团队一直在推FATE这个联......
  • 10.16
    编写一个方法,使用以上算法生成指定数目(比如1000个)的随机整数。源代码:importjava.util.Scanner;importjava.util.Random;publicclassMain{publicstaticvoidmain(String[]args){Scannersin=newScanner(System.in);System.out.println("请输入想......
  • 大二快乐日记10.16
    2.配置多个<url-pattern>子元素从Servlet2.5开始,<servlet-mapping>元素可以包含多个<url-pattern>子元素,每个<url-pattern>代表一个虚拟路径的映射规则。因此,通过在一个<servlet-mapping>元素中配置多个<url-pattern>子元素,也可以实现Servlet的多重映射。以ser......
  • 10.16 二分查找(加分项喔)
    上周一成功回答建民老师课上问题:对于不同分数对应的优秀程度,如何减少对比次数:二分查找(也叫折半查找算法):二分查找针对的是一个有序的数据集合时间复杂度:O(logn)但是二分查找的应用场景比较有限:底层必须依赖数组,并且要求数据有......