
Decision-tree baseline using features with a correlation coefficient above 0.3

Posted: 2022-12-09 14:11:22
Tags: index, baseline, 0.3, feature, train, np, import, out, decision tree

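The model reaches an accuracy of 0.74 on the test set, a slight improvement, which suggests that selecting features by their correlation coefficient with the label is a good choice.

```python
import numpy as np
import pandas as pd
from sklearn import tree
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold, cross_val_score

# Load the training data; after dropping ID, the last column is the class label.
df = pd.read_csv('train.csv')
df = df.drop(['ID'], axis=1)
df = df.to_numpy()

# Features: magnitude of the FFT of each row's raw values.
feature = np.abs(np.fft.fft(df[:, :-1]))
# Append the label as the last column so it can be correlated with the features.
feature = np.concatenate((feature, np.reshape(df[:, -1], (-1, 1))), axis=1)
train = pd.DataFrame(feature)

# Keep columns whose absolute correlation with the label exceeds 0.3.
# The label correlates perfectly with itself, so it stays as the last column.
# Note: selection uses all training rows, so the CV scores below are somewhat optimistic.
label_col = train.columns[-1]  # 240 in the original data
heat = train.corr()
fe = heat.index[abs(heat[label_col]) > 0.3]
train = train.to_numpy()
train = train[:, fe]

# Sweep max_depth from 1 to 30 and print:
# k, mean train accuracy, mean validation accuracy (manual KFold), cross_val_score mean.
# cross_val_score(cv=5) uses stratified folds for a classifier, so its mean can differ.
kf = KFold(n_splits=5, shuffle=False)
for k in range(30):
    train_acc_sum = 0
    val_acc_sum = 0
    n_folds = 0
    for train_index, test_index in kf.split(train):
        n_folds += 1
        tfeature = train[train_index, :-1]
        label = train[train_index, -1]
        clf = tree.DecisionTreeClassifier(criterion='entropy', random_state=0, max_depth=k + 1)
        clf.fit(tfeature, label)
        ttest = train[test_index, :-1]
        testlabel = train[test_index, -1]
        train_acc_sum += accuracy_score(label, clf.predict(tfeature))
        val_acc_sum += accuracy_score(testlabel, clf.predict(ttest))
    clf1 = tree.DecisionTreeClassifier(criterion='entropy', random_state=0, max_depth=k + 1)
    scores = cross_val_score(clf1, train[:, :-1], train[:, -1], cv=5)
    print(k, train_acc_sum / n_folds, val_acc_sum / n_folds, scores.mean())

# Final model: max_depth=5 (k=4 in the sweep), fitted once on all training rows.
clf1 = tree.DecisionTreeClassifier(criterion='entropy', random_state=0, max_depth=4 + 1)
clf1.fit(train[:, :-1], train[:, -1])

# Apply the same FFT transform and column selection to the test set.
df1 = pd.read_csv('test.csv')
df1 = df1.drop(['ID'], axis=1)
df1 = df1.to_numpy()
feature = np.abs(np.fft.fft(df1))
feature = feature[:, fe[:-1]]  # drop the label's index from the selected columns
out = clf1.predict(feature)

# Write the submission; IDs run consecutively from 210, as in the original.
out = pd.DataFrame(out)
out.columns = ['CLASS']
out['ID'] = np.arange(210, 210 + out.shape[0])
out[['ID', 'CLASS']].to_csv('out3.csv', index=False)
```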
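The correlation filter is the part of this baseline that does the feature selection, so it may help to see it separated from the CSV handling. The following is a minimal numpy-only sketch, assuming all feature columns are numeric; `fft_magnitude` and `select_by_correlation` are hypothetical helper names, not part of the original code, and the toy data is made up for illustration. It computes the Pearson correlation (the default method of pandas `DataFrame.corr()`) between each feature column and the label and keeps the columns with |r| > 0.3.

```python
import numpy as np


def fft_magnitude(rows):
    """Hypothetical helper: row-wise |FFT|, like np.abs(np.fft.fft(...)) in the post."""
    return np.abs(np.fft.fft(np.asarray(rows, dtype=float), axis=1))


def select_by_correlation(features, labels, threshold=0.3):
    """Hypothetical helper: indices of feature columns with |Pearson r| > threshold.

    Constant columns have an undefined correlation (0/0); they are treated as
    r = 0 and never selected, just as a NaN fails the `> 0.3` test in pandas.
    """
    x = np.asarray(features, dtype=float)
    y = np.asarray(labels, dtype=float)
    xc = x - x.mean(axis=0)  # center each feature column
    yc = y - y.mean()        # center the label
    cov = (xc * yc[:, None]).sum(axis=0)
    denom = np.sqrt((xc ** 2).sum(axis=0) * (yc ** 2).sum())
    with np.errstate(invalid="ignore", divide="ignore"):
        r = cov / denom
    r = np.nan_to_num(r)     # NaN from a constant column -> 0
    return np.flatnonzero(np.abs(r) > threshold)


# Toy example: 8 samples, 4 feature columns.
y = np.array([0, 1, 0, 1, 0, 1, 0, 1], dtype=float)
toy = np.column_stack([
    2 * y + 1,                     # perfectly correlated (r = 1)
    np.ones(8),                    # constant, correlation undefined
    [1, 1, -1, -1, 1, 1, -1, -1],  # uncorrelated (r = 0)
    -y,                            # perfectly anti-correlated (r = -1)
])
idx = select_by_correlation(toy, y)
print(idx.tolist())  # [0, 3]

# The same indices are applied directly to the test features.
test_toy = np.arange(16, dtype=float).reshape(4, 4)
print(test_toy[:, idx].shape)  # (4, 2)

# FFT magnitudes per row: approximately [4, 0, 0, 0] and [0, 0, 4, 0].
spectra = fft_magnitude([[1, 1, 1, 1], [1, -1, 1, -1]])
```

Returning feature indices only, rather than including the label column as `fe` does, means the same index array can be applied to the test matrix without the `fe[:-1]` step.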

From: https://www.cnblogs.com/hahaah/p/16968772.html
