@@ -77,6 +77,13 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
7777 std::move (pruner_),
7878 std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone ()),
7979 int_constraint_, dmat));
80+ if (rabit::IsDistributed ()) {
81+ builder_->SetHistSynchronizer (new DistributedHistSynchronizer (builder_.get ()));
82+ builder_->SetHistRowsAdder (new DistributedHistRowsAdder (builder_.get ()));
83+ } else {
84+ builder_->SetHistSynchronizer (new BatchHistSynchronizer (builder_.get ()));
85+ builder_->SetHistRowsAdder (new BatchHistRowsAdder (builder_.get ()));
86+ }
8087 }
8188 for (auto tree : trees) {
8289 builder_->Update (gmat_, gmatb_, column_matrix_, gpair, dmat, tree);
@@ -97,72 +104,146 @@ bool QuantileHistMaker::UpdatePredictionCache(
97104 }
98105}
99106
100- void QuantileHistMaker::Builder::ParallelSubtractionHist (const common::BlockedSpace2d& space,
101- const std::vector<ExpandEntry>& nodes,
102- const RegTree * p_tree) {
103- common::ParallelFor2d (space, this ->nthread_ , [&](size_t node, common::Range1d r) {
104- const auto entry = nodes[node];
105- if (!((*p_tree)[entry.nid ].IsLeftChild ())) {
106- auto this_hist = hist_[entry.nid ];
107+ void BatchHistSynchronizer::SyncHistograms (int starting_index,
108+ int sync_count,
109+ RegTree *p_tree) {
110+ builder_->builder_monitor_ .Start (" SyncHistograms" );
111+ const size_t nbins = builder_->hist_builder_ .GetNumBins ();
112+ common::BlockedSpace2d space (builder_->nodes_for_explicit_hist_build_ .size (), [&](size_t node) {
113+ return nbins;
114+ }, 1024 );
107115
108- if (!(*p_tree)[entry.nid ].IsRoot () && entry.sibling_nid > -1 ) {
109- auto parent_hist = hist_[(*p_tree)[entry.nid ].Parent ()];
110- auto sibling_hist = hist_[entry.sibling_nid ];
111- SubtractionHist (this_hist, parent_hist, sibling_hist, r.begin (), r.end ());
112- }
116+ common::ParallelFor2d (space, builder_->nthread_ , [&](size_t node, common::Range1d r) {
117+ const auto entry = builder_->nodes_for_explicit_hist_build_ [node];
118+ auto this_hist = builder_->hist_ [entry.nid ];
119+ // Merging histograms from each thread into once
120+ builder_->hist_buffer_ .ReduceHist (node, r.begin (), r.end ());
121+
122+ if (!(*p_tree)[entry.nid ].IsRoot () && entry.sibling_nid > -1 ) {
123+ const size_t parent_id = (*p_tree)[entry.nid ].Parent ();
124+ auto parent_hist = builder_->hist_ [parent_id];
125+ auto sibling_hist = builder_->hist_ [entry.sibling_nid ];
126+ SubtractionHist (sibling_hist, parent_hist, this_hist, r.begin (), r.end ());
113127 }
114128 });
129+ builder_->builder_monitor_ .Stop (" SyncHistograms" );
115130}
116131
117- void QuantileHistMaker::Builder::SyncHistograms (
118- int starting_index,
119- int sync_count,
120- RegTree *p_tree) {
121- builder_monitor_.Start (" SyncHistograms" );
122-
123- const bool isDistributed = rabit::IsDistributed ();
124- const size_t nbins = hist_builder_.GetNumBins ();
125- common::BlockedSpace2d space (nodes_for_explicit_hist_build_.size (), [&](size_t node) {
132+ void DistributedHistSynchronizer::SyncHistograms (int starting_index,
133+ int sync_count,
134+ RegTree *p_tree) {
135+ builder_->builder_monitor_ .Start (" SyncHistograms" );
136+ const size_t nbins = builder_->hist_builder_ .GetNumBins ();
137+ common::BlockedSpace2d space (builder_->nodes_for_explicit_hist_build_ .size (), [&](size_t node) {
126138 return nbins;
127139 }, 1024 );
128- common::ParallelFor2d (space, this ->nthread_ , [&](size_t node, common::Range1d r) {
129- const auto entry = nodes_for_explicit_hist_build_[node];
130- auto this_hist = hist_[entry.nid ];
140+ common::ParallelFor2d (space, builder_ ->nthread_ , [&](size_t node, common::Range1d r) {
141+ const auto entry = builder_-> nodes_for_explicit_hist_build_ [node];
142+ auto this_hist = builder_-> hist_ [entry.nid ];
131143 // Merging histograms from each thread into once
132- hist_buffer_.ReduceHist (node, r.begin (), r.end ());
133- if (isDistributed) {
134- // Store posible parent node
135- auto this_local = hist_local_worker_[entry.nid ];
136- CopyHist (this_local, this_hist, r.begin (), r.end ());
137- }
144+ builder_->hist_buffer_ .ReduceHist (node, r.begin (), r.end ());
145+ // Store posible parent node
146+ auto this_local = builder_->hist_local_worker_ [entry.nid ];
147+ CopyHist (this_local, this_hist, r.begin (), r.end ());
138148
139149 if (!(*p_tree)[entry.nid ].IsRoot () && entry.sibling_nid > -1 ) {
140150 const size_t parent_id = (*p_tree)[entry.nid ].Parent ();
141- auto parent_hist = isDistributed ? hist_local_worker_[parent_id] : hist_ [parent_id];
142- auto sibling_hist = hist_[entry.sibling_nid ];
151+ auto parent_hist = builder_-> hist_local_worker_ [parent_id];
152+ auto sibling_hist = builder_-> hist_ [entry.sibling_nid ];
143153 SubtractionHist (sibling_hist, parent_hist, this_hist, r.begin (), r.end ());
144- if (isDistributed) {
145- // Store posible parent node
146- auto sibling_local = hist_local_worker_[entry.sibling_nid ];
147- CopyHist (sibling_local, sibling_hist, r.begin (), r.end ());
154+ // Store posible parent node
155+ auto sibling_local = builder_->hist_local_worker_ [entry.sibling_nid ];
156+ CopyHist (sibling_local, sibling_hist, r.begin (), r.end ());
157+ }
158+ });
159+ builder_->builder_monitor_ .Start (" SyncHistogramsAllreduce" );
160+ this ->builder_ ->histred_ .Allreduce (builder_->hist_ [starting_index].data (),
161+ builder_->hist_builder_ .GetNumBins () * sync_count);
162+ builder_->builder_monitor_ .Stop (" SyncHistogramsAllreduce" );
163+
164+ ParallelSubtractionHist (space, builder_->nodes_for_explicit_hist_build_ , p_tree);
165+
166+ common::BlockedSpace2d space2 (builder_->nodes_for_subtraction_trick_ .size (), [&](size_t node) {
167+ return nbins;
168+ }, 1024 );
169+ ParallelSubtractionHist (space2, builder_->nodes_for_subtraction_trick_ , p_tree);
170+ builder_->builder_monitor_ .Stop (" SyncHistograms" );
171+ }
172+
173+ void DistributedHistSynchronizer::ParallelSubtractionHist (const common::BlockedSpace2d& space,
174+ const std::vector<QuantileHistMaker::Builder::ExpandEntry>& nodes,
175+ const RegTree * p_tree) {
176+ common::ParallelFor2d (space, builder_->nthread_ , [&](size_t node, common::Range1d r) {
177+ const auto entry = nodes[node];
178+ if (!((*p_tree)[entry.nid ].IsLeftChild ())) {
179+ auto this_hist = builder_->hist_ [entry.nid ];
180+
181+ if (!(*p_tree)[entry.nid ].IsRoot () && entry.sibling_nid > -1 ) {
182+ auto parent_hist = builder_->hist_ [(*p_tree)[entry.nid ].Parent ()];
183+ auto sibling_hist = builder_->hist_ [entry.sibling_nid ];
184+ SubtractionHist (this_hist, parent_hist, sibling_hist, r.begin (), r.end ());
148185 }
149186 }
150187 });
188+ }
189+
190+ void BatchHistRowsAdder::AddHistRows (int *starting_index, int *sync_count,
191+ RegTree *p_tree) {
192+ builder_->builder_monitor_ .Start (" AddHistRows" );
193+
194+ for (auto const & entry : builder_->nodes_for_explicit_hist_build_ ) {
195+ int nid = entry.nid ;
196+ builder_->hist_ .AddHistRow (nid);
197+ (*starting_index) = std::min (nid, (*starting_index));
198+ }
199+ (*sync_count) = builder_->nodes_for_explicit_hist_build_ .size ();
151200
152- if (isDistributed) {
153- builder_monitor_.Start (" SyncHistogramsAllreduce" );
154- this ->histred_ .Allreduce (hist_[starting_index].data (), hist_builder_.GetNumBins () * sync_count);
155- builder_monitor_.Stop (" SyncHistogramsAllreduce" );
201+ for (auto const & node : builder_->nodes_for_subtraction_trick_ ) {
202+ builder_->hist_ .AddHistRow (node.nid );
203+ }
156204
157- ParallelSubtractionHist (space, nodes_for_explicit_hist_build_, p_tree);
205+ builder_->builder_monitor_ .Stop (" AddHistRows" );
206+ }
158207
159- common::BlockedSpace2d space2 (nodes_for_subtraction_trick_.size (), [&](size_t node) {
160- return nbins;
161- }, 1024 );
162- ParallelSubtractionHist (space2, nodes_for_subtraction_trick_, p_tree);
208+ void DistributedHistRowsAdder::AddHistRows (int *starting_index, int *sync_count,
209+ RegTree *p_tree) {
210+ builder_->builder_monitor_ .Start (" AddHistRows" );
211+ const size_t explicit_size = builder_->nodes_for_explicit_hist_build_ .size ();
212+ const size_t subtaction_size = builder_->nodes_for_subtraction_trick_ .size ();
213+ std::vector<int > merged_node_ids (explicit_size + subtaction_size);
214+ for (size_t i = 0 ; i < explicit_size; ++i) {
215+ merged_node_ids[i] = builder_->nodes_for_explicit_hist_build_ [i].nid ;
216+ }
217+ for (size_t i = 0 ; i < subtaction_size; ++i) {
218+ merged_node_ids[explicit_size + i] =
219+ builder_->nodes_for_subtraction_trick_ [i].nid ;
220+ }
221+ std::sort (merged_node_ids.begin (), merged_node_ids.end ());
222+ int n_left = 0 ;
223+ for (auto const & nid : merged_node_ids) {
224+ if ((*p_tree)[nid].IsLeftChild ()) {
225+ builder_->hist_ .AddHistRow (nid);
226+ (*starting_index) = std::min (nid, (*starting_index));
227+ n_left++;
228+ builder_->hist_local_worker_ .AddHistRow (nid);
229+ }
163230 }
231+ for (auto const & nid : merged_node_ids) {
232+ if (!((*p_tree)[nid].IsLeftChild ())) {
233+ builder_->hist_ .AddHistRow (nid);
234+ builder_->hist_local_worker_ .AddHistRow (nid);
235+ }
236+ }
237+ (*sync_count) = std::min (1 , n_left);
238+ builder_->builder_monitor_ .Stop (" AddHistRows" );
239+ }
240+
241+ void QuantileHistMaker::Builder::SetHistSynchronizer (HistSynchronizer* sync) {
242+ hist_synchronizer_.reset (sync);
243+ }
164244
165- builder_monitor_.Stop (" SyncHistograms" );
245+ void QuantileHistMaker::Builder::SetHistRowsAdder (HistRowsAdder* adder) {
246+ hist_rows_adder_.reset (adder);
166247}
167248
168249void QuantileHistMaker::Builder::BuildHistogramsLossGuide (
@@ -183,56 +264,11 @@ void QuantileHistMaker::Builder::BuildHistogramsLossGuide(
183264 int starting_index = std::numeric_limits<int >::max ();
184265 int sync_count = 0 ;
185266
186- AddHistRows (&starting_index, &sync_count, p_tree);
267+ hist_rows_adder_-> AddHistRows (&starting_index, &sync_count, p_tree);
187268 BuildLocalHistograms (gmat, gmatb, p_tree, gpair_h);
188- SyncHistograms (starting_index, sync_count, p_tree);
189- }
190-
191-
192- void QuantileHistMaker::Builder::AddHistRows (int *starting_index, int *sync_count,
193- RegTree *p_tree) {
194- builder_monitor_.Start (" AddHistRows" );
195-
196- std::vector<int > merged_hist (nodes_for_explicit_hist_build_.size () +
197- nodes_for_subtraction_trick_.size ());
198- for (size_t i = 0 ; i < nodes_for_explicit_hist_build_.size (); ++i) {
199- merged_hist[i] = nodes_for_explicit_hist_build_[i].nid ;
200- }
201- for (size_t i = 0 ; i < nodes_for_subtraction_trick_.size (); ++i) {
202- merged_hist[nodes_for_explicit_hist_build_.size () + i] =
203- nodes_for_subtraction_trick_[i].nid ;
204- }
205- std::sort (merged_hist.begin (), merged_hist.end ());
206- int n_left = 0 ;
207- for (auto const & nid : merged_hist) {
208- if ((*p_tree)[nid].IsLeftChild ()) {
209- hist_.AddHistRow (nid);
210- (*starting_index) = std::min (nid, (*starting_index));
211- n_left++;
212- if (rabit::IsDistributed ()) {
213- hist_local_worker_.AddHistRow (nid);
214- }
215- }
216- }
217- for (auto const & nid : merged_hist) {
218- if (!((*p_tree)[nid].IsLeftChild ())) {
219- hist_.AddHistRow (nid);
220- if (rabit::IsDistributed ()) {
221- hist_local_worker_.AddHistRow (nid);
222- }
223- }
224- }
225-
226- if (n_left == 0 ) {
227- (*sync_count) = 1 ;
228- } else {
229- (*sync_count) = n_left;
230- }
231-
232- builder_monitor_.Stop (" AddHistRows" );
269+ hist_synchronizer_->SyncHistograms (starting_index, sync_count, p_tree);
233270}
234271
235-
236272void QuantileHistMaker::Builder::BuildLocalHistograms (
237273 const GHistIndexMatrix &gmat,
238274 const GHistIndexBlockMatrix &gmatb,
@@ -407,10 +443,9 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise(
407443 std::vector<ExpandEntry> temp_qexpand_depth;
408444 SplitSiblings (qexpand_depth_wise_, &nodes_for_explicit_hist_build_,
409445 &nodes_for_subtraction_trick_, p_tree);
410- AddHistRows (&starting_index, &sync_count, p_tree);
411-
446+ hist_rows_adder_->AddHistRows (&starting_index, &sync_count, p_tree);
412447 BuildLocalHistograms (gmat, gmatb, p_tree, gpair_h);
413- SyncHistograms (starting_index, sync_count, p_tree);
448+ hist_synchronizer_-> SyncHistograms (starting_index, sync_count, p_tree);
414449 BuildNodeStats (gmat, p_fmat, p_tree, gpair_h);
415450
416451 EvaluateAndApplySplits (gmat, column_matrix, p_tree, &num_leaves, depth, ×tamp,
0 commit comments