首页 > 其他分享 >sklearn KMeans包结果无法复现

sklearn KMeans包结果无法复现

时间:2023-10-06 23:23:10浏览次数:47  
标签:random KMeans print seed 复现 np import numpy sklearn

随机问题

numpy等随机数

使用numpy等随机数时,一定要注意随机函数会首到上次随机结果的影响。e.g.

# 如果代码如下:
import numpy as np
np.random.seed(20)  # 为numpy设置随机种子
for i in range(10):
    x = np.random.random()
    print(x)

会发现每次随机的结果不一致,这是由于第二次随机时会收到第一次影响。

如果我们需要for里面的每次随机一致需要将种子设置放在for里面,e.g

import numpy as np
for i in range(10):
    np.random.seed(20)  # 为numpy设置随机种子
    x = np.random.random()
    print(x)

sklearn KMeans包结果无法复现问题

今天写代码有个问题困扰了我很久,今天分享出来,以及解决办法。其中出现问题代码片段简单化如下:

import os
import random

import numpy as np
import torch
# from kmeans_pytorch import kmeans

from sklearn.cluster import KMeans

def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)  # 为numpy设置随机种子
    torch.manual_seed(seed) # 为CPU设置随机种子
    torch.cuda.manual_seed(seed) # 为当前GPU设置随机种子

setup_seed(20)
center = []
pred = []
data = np.random.random([500,512])
for i in range(10):
    print("k-means++: {}".format(i))
    k = KMeans(n_clusters=10,n_init=10,init='k-means++',random_state=0)
    k.fit_predict(data)
    center.append(k.cluster_centers_)
    pred.append(k.labels_)

for i in range(9):
   print("-------------------center------------")
   print(str(center[i] == center[i+1]).count("False"))
   print("-------------------pred------------")
   print(str(pred[i] == pred[i+1]).count("False"))

代码大概意思就是使用kmeans将数据重复聚类10次,并统计查看每次聚类结果中的聚类中心和聚类标签是否一致。这种需求是一般是为了实验的可复现。

当我们的数据量较小时,如上500个样本,每个512维度不会出现不一致情况。

Untitled

但是一旦我们将数据量提升至5000或者50000(对于一般数据来说不算多),就会发现每次聚类结果并不会一致。下面是5000

Untitled

对于仅仅需要label的算法而言,影响不大。但是像DEC(深度嵌入聚类)需要使用kmeans算法初始化质心而言就会使得每次算法结果不一致,使得参数无从下手调整。

问题:经过多次实验,初步认为是由于kmeans精度所导致的,每次Debug时,前几位数每次都是一致的,但是随着epoch的增加后面8,9位小数点的累加就会导致结果偏差。

解决:1. 可以使用np.round()限制小数点位数。2. 使用pytorch-kmeans(推荐)

不过其中需要注意:他聚类的是tensor,其次正如开始所述的随机问题,为了保证每次结果的一致我们需要每次for设定一次种子让pytorch-kmeans的质心每次随机初始化都一致。

import os
import random

import numpy as np
import torch
from kmeans_pytorch import kmeans

# from sklearn.cluster import KMeans

def setup_seed(seed):
    random.seed(seed)
    np.random.seed(seed)  # 为numpy设置随机种子
    torch.manual_seed(seed) # 为CPU设置随机种子
    torch.cuda.manual_seed(seed) # 为当前GPU设置随机种子

setup_seed(20)
center = []
pred = []
data = torch.from_numpy(np.random.random([5000,512]))
for i in range(10):
    print("k-means++: {}".format(i))
    np.random.seed(20)  # 为numpy设置随机种子
    cluster_ids_x, cluster_centers = kmeans(X=data, num_clusters=10, distance='euclidean', device=torch.device("cuda:0"))
    center.append(cluster_centers)
    pred.append(cluster_ids_x)

for i in range(9):
   print("-------------------center------------")
   print(str(center[i] == center[i+1]).count("False"))
   print("-------------------pred------------")
   print(str(pred[i] == pred[i+1]).count("False"))

结果:

Untitled

标签:random,KMeans,print,seed,复现,np,import,numpy,sklearn
From: https://www.cnblogs.com/12tree/p/17745262.html

相关文章

  • vulnhub_phpmyadmin_CVE-2016-5734漏洞复现
    漏洞复现练习漏洞简介phpMyAdmin是一套开源的、基于Web的MySQL数据库管理工具。在其查找并替换字符串功能中,将用户输入的信息拼接进preg_replace函数第一个参数中。在PHP5.4.7以前,preg_replace的第一个参数可以利用\0进行截断,并将正则模式修改为e。众所周知,e模式的正则支持......
  • sklearn.pipeline的作用 函数参数解释及代码举例
    sklearn.pipeline是一个实用的工具,可以将多个数据预处理步骤和机器学习模型组合成一个整体,从而简化了机器学习的流程。Pipeline类可以将多个算法模型串联起来,比如将特征提取、归一化、分类组织在一起形成一个典型的机器学习问题工作流。主要带来两点好处:1.直接调用fit和predict方......
  • pandas.get_dummies与sklearn.preprocessing.OneHotEncoder
    pandas.get_dummies、sklearn.preprocessing.OneHotEncoder.fit_transform和sklearn.preprocessing.OneHotEncoder都用于对分类变量进行独热编码,但它们在实现和使用上有一些区别。pandas.get_dummies:解释:pandas.get_dummies是pandas库中的一个函数,用于将分类变量转化为虚拟变......
  • vulhub thinphp 5.0.23 学习复现
     靶场是基于vulhub的靶场docker-composeup-d 启动靶场并且漏洞环境是thinkphp5.0.23 现在验证漏洞列出了当前目录文件 可以,验证完了,看到漏洞存在 现在想尝试弹一个shell,开启监听  ok!bp可以,但是,查看服务器并未得到shell不清楚是为什么,等下再看,先拿web......
  • 【CVE-2019-0708】远程桌面服务远程执行代码漏洞复现
    【CVE-2019-0708】远程桌面服务远程执行代码漏洞当未经身份验证的攻击者使用RDP连接到目标系统并发送经特殊设计的请求时,远程桌面服务中存在远程执行代码漏洞。此漏洞是预身份验证,无需用户交互。成功利用此漏洞的攻击者可以在目标系统上执行任意代码。靶机:windows7IP:192......
  • 从零开始复现CVE-2023-34644
    从零开始复现CVE-2023-34644说实话复现这个漏洞光调试我就调了一个星期,主要是逆向很难仿真启动脚本tarczfrootfs.tar.gz./rootfsscprootfs.tar.gzroot@192.168.192.135:/root/rootfscdrootfschmod-R777./mount-bind/procprocmount-bind/devdevchroot.......
  • R语言文本挖掘:kmeans聚类分析上海玛雅水公园景区五一假期评论词云可视化|附代码数据
    全文链接:http://tecdat.cn/?p=32307原文出处:拓端数据部落公众号互联网时代,大量的新闻信息、网络交互、舆情信息以文本形式存储在数据库中,如何利用数据分析和文本挖掘的算法,将海量文本的价值挖掘出来,成为我们团队近期的一个研究方向,本案例就是我们的一个尝试。文本聚类其实也就......
  • from sklearn.datasets.samples_generator import make_blobs
     fromsklearn.datasets.samples_generatorimportmake_blobsmake_blobs方法:sklearn.datasets.make_blobs(n_samples=100,n_features=2,centers=3,cluster_std=1.0,center_box=(-10.0,10.0),shuffle=True,random_state=None)make_blobs函数是为聚类或分类产生数据集,产生一......
  • 【漏洞复现】深信服 SG上网优化管理系统 catjs.php 任意文件读取漏洞
    1、简介2、漏洞描述深信服SG上网优化管理系统catjs.php存在任意文件读取漏洞,攻击者通过漏洞可以获取服务器上的敏感文件3、受影响版本深信服SG上网优化管理系统4、FOFA语句title==“SANGFOR上网优化管理”5、漏洞复现POCPOST/php/catjs.phpHTTP/1.1Host:User-A......
  • 【漏洞复现】亿赛通 电子文档安全管理系统 DecryptApplicationService2 任意文件上传
    1、简介亿赛通电子文档安全管理系统(简称:CDG)是一款电子文档安全加密软件,该系统利用驱动层透明加密技术,通过对电子文档的加密保护,防止内部员工泄密和外部人员非法窃取企业校心重要数据资产,对电子文档进行全生命周期防护,系统具有透明加密、主动加密、智能加密等多种加密方式,用户可......