首页 > 其他分享 >【深度学习】KMeans中自动K值的确认方法

【深度学习】KMeans中自动K值的确认方法

时间:2023-12-16 12:01:04浏览次数:32  
标签:得分 系数 silhouette 深度 确认 KMeans 聚类 轮廓


1  前言

聚类常用于数据探索或挖掘前期,在没有做先验经验的背景下做的探索性分析,也适用于样本量较大情况下的数据预处理等方面工作。例如针对企业整体用户特征,在未得到相关知识或经验之前先根据数据本身特点进行用户分群,然后再针对不同群体做进一步分析;例如将连续数据做离散化,便于做后续分类分析应用。

KMeans是聚类方法中非常常用的方法,并且在正确确定K的情况下,KMeans对类别的划分跟分类算法的差异性非常小,这也意味着KMeans是一个准确率非常接近实际分类的算法。本文将讨论如下基于自动化的方法确立K值。

本案例是《Python数据分析与数据化运营》中“7.11案例-基于自动K值的KMeans广告效果聚类分析”的一部分,其ad_performance.txt和源代码chapter7_code2.py位于“附件-chapter7”中,该附件可以在可从http://www.dataivy.cn/book/python_book.zip或https://pan.baidu.com/s/1kUUBWNX下载。

2  实现思路

K值的确定一直是KMeans算法的关键,而由于KMeans是一个非监督式学习,因此没有所谓的“最佳”K值。但是,从数据本身的特征来讲,最佳K值对应的类别下应该是类内距离最小化并且类间距离最大化。有多个指标可以用来评估这种特征,比如平均轮廓系数、类内距离/类间距离等都可以做此类评估。基于这种思路,我们可以通过枚举法计每个K下的平均轮廓系数值,然后选出平均轮廓系数最大下的K值。

3  核心过程

假设我们已经拥有一份预处理过的数据集,其中的异常值、缺失值、数据标准化等前期工作都已经完成。下面是完成自动K值确定的核心流程:


1. 
2.  
3.  
4. score_list = list()  # 用来存储每个K下模型的平局轮廓系数
5. silhouette_int = -1  # 初始化的平均轮廓系数阀值
6. for n_clusters in range(2, 10): # 遍历从2到10几个有限组
7.     model_kmeans =KMeans(n_clusters=n_clusters, random_state=0) # 建立聚类模型对象
8.     cluster_labels_tmp =model_kmeans.fit_predict(X)  # 训练聚类模型
9.     silhouette_tmp =metrics.silhouette_score(X, cluster_labels_tmp) # 得到每个K下的平均轮廓系数
10.     if silhouette_tmp >silhouette_int:  # 如果平均轮廓系数更高
11.         best_k =n_clusters  # 将最好的K存储下来
12.         silhouette_int =silhouette_tmp  # 将最好的平均轮廓得分存储下来
13.         best_kmeans =model_kmeans  # 将最好的模型存储下来
14.         cluster_labels_k =cluster_labels_tmp  # 将最好的聚类标签存储下来
15.    score_list.append([n_clusters, silhouette_tmp])  # 将每次K及其得分追加到列表
16. print ('{:^60}'.format(‘K value and silhouette summary:’))
17. print (np.array(score_list)) # 打印输出所有K下的详细得分
18. print (‘Best K is:{0} with average silhouette of{1}’.format(best_k, silhouette_int.round(4)))


该步骤的主要实现过程如下:

定义初始变量score_list和silhouette_int。score_list用来存储每个K下模型的平局轮廓系数,方便在最终打印输出详细计算结果;silhouette_int的初始值设置为-1,每个K下计算得到的平均轮廓系数如果比该值大,则将其值赋值给silhouette_int。

提示:对于平均轮廓系数而言,其值域分布式[-1,1]。因此silhouette_int的初始值可以设置为-1或比-1更小的值。

使用for循环遍历每个K值,这里的K的范围确定为从2-10.一般而言,用于聚类分析的K值的确定不会太大。如果值太大,那么聚类效果可能不明显,因为大量信息的都会被分散到各个小类之中,会导致数据的碎片化。

通过KMeans(n_clusters=n_clusters, random_state=0)建立KMeans模型对象model_kmeans,设置聚类数为循环中得到的K值,设置固定的初始状态。

对model_kmeans使用fit_predict得到其训练集的聚类标签。该步骤其实无需通过predict获得标签,可以先使用fit方法对模型做训练,然后使用模型对象model_kmeans的label_属性获得其训练集的标签分类。

使用metrics.silhouette_score方法对数据集做平均轮廓系数得分检验,将其得分赋值给silhouette_tmp,输入参数有两个:

  • X:为原始输入的数组或矩阵
  • cluster_labels:训练集对应的聚类标签

接下来做判断,如果计算后的得分大于初始化变量的得分,那么:

  • 将最佳K值存储下来,便于后续输出展示
  • 将最好的平均轮廓得分存储下来,便于跟其他后续得分做比较以及输出展示
  • 将最好的模型存储下来,这样省去了后续再做最优模型下fit(训练)的工作
  • 将最好的聚类标签存储下来,这样方便下面将原始训练集与最终标签合并

每次循环结束后,将当次循环的K值以及对应的评论轮廓得分使用append方法追加到列表。

最后打印输出每个K值下详细信息以及最后K值和最优评论轮廓得分,返回数据如下:




K value and silhouette summary:
[[ 2.          0.46692821]
 [ 3.          0.54904646]
 [ 4.          0.56968547]
 [ 5.          0.48186604]
 [ 6.          0.45477667]
 [ 7.          0.48204261]
 [ 8.          0.50447223]
 [ 9.          0.52697493]]
Best K is:4 with average silhouette of 0.5697


上述结果显示了不同K下的平均轮廓得分。就经验看,如果平均轮廓得分值小于0,意味着聚类效果不佳;如果值大约0且小于0.5,那么说明聚类效果一般;如果值大于0.5,则说明聚类效果比较好。本案例在K=4时,得分为0.5697,说明效果较好。

对于上述得到的结果,将最优K值下得到的KMeans模型的结果,可以通过各类别的类内、外数据的对比以及配合雷达图或极坐标图做分析解释。

【深度学习】KMeans中自动K值的确认方法_人工智能

各聚类类别显著特征对比

4  引申思考

注意,即使在数据上聚类特征最明显,也并不意味着聚类结果就是有效的,因为这里的聚类结果用来分析使用,不同类别间需要具有明显的差异性特征并且类别间的样本量需要大体分布均衡。而确定最佳K值时却没有考虑到这些“业务性”因素。

案例中通过平均轮廓系数的方法得到的最佳K值不一定在业务上具有明显的解读和应用价值。如果最佳K值的解读无效怎么办?有两种思路:

  • 扩大K值范围,例如将K的范围调整为[2,12],然后再次运算看更大范围内得到的K值是否更加有效并且能符合业务解读和应用需求。
  • 得到平均轮廓系数“次要好”(而不是最好)的K值,再对其结果做分析。

对于不同类别的典型特征的对比,除了使用雷达图直观的显示外,还可以使用多个柱形图的形式,将每个类别对应特征的值做柱形图统计,这样也是一个非常直观的对比方法。具体参考下图:

【深度学习】KMeans中自动K值的确认方法_深度学习_02


标签:得分,系数,silhouette,深度,确认,KMeans,聚类,轮廓
From: https://blog.51cto.com/u_14682436/8851329

相关文章

  • 与开发者深度对话、携手共进,鲸鸿动能广告联盟城市系列沙龙完美收官
    12月14日,以“聚势增长,合作共赢”为主题的鲸鸿动能广告联盟城市沙龙,在北京成功举办,作为年度城市系列沙龙的最后一站,北京站汇聚了当地各个行业的领先开发者,近百位参会嘉宾围绕商业变现方案和策略玩法交流分享,积极应对市场和用户变化,期待与鲸鸿动能携手共同探索商业化未来,共创增长新......
  • 【Python】人工智能-机器学习——不调库手撕深度网络分类问题
    1.作业内容描述1.1背景数据集大小150该数据有4个属性,分别如下Sepal.Length:花萼长度(cm)Sepal.Width:花萼宽度单位(cm)Petal.Length:花瓣长度(cm)Petal.Width:花瓣宽度(cm)category:类别(IrisSetosa\IrisVersicolour\IrisVirginica)1.2要求在不调用机器学习库的情......
  • 深度学习3D网络---PointNet++
    PointNet++地址:http://stanford.edu/~rqi/pointnet2/1.两者主要不同点考虑到PointNet特征提取时只考虑单点,不能很好的表示局部结构==>PointNet++引入了sampling&grouping,考虑局部领域特征PointNet中globalfeature直接由maxpool得到,容易造成信息丢失==>PointNet++采......
  • 深度学习笔记4:在卷积基上添加数据增强代码块和分类器
    特征提取的另一种方式是将原有模型与一个新的密集分类器相连接,以构建一个新的模型,然后对整个模型进行端到端的训练。这种方法在输入数据上进行整体训练,使模型能够更好地适应数据特性并提取更有效的特征。通过这种方式,模型的性能可以得到进一步提高,同时也能更好地捕捉到数据中......
  • 基于AlexNet深度学习网络的智能垃圾分类系统matlab仿真
    1.算法运行效果图预览  2.算法运行软件版本matlab2022a 3.算法理论概述       基于AlexNet深度学习网络的智能垃圾分类系统是一种利用深度学习技术,对垃圾图像进行分类的系统。下面将详细介绍这种系统的原理和数学公式。 3.1、基于AlexNet深度学习网络的......
  • 羚通视频智能分析平台烟火检测与烟火识别算法的深度解析
    随着科技的不断发展,人工智能技术在各个领域的应用越来越广泛。其中,视频智能分析技术以其高效、准确的特点,被广泛应用于安全防护、环境监测等多个领域。今天,我们将重点介绍羚通视频智能分析平台中的烟火检测、烟火识别算法。一、烟火检测的重要性烟火检测是公共安全的重要环节,它能......
  • OpenGL的深度缓冲
      如果我们想要在三维空间里画两个正方形:一个红色的,一个绿色的,而且从人眼的观察角度看,绿色正方形在红色正方形的后面。最后看上去应该是这样的:要点在于,从观察者的角度看,绿色正方形在红色正方形的后面,因此绿色正方形的一部分被红色正方形遮挡。  然而,在启用深度测试前,正方形......
  • 拆解全景,解锁未来——深度分析大模型六大领域及五大应用解决方案
    在本篇文章中,我们将带您首先通过解读LLM的全景图,深入探讨了LLM的六个关键领域,随后提出五种主要方案以解决企业在这一技术领域面临的挑战。从商业模型到开源模型、微调、自定义构建,再到与AI提供商的合作,本文将引领您深入了解LLM的技术脉络,为探索和应用这一技术提供一些思考......
  • 12.12邻接表存储实现图的深度优先遍历(c++)
    今天学习了数据结构中的邻接表存储实现图的深度优先遍历,其中让我受益匪浅,以下是我的解题思路。编写程序,实现由邻接表存储实现无向图的深度优先搜索遍历的功能。顶点为字符型。输入格式:第一行输入顶点个数及边的个数,第二行依次输入各顶点,第三行开始依次输入边的两个顶点,用空格......
  • 聊天记录年度报告一览无余:轻松多格式导出永久保存,深度智能分析
    聊天记录年度报告一览无余:轻松多格式导出永久保存,深度智能分析1.功能简介效果展示一个用于提取微信聊天记录的工具,支持将聊天记录导出成HTML、Word、CSV文档,以实现永久保存。此外,该工具还具有对聊天记录进行分析的功能,可以生成年度聊天报告,帮助用户更好地了解和回顾与他人的沟通......