@@ -29,7 +29,8 @@ class QuantileHistMock : public QuantileHistMaker {
2929 std::unique_ptr<SplitEvaluator> spliteval,
3030 FeatureInteractionConstraintHost int_constraint,
3131 DMatrix const * fmat)
32- : RealImpl(param, std::move(pruner), std::move(spliteval), std::move(int_constraint), fmat) {}
32+ : RealImpl(param, std::move(pruner), std::move(spliteval),
33+ std::move (int_constraint), fmat) {}
3334
3435 public:
3536 void TestInitData (const GHistIndexMatrix& gmat,
@@ -120,6 +121,147 @@ class QuantileHistMock : public QuantileHistMaker {
120121 omp_set_num_threads (nthreads);
121122 }
122123
124+ void TestAddHistRows (const GHistIndexMatrix& gmat,
125+ const std::vector<GradientPair>& gpair,
126+ DMatrix* p_fmat,
127+ RegTree* tree) {
128+ RealImpl::InitData (gmat, gpair, *p_fmat, *tree);
129+
130+ int starting_index = std::numeric_limits<int >::max ();
131+ int sync_count = 0 ;
132+ nodes_for_explicit_hist_build_.clear ();
133+ nodes_for_subtraction_trick_.clear ();
134+
135+ tree->ExpandNode (0 , 0 , 0 , false , 0 , 0 , 0 , 0 , 0 , 0 , 0 );
136+ tree->ExpandNode ((*tree)[0 ].LeftChild (), 0 , 0 , false , 0 , 0 , 0 , 0 , 0 , 0 , 0 );
137+ tree->ExpandNode ((*tree)[0 ].RightChild (), 0 , 0 , false , 0 , 0 , 0 , 0 , 0 , 0 , 0 );
138+ nodes_for_explicit_hist_build_.emplace_back (3 , 4 , tree->GetDepth (3 ), 0 .0f , 0 );
139+ nodes_for_explicit_hist_build_.emplace_back (4 , 3 , tree->GetDepth (4 ), 0 .0f , 0 );
140+ nodes_for_subtraction_trick_.emplace_back (5 , 6 , tree->GetDepth (5 ), 0 .0f , 0 );
141+ nodes_for_subtraction_trick_.emplace_back (6 , 5 , tree->GetDepth (6 ), 0 .0f , 0 );
142+
143+ hist_rows_adder_->AddHistRows (&starting_index, &sync_count, tree);
144+ ASSERT_EQ (sync_count, 2 );
145+ ASSERT_EQ (starting_index, 3 );
146+
147+ for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
148+ ASSERT_EQ (hist_.RowExists (node.nid ), true );
149+ }
150+ for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
151+ ASSERT_EQ (hist_.RowExists (node.nid ), true );
152+ }
153+ }
154+
155+
156+ void TestSyncHistograms (const GHistIndexMatrix& gmat,
157+ const std::vector<GradientPair>& gpair,
158+ DMatrix* p_fmat,
159+ RegTree* tree) {
160+ // init
161+ RealImpl::InitData (gmat, gpair, *p_fmat, *tree);
162+
163+ int starting_index = std::numeric_limits<int >::max ();
164+ int sync_count = 0 ;
165+ nodes_for_explicit_hist_build_.clear ();
166+ nodes_for_subtraction_trick_.clear ();
167+ // level 0
168+ nodes_for_explicit_hist_build_.emplace_back (0 , -1 , tree->GetDepth (0 ), 0 .0f , 0 );
169+ hist_rows_adder_->AddHistRows (&starting_index, &sync_count, tree);
170+ tree->ExpandNode (0 , 0 , 0 , false , 0 , 0 , 0 , 0 , 0 , 0 , 0 );
171+
172+ nodes_for_explicit_hist_build_.clear ();
173+ nodes_for_subtraction_trick_.clear ();
174+ // level 1
175+ nodes_for_explicit_hist_build_.emplace_back ((*tree)[0 ].LeftChild (), (*tree)[0 ].RightChild (),
176+ tree->GetDepth (1 ), 0 .0f , 0 );
177+ nodes_for_subtraction_trick_.emplace_back ((*tree)[0 ].RightChild (), (*tree)[0 ].LeftChild (),
178+ tree->GetDepth (2 ), 0 .0f , 0 );
179+ hist_rows_adder_->AddHistRows (&starting_index, &sync_count, tree);
180+ tree->ExpandNode ((*tree)[0 ].LeftChild (), 0 , 0 , false , 0 , 0 , 0 , 0 , 0 , 0 , 0 );
181+ tree->ExpandNode ((*tree)[0 ].RightChild (), 0 , 0 , false , 0 , 0 , 0 , 0 , 0 , 0 , 0 );
182+
183+ nodes_for_explicit_hist_build_.clear ();
184+ nodes_for_subtraction_trick_.clear ();
185+ // level 2
186+ nodes_for_explicit_hist_build_.emplace_back (3 , 4 , tree->GetDepth (3 ), 0 .0f , 0 );
187+ nodes_for_subtraction_trick_.emplace_back (4 , 3 , tree->GetDepth (4 ), 0 .0f , 0 );
188+ nodes_for_explicit_hist_build_.emplace_back (5 , 6 , tree->GetDepth (5 ), 0 .0f , 0 );
189+ nodes_for_subtraction_trick_.emplace_back (6 , 5 , tree->GetDepth (6 ), 0 .0f , 0 );
190+ hist_rows_adder_->AddHistRows (&starting_index, &sync_count, tree);
191+
192+ const size_t n_nodes = nodes_for_explicit_hist_build_.size ();
193+ ASSERT_EQ (n_nodes, 2 );
194+ row_set_collection_.AddSplit (0 , (*tree)[0 ].LeftChild (),
195+ (*tree)[0 ].RightChild (), 4 , 4 );
196+ row_set_collection_.AddSplit (1 , (*tree)[1 ].LeftChild (),
197+ (*tree)[1 ].RightChild (), 2 , 2 );
198+ row_set_collection_.AddSplit (2 , (*tree)[2 ].LeftChild (),
199+ (*tree)[2 ].RightChild (), 2 , 2 );
200+
201+ common::BlockedSpace2d space (n_nodes, [&](size_t node) {
202+ const int32_t nid = nodes_for_explicit_hist_build_[node].nid ;
203+ return row_set_collection_[nid].Size ();
204+ }, 256 );
205+
206+ std::vector<GHistRow> target_hists (n_nodes);
207+ for (size_t i = 0 ; i < nodes_for_explicit_hist_build_.size (); ++i) {
208+ const int32_t nid = nodes_for_explicit_hist_build_[i].nid ;
209+ target_hists[i] = hist_[nid];
210+ }
211+
212+ const size_t nbins = hist_builder_.GetNumBins ();
213+ // set values to specific nodes hist
214+ std::vector<size_t > n_ids = {1 , 2 };
215+ for (size_t i : n_ids) {
216+ auto this_hist = hist_[i];
217+ using FPType = decltype (tree::GradStats::sum_grad);
218+ FPType* p_hist = reinterpret_cast <FPType*>(this_hist.data ());
219+ for (size_t bin_id = 0 ; bin_id < 2 *nbins; ++bin_id) {
220+ p_hist[bin_id] = 2 *bin_id;
221+ }
222+ }
223+ n_ids[0 ] = 3 ;
224+ n_ids[1 ] = 5 ;
225+ for (size_t i : n_ids) {
226+ auto this_hist = hist_[i];
227+ using FPType = decltype (tree::GradStats::sum_grad);
228+ FPType* p_hist = reinterpret_cast <FPType*>(this_hist.data ());
229+ for (size_t bin_id = 0 ; bin_id < 2 *nbins; ++bin_id) {
230+ p_hist[bin_id] = bin_id;
231+ }
232+ }
233+
234+ hist_buffer_.Reset (1 , n_nodes, space, target_hists);
235+ // sync hist
236+ hist_synchronizer_->SyncHistograms (starting_index, sync_count, tree);
237+
238+ auto check_hist = [] (const GHistRow parent, const GHistRow left,
239+ const GHistRow right, size_t begin, size_t end) {
240+ using FPType = decltype (tree::GradStats::sum_grad);
241+ const FPType* p_parent = reinterpret_cast <const FPType*>(parent.data ());
242+ const FPType* p_left = reinterpret_cast <const FPType*>(left.data ());
243+ const FPType* p_right = reinterpret_cast <const FPType*>(right.data ());
244+ for (size_t i = 2 * begin; i < 2 * end; ++i) {
245+ ASSERT_EQ (p_parent[i], p_left[i] + p_right[i]);
246+ }
247+ };
248+ for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
249+ auto this_hist = hist_[node.nid ];
250+ const size_t parent_id = (*tree)[node.nid ].Parent ();
251+ auto parent_hist = hist_[parent_id];
252+ auto sibling_hist = hist_[node.sibling_nid ];
253+
254+ check_hist (parent_hist, this_hist, sibling_hist, 0 , nbins);
255+ }
256+ for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
257+ auto this_hist = hist_[node.nid ];
258+ const size_t parent_id = (*tree)[node.nid ].Parent ();
259+ auto parent_hist = hist_[parent_id];
260+ auto sibling_hist = hist_[node.sibling_nid ];
261+
262+ check_hist (parent_hist, this_hist, sibling_hist, 0 , nbins);
263+ }
264+ }
123265
124266 void TestBuildHist (int nid,
125267 const GHistIndexMatrix& gmat,
@@ -249,7 +391,6 @@ class QuantileHistMock : public QuantileHistMaker {
249391 TestEvaluateSplit (quantile_index_block, tree);
250392 omp_set_num_threads (1 );
251393 }
252-
253394 };
254395
255396 int static constexpr kNRows = 8 , kNCols = 16 ;
@@ -259,7 +400,7 @@ class QuantileHistMock : public QuantileHistMaker {
259400
260401 public:
261402 explicit QuantileHistMock (
262- const std::vector<std::pair<std::string, std::string> >& args) :
403+ const std::vector<std::pair<std::string, std::string> >& args, bool batch = true ) :
263404 cfg_{args} {
264405 QuantileHistMaker::Configure (args);
265406 spliteval_->Init (¶m_);
@@ -271,6 +412,13 @@ class QuantileHistMock : public QuantileHistMaker {
271412 std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone ()),
272413 int_constraint_,
273414 dmat_.get ()));
415+ if (batch) {
416+ builder_->SetHistSynchronizer (new BatchHistSynchronizer (builder_.get ()));
417+ builder_->SetHistRowsAdder (new BatchHistRowsAdder (builder_.get ()));
418+ } else {
419+ builder_->SetHistSynchronizer (new DistributedHistSynchronizer (builder_.get ()));
420+ builder_->SetHistRowsAdder (new DistributedHistRowsAdder (builder_.get ()));
421+ }
274422 }
275423 ~QuantileHistMock () override = default ;
276424
@@ -305,6 +453,34 @@ class QuantileHistMock : public QuantileHistMaker {
305453
306454 builder_->TestInitDataSampling (gmat, gpair, dmat_.get (), tree);
307455 }
456+
457+ void TestAddHistRows () {
458+ size_t constexpr kMaxBins = 4 ;
459+ common::GHistIndexMatrix gmat;
460+ gmat.Init (dmat_.get (), kMaxBins );
461+
462+ RegTree tree = RegTree ();
463+ tree.param .UpdateAllowUnknown (cfg_);
464+ std::vector<GradientPair> gpair =
465+ { {0 .23f , 0 .24f }, {0 .23f , 0 .24f }, {0 .23f , 0 .24f }, {0 .23f , 0 .24f },
466+ {0 .27f , 0 .29f }, {0 .27f , 0 .29f }, {0 .27f , 0 .29f }, {0 .27f , 0 .29f } };
467+ builder_->TestAddHistRows (gmat, gpair, dmat_.get (), &tree);
468+ }
469+
470+ void TestSyncHistograms () {
471+ size_t constexpr kMaxBins = 4 ;
472+ common::GHistIndexMatrix gmat;
473+ gmat.Init (dmat_.get (), kMaxBins );
474+
475+ RegTree tree = RegTree ();
476+ tree.param .UpdateAllowUnknown (cfg_);
477+ std::vector<GradientPair> gpair =
478+ { {0 .23f , 0 .24f }, {0 .23f , 0 .24f }, {0 .23f , 0 .24f }, {0 .23f , 0 .24f },
479+ {0 .27f , 0 .29f }, {0 .27f , 0 .29f }, {0 .27f , 0 .29f }, {0 .27f , 0 .29f } };
480+ builder_->TestSyncHistograms (gmat, gpair, dmat_.get (), &tree);
481+ }
482+
483+
308484 void TestBuildHist () {
309485 RegTree tree = RegTree ();
310486 tree.param .UpdateAllowUnknown (cfg_);
@@ -340,6 +516,34 @@ TEST(QuantileHist, InitDataSampling) {
340516 maker.TestInitDataSampling ();
341517}
342518
519+ TEST (QuantileHist, AddHistRows) {
520+ std::vector<std::pair<std::string, std::string>> cfg
521+ {{" num_feature" , std::to_string (QuantileHistMock::GetNumColumns ())}};
522+ QuantileHistMock maker (cfg);
523+ maker.TestAddHistRows ();
524+ }
525+
526+ TEST (QuantileHist, SyncHistograms) {
527+ std::vector<std::pair<std::string, std::string>> cfg
528+ {{" num_feature" , std::to_string (QuantileHistMock::GetNumColumns ())}};
529+ QuantileHistMock maker (cfg);
530+ maker.TestSyncHistograms ();
531+ }
532+
533+ TEST (QuantileHist, DistributedAddHistRows) {
534+ std::vector<std::pair<std::string, std::string>> cfg
535+ {{" num_feature" , std::to_string (QuantileHistMock::GetNumColumns ())}};
536+ QuantileHistMock maker (cfg, false );
537+ maker.TestAddHistRows ();
538+ }
539+
540+ TEST (QuantileHist, DistributedSyncHistograms) {
541+ std::vector<std::pair<std::string, std::string>> cfg
542+ {{" num_feature" , std::to_string (QuantileHistMock::GetNumColumns ())}};
543+ QuantileHistMock maker (cfg, false );
544+ maker.TestSyncHistograms ();
545+ }
546+
343547TEST (QuantileHist, BuildHist) {
344548 // Don't enable feature grouping
345549 std::vector<std::pair<std::string, std::string>> cfg
0 commit comments