From 49f51e1a61bae5ebdc3a36d6c74de3f990298e5c Mon Sep 17 00:00:00 2001 From: fis Date: Sun, 21 Jun 2020 07:00:36 +0800 Subject: [PATCH 1/2] [WIP] Remove WXQ and GK. * GK is not used. * WXQ is not documented and doesn't provide better performance. --- src/common/quantile.h | 269 ------------------------------- src/tree/updater_basemaker-inl.h | 6 +- src/tree/updater_histmaker.cc | 15 +- src/tree/updater_skmaker.cc | 27 ++-- 4 files changed, 24 insertions(+), 293 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index 067b041bc4e9..014cca037a0e 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -341,256 +341,6 @@ struct WQSummary { } }; -/*! \brief try to do efficient pruning */ -template -struct WXQSummary : public WQSummary { - // redefine entry type - using Entry = typename WQSummary::Entry; - // constructor - WXQSummary(Entry *data, size_t size) - : WQSummary(data, size) {} - // check if the block is large chunk - inline static bool CheckLarge(const Entry &e, RType chunk) { - return e.RMinNext() > e.RMaxPrev() + chunk; - } - // set prune - inline void SetPrune(const WQSummary &src, size_t maxsize) { - if (src.size <= maxsize) { - this->CopyFrom(src); return; - } - RType begin = src.data[0].rmax; - // n is number of points exclude the min/max points - size_t n = maxsize - 2, nbig = 0; - // these is the range of data exclude the min/max point - RType range = src.data[src.size - 1].rmin - begin; - // prune off zero weights - if (range == 0.0f || maxsize <= 2) { - // special case, contain only two effective data pts - this->data[0] = src.data[0]; - this->data[1] = src.data[src.size - 1]; - this->size = 2; - return; - } else { - range = std::max(range, static_cast(1e-3f)); - } - // Get a big enough chunk size, bigger than range / n - // (multiply by 2 is a safe factor) - const RType chunk = 2 * range / n; - // minimized range - RType mrange = 0; - { - // first scan, grab all the big chunk - // moving block index, exclude the two ends. - size_t bid = 0; - for (size_t i = 1; i < src.size - 1; ++i) { - // detect big chunk data point in the middle - // always save these data points. - if (CheckLarge(src.data[i], chunk)) { - if (bid != i - 1) { - // accumulate the range of the rest points - mrange += src.data[i].RMaxPrev() - src.data[bid].RMinNext(); - } - bid = i; ++nbig; - } - } - if (bid != src.size - 2) { - mrange += src.data[src.size-1].RMaxPrev() - src.data[bid].RMinNext(); - } - } - // assert: there cannot be more than n big data points - if (nbig >= n) { - // see what was the case - LOG(INFO) << " check quantile stats, nbig=" << nbig << ", n=" << n; - LOG(INFO) << " srcsize=" << src.size << ", maxsize=" << maxsize - << ", range=" << range << ", chunk=" << chunk; - src.Print(); - CHECK(nbig < n) << "quantile: too many large chunk"; - } - this->data[0] = src.data[0]; - this->size = 1; - // The counter on the rest of points, to be selected equally from small chunks. - n = n - nbig; - // find the rest of point - size_t bid = 0, k = 1, lastidx = 0; - for (size_t end = 1; end < src.size; ++end) { - if (end == src.size - 1 || CheckLarge(src.data[end], chunk)) { - if (bid != end - 1) { - size_t i = bid; - RType maxdx2 = src.data[end].RMaxPrev() * 2; - for (; k < n; ++k) { - RType dx2 = 2 * ((k * mrange) / n + begin); - if (dx2 >= maxdx2) break; - while (i < end && - dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i; - if (i == end) break; - if (dx2 < src.data[i].RMinNext() + src.data[i + 1].RMaxPrev()) { - if (i != lastidx) { - this->data[this->size++] = src.data[i]; lastidx = i; - } - } else { - if (i + 1 != lastidx) { - this->data[this->size++] = src.data[i + 1]; lastidx = i + 1; - } - } - } - } - if (lastidx != end) { - this->data[this->size++] = src.data[end]; - lastidx = end; - } - bid = end; - // shift base by the gap - begin += src.data[bid].RMinNext() - src.data[bid].RMaxPrev(); - } - } - } -}; -/*! - * \brief traditional GK summary - */ -template -struct GKSummary { - /*! \brief an entry in the sketch summary */ - struct Entry { - /*! \brief minimum rank */ - RType rmin; - /*! \brief maximum rank */ - RType rmax; - /*! \brief the value of data */ - DType value; - // constructor - Entry() = default; - // constructor - Entry(RType rmin, RType rmax, DType value) - : rmin(rmin), rmax(rmax), value(value) {} - }; - /*! \brief input data queue before entering the summary */ - struct Queue { - // the input queue - std::vector queue; - // end of the queue - size_t qtail; - // push data to the queue - inline void Push(DType x, RType w) { - queue[qtail++] = x; - } - inline void MakeSummary(GKSummary *out) { - std::sort(queue.begin(), queue.begin() + qtail); - out->size = qtail; - for (size_t i = 0; i < qtail; ++i) { - out->data[i] = Entry(i + 1, i + 1, queue[i]); - } - } - }; - /*! \brief data field */ - Entry *data; - /*! \brief number of elements in the summary */ - size_t size; - GKSummary(Entry *data, size_t size) - : data(data), size(size) {} - /*! \brief the maximum error of the summary */ - inline RType MaxError() const { - RType res = 0; - for (size_t i = 1; i < size; ++i) { - res = std::max(data[i].rmax - data[i-1].rmin, res); - } - return res; - } - /*! \return maximum rank in the summary */ - inline RType MaxRank() const { - return data[size - 1].rmax; - } - /*! - * \brief copy content from src - * \param src source sketch - */ - inline void CopyFrom(const GKSummary &src) { - size = src.size; - std::memcpy(data, src.data, sizeof(Entry) * size); - } - inline void CheckValid(RType eps) const { - // assume always valid - } - /*! \brief used for debug purpose, print the summary */ - inline void Print() const { - for (size_t i = 0; i < size; ++i) { - LOG(CONSOLE) << "x=" << data[i].value << "\t" - << "[" << data[i].rmin << "," << data[i].rmax << "]"; - } - } - /*! - * \brief set current summary to be pruned summary of src - * assume data field is already allocated to be at least maxsize - * \param src source summary - * \param maxsize size we can afford in the pruned sketch - */ - inline void SetPrune(const GKSummary &src, size_t maxsize) { - if (src.size <= maxsize) { - this->CopyFrom(src); return; - } - const RType max_rank = src.MaxRank(); - this->size = maxsize; - data[0] = src.data[0]; - size_t n = maxsize - 1; - RType top = 1; - for (size_t i = 1; i < n; ++i) { - RType k = (i * max_rank) / n; - while (k > src.data[top + 1].rmax) ++top; - // assert src.data[top].rmin <= k - // because k > src.data[top].rmax >= src.data[top].rmin - if ((k - src.data[top].rmin) < (src.data[top+1].rmax - k)) { - data[i] = src.data[top]; - } else { - data[i] = src.data[top + 1]; - } - } - data[n] = src.data[src.size - 1]; - } - inline void SetCombine(const GKSummary &sa, - const GKSummary &sb) { - if (sa.size == 0) { - this->CopyFrom(sb); return; - } - if (sb.size == 0) { - this->CopyFrom(sa); return; - } - CHECK(sa.size > 0 && sb.size > 0) << "invalid input for merge"; - const Entry *a = sa.data, *a_end = sa.data + sa.size; - const Entry *b = sb.data, *b_end = sb.data + sb.size; - this->size = sa.size + sb.size; - RType aprev_rmin = 0, bprev_rmin = 0; - Entry *dst = this->data; - while (a != a_end && b != b_end) { - if (a->value < b->value) { - *dst = Entry(bprev_rmin + a->rmin, - a->rmax + b->rmax - 1, a->value); - aprev_rmin = a->rmin; - ++dst; ++a; - } else { - *dst = Entry(aprev_rmin + b->rmin, - b->rmax + a->rmax - 1, b->value); - bprev_rmin = b->rmin; - ++dst; ++b; - } - } - if (a != a_end) { - RType bprev_rmax = (b_end - 1)->rmax; - do { - *dst = Entry(bprev_rmin + a->rmin, bprev_rmax + a->rmax, a->value); - ++dst; ++a; - } while (a != a_end); - } - if (b != b_end) { - RType aprev_rmax = (a_end - 1)->rmax; - do { - *dst = Entry(aprev_rmin + b->rmin, aprev_rmax + b->rmax, b->value); - ++dst; ++b; - } while (b != b_end); - } - CHECK(dst == data + size) << "bug in combine"; - } -}; - /*! * \brief template for all quantile sketch algorithm * that uses merge/prune scheme @@ -831,25 +581,6 @@ template class WQuantileSketch : public QuantileSketchTemplate > { }; - -/*! - * \brief Quantile sketch use WXQSummary - * \tparam DType type of data content - * \tparam RType type of rank - */ -template -class WXQuantileSketch : - public QuantileSketchTemplate > { -}; -/*! - * \brief Quantile sketch use WQSummary - * \tparam DType type of data content - * \tparam RType type of rank - */ -template -class GKQuantileSketch : - public QuantileSketchTemplate > { -}; } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_QUANTILE_H_ diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index 66ab91982407..18ce6e4890e3 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -376,7 +376,7 @@ class BaseMaker: public TreeUpdater { /*! \brief current size of sketch */ double next_goal; // pointer to the sketch to put things in - common::WXQuantileSketch *sketch; + common::WQuantileSketch *sketch; // initialize the space inline void Init(unsigned max_size) { next_goal = -1.0f; @@ -404,7 +404,7 @@ class BaseMaker: public TreeUpdater { last_fvalue > sketch->temp.data[sketch->temp.size-1].value) { // push to sketch sketch->temp.data[sketch->temp.size] = - common::WXQuantileSketch:: + common::WQuantileSketch:: Entry(static_cast(rmin), static_cast(rmax), static_cast(wmin), last_fvalue); @@ -442,7 +442,7 @@ class BaseMaker: public TreeUpdater { << ", stemp.size=" << sketch->temp.size; // push to sketch sketch->temp.data[sketch->temp.size] = - common::WXQuantileSketch:: + common::WQuantileSketch:: Entry(static_cast(rmin), static_cast(rmax), static_cast(wmin), last_fvalue); diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index c4fdbe3c0308..7639793b37d9 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -317,7 +317,7 @@ class CQHistMaker: public HistMaker { } }; // sketch type used for this - using WXQSketch = common::WXQuantileSketch; + using WQSketch = common::WQuantileSketch; // initialize the work set of tree void InitWorkSet(DMatrix *p_fmat, const RegTree &tree, @@ -439,14 +439,14 @@ class CQHistMaker: public HistMaker { } } for (size_t i = 0; i < sketchs_.size(); ++i) { - common::WXQuantileSketch::SummaryContainer out; + common::WQuantileSketch::SummaryContainer out; sketchs_[i].GetSummary(&out); summary_array_[i].SetPrune(out, max_size); } CHECK_EQ(summary_array_.size(), sketchs_.size()); } if (summary_array_.size() != 0) { - size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); + size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_size); sreducer_.Allreduce(dmlc::BeginPtr(summary_array_), nbytes, summary_array_.size()); } // now we get the final result of sketch, setup the cut @@ -457,7 +457,8 @@ class CQHistMaker: public HistMaker { for (unsigned int i : fset) { int offset = feat2workindex_[i]; if (offset >= 0) { - const WXQSketch::Summary &a = summary_array_[wid * work_set_size + offset]; + const WQSketch::Summary &a = + summary_array_[wid * work_set_size + offset]; for (size_t i = 1; i < a.size; ++i) { bst_float cpt = a.data[i].value - kRtEps; if (i == 1 || cpt > this->wspace_.cut.back()) { @@ -630,11 +631,11 @@ class CQHistMaker: public HistMaker { // node statistics std::vector node_stats_; // summary array - std::vector summary_array_; + std::vector summary_array_; // reducer for summary - rabit::SerializeReducer sreducer_; + rabit::SerializeReducer sreducer_; // per node, per feature sketch - std::vector > sketchs_; + std::vector > sketchs_; }; // global proposal diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc index 69cb4e58bbed..e70499c41ee6 100644 --- a/src/tree/updater_skmaker.cc +++ b/src/tree/updater_skmaker.cc @@ -75,7 +75,7 @@ class SketchMaker: public BaseMaker { } } // define the sketch we want to use - using WXQSketch = common::WXQuantileSketch; + using WQSketch = common::WQuantileSketch; private: // statistics needed in the gradient calculation @@ -152,12 +152,12 @@ class SketchMaker: public BaseMaker { // synchronize sketch summary_array_.resize(sketchs_.size()); for (size_t i = 0; i < sketchs_.size(); ++i) { - common::WXQuantileSketch::SummaryContainer out; + common::WQuantileSketch::SummaryContainer out; sketchs_[i].GetSummary(&out); summary_array_[i].Reserve(max_size); summary_array_[i].SetPrune(out, max_size); } - size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); + size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_size); sketch_reducer_.Allreduce(dmlc::BeginPtr(summary_array_), nbytes, summary_array_.size()); } // update sketch information in column fid @@ -301,11 +301,10 @@ class SketchMaker: public BaseMaker { p_tree->Stat(nid).base_weight = static_cast(node_sum.CalcWeight(param_)); p_tree->Stat(nid).sum_hess = static_cast(node_sum.sum_hess); } - inline void EnumerateSplit(const WXQSketch::Summary &pos_grad, - const WXQSketch::Summary &neg_grad, - const WXQSketch::Summary &sum_hess, - const SKStats &node_sum, - bst_uint fid, + inline void EnumerateSplit(const WQSketch::Summary &pos_grad, + const WQSketch::Summary &neg_grad, + const WQSketch::Summary &sum_hess, + const SKStats &node_sum, bst_uint fid, SplitEntry *best) { if (sum_hess.size == 0) return; double root_gain = node_sum.CalcGain(param_); @@ -328,9 +327,9 @@ class SketchMaker: public BaseMaker { feat_sum.sum_hess = sum_hess.data[sum_hess.size - 1].rmax; size_t ipos = 0, ineg = 0, ihess = 0; for (size_t i = 1; i < fsplits.size(); ++i) { - WXQSketch::Entry pos = pos_grad.Query(fsplits[i], ipos); - WXQSketch::Entry neg = neg_grad.Query(fsplits[i], ineg); - WXQSketch::Entry hess = sum_hess.Query(fsplits[i], ihess); + WQSketch::Entry pos = pos_grad.Query(fsplits[i], ipos); + WQSketch::Entry neg = neg_grad.Query(fsplits[i], ineg); + WQSketch::Entry hess = sum_hess.Query(fsplits[i], ihess); SKStats s, c; s.pos_grad = 0.5f * (pos.rmin + pos.rmax - pos.wmin); s.neg_grad = 0.5f * (neg.rmin + neg.rmax - neg.wmin); @@ -379,13 +378,13 @@ class SketchMaker: public BaseMaker { // node statistics std::vector node_stats_; // summary array - std::vector summary_array_; + std::vector summary_array_; // reducer for summary rabit::Reducer stats_reducer_; // reducer for summary - rabit::SerializeReducer sketch_reducer_; + rabit::SerializeReducer sketch_reducer_; // per node, per feature sketch - std::vector > sketchs_; + std::vector > sketchs_; }; XGBOOST_REGISTER_TREE_UPDATER(SketchMaker, "grow_skmaker") From 92d485f52f72ed8efdc0d5e32e86fbeac1a795a3 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 22 Jun 2020 18:19:12 +0800 Subject: [PATCH 2/2] Restore WXQ. --- src/common/quantile.h | 114 +++++++++++++++++++++++++++++++ src/tree/updater_basemaker-inl.h | 6 +- src/tree/updater_histmaker.cc | 15 ++-- src/tree/updater_skmaker.cc | 27 ++++---- 4 files changed, 138 insertions(+), 24 deletions(-) diff --git a/src/common/quantile.h b/src/common/quantile.h index 014cca037a0e..c0079ff8ebc8 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -341,6 +341,110 @@ struct WQSummary { } }; +/*! \brief try to do efficient pruning */ +template +struct WXQSummary : public WQSummary { + // redefine entry type + using Entry = typename WQSummary::Entry; + // constructor + WXQSummary(Entry *data, size_t size) + : WQSummary(data, size) {} + // check if the block is large chunk + inline static bool CheckLarge(const Entry &e, RType chunk) { + return e.RMinNext() > e.RMaxPrev() + chunk; + } + // set prune + inline void SetPrune(const WQSummary &src, size_t maxsize) { + if (src.size <= maxsize) { + this->CopyFrom(src); return; + } + RType begin = src.data[0].rmax; + // n is number of points exclude the min/max points + size_t n = maxsize - 2, nbig = 0; + // these is the range of data exclude the min/max point + RType range = src.data[src.size - 1].rmin - begin; + // prune off zero weights + if (range == 0.0f || maxsize <= 2) { + // special case, contain only two effective data pts + this->data[0] = src.data[0]; + this->data[1] = src.data[src.size - 1]; + this->size = 2; + return; + } else { + range = std::max(range, static_cast(1e-3f)); + } + // Get a big enough chunk size, bigger than range / n + // (multiply by 2 is a safe factor) + const RType chunk = 2 * range / n; + // minimized range + RType mrange = 0; + { + // first scan, grab all the big chunk + // moving block index, exclude the two ends. + size_t bid = 0; + for (size_t i = 1; i < src.size - 1; ++i) { + // detect big chunk data point in the middle + // always save these data points. + if (CheckLarge(src.data[i], chunk)) { + if (bid != i - 1) { + // accumulate the range of the rest points + mrange += src.data[i].RMaxPrev() - src.data[bid].RMinNext(); + } + bid = i; ++nbig; + } + } + if (bid != src.size - 2) { + mrange += src.data[src.size-1].RMaxPrev() - src.data[bid].RMinNext(); + } + } + // assert: there cannot be more than n big data points + if (nbig >= n) { + // see what was the case + LOG(INFO) << " check quantile stats, nbig=" << nbig << ", n=" << n; + LOG(INFO) << " srcsize=" << src.size << ", maxsize=" << maxsize + << ", range=" << range << ", chunk=" << chunk; + src.Print(); + CHECK(nbig < n) << "quantile: too many large chunk"; + } + this->data[0] = src.data[0]; + this->size = 1; + // The counter on the rest of points, to be selected equally from small chunks. + n = n - nbig; + // find the rest of point + size_t bid = 0, k = 1, lastidx = 0; + for (size_t end = 1; end < src.size; ++end) { + if (end == src.size - 1 || CheckLarge(src.data[end], chunk)) { + if (bid != end - 1) { + size_t i = bid; + RType maxdx2 = src.data[end].RMaxPrev() * 2; + for (; k < n; ++k) { + RType dx2 = 2 * ((k * mrange) / n + begin); + if (dx2 >= maxdx2) break; + while (i < end && + dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i; + if (i == end) break; + if (dx2 < src.data[i].RMinNext() + src.data[i + 1].RMaxPrev()) { + if (i != lastidx) { + this->data[this->size++] = src.data[i]; lastidx = i; + } + } else { + if (i + 1 != lastidx) { + this->data[this->size++] = src.data[i + 1]; lastidx = i + 1; + } + } + } + } + if (lastidx != end) { + this->data[this->size++] = src.data[end]; + lastidx = end; + } + bid = end; + // shift base by the gap + begin += src.data[bid].RMinNext() - src.data[bid].RMaxPrev(); + } + } + } +}; /*! * \brief template for all quantile sketch algorithm * that uses merge/prune scheme @@ -581,6 +685,16 @@ template class WQuantileSketch : public QuantileSketchTemplate > { }; + +/*! + * \brief Quantile sketch use WXQSummary + * \tparam DType type of data content + * \tparam RType type of rank + */ +template +class WXQuantileSketch : + public QuantileSketchTemplate > { +}; } // namespace common } // namespace xgboost #endif // XGBOOST_COMMON_QUANTILE_H_ diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index 18ce6e4890e3..66ab91982407 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -376,7 +376,7 @@ class BaseMaker: public TreeUpdater { /*! \brief current size of sketch */ double next_goal; // pointer to the sketch to put things in - common::WQuantileSketch *sketch; + common::WXQuantileSketch *sketch; // initialize the space inline void Init(unsigned max_size) { next_goal = -1.0f; @@ -404,7 +404,7 @@ class BaseMaker: public TreeUpdater { last_fvalue > sketch->temp.data[sketch->temp.size-1].value) { // push to sketch sketch->temp.data[sketch->temp.size] = - common::WQuantileSketch:: + common::WXQuantileSketch:: Entry(static_cast(rmin), static_cast(rmax), static_cast(wmin), last_fvalue); @@ -442,7 +442,7 @@ class BaseMaker: public TreeUpdater { << ", stemp.size=" << sketch->temp.size; // push to sketch sketch->temp.data[sketch->temp.size] = - common::WQuantileSketch:: + common::WXQuantileSketch:: Entry(static_cast(rmin), static_cast(rmax), static_cast(wmin), last_fvalue); diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 7639793b37d9..c4fdbe3c0308 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -317,7 +317,7 @@ class CQHistMaker: public HistMaker { } }; // sketch type used for this - using WQSketch = common::WQuantileSketch; + using WXQSketch = common::WXQuantileSketch; // initialize the work set of tree void InitWorkSet(DMatrix *p_fmat, const RegTree &tree, @@ -439,14 +439,14 @@ class CQHistMaker: public HistMaker { } } for (size_t i = 0; i < sketchs_.size(); ++i) { - common::WQuantileSketch::SummaryContainer out; + common::WXQuantileSketch::SummaryContainer out; sketchs_[i].GetSummary(&out); summary_array_[i].SetPrune(out, max_size); } CHECK_EQ(summary_array_.size(), sketchs_.size()); } if (summary_array_.size() != 0) { - size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_size); + size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); sreducer_.Allreduce(dmlc::BeginPtr(summary_array_), nbytes, summary_array_.size()); } // now we get the final result of sketch, setup the cut @@ -457,8 +457,7 @@ class CQHistMaker: public HistMaker { for (unsigned int i : fset) { int offset = feat2workindex_[i]; if (offset >= 0) { - const WQSketch::Summary &a = - summary_array_[wid * work_set_size + offset]; + const WXQSketch::Summary &a = summary_array_[wid * work_set_size + offset]; for (size_t i = 1; i < a.size; ++i) { bst_float cpt = a.data[i].value - kRtEps; if (i == 1 || cpt > this->wspace_.cut.back()) { @@ -631,11 +630,11 @@ class CQHistMaker: public HistMaker { // node statistics std::vector node_stats_; // summary array - std::vector summary_array_; + std::vector summary_array_; // reducer for summary - rabit::SerializeReducer sreducer_; + rabit::SerializeReducer sreducer_; // per node, per feature sketch - std::vector > sketchs_; + std::vector > sketchs_; }; // global proposal diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc index e70499c41ee6..69cb4e58bbed 100644 --- a/src/tree/updater_skmaker.cc +++ b/src/tree/updater_skmaker.cc @@ -75,7 +75,7 @@ class SketchMaker: public BaseMaker { } } // define the sketch we want to use - using WQSketch = common::WQuantileSketch; + using WXQSketch = common::WXQuantileSketch; private: // statistics needed in the gradient calculation @@ -152,12 +152,12 @@ class SketchMaker: public BaseMaker { // synchronize sketch summary_array_.resize(sketchs_.size()); for (size_t i = 0; i < sketchs_.size(); ++i) { - common::WQuantileSketch::SummaryContainer out; + common::WXQuantileSketch::SummaryContainer out; sketchs_[i].GetSummary(&out); summary_array_[i].Reserve(max_size); summary_array_[i].SetPrune(out, max_size); } - size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(max_size); + size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_size); sketch_reducer_.Allreduce(dmlc::BeginPtr(summary_array_), nbytes, summary_array_.size()); } // update sketch information in column fid @@ -301,10 +301,11 @@ class SketchMaker: public BaseMaker { p_tree->Stat(nid).base_weight = static_cast(node_sum.CalcWeight(param_)); p_tree->Stat(nid).sum_hess = static_cast(node_sum.sum_hess); } - inline void EnumerateSplit(const WQSketch::Summary &pos_grad, - const WQSketch::Summary &neg_grad, - const WQSketch::Summary &sum_hess, - const SKStats &node_sum, bst_uint fid, + inline void EnumerateSplit(const WXQSketch::Summary &pos_grad, + const WXQSketch::Summary &neg_grad, + const WXQSketch::Summary &sum_hess, + const SKStats &node_sum, + bst_uint fid, SplitEntry *best) { if (sum_hess.size == 0) return; double root_gain = node_sum.CalcGain(param_); @@ -327,9 +328,9 @@ class SketchMaker: public BaseMaker { feat_sum.sum_hess = sum_hess.data[sum_hess.size - 1].rmax; size_t ipos = 0, ineg = 0, ihess = 0; for (size_t i = 1; i < fsplits.size(); ++i) { - WQSketch::Entry pos = pos_grad.Query(fsplits[i], ipos); - WQSketch::Entry neg = neg_grad.Query(fsplits[i], ineg); - WQSketch::Entry hess = sum_hess.Query(fsplits[i], ihess); + WXQSketch::Entry pos = pos_grad.Query(fsplits[i], ipos); + WXQSketch::Entry neg = neg_grad.Query(fsplits[i], ineg); + WXQSketch::Entry hess = sum_hess.Query(fsplits[i], ihess); SKStats s, c; s.pos_grad = 0.5f * (pos.rmin + pos.rmax - pos.wmin); s.neg_grad = 0.5f * (neg.rmin + neg.rmax - neg.wmin); @@ -378,13 +379,13 @@ class SketchMaker: public BaseMaker { // node statistics std::vector node_stats_; // summary array - std::vector summary_array_; + std::vector summary_array_; // reducer for summary rabit::Reducer stats_reducer_; // reducer for summary - rabit::SerializeReducer sketch_reducer_; + rabit::SerializeReducer sketch_reducer_; // per node, per feature sketch - std::vector > sketchs_; + std::vector > sketchs_; }; XGBOOST_REGISTER_TREE_UPDATER(SketchMaker, "grow_skmaker")