首页 > 其他分享 >Kmeans2D数据类别划分

Kmeans2D数据类别划分

时间:2024-10-12 19:46:54浏览次数:3  
标签:loc plt V2 V1 划分 Kmeans2D 类别 data scatter

读取数据:

import numpy as np
import pandas as pd

# Load the 2-D sample set (feature columns V1, V2 plus a ground-truth
# 'labels' column) and peek at the first rows.
data = pd.read_csv("data_2D.csv")
data.head()

读取输入及标签(标签用于后期对比观察模型结果):

# Inputs for clustering: every column except the ground-truth labels.
X = data.drop(columns=['labels'])
# Ground-truth labels, kept aside only to judge the clustering afterwards.
y = data['labels']
y.head()

观察各类别的样本数量,并绘制原始数据散点图:

# Number of samples in each ground-truth class. It is printed because a bare
# expression in the middle of a notebook cell is never displayed.
# y.value_counts() replaces pd.value_counts(), which is deprecated since pandas 2.1.
print(y.value_counts())

# Raw data, plotted without any class information.
from matplotlib import pyplot as plt
fig1 = plt.figure()
plt.scatter(data['V1'], data['V2'])
plt.title('un-label data')
plt.xlabel('V1')
plt.ylabel('V2')
plt.show()

原始数据分类显示:

# Raw data coloured by its ground-truth label (one scatter series per class).
fig2 = plt.figure()
classes = np.unique(y)  # sorted distinct labels, e.g. 0, 1, 2
handles = [
    plt.scatter(data.loc[y == k, 'V1'], data.loc[y == k, 'V2'])
    for k in classes
]
# The points ARE labelled here, so the title says so.
plt.title('labeled data')
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend(handles, [f'label{k}' for k in classes])
plt.show()

打印数据维度,建立模型并绘图:

print(X.shape, y.shape)

# Build the k-means model: 3 clusters. random_state=0 makes every training
# run give the same result.
from sklearn.cluster import KMeans
KM = KMeans(n_clusters=3, random_state=0)
KM.fit(X)
centers = KM.cluster_centers_  # one row per cluster; columns follow X (V1, V2)
print(type(centers), len(centers), centers.shape)

# Ground-truth labels with the learned cluster centres overlaid.
fig3 = plt.figure()
classes = np.unique(y)
handles = [
    plt.scatter(data.loc[y == k, 'V1'], data.loc[y == k, 'V2'])
    for k in classes
]
plt.title('labeled data')
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend(handles, [f'label{k}' for k in classes])
plt.scatter(centers[:, 0], centers[:, 1])  # cluster centres
plt.show()

# Predict the cluster of a single new point. The model was fitted on a
# DataFrame, so pass one with the same column names. A bare list makes
# sklearn warn that X does not have valid feature names.
new_point = pd.DataFrame([[40, 60]], columns=X.columns)
y_predict_test = KM.predict(new_point)
print(y_predict_test)
(3000, 2) (3000,)

预测:

# Predict a single new point. Use the training column names so sklearn does
# not warn about missing feature names.
y_predict_test = KM.predict(pd.DataFrame([[40, 60]], columns=X.columns))
print(y_predict_test)

# Cluster id for every sample. Compare the cluster sizes with the true class
# sizes. Series.value_counts replaces the deprecated pd.value_counts.
y_predict = KM.predict(X)
print(pd.Series(y_predict).value_counts(), y.value_counts())
[2]
0    1149
1     952
2     899
Name: count, dtype: int64 labels
2    1156
1     954
0     890
Name: count, dtype: int64

模型评估,并与原始标签图对比:

# Score the raw clustering against the ground-truth labels.
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y, y_predict)
# NOTE: KMeans numbers its clusters arbitrarily (cluster id k need not mean
# true label k), so this score can be low even when the clusters are good.
print(accuracy)

# Left panel: points coloured by predicted cluster id.
plt.subplot(121)
clusters = np.unique(y_predict)
handles = [
    plt.scatter(data.loc[y_predict == k, 'V1'], data.loc[y_predict == k, 'V2'])
    for k in clusters
]
plt.title('y_predict data')
plt.xlabel('V1')
plt.ylabel('V2')
# These are cluster ids, not yet aligned with the true labels, so name them as such.
plt.legend(handles, [f'cluster{k}' for k in clusters])
plt.scatter(centers[:, 0], centers[:, 1])  # cluster centres

# Right panel: the same points coloured by their true label, for comparison.
plt.subplot(122)
classes = np.unique(y)
handles = [
    plt.scatter(data.loc[y == k, 'V1'], data.loc[y == k, 'V2'])
    for k in classes
]
plt.title('y data')
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend(handles, [f'label{k}' for k in classes])
plt.scatter(centers[:, 0], centers[:, 1])  # cluster centres

plt.show()
0.31966666666666665

KMeans 虽然正确地分成了三类,但聚类编号是任意分配的,与真实标签的编号并不对应(因此上面的准确率很低),需要矫正对应关系:

# KMeans numbers its clusters arbitrarily, so cluster id k need not be true
# label k. The old code hard-coded the permutation observed for one particular
# fit. Instead, map each cluster to the true label carried by most of its
# members (majority vote), so the mapping stays correct if the fit changes.
import warnings

cluster_to_label = {
    c: y[y_predict == c].value_counts().idxmax()
    for c in np.unique(y_predict)
}
if len(set(cluster_to_label.values())) != len(cluster_to_label):
    # Majority vote is not guaranteed to be one-to-one on poorly separated data.
    warnings.warn(f'several clusters share a majority label: {cluster_to_label}')

y_corrected = [cluster_to_label[c] for c in y_predict]
# Series.value_counts replaces the deprecated top-level pd.value_counts.
print(pd.Series(y_corrected).value_counts(), y.value_counts())
2    1149
1     952
0     899
Name: count, dtype: int64 labels
2    1156
1     954
0     890
Name: count, dtype: int64

矫正后准确率评估:

# The element-wise masks used in the plots below (y_corrected == k) need an
# ndarray rather than a list. Convert first, then score; accuracy_score
# gives the same result for either form.
y_corrected = np.asarray(y_corrected)
print(accuracy_score(y, y_corrected))
0.997

矫正后数据对比(矫正与原始数据):

# Side-by-side check after correction. Left: points coloured by the remapped
# cluster ids. Right: points coloured by the ground-truth labels. Both panels
# overlay the k-means centres.
v1 = data.loc[:, 'V1']
v2 = data.loc[:, 'V2']
legend_text = ('label0', 'label1', 'label2')

fig6 = plt.subplot(121)
label0, label1, label2 = (
    plt.scatter(v1[y_corrected == k], v2[y_corrected == k]) for k in (0, 1, 2)
)
plt.title('y_corrected data')
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend((label0, label1, label2), legend_text)
plt.scatter(centers[:, 0], centers[:, 1])  # cluster centres

fig7 = plt.subplot(122)
label0, label1, label2 = (
    plt.scatter(v1[y == k], v2[y == k]) for k in (0, 1, 2)
)
plt.title('y data')
plt.xlabel('V1')
plt.ylabel('V2')
plt.legend((label0, label1, label2), legend_text)
plt.scatter(centers[:, 0], centers[:, 1])  # cluster centres

plt.show()

标签:loc,plt,V2,V1,划分,Kmeans2D,类别,data,scatter
From: https://blog.csdn.net/chu_kuang_/article/details/142859182

相关文章