首页 > 其他分享 >20241120

20241120

时间:2024-11-20 11:09:04浏览次数:1  
标签:剪枝 20241120 tree train test model 节点

一、实验目的

深入理解决策树、预剪枝和后剪枝的算法原理,能够使用 Python 语言实现带有预剪枝 和后剪枝的决策树算法 C4.5 算法的训练与测试,并且使用五折交叉验证算法进行模型训练 与评估。

二、实验内容

(1)从 scikit-learn 库中加载 iris 数据集,使用留出法留出 1/3 的样本作为测试集(注意同分布取样);

(2)使用训练集训练分类带有预剪枝和后剪枝的 C4.5 算法;

(3)使用五折交叉验证对模型性能(准确度、精度、召回率和 F1 值)进行评估和选 择;

(4)使用测试集,测试模型的性能,对测试结果进行分析,完成实验报告中实验三的部分。

 

三、算法步骤、代码、及结果

   1. 算法伪代码

输入: 数据集 D, 属性集 A, 剪枝策略 P(预剪枝或后剪枝)

输出: 决策树 T

 

1. 若 D 中所有样本同属一个类别,返回叶节点;

2. 若 A 为空集,返回叶节点,其类别为 D 中样本最多的类别;

3. 对 A 中每个属性,计算其信息增益率,选择增益率最高的属性作为分裂节点;

4. 根据分裂属性将 D 分成子集,递归调用构建子树;

5. 预剪枝策略:若分裂导致验证集性能下降,则停止分裂;

6. 后剪枝策略:

   a. 对生成的树进行修剪;

   b. 比较剪枝前后的性能,若剪枝后性能无显著下降,则保留剪枝;

7. 返回构建的决策树 T。

 

   2. 算法主要代码

完整源代码\调用库方法(函数参数说明)

import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, make_scorer, precision_score, recall_score, f1_score

# 数据加载与分割
def load_and_split_data():
    iris = load_iris()
    X, y = iris.data, iris.target
    # 留出法:1/3 的数据作为测试集
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=1/3, random_state=42, stratify=y
    )
    return X_train, X_test, y_train, y_test


# 带预剪枝的决策树
def train_pre_pruned_tree(X_train, y_train, max_depth=3, min_samples_split=5):
    model = DecisionTreeClassifier(
        criterion='entropy',  # 基于信息增益率
        max_depth=max_depth,  # 最大深度限制
        min_samples_split=min_samples_split,  # 节点最小样本数
        random_state=42  # 随机种子
    )
    model.fit(X_train, y_train)
    return model


# 后剪枝实现
def post_prune_tree(model, X_val, y_val):
    tree = model.tree_

    # 递归修剪函数
    def prune(node_id):
        # 如果是叶子节点,直接返回
        if tree.children_left[node_id] == -1 and tree.children_right[node_id] == -1:
            return

        # 递归修剪子节点
        if tree.children_left[node_id] != -1:
            prune(tree.children_left[node_id])
        if tree.children_right[node_id] != -1:
            prune(tree.children_right[node_id])

        # 检查是否可以合并子节点
        if (
            tree.children_left[node_id] != -1
            and tree.children_right[node_id] != -1
        ):
            # 模拟剪枝:移除子节点
            left_child = tree.children_left[node_id]
            right_child = tree.children_right[node_id]

            # 保存当前节点的类别和样本权重
            original_class = tree.value[node_id]
            tree.children_left[node_id] = -1
            tree.children_right[node_id] = -1

            # 重新评估性能
            y_pred = model.predict(X_val)
            acc_after_prune = np.mean(y_pred == y_val)

            # 如果剪枝后性能下降,恢复子节点
            if acc_after_prune < np.mean(model.predict(X_val) == y_val):
                tree.children_left[node_id] = left_child
                tree.children_right[node_id] = right_child

    # 从根节点开始修剪
    prune(0)
    return model


# 五折交叉验证
def cross_validate_model(model, X_train, y_train):
    # 定义评分指标
    scorers = {
        'accuracy': 'accuracy',
        'precision': make_scorer(precision_score, average='macro'),
        'recall': make_scorer(recall_score, average='macro'),
        'f1': make_scorer(f1_score, average='macro')
    }

    # 交叉验证
    results = {}
    for metric, scorer in scorers.items():
        scores = cross_val_score(model, X_train, y_train, cv=5, scoring=scorer)
        results[metric] = scores.mean()
        print(f"{metric.capitalize()} 平均值: {scores.mean():.4f}")
    return results


# 测试模型性能
def evaluate_model(model, X_test, y_test):
    y_pred = model.predict(X_test)
    report = classification_report(y_test, y_pred, output_dict=True, zero_division=0)
    print("测试集性能:")
    print(classification_report(y_test, y_pred, zero_division=0))
    return report


# 主函数
if __name__ == "__main__":
    # 数据加载与分割
    X_train, X_test, y_train, y_test = load_and_split_data()

    # 带预剪枝的决策树训练
    pre_pruned_model = train_pre_pruned_tree(X_train, y_train)

    # 五折交叉验证评估
    print("\n五折交叉验证评估:")
    cross_validate_model(pre_pruned_model, X_train, y_train)

    # 后剪枝
    print("\n开始后剪枝:")
    pruned_model = post_prune_tree(pre_pruned_model, X_test, y_test)

    # 测试集评估
    print("\n测试集评估:")
    evaluate_model(pruned_model, X_test, y_test)

 

 

 

1.train_test_split

来源: sklearn.model_selection.train_test_split

功能: 将数据集划分为训练集和测试集。

参数说明:

X: 特征数据集。

y: 目标标签数据集。

test_size: 测试集所占比例(0-1 之间的浮点数)。

train_size: 训练集所占比例(可选,与 test_size 互斥)。

random_state: 随机种子,用于结果复现。

stratify: 如果为 y,则按目标变量比例分层抽样。

 


 

2. DecisionTreeClassifier

来源: sklearn.tree.DecisionTreeClassifier

功能: 使用决策树对数据进行分类。

参数说明:

criterion: 划分标准,默认值为 "gini";本代码中使用 "entropy"(基于信息增益率)。

max_depth: 决策树的最大深度,用于防止过拟合。

min_samples_split: 每个节点划分所需的最小样本数。

class_weight: 类别权重,可以设置为 "balanced",根据数据集自动调整类别权重。

random_state: 随机种子,用于结果复现。

 


 

3. cross_val_score

来源: sklearn.model_selection.cross_val_score

功能: 使用交叉验证评估模型性能。

参数说明:

estimator: 需要评估的模型。

X: 特征数据集。

y: 目标标签数据集。

cv: 交叉验证的折数,默认为 5。

scoring: 评分方法,可选 "accuracy", "precision", "recall", "f1" 等。

 


 

4. classification_report

来源: sklearn.metrics.classification_report

功能: 生成分类任务的详细评估报告,包括准确率、精度、召回率和 F1 值。

参数说明:

y_true: 真实标签。

y_pred: 模型预测结果。

output_dict: 如果为 True,返回字典格式的报告。

zero_division: 默认值为 "warn"。当分母为零时,控制返回的值(0 或 1)。

 


 

5. make_scorer

来源: sklearn.metrics.make_scorer

功能: 将自定义的评分方法转换为可用于 cross_val_score 的格式。

参数说明:

score_func: 自定义评分函数,如 precision_score, recall_score, f1_score 等。

greater_is_better: 是否更高的分数更好,默认 True。

average: 当用于多分类问题时,控制指标计算方式,例如 "macro" 或 "weighted"。

 


 

6. np.unique

来源: numpy.unique

功能: 查找数组中的唯一值,并返回每个值的计数。

参数说明:

return_counts: 如果为 True,同时返回每个唯一值的计数。

 


 

7. model.tree_ 属性

来源: sklearn.tree.DecisionTreeClassifier.tree_

功能: 获取决策树的底层数据结构。

主要属性:

children_left: 每个节点的左子节点索引,叶节点为 -1。

children_right: 每个节点的右子节点索引,叶节点为 -1。

feature: 每个节点分裂时使用的特征索引。

threshold: 每个节点分裂时使用的阈值。

value: 每个节点的类别分布(数量)。

标签:剪枝,20241120,tree,train,test,model,节点
From: https://www.cnblogs.com/gyg1222/p/18556478

相关文章

  • 20241120 校内模拟赛 T3 题解
    题目描述给定一个数列\(A\),数列的元素取值范围为\([1,m]\)。请计算有多少个非空子区间满足以下条件:该区间内每个元素的出现次数都相同(没有出现的元素视为出现\(0\)次)。例如,当\(m=3\)时,\([1,2,3]\)和\([1,1,3,2,3,2]\)是满足条件的区间,而\([1,2,2,3]\)和\([1,1,3,3]......