Skip to content

Commit a4d6e96

Browse files
author
SHVETS, KIRILL
committed
strategy pattern was applied
1 parent 4d4b11e commit a4d6e96

File tree

2 files changed

+206
-103
lines changed

2 files changed

+206
-103
lines changed

src/tree/updater_quantile_hist.cc

Lines changed: 131 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,13 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
7777
std::move(pruner_),
7878
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
7979
int_constraint_, dmat));
80+
if (rabit::IsDistributed()) {
81+
builder_->SetHistSynchronizer(new DistributedHistSynchronizer(builder_.get()));
82+
builder_->SetHistRowsAdder(new DistributedHistRowsAdder(builder_.get()));
83+
} else {
84+
builder_->SetHistSynchronizer(new BatchHistSynchronizer(builder_.get()));
85+
builder_->SetHistRowsAdder(new BatchHistRowsAdder(builder_.get()));
86+
}
8087
}
8188
for (auto tree : trees) {
8289
builder_->Update(gmat_, gmatb_, column_matrix_, gpair, dmat, tree);
@@ -97,72 +104,146 @@ bool QuantileHistMaker::UpdatePredictionCache(
97104
}
98105
}
99106

100-
void QuantileHistMaker::Builder::ParallelSubtractionHist(const common::BlockedSpace2d& space,
101-
const std::vector<ExpandEntry>& nodes,
102-
const RegTree * p_tree) {
103-
common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
104-
const auto entry = nodes[node];
105-
if (!((*p_tree)[entry.nid].IsLeftChild())) {
106-
auto this_hist = hist_[entry.nid];
107+
void BatchHistSynchronizer::SyncHistograms(int starting_index,
108+
int sync_count,
109+
RegTree *p_tree) {
110+
builder_->builder_monitor_.Start("SyncHistograms");
111+
const size_t nbins = builder_->hist_builder_.GetNumBins();
112+
common::BlockedSpace2d space(builder_->nodes_for_explicit_hist_build_.size(), [&](size_t node) {
113+
return nbins;
114+
}, 1024);
107115

108-
if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
109-
auto parent_hist = hist_[(*p_tree)[entry.nid].Parent()];
110-
auto sibling_hist = hist_[entry.sibling_nid];
111-
SubtractionHist(this_hist, parent_hist, sibling_hist, r.begin(), r.end());
112-
}
116+
common::ParallelFor2d(space, builder_->nthread_, [&](size_t node, common::Range1d r) {
117+
const auto entry = builder_->nodes_for_explicit_hist_build_[node];
118+
auto this_hist = builder_->hist_[entry.nid];
119+
// Merging histograms from each thread into once
120+
builder_->hist_buffer_.ReduceHist(node, r.begin(), r.end());
121+
122+
if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
123+
const size_t parent_id = (*p_tree)[entry.nid].Parent();
124+
auto parent_hist = builder_->hist_[parent_id];
125+
auto sibling_hist = builder_->hist_[entry.sibling_nid];
126+
SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
113127
}
114128
});
129+
builder_->builder_monitor_.Stop("SyncHistograms");
115130
}
116131

117-
void QuantileHistMaker::Builder::SyncHistograms(
118-
int starting_index,
119-
int sync_count,
120-
RegTree *p_tree) {
121-
builder_monitor_.Start("SyncHistograms");
122-
123-
const bool isDistributed = rabit::IsDistributed();
124-
const size_t nbins = hist_builder_.GetNumBins();
125-
common::BlockedSpace2d space(nodes_for_explicit_hist_build_.size(), [&](size_t node) {
132+
void DistributedHistSynchronizer::SyncHistograms(int starting_index,
133+
int sync_count,
134+
RegTree *p_tree) {
135+
builder_->builder_monitor_.Start("SyncHistograms");
136+
const size_t nbins = builder_->hist_builder_.GetNumBins();
137+
common::BlockedSpace2d space(builder_->nodes_for_explicit_hist_build_.size(), [&](size_t node) {
126138
return nbins;
127139
}, 1024);
128-
common::ParallelFor2d(space, this->nthread_, [&](size_t node, common::Range1d r) {
129-
const auto entry = nodes_for_explicit_hist_build_[node];
130-
auto this_hist = hist_[entry.nid];
140+
common::ParallelFor2d(space, builder_->nthread_, [&](size_t node, common::Range1d r) {
141+
const auto entry = builder_->nodes_for_explicit_hist_build_[node];
142+
auto this_hist = builder_->hist_[entry.nid];
131143
// Merging histograms from each thread into once
132-
hist_buffer_.ReduceHist(node, r.begin(), r.end());
133-
if (isDistributed) {
134-
// Store posible parent node
135-
auto this_local = hist_local_worker_[entry.nid];
136-
CopyHist(this_local, this_hist, r.begin(), r.end());
137-
}
144+
builder_->hist_buffer_.ReduceHist(node, r.begin(), r.end());
145+
// Store posible parent node
146+
auto this_local = builder_->hist_local_worker_[entry.nid];
147+
CopyHist(this_local, this_hist, r.begin(), r.end());
138148

139149
if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
140150
const size_t parent_id = (*p_tree)[entry.nid].Parent();
141-
auto parent_hist = isDistributed ? hist_local_worker_[parent_id] : hist_[parent_id];
142-
auto sibling_hist = hist_[entry.sibling_nid];
151+
auto parent_hist = builder_->hist_local_worker_[parent_id];
152+
auto sibling_hist = builder_->hist_[entry.sibling_nid];
143153
SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
144-
if (isDistributed) {
145-
// Store posible parent node
146-
auto sibling_local = hist_local_worker_[entry.sibling_nid];
147-
CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
154+
// Store posible parent node
155+
auto sibling_local = builder_->hist_local_worker_[entry.sibling_nid];
156+
CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
157+
}
158+
});
159+
builder_->builder_monitor_.Start("SyncHistogramsAllreduce");
160+
this->builder_->histred_.Allreduce(builder_->hist_[starting_index].data(),
161+
builder_->hist_builder_.GetNumBins() * sync_count);
162+
builder_->builder_monitor_.Stop("SyncHistogramsAllreduce");
163+
164+
ParallelSubtractionHist(space, builder_->nodes_for_explicit_hist_build_, p_tree);
165+
166+
common::BlockedSpace2d space2(builder_->nodes_for_subtraction_trick_.size(), [&](size_t node) {
167+
return nbins;
168+
}, 1024);
169+
ParallelSubtractionHist(space2, builder_->nodes_for_subtraction_trick_, p_tree);
170+
builder_->builder_monitor_.Stop("SyncHistograms");
171+
}
172+
173+
void DistributedHistSynchronizer::ParallelSubtractionHist(const common::BlockedSpace2d& space,
174+
const std::vector<QuantileHistMaker::Builder::ExpandEntry>& nodes,
175+
const RegTree * p_tree) {
176+
common::ParallelFor2d(space, builder_->nthread_, [&](size_t node, common::Range1d r) {
177+
const auto entry = nodes[node];
178+
if (!((*p_tree)[entry.nid].IsLeftChild())) {
179+
auto this_hist = builder_->hist_[entry.nid];
180+
181+
if (!(*p_tree)[entry.nid].IsRoot() && entry.sibling_nid > -1) {
182+
auto parent_hist = builder_->hist_[(*p_tree)[entry.nid].Parent()];
183+
auto sibling_hist = builder_->hist_[entry.sibling_nid];
184+
SubtractionHist(this_hist, parent_hist, sibling_hist, r.begin(), r.end());
148185
}
149186
}
150187
});
188+
}
189+
190+
void BatchHistRowsAdder::AddHistRows(int *starting_index, int *sync_count,
191+
RegTree *p_tree) {
192+
builder_->builder_monitor_.Start("AddHistRows");
193+
194+
for (auto const& entry : builder_->nodes_for_explicit_hist_build_) {
195+
int nid = entry.nid;
196+
builder_->hist_.AddHistRow(nid);
197+
(*starting_index) = std::min(nid, (*starting_index));
198+
}
199+
(*sync_count) = builder_->nodes_for_explicit_hist_build_.size();
151200

152-
if (isDistributed) {
153-
builder_monitor_.Start("SyncHistogramsAllreduce");
154-
this->histred_.Allreduce(hist_[starting_index].data(), hist_builder_.GetNumBins() * sync_count);
155-
builder_monitor_.Stop("SyncHistogramsAllreduce");
201+
for (auto const& node : builder_->nodes_for_subtraction_trick_) {
202+
builder_->hist_.AddHistRow(node.nid);
203+
}
156204

157-
ParallelSubtractionHist(space, nodes_for_explicit_hist_build_, p_tree);
205+
builder_->builder_monitor_.Stop("AddHistRows");
206+
}
158207

159-
common::BlockedSpace2d space2(nodes_for_subtraction_trick_.size(), [&](size_t node) {
160-
return nbins;
161-
}, 1024);
162-
ParallelSubtractionHist(space2, nodes_for_subtraction_trick_, p_tree);
208+
void DistributedHistRowsAdder::AddHistRows(int *starting_index, int *sync_count,
209+
RegTree *p_tree) {
210+
builder_->builder_monitor_.Start("AddHistRows");
211+
const size_t explicit_size = builder_->nodes_for_explicit_hist_build_.size();
212+
const size_t subtaction_size = builder_->nodes_for_subtraction_trick_.size();
213+
std::vector<int> merged_node_ids(explicit_size + subtaction_size);
214+
for (size_t i = 0; i < explicit_size; ++i) {
215+
merged_node_ids[i] = builder_->nodes_for_explicit_hist_build_[i].nid;
216+
}
217+
for (size_t i = 0; i < subtaction_size; ++i) {
218+
merged_node_ids[explicit_size + i] =
219+
builder_->nodes_for_subtraction_trick_[i].nid;
220+
}
221+
std::sort(merged_node_ids.begin(), merged_node_ids.end());
222+
int n_left = 0;
223+
for (auto const& nid : merged_node_ids) {
224+
if ((*p_tree)[nid].IsLeftChild()) {
225+
builder_->hist_.AddHistRow(nid);
226+
(*starting_index) = std::min(nid, (*starting_index));
227+
n_left++;
228+
builder_->hist_local_worker_.AddHistRow(nid);
229+
}
163230
}
231+
for (auto const& nid : merged_node_ids) {
232+
if (!((*p_tree)[nid].IsLeftChild())) {
233+
builder_->hist_.AddHistRow(nid);
234+
builder_->hist_local_worker_.AddHistRow(nid);
235+
}
236+
}
237+
(*sync_count) = std::min(1, n_left);
238+
builder_->builder_monitor_.Stop("AddHistRows");
239+
}
240+
241+
void QuantileHistMaker::Builder::SetHistSynchronizer(HistSynchronizer* sync) {
242+
hist_synchronizer_.reset(sync);
243+
}
164244

165-
builder_monitor_.Stop("SyncHistograms");
245+
void QuantileHistMaker::Builder::SetHistRowsAdder(HistRowsAdder* adder) {
246+
hist_rows_adder_.reset(adder);
166247
}
167248

168249
void QuantileHistMaker::Builder::BuildHistogramsLossGuide(
@@ -183,56 +264,11 @@ void QuantileHistMaker::Builder::BuildHistogramsLossGuide(
183264
int starting_index = std::numeric_limits<int>::max();
184265
int sync_count = 0;
185266

186-
AddHistRows(&starting_index, &sync_count, p_tree);
267+
hist_rows_adder_->AddHistRows(&starting_index, &sync_count, p_tree);
187268
BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
188-
SyncHistograms(starting_index, sync_count, p_tree);
189-
}
190-
191-
192-
void QuantileHistMaker::Builder::AddHistRows(int *starting_index, int *sync_count,
193-
RegTree *p_tree) {
194-
builder_monitor_.Start("AddHistRows");
195-
196-
std::vector<int> merged_hist(nodes_for_explicit_hist_build_.size() +
197-
nodes_for_subtraction_trick_.size());
198-
for (size_t i = 0; i < nodes_for_explicit_hist_build_.size(); ++i) {
199-
merged_hist[i] = nodes_for_explicit_hist_build_[i].nid;
200-
}
201-
for (size_t i = 0; i < nodes_for_subtraction_trick_.size(); ++i) {
202-
merged_hist[nodes_for_explicit_hist_build_.size() + i] =
203-
nodes_for_subtraction_trick_[i].nid;
204-
}
205-
std::sort(merged_hist.begin(), merged_hist.end());
206-
int n_left = 0;
207-
for (auto const& nid : merged_hist) {
208-
if ((*p_tree)[nid].IsLeftChild()) {
209-
hist_.AddHistRow(nid);
210-
(*starting_index) = std::min(nid, (*starting_index));
211-
n_left++;
212-
if (rabit::IsDistributed()) {
213-
hist_local_worker_.AddHistRow(nid);
214-
}
215-
}
216-
}
217-
for (auto const& nid : merged_hist) {
218-
if (!((*p_tree)[nid].IsLeftChild())) {
219-
hist_.AddHistRow(nid);
220-
if (rabit::IsDistributed()) {
221-
hist_local_worker_.AddHistRow(nid);
222-
}
223-
}
224-
}
225-
226-
if (n_left == 0) {
227-
(*sync_count) = 1;
228-
} else {
229-
(*sync_count) = n_left;
230-
}
231-
232-
builder_monitor_.Stop("AddHistRows");
269+
hist_synchronizer_->SyncHistograms(starting_index, sync_count, p_tree);
233270
}
234271

235-
236272
void QuantileHistMaker::Builder::BuildLocalHistograms(
237273
const GHistIndexMatrix &gmat,
238274
const GHistIndexBlockMatrix &gmatb,
@@ -407,10 +443,9 @@ void QuantileHistMaker::Builder::ExpandWithDepthWise(
407443
std::vector<ExpandEntry> temp_qexpand_depth;
408444
SplitSiblings(qexpand_depth_wise_, &nodes_for_explicit_hist_build_,
409445
&nodes_for_subtraction_trick_, p_tree);
410-
AddHistRows(&starting_index, &sync_count, p_tree);
411-
446+
hist_rows_adder_->AddHistRows(&starting_index, &sync_count, p_tree);
412447
BuildLocalHistograms(gmat, gmatb, p_tree, gpair_h);
413-
SyncHistograms(starting_index, sync_count, p_tree);
448+
hist_synchronizer_->SyncHistograms(starting_index, sync_count, p_tree);
414449
BuildNodeStats(gmat, p_fmat, p_tree, gpair_h);
415450

416451
EvaluateAndApplySplits(gmat, column_matrix, p_tree, &num_leaves, depth, &timestamp,

0 commit comments

Comments
 (0)