具体步骤:
1、导入相关扩展包
from sklearn.model_selection import train_test_split # 划分数据集
from sklearn.feature_extraction import DictVectorizer #字典特征值提取
from sklearn.tree import DecisionTreeClassifier # 决策树
from sklearn.tree import export_graphviz # 决策树可视化
import pandas as pd
2、获取数据
titanic=pd.read_csv("./train.csv")
3、筛选特征值和目标值
x=titanic[["Pclass","Age","Sex"]] #特征值
y=titanic["Survived"] #目标值
特征值:
目标值:
4、转化为字典
x=x.to_dict(orient="records")
转化结果:
5、字典特征值抽取
transfer=DictVectorizer()
x_train=transfer.fit_transform(x_train)
x_test=transfer.transform(x_test)
6、决策树预估器(estimator)
estimator = DecisionTreeClassifier(criterion="entropy") # criterion默认为'gini'系数,也可选择信息增益熵'entropy'
estimator.fit(x_train, y_train) # 调用fit()方法进行训练,()内为训练集的特征值与目标值
7、模型评估
方法一:直接对比真实值和预测值
y_predict = estimator.predict(x_test) # 传入测试集特征值,预测所给测试集的目标值
print("y_predict:\n", y_predict)
print("直接对比真实值和预测值:\n", y_test == y_predict)
方法二:计算准确率
score = estimator.score(x_test, y_test) # 传入测试集的特征值和目标值
8、决策树可视化
export_graphviz(estimator, out_file="titanic_tree.dot", feature_names=transfer.get_feature_names())
主要代码:
def titanic_demo():
# 1.获取数据
titanic=pd.read_csv("./train.csv")
# 2.筛选特征值和目标值
x=titanic[["Pclass","Age","Sex"]] #特征值
y=titanic["Survived"] #目标值
# print(x.head())
# print(y.head())
# 3.数据处理(缺失值处理,特征值——>字典类型)
#缺失值处理
x["Age"].fillna(x["Age"].mean(),inplace=True)
# print(x)
#转换为字典
x=x.to_dict(orient="records")
# print(x)
# 4.划分数据集
x_train,x_test,y_train,y_test=train_test_split(x,y,random_state=22)
# 5.字典特征抽取
transfer=DictVectorizer()
x_train=transfer.fit_transform(x_train)
x_test=transfer.transform(x_test)
# 6.决策树预估器(estimator)
estimator = DecisionTreeClassifier(criterion="entropy") # criterion默认为'gini'系数,也可选择信息增益熵'entropy'
estimator.fit(x_train, y_train) # 调用fit()方法进行训练,()内为训练集的特征值与目标值
# 7.模型评估
# 方法一:直接对比真实值和预测值
y_predict = estimator.predict(x_test) # 传入测试集特征值,预测所给测试集的目标值
print("y_predict:\n", y_predict)
print("直接对比真实值和预测值:\n", y_test == y_predict)
# 方法二:计算准确率
score = estimator.score(x_test, y_test) # 传入测试集的特征值和目标值
print("准确率为:\n", score)
# 8.决策树可视化
export_graphviz(estimator, out_file="titanic_tree.dot", feature_names=transfer.get_feature_names())
return None
运行结果:
可视化结果(因图规模过大导致截图展示不完整):