首页 > 编程语言 >xgboost源码阅读

xgboost源码阅读

时间:2024-01-22 16:59:27浏览次数:39  
标签:std const -- xgboost tree 源码 阅读 节点 gpair

xgboost分享

介绍

xgboost全称是extreme Gradient Boosting,可以看做是GBDT算法(Gradient Boosting Decision Tree)的高效实现。

1.1 Boosting思想

Boosting方法训练基分类器时采用串行的方式,各个基分类器之间有依赖。它的基本思路是将基分类器层层叠加,每一层在训练的时候,对前一层基分类器分错的样本,给予更高的权重。测试时,根据各层分类器的结果的加权得到最终结果。

Bagging与Boosting的串行训练方式不同,Bagging方法在训练过程中,各基分类器之间无强依赖,可以进行并行训练。

1.2 模型定义

1.3 目标函数

:第s轮的模型等于s-1轮的加上第s轮生成的,因此目标为找到每一轮的使得当前轮的Obj最小。

:损失函数,后面需要通过二阶泰勒逼近,所以必须可微且是凸函数(向下凸)

:正则项

:控制树复杂度,叶子节点数及其系数

:控制叶子结点的权重分数

1.4 目标函数近似

1.4.1 改写

1.4.2 Taylor展开,合并常数项,找最优权重

通过以上近似目标函数,每一轮的目标就是找到最小Obj。

由上可知,Obj是每个叶子节点的Obj的和,因此在分裂时可以计算单个叶子节点分裂后对总体Obj的影响,来确定分裂点。

xgboost在每轮分裂时使用贪心算法,计算分裂后的值是否比分裂前少

分裂前:

j 为叶子节点

1.5 总结

目标:树的目标函数(Obj)最小,Obj与叶子节点有关。

策略:贪心,通过找到每一轮使得Obj减少最多的特征与切分点,也就是最大的点。

1.6 实际计算过程

开始新一轮树的update

计算各样本的梯度

选择叶子结点

计算不同特征,不同切分点的

选择最大的特征与切分点

剪枝,如果算得的小于设置的阈值,则不进行分裂。

将最终的树加入模型

  1. 源码

2.1 代码结构

|--xgboost
|--include
|--xgboost //定义了 xgboost 相关的头文件
|--src
|--c_api
|--common //一些通用文件,如对配置文件的处理
|--data //使用的数据结构,如 DMatrix
|--gbm //定义分类器,如 gbtree 和 gblinear
|--metric //定义评价函数
|--objective //定义目标函数
|--tree //对树的一些列操作
|--collective //集合通讯,AllReduce实现

程序主入口为src中的cli_main.cc,其中调用了cli.Run(),run中主要内容如下


//cli_main.cc

|--main()

    |--Run()

        |--switch(param.task)

            {  
                Case kTrain: CLITrain();break;

                Case KDumpModel: CLIDumpModel();break;

                Case KPredict: CLIPredict();break;

            }

2.2 单机训练

单机训练以回归任务为例,分类与此相似。

2.2.1 时序图

2.2.2 类图

2.2.3 CLITrain的主要逻辑

Learner:训练的接口,成员变量有目标函数obj_,梯度提升模型gbm_,评估器metrics_,训练上下文ctx_

ObjFunction:目标函数接口,GetGradient

GradientBooster:梯度提示模型接口

TreeUpdater:树update接口,Update(),有多个子类。例如:适用于单机的精确贪心ColMaker,适用于分布式的全局近似逼近GlobalApproxUpdater

Builder:TreeUpdater中的成员变量,实际执行Update()

Metric:评估器接口

Context:训练上下文

Model:模型接口,定义了存储模型和读取模型的方法

Configurable:配置接口,定义了存储和读取配置的方法

dmlc::Serializable:确保可序列化

CLITrain的主要逻辑如下


|--CLITrain()

    |--DMatrix::Load() // 训练数据加载入内存

  	|--ResetLearner() // Learner初始化
  
    |--for (int iter = 0; iter < max_iter;   ++iter) //训练与评估

        {

                Learner::UpdateOneIter();

                Learner::EvalOneIter();

        }

数据加载

在CLITrain中,通过param中的train_path,将训练数据加载进std::shared_ptr dtrain,并存储进缓存中cache_mats。之后加载评价数据进eval_datasets与cache_mats中。

|--DMatrix
		// 各种元信息
    |--MetaInfo
    
    		|--std::vector<std::string> feature_names;

				|--...

Learner初始化

利用上面生成的cache_mats,使用ResetLearner函数对Learner进行初始化返回LearnerImpl,同时使用cache初始化LearnelIO。

继承路线为


  	|--ResetLearner() // Learner初始化
  			
        |--Learner::Create()

          	|--LearnerImpl
        
    		|--Learner::SetParams()

2.2.4 Learner::UpdateOneItera

|--Learner::UpdateOneIter()
		// 初始化Learner
    |--LearnerConfiguration::Configure()
  	// 计算输入数据的base_core,存储进mparam_.base_score变量
  	|--LearnerConfiguration::InitBaseScore()
  	// 检测训练数据是否有效,主要是检测训练数据特征数与learner模型参数中的特征数是否相同
  	// 以及训练数据是否是空集,如果训练数据行数为0直接error
  	|--LearnerImpl::ValidateDMatrix()
  	// 缓存训练数据以及gpu_id,生成后续训练的存储预测结果的输出数组predt
  	|--PredictionContainer::Cache()
  	// 预测训练数据结果,存储于predt中
  	|--LearnerImpl::PredictRaw()
		// 计算并记录初始的hess和gradient的值,用来查找下一轮分裂点,具体实现会根据集成的子类
    // 进行实现,最后结果会存储在gpair_中,其类型为GradientPairInternal的实现类GradientPair
    // 包含hess_,grad_并且有相应的Add()等函数
    |--ObjFunction::GetGradient()

    |--GradientBooster::DoBoost()

主要逻辑为两部分:

生成初始化的数据

进行树的生成

树的生成从DoBoost开始,GradientBooster的子类中,Class GBTree 用的比较多。 DoBoost() 函数执行的操作如下:

//gbtree.cc

|--GBTree::DoBoost()

    |--GBTree::BoostNewTrees()
  
  			|--TreeUpdater:Update()

  	|--GBTree::UpdateTreeLeaf	

DoBoost() 调用了 BoostNewTrees() 函数。在 BoostNewTrees() 中先初始化了 TreeUpdater 实例,在调用其 Update 函数生成一棵回归树。TreeUpdater 是一个抽象类,根据使用算法不同其派生出许多不同的 Updater,这些 Updater 都在 src/tree 目录下。

|--include
		|--tree_updater.h 
    /*
    定义了TreeUpdater接口
    */
|--src

    |--tree

				|--updater_approx.cc
        /*
        class GlobalApproxUpdater 使用的是直方图法;
        */
        
        |--updater_colmaker.cc
        /*
        class ColMaker 使用的是基本枚举贪婪搜索算法,通过
        枚举所有的特征来寻找最佳分裂点;
				*/

        |--updater_gpu_hist.cc
        /*
        class GPUHistMaker,GPUGlobalApproxMaker 基于gpu的直方图实现;
        */
        
        |--updater_prune.cc
        /*
        class TreePruner 是树的剪枝操作;
        */
        
        |--updater_quantile_hist.cc
        
        /*
        class QuantileHistMaker 使用的是直方图法;
        */
        
        |--updater_refresh.cc    
        /*
        class TreeRefresher 用于刷新数据集上树的统计信息和叶子值;
        */
        
        |--updater_sync.cc
        /*
        class TreeSyncher 是在分布式情况下保证树的一致性。
        */

2.2.5 ColMaker

使用精确贪心

在 Class ColMaker 定义了一些数据结构用于辅助树的构造。

struct ThreadEntry // 存储每个节点暂时的状态,数量为:线程 * 每个线程的节点
struct NodeEntry   // 同样存储信息,存储每个节点的root_gain, weight等
class Builder{  		 // 实际运行Updata的类
 	const TrainParam& param; //训练参数,即我们设置的一些超参数
	std::vector<int> position;  //当前样本实例在回归树结中对应点的索引
	std::vector<NodeEntry> snode; //回归树中的结点
	std::vector<int> qexpand_;  //保存将有可能分裂的节点的索引 
}

在 Class ColMaker 中定义了一个 Class Builder 类,树所有的构造过程都由这个类完成

具体内容如下

    virtual void Update(const std::vector<GradientPair>& gpair,
                        DMatrix* p_fmat,
                        RegTree* p_tree) {
      std::vector<int> newnodes;
      /*
      1. 初始化position_,有无效或者抽样没有抽到的数据,就将position_中对应的值取反
      2. 初始化列采样器 column_sampler_
      3. 初始化单个节点构造时的统计信息存储结构 stemp_,清空,大小设置为256,使用时通过读取对应线程id与
      节点id,来获取hess(hessian)和grad(gradian)数据,存储的是已经遍历过的特征数据
      4. 初始化单个节点统计信息存储结构 snode_,清空,大小设置为256
      5. 初始化待分裂节点队列qexpand_,清空,初始有一个0
      */
      this->InitData(gpair, *p_fmat);
      /*
      计算结点增益root_gain,也就是未分裂该节点的损失函数,将用于判断该点是否需要分裂。
      计算当前点的权值weigtht,最终模型输出就是叶子结点 weight 的线性组合。
      
      含有并行计算,通过stemp_记录每个线程相同节点的hess和grad,
      汇总到主线程后。计算当前节点的root_gain与weight
      */
      this->InitNewNode(qexpand_, gpair, *p_fmat, *p_tree);
      // We can check max_leaves too, but might break some grid searching pipelines.
      CHECK_GT(param_.max_depth, 0) << "exact tree method doesn't support unlimited depth.";
      // 按照深度开始生成树
      for (int depth = 0; depth < param_.max_depth; ++depth) {
        /* 
        寻找最佳分裂点
        */
        this->FindSplit(depth, qexpand_, gpair, p_fmat, p_tree);
        /*
        更新各个数据对应树种节点的位置,也就是更新position_
        */
        this->ResetPosition(qexpand_, p_fmat, *p_tree);
        /*
        更新待分裂节点,如果这个节点这一轮没有分裂,就去掉
        qexpand_
        */
        this->UpdateQueueExpand(*p_tree, qexpand_, &newnodes);
        /*
        更新新叶子节点的权值,增益
        */
        this->InitNewNode(newnodes, gpair, *p_fmat, *p_tree);
        for (auto nid : qexpand_) {
          if ((*p_tree)[nid].IsLeaf()) {
            continue;
          }
          int cleft = (*p_tree)[nid].LeftChild();
          int cright = (*p_tree)[nid].RightChild();
					/*
          给定的树节点上应用约束条件,通过调整节点的上下界来影响树的生长过程,
          从而优化模型的性能
          */
          tree_evaluator_.AddSplit(nid, cleft, cright, snode_[nid].best.SplitIndex(),
                                   snode_[cleft].weight, snode_[cright].weight);
          /*
          管理和应用特征交互约束。通过记录先前的分裂历史和处理交互约束,代码
          确保在新节点中允许使用的特征受到先前约束的影响。
          */
          interaction_constraints_.Split(nid, snode_[nid].best.SplitIndex(), cleft, cright);
        }
        qexpand_ = newnodes;
        // if nothing left to be expand, break
        if (qexpand_.size() == 0) break;
      }
      // set all the rest expanding nodes to leaf
      for (const int nid : qexpand_) {
        (*p_tree)[nid].SetLeaf(snode_[nid].weight * param_.learning_rate);
      }
      // remember auxiliary statistics in the tree node
      for (int nid = 0; nid < p_tree->NumNodes(); ++nid) {
        p_tree->Stat(nid).loss_chg = snode_[nid].best.loss_chg;
        p_tree->Stat(nid).base_weight = snode_[nid].weight;
        p_tree->Stat(nid).sum_hess = static_cast<float>(snode_[nid].stats.sum_hess);
      }
    }

2.2.6 FindSplit()

其中调用了UpdateSolution,对于训练数据batch(此时按照value排序,CSC存储模式,每一行都是一组排好序的feature value)查找时通过feat_set遍历所有的特征id,在batch中读取对应的特征值,之后通过EnumerateSplit并行计算不同特征的最佳分裂点。各线程的计算结果存储在对应stemp_[tid]中。计算完成后,通过stemp_中的数据,将结果更新到snode_中。

ColMaker属于单机执行,且规定其所有数据都会存储于一个block中。

在生成batch的过程中,可以看到一个batch的存储内容就是一个page。

BatchSet<SortedCSCPage> 
SimpleDMatrix::GetSortedColumnBatches(Context const* ctx) {
  ...
  auto begin_iter =
      BatchIterator<SortedCSCPage>(new SimpleBatchIteratorImpl<SortedCSCPage>(sorted_column_page_));
  ...
}
}  // namespace xgboost::data

page是DMatrix中的内容,通过读取数据后遍历读取的迭代器,全部存储入page_中,因此单机模式下,可以看出一个page中就会保存所有内容。

DMatrix* DMatrix::Load(const std::string& uri,
                       bool silent, DataSplitMode data_split_mode) {
  ...
    // 通过
    std::unique_ptr<dmlc::Parser<std::uint32_t>> parser(
    		// 生成解析器,单机模式下解析器partid = 1,naprt = 1,不分片。
        dmlc::Parser<std::uint32_t>::Create(fname.c_str(),
                                            partid, npart, "auto"));
        data::FileAdapter adapter(parser.get());
        dmat = DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(),
                               Context{}.Threads(),
                               cache_file,
                               data_split_mode);
  ...
}

template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread,
                             DataSplitMode data_split_mode) {
....
  while (adapter->Next()) {
....
  // 所有的数据全部写入page中
    auto batch_max_columns = sparse_page_->Push(batch, missing, ctx.Threads());
....
    }
  }

综上,单机模式下,FindSplit时一个线程会一次扫描单个特征的所有特征值。

2.2.7 剪枝

每一轮树生成后,都会进行一次剪枝。剪枝分布式情况下,剪枝完会同步信息。

  void Update(TrainParam const* param, HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat,
              common::Span<HostDeviceVector<bst_node_t>> out_position,
              const std::vector<RegTree*>& trees) override {
    pruner_monitor_.Start("PrunerUpdate");
    for (auto tree : trees) {
      this->DoPrune(param, tree);
    }
    syncher_->Update(param, gpair, p_fmat, out_position, trees);
    pruner_monitor_.Stop("PrunerUpdate");
  }

剪枝条件为:当前节点收益小于最小收益,或者深度超过最大深度。

 return loss_chg < this->min_split_loss || 
   (this->max_depth != 0 && depth > this->max_depth);
  1. xgboost on spark实现

spark下各个节点计算自己的数据,生成直方图,最后通过AllReduce进行汇总数据,生成总的直方图,之后分裂节点,通过AllReduce进行同步。

3.1 训练链路

3.2 数据生成

在spark端,生成batchIter,存储数据结构为LabeledPoint,大小为32k条数据。

  public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError {
    if (iter == null) {
      throw new NullPointerException("iter: null");
    }
    // 32k as batch size
    int batchSize = 32 << 10;
    Iterator<DataBatch> batchIter = new DataBatch.BatchIterator(iter, batchSize);
    long[] out = new long[1];
    XGBoostJNI.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out));
    handle = out[0];
  }

xgboost中,通过jni接取到batch中,将batch的数据解析后存储到自身的DMatrix中,并且返回指向该DMatrix的指针,之后训练通过传入该指针进行训练数据的读取。

3.3 训练

相对于单机,在spark时已经进行了数据的生成,所以在调用时直接生成树,于是直接调用Doboost。

3.3.1 spark支持的booster与updater

gbtree的booster

  1. 分布式相关的更新器,后两种在2.0中可以使用

"grow_histmaker,prune","grow_quantile_histmaker","grow_gpu_hist"的updater。

3.3.2 GlobalApproxUpdater

除了gpu的updater,xgboost on spark的updater使用全局近似直方图。

全局近似直方图

直方图就是一定范围内数据的集合,根据切分点划分。近似算法首先按照特征取值的统计分布的一些百分位点确定一些候选分裂点,然后算法将连续的值映射到 buckets中,然后汇总统计数据,并根据聚合统计数据在候选节点中找到最佳节点。

每个节点都会算自己节点上数据的直方图,根据切分点进行划分,汇总对应结构上的梯度数据。最后计算好直方图后进行汇总,进行节点分裂,之后同步分裂数据。

直方图存储的结构

class GHistRow实际是里面的double数组。

直方图存储逻辑

叶子节点中一个特征就是一个直方图,不同的数据的不同特征通过预处理的GHistIndexMatrix结构对应到对应直方图的位置。

切分点存储结构是什么

GHistIndexMatrix 中的uint8_t数组。

构建直方图所需要的数组在什么时候初始化的

在构建直方图时,开始遍历DMatrix就会构建梯度索引:p_fmat->GetBatches。通过该索引,存储特征值的Entry可以索引到其对应的直方图,从而实现直方图的构建。

  /** Main entry point of this class, build histogram for tree nodes. */
  void BuildHist(std::size_t page_idx, common::BlockedSpace2d const &space,
                 GHistIndexMatrix const &gidx, common::RowSetCollection const &row_set_collection,
                 std::vector<bst_node_t> const &nodes_to_build,
                 linalg::VectorView<GradientPair const> gpair, bool force_read_by_column = false) {
    ...
    if (gidx.IsDense()) {
      this->BuildLocalHistograms<false>(space, gidx, nodes_to_build, row_set_collection,
                                        gpair.Values(), force_read_by_column);
    } else {
      this->BuildLocalHistograms<true>(space, gidx, nodes_to_build, row_set_collection,
                                       gpair.Values(), force_read_by_column);
    }
  }

template <bool do_prefetch, class BuildingManager>
void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
                             const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
                             GHistRow hist) {
	...
  // 梯度索引数组,存储当前特征值指向哪个直方图
  const BinIdxType *gradient_index = gmat.index.data<BinIdxType>();
	...

  for (std::size_t i = 0; i < size; ++i) {
    // 对每一行特征进行处理
    // 列的开始
    const size_t icol_start =
        kAnyMissing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
    const size_t icol_end =
        kAnyMissing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;

    const size_t row_size = icol_end - icol_start;
    // row_index: 样本索引
    const size_t idx_gh = two * rid[i];

    if (do_prefetch) {
      const size_t icol_start_prefetch =
          kAnyMissing
              ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset])
              : get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features;
      const size_t icol_end_prefetch =
          kAnyMissing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
                      : icol_start_prefetch + n_features;

      PREFETCH_READ_T0(p_gpair + two * rid[i + Prefetch::kPrefetchOffset]);
      for (size_t j = icol_start_prefetch; j < icol_end_prefetch;
           j += Prefetch::GetPrefetchStep<uint32_t>()) {
        PREFETCH_READ_T0(gradient_index + j);
      }
    }
    const BinIdxType *gr_index_local = gradient_index + icol_start;

    // The trick with pgh_t buffer helps the compiler to generate faster binary.
    const float pgh_t[] = {p_gpair[idx_gh], p_gpair[idx_gh + 1]};
    for (size_t j = 0; j < row_size; ++j) {
      // 特征对应直方图的索引,通过梯度数组[特征位于样本的第几个] + 后面的判断
      const uint32_t idx_bin =
          two * (static_cast<uint32_t>(gr_index_local[j]) + (kAnyMissing ? 0 : offsets[j]));
      // 聚合直方图对应特征的内容
      auto hist_local = hist_data + idx_bin;
      *(hist_local) += pgh_t[0];
      *(hist_local + 1) += pgh_t[1];
    }
  }
}

在实际训练时,通过BoostNewTrees()的调用,获取需要更新的树,如果参数为kDefault则是生成新的树,如果为kUpdate则是将已有的树再次更新。

void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, int bst_group,
                           std::vector<HostDeviceVector<bst_node_t>>* out_position,
                           TreesOneGroup* ret) {
  std::vector<RegTree*> new_trees;
  ret->clear();
  // create the trees
  for (int i = 0; i < model_.param.num_parallel_tree; ++i) {
    // 新树的生成
    if (tparam_.process_type == TreeProcessType::kDefault) {
      CHECK(!updaters_.empty());
      CHECK(!updaters_.front()->CanModifyTree())
          << "Updater: `" << updaters_.front()->Name() << "` "
          << "can not be used to create new trees. "
          << "Set `process_type` to `update` if you want to update existing "
             "trees.";
      // create new tree
      std::unique_ptr<RegTree> ptr(new RegTree{this->model_.learner_model_param->LeafLength(),
                                               this->model_.learner_model_param->num_feature});
      new_trees.push_back(ptr.get());
      ret->push_back(std::move(ptr));
    } else if (tparam_.process_type == TreeProcessType::kUpdate) 
      for (auto const& up : updaters_) {
        CHECK(up->CanModifyTree())
            << "Updater: `" << up->Name() << "` "
            << "can not be used to modify existing trees. "
            << "Set `process_type` to `default` if you want to build new trees.";
      }
      CHECK_LT(model_.trees.size(), model_.trees_to_update.size())
          << "No more tree left for updating.  For updating existing trees, "
          << "boosting rounds can not exceed previous training rounds";
      // move an existing tree from trees_to_update
    	// 使用已有的树
      auto t = std::move(model_.trees_to_update[model_.trees.size() +
                                                bst_group * model_.param.num_parallel_tree + i]);
      new_trees.push_back(t.get());
      ret->push_back(std::move(t));
    }
  ...
    // 更新树
    for (auto& up : updaters_) {
    	up->Update(&tree_param_, gpair, p_fmat,
                 common::Span<HostDeviceVector<bst_node_t>>{*out_position}, new_trees);
  	}
  ...
  }

首先初始化数据,找到需要分裂的点以及生成节点对应的直方图。之后遍历需要分裂的点,看是否可以分裂。如果可以分裂则重新分配数据集,并且构建分裂后点的直方图并且加入队列中等待遍历,当树Update完毕后,同步新树的信息。

  void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *m,
              common::Span<HostDeviceVector<bst_node_t>> out_position,
              const std::vector<RegTree *> &trees) override {
    CHECK(hist_param_.GetInitialised());
    pimpl_ = std::make_unique<GloablApproxBuilder>(param, &hist_param_, m->Info(), ctx_,
                                                   column_sampler_, task_, &monitor_);

    linalg::Matrix<GradientPair> h_gpair;
    // Obtain the hessian values for weighted sketching
    InitData(*param, gpair, &h_gpair);
    std::vector<float> hess(h_gpair.Size());
    auto const &s_gpair = h_gpair.Data()->ConstHostVector();
    std::transform(s_gpair.begin(), s_gpair.end(), hess.begin(),
                   [](auto g) { return g.GetHess(); });

    cached_ = m;

    std::size_t t_idx = 0;
    for (auto p_tree : trees) {
      this->pimpl_->UpdateTree(m, s_gpair, hess, p_tree, &out_position[t_idx]);
      hist_param_.CheckTreesSynchronized(p_tree);
      ++t_idx;
    }
  }


  void UpdateTree(DMatrix *p_fmat, std::vector<GradientPair> const &gpair, common::Span<float> hess,
                  RegTree *p_tree, HostDeviceVector<bst_node_t> *p_out_position) {
    p_last_tree_ = p_tree;
    this->InitData(p_fmat, p_tree, hess);

    Driver<CPUExpandEntry> driver(*param_);
    auto &tree = *p_tree;
    driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
    auto expand_set = driver.Pop();

    /**
     * Note for update position
     * Root:
     *   Not applied: No need to update position as initialization has got all the rows ordered.
     *   Applied: Update position is run on applied nodes so the rows are partitioned.
     * Non-root:
     *   Not applied: That node is root of the subtree, same rule as root.
     *   Applied: Ditto
     */

    while (!expand_set.empty()) {
      // candidates that can be further splited.
      // 下一轮的子节点可以接着切分的点
      std::vector<CPUExpandEntry> valid_candidates;
      // candidates that can be applied.
      // 本轮可以切分的节点
      std::vector<CPUExpandEntry> applied;
      for (auto const &candidate : expand_set) {
        // 分裂节点
        evaluator_.ApplyTreeSplit(candidate, p_tree);
        applied.push_back(candidate);
        if (driver.IsChildValid(candidate)) {
          valid_candidates.emplace_back(candidate);
        }
      }

      monitor_->Start("UpdatePosition");
      size_t page_id = 0;
      for (auto const &page :
           p_fmat->GetBatches<GHistIndexMatrix>(ctx_, BatchSpec(*param_, hess))) {
        /**
         * 节点分裂后,数据集的重新分配和更新
        */
        partitioner_.at(page_id).UpdatePosition(ctx_, page, applied, p_tree);
        page_id++;
      }
      monitor_->Stop("UpdatePosition");

      std::vector<CPUExpandEntry> best_splits;
      if (!valid_candidates.empty()) {
        // 会基于左右子节点构筑直方图
        this->BuildHistogram(p_fmat, p_tree, valid_candidates, gpair, hess);
        for (auto const &candidate : valid_candidates) {
          int left_child_nidx = tree[candidate.nid].LeftChild();
          int right_child_nidx = tree[candidate.nid].RightChild();
          CPUExpandEntry l_best{left_child_nidx, tree.GetDepth(left_child_nidx)};
          CPUExpandEntry r_best{right_child_nidx, tree.GetDepth(right_child_nidx)};
          best_splits.push_back(l_best);
          best_splits.push_back(r_best);
        }
        auto const &histograms = histogram_builder_.Histogram(0);
        auto ft = p_fmat->Info().feature_types.ConstHostSpan();
        monitor_->Start("EvaluateSplits");
        evaluator_.EvaluateSplits(histograms, feature_values_, ft, *p_tree, &best_splits);
        monitor_->Stop("EvaluateSplits");
      }
      driver.Push(best_splits.begin(), best_splits.end());
      expand_set = driver.Pop();
    }

    auto &h_position = p_out_position->HostVector();
    this->LeafPartition(tree, hess, &h_position);
  }
};

3.4 分布式与单机区别

分布式相对于单机,树增长方法有了限制,只能使用近似直方图以及GPU相关算法。

在实现上主要是增加了Allreduce来统计汇总每一轮的相关信息。例如在生成直方图时会汇总各节点对于各树节点生成的直方图梯度信息

 void SyncHistogram(RegTree const *p_tree, std::vector<bst_node_t> const &nodes_to_build,
                     std::vector<bst_node_t> const &nodes_to_trick) {
    ...
    if (is_distributed_ && !is_col_split_) {
		...
      // 分布式
      collective::Allreduce<collective::Operation::kSum>(
          reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
    }
		...
    // 多线程
    common::BlockedSpace2d const &subspace =
        nodes_to_trick.size() == nodes_to_build.size()
            ? space
            : common::BlockedSpace2d{nodes_to_trick.size(),
                                     [&](std::size_t) { return n_total_bins; }, 1024};
    common::ParallelFor2d(
        subspace, this->n_threads_, [&](std::size_t nidx_in_set, common::Range1d r) {
          auto subtraction_nidx = nodes_to_trick[nidx_in_set];
          auto parent_id = p_tree->Parent(subtraction_nidx);
          auto sibling_nidx = p_tree->IsLeftChild(subtraction_nidx) ? p_tree->RightChild(parent_id)
                                                                    : p_tree->LeftChild(parent_id);
          auto sibling_hist = this->hist_[sibling_nidx];
          auto parent_hist = this->hist_[parent_id];
          auto subtract_hist = this->hist_[subtraction_nidx];
          common::SubtractionHist(subtract_hist, parent_hist, sibling_hist, r.begin(), r.end());
        });
  }

标签:std,const,--,xgboost,tree,源码,阅读,节点,gpair
From: https://www.cnblogs.com/yych0745/p/17980432

相关文章

  • FutureTask源码阅读
    目录简介例子代码分析成员变量方法参考链接本人的源码阅读主要聚焦于类的使用场景,一般只在java层面进行分析,没有深入到一些native方法的实现。并且由于知识储备不完整,很可能出现疏漏甚至是谬误,欢迎指出共同学习本文基于corretto-17.0.9源码,参考本文时请打开相应的源码对照,否则......
  • Linux下源码安装
    Linux下源码安装很多开源库都没有说明怎么安装,这里记录一下一般方法。步骤以wldgb为例:克隆下源码后,发现README中没有说怎么安装,观察文件:一般来说,autogen.sh是用来生成configure的,然后configure是用来生成makefile的。如果不确定,可以看一下这些文件中的内容,就知道大概是怎......
  • 【快速阅读三】使用泊松融合实现单幅图的无缝拼贴及消除两幅图片直接的拼接缝隙。
    泊松融合还可以创建一些很有意思的图片,比如一张图片任意规格平铺,使用泊松融合后,平铺的边界处过渡的很自然,另外,对于两张图片,由于局部亮度等等的影响,导致拼接在一起时不自然,也可以使用泊松融合予以解决。在【快速阅读二】从OpenCv的代码中扣取泊松融合算......
  • 【快速阅读二】从OpenCv的代码中扣取泊松融合算子(Poisson Image Editing)并稍作优化
    泊松融合是一种非常不错多图融合算法,在OpenCv的相关版本中也包含了该算子模块,作者尝试着从OpenCv的大仓库中扣取出该算子的全部代码,并分享了一些在扣取代码中的心得和收获。泊松融合我自己写的第一版程序大概是2016年在某个小房间里折腾出来的,当时是用......
  • zookeeper源码(06)ZooKeeperServer及子类
    ZooKeeperServer实现了单机版zookeeper服务端功能,子类实现了更加丰富的分布式集群功能:ZooKeeperServer|--QuorumZooKeeperServer|--LeaderZooKeeperServer|--LearnerZooKeeperServer|--FollowerZooKeeperServer|--ObserverZooKeeperServer|-......
  • 基于pytest搭建接口自动化测试框架,提供源码
     基于pytest搭建接口自动化测试框架 框架整体介绍和方法教程第三代框架使用教程,该框架比第二代这个完善了很多https://blog.csdn.net/aaaaaaaaanjjj/article/details/129597973新框架(第二代比这个功能多了很多,用例使用yaml编写)pytest+yaml设计接口自动化框架过程记录......
  • 华企盾DSC:外发文件设置编辑权限 阅读次数 阅后即焚 文件过期
    互联网时代,信息流通迅速,一份关键的内部文件一旦外泄,可能毁掉公司数月、甚至数年的努力。企业多次碰壁后终于发现,仅仅依靠员工层层审批、体系内控制,已难以防止数据泄密这一严重问题。更为糟糕的是,一旦文件发送出去,系统往往不能有效地控制未授权阅读的发生。痛定思痛,企业用户渴望有......
  • [转帖]MySQL多版本并发控制机制(MVCC)-源码浅析
    https://zhuanlan.zhihu.com/p/144682180 MySQL多版本并发控制机制(MVCC)-源码浅析前言作为一个数据库爱好者,自己动手写过简单的SQL解析器以及存储引擎,但感觉还是不够过瘾。<<事务处理-概念与技术>>诚然讲的非常透彻,但只能提纲挈领,不能让你玩转某个真正的数据库。感谢c......
  • 《Deep Long-Tailed Learning: A Survey》阅读笔记
    论文标题《DeepLong-TailedLearning:ASurvey》深度长尾学习:调查作者YifanZhang、BingyiKang、BryanHooi、ShuichengYan(IEEEFellow)和JiashiFeng来自新加坡国立大学计算机学院、字节跳动AILab和SEAAILab初读摘要长尾类别不平衡(long-tailedclassimbala......
  • Feign源码解析7:nacos loadbalancer不支持静态ip的负载均衡
    背景在feign中,一般是通过eureka、nacos等获取服务实例,但有时候调用一些服务时,人家给的是ip或域名,我们这时候还能用Feign这一套吗?可以的。有两种方式,一种是直接指定url:这种是服务端自己会保证高可用、负载均衡那些。但也可能对方给了多个url(一般不会这样,但是在app场景下,为了......