Skip to content

Commit f36e5cf

Browse files
author
SHVETS, KIRILL
committed
tests were added and found bug was fixed(min->max)
1 parent a4d6e96 commit f36e5cf

File tree

3 files changed

+220
-10
lines changed

3 files changed

+220
-10
lines changed

src/tree/updater_quantile_hist.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ void DistributedHistRowsAdder::AddHistRows(int *starting_index, int *sync_count,
234234
builder_->hist_local_worker_.AddHistRow(nid);
235235
}
236236
}
237-
(*sync_count) = std::min(1, n_left);
237+
(*sync_count) = std::max(1, n_left);
238238
builder_->builder_monitor_.Stop("AddHistRows");
239239
}
240240

src/tree/updater_quantile_hist.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,9 @@ class HistSynchronizer {
400400
virtual void SyncHistograms(int starting_index,
401401
int sync_count,
402402
RegTree *p_tree) = 0;
403+
virtual ~HistSynchronizer() {
404+
builder_ = nullptr;
405+
}
403406

404407
protected:
405408
QuantileHistMaker::Builder* builder_;
@@ -409,19 +412,19 @@ class BatchHistSynchronizer: public HistSynchronizer {
409412
public:
410413
explicit BatchHistSynchronizer(QuantileHistMaker::Builder* builder): HistSynchronizer(builder) {}
411414

412-
virtual void SyncHistograms(int starting_index,
415+
void SyncHistograms(int starting_index,
413416
int sync_count,
414-
RegTree *p_tree);
417+
RegTree *p_tree) override;
415418
};
416419

417420
class DistributedHistSynchronizer: public HistSynchronizer {
418421
public:
419422
explicit DistributedHistSynchronizer(QuantileHistMaker::Builder* builder):
420423
HistSynchronizer(builder) {}
421424

422-
virtual void SyncHistograms(int starting_index,
425+
void SyncHistograms(int starting_index,
423426
int sync_count,
424-
RegTree *p_tree);
427+
RegTree *p_tree) override;
425428
void ParallelSubtractionHist(const common::BlockedSpace2d& space,
426429
const std::vector<QuantileHistMaker::Builder::ExpandEntry>& nodes,
427430
const RegTree * p_tree);
@@ -431,6 +434,9 @@ class HistRowsAdder {
431434
public:
432435
explicit HistRowsAdder(QuantileHistMaker::Builder* builder) : builder_(builder) {}
433436
virtual void AddHistRows(int *starting_index, int *sync_count, RegTree *p_tree) = 0;
437+
virtual ~HistRowsAdder() {
438+
builder_ = nullptr;
439+
}
434440

435441
protected:
436442
QuantileHistMaker::Builder* builder_;
@@ -440,14 +446,14 @@ class BatchHistRowsAdder: public HistRowsAdder {
440446
public:
441447
explicit BatchHistRowsAdder(QuantileHistMaker::Builder* builder) : HistRowsAdder(builder) {}
442448

443-
void AddHistRows(int *starting_index, int *sync_count, RegTree *p_tree);
449+
void AddHistRows(int *starting_index, int *sync_count, RegTree *p_tree) override;
444450
};
445451

446452
class DistributedHistRowsAdder: public HistRowsAdder {
447453
public:
448454
explicit DistributedHistRowsAdder(QuantileHistMaker::Builder* builder) : HistRowsAdder(builder) {}
449455

450-
void AddHistRows(int *starting_index, int *sync_count, RegTree *p_tree);
456+
void AddHistRows(int *starting_index, int *sync_count, RegTree *p_tree) override;
451457
};
452458

453459

tests/cpp/tree/test_quantile_hist.cc

Lines changed: 207 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ class QuantileHistMock : public QuantileHistMaker {
2929
std::unique_ptr<SplitEvaluator> spliteval,
3030
FeatureInteractionConstraintHost int_constraint,
3131
DMatrix const* fmat)
32-
: RealImpl(param, std::move(pruner), std::move(spliteval), std::move(int_constraint), fmat) {}
32+
: RealImpl(param, std::move(pruner), std::move(spliteval),
33+
std::move(int_constraint), fmat) {}
3334

3435
public:
3536
void TestInitData(const GHistIndexMatrix& gmat,
@@ -120,6 +121,147 @@ class QuantileHistMock : public QuantileHistMaker {
120121
omp_set_num_threads(nthreads);
121122
}
122123

124+
void TestAddHistRows(const GHistIndexMatrix& gmat,
125+
const std::vector<GradientPair>& gpair,
126+
DMatrix* p_fmat,
127+
RegTree* tree) {
128+
RealImpl::InitData(gmat, gpair, *p_fmat, *tree);
129+
130+
int starting_index = std::numeric_limits<int>::max();
131+
int sync_count = 0;
132+
nodes_for_explicit_hist_build_.clear();
133+
nodes_for_subtraction_trick_.clear();
134+
135+
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
136+
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
137+
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
138+
nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
139+
nodes_for_explicit_hist_build_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
140+
nodes_for_subtraction_trick_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
141+
nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
142+
143+
hist_rows_adder_->AddHistRows(&starting_index, &sync_count, tree);
144+
ASSERT_EQ(sync_count, 2);
145+
ASSERT_EQ(starting_index, 3);
146+
147+
for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
148+
ASSERT_EQ(hist_.RowExists(node.nid), true);
149+
}
150+
for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
151+
ASSERT_EQ(hist_.RowExists(node.nid), true);
152+
}
153+
}
154+
155+
156+
void TestSyncHistograms(const GHistIndexMatrix& gmat,
157+
const std::vector<GradientPair>& gpair,
158+
DMatrix* p_fmat,
159+
RegTree* tree) {
160+
// init
161+
RealImpl::InitData(gmat, gpair, *p_fmat, *tree);
162+
163+
int starting_index = std::numeric_limits<int>::max();
164+
int sync_count = 0;
165+
nodes_for_explicit_hist_build_.clear();
166+
nodes_for_subtraction_trick_.clear();
167+
// level 0
168+
nodes_for_explicit_hist_build_.emplace_back(0, -1, tree->GetDepth(0), 0.0f, 0);
169+
hist_rows_adder_->AddHistRows(&starting_index, &sync_count, tree);
170+
tree->ExpandNode(0, 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
171+
172+
nodes_for_explicit_hist_build_.clear();
173+
nodes_for_subtraction_trick_.clear();
174+
// level 1
175+
nodes_for_explicit_hist_build_.emplace_back((*tree)[0].LeftChild(), (*tree)[0].RightChild(),
176+
tree->GetDepth(1), 0.0f, 0);
177+
nodes_for_subtraction_trick_.emplace_back((*tree)[0].RightChild(), (*tree)[0].LeftChild(),
178+
tree->GetDepth(2), 0.0f, 0);
179+
hist_rows_adder_->AddHistRows(&starting_index, &sync_count, tree);
180+
tree->ExpandNode((*tree)[0].LeftChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
181+
tree->ExpandNode((*tree)[0].RightChild(), 0, 0, false, 0, 0, 0, 0, 0, 0, 0);
182+
183+
nodes_for_explicit_hist_build_.clear();
184+
nodes_for_subtraction_trick_.clear();
185+
// level 2
186+
nodes_for_explicit_hist_build_.emplace_back(3, 4, tree->GetDepth(3), 0.0f, 0);
187+
nodes_for_subtraction_trick_.emplace_back(4, 3, tree->GetDepth(4), 0.0f, 0);
188+
nodes_for_explicit_hist_build_.emplace_back(5, 6, tree->GetDepth(5), 0.0f, 0);
189+
nodes_for_subtraction_trick_.emplace_back(6, 5, tree->GetDepth(6), 0.0f, 0);
190+
hist_rows_adder_->AddHistRows(&starting_index, &sync_count, tree);
191+
192+
const size_t n_nodes = nodes_for_explicit_hist_build_.size();
193+
ASSERT_EQ(n_nodes, 2);
194+
row_set_collection_.AddSplit(0, (*tree)[0].LeftChild(),
195+
(*tree)[0].RightChild(), 4, 4);
196+
row_set_collection_.AddSplit(1, (*tree)[1].LeftChild(),
197+
(*tree)[1].RightChild(), 2, 2);
198+
row_set_collection_.AddSplit(2, (*tree)[2].LeftChild(),
199+
(*tree)[2].RightChild(), 2, 2);
200+
201+
common::BlockedSpace2d space(n_nodes, [&](size_t node) {
202+
const int32_t nid = nodes_for_explicit_hist_build_[node].nid;
203+
return row_set_collection_[nid].Size();
204+
}, 256);
205+
206+
std::vector<GHistRow> target_hists(n_nodes);
207+
for (size_t i = 0; i < nodes_for_explicit_hist_build_.size(); ++i) {
208+
const int32_t nid = nodes_for_explicit_hist_build_[i].nid;
209+
target_hists[i] = hist_[nid];
210+
}
211+
212+
const size_t nbins = hist_builder_.GetNumBins();
213+
// set values to specific nodes hist
214+
std::vector<size_t> n_ids = {1, 2};
215+
for (size_t i : n_ids) {
216+
auto this_hist = hist_[i];
217+
using FPType = decltype(tree::GradStats::sum_grad);
218+
FPType* p_hist = reinterpret_cast<FPType*>(this_hist.data());
219+
for (size_t bin_id = 0; bin_id < 2*nbins; ++bin_id) {
220+
p_hist[bin_id] = 2*bin_id;
221+
}
222+
}
223+
n_ids[0] = 3;
224+
n_ids[1] = 5;
225+
for (size_t i : n_ids) {
226+
auto this_hist = hist_[i];
227+
using FPType = decltype(tree::GradStats::sum_grad);
228+
FPType* p_hist = reinterpret_cast<FPType*>(this_hist.data());
229+
for (size_t bin_id = 0; bin_id < 2*nbins; ++bin_id) {
230+
p_hist[bin_id] = bin_id;
231+
}
232+
}
233+
234+
hist_buffer_.Reset(1, n_nodes, space, target_hists);
235+
// sync hist
236+
hist_synchronizer_->SyncHistograms(starting_index, sync_count, tree);
237+
238+
auto check_hist = [] (const GHistRow parent, const GHistRow left,
239+
const GHistRow right, size_t begin, size_t end) {
240+
using FPType = decltype(tree::GradStats::sum_grad);
241+
const FPType* p_parent = reinterpret_cast<const FPType*>(parent.data());
242+
const FPType* p_left = reinterpret_cast<const FPType*>(left.data());
243+
const FPType* p_right = reinterpret_cast<const FPType*>(right.data());
244+
for (size_t i = 2 * begin; i < 2 * end; ++i) {
245+
ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]);
246+
}
247+
};
248+
for (const ExpandEntry& node : nodes_for_explicit_hist_build_) {
249+
auto this_hist = hist_[node.nid];
250+
const size_t parent_id = (*tree)[node.nid].Parent();
251+
auto parent_hist = hist_[parent_id];
252+
auto sibling_hist = hist_[node.sibling_nid];
253+
254+
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
255+
}
256+
for (const ExpandEntry& node : nodes_for_subtraction_trick_) {
257+
auto this_hist = hist_[node.nid];
258+
const size_t parent_id = (*tree)[node.nid].Parent();
259+
auto parent_hist = hist_[parent_id];
260+
auto sibling_hist = hist_[node.sibling_nid];
261+
262+
check_hist(parent_hist, this_hist, sibling_hist, 0, nbins);
263+
}
264+
}
123265

124266
void TestBuildHist(int nid,
125267
const GHistIndexMatrix& gmat,
@@ -249,7 +391,6 @@ class QuantileHistMock : public QuantileHistMaker {
249391
TestEvaluateSplit(quantile_index_block, tree);
250392
omp_set_num_threads(1);
251393
}
252-
253394
};
254395

255396
int static constexpr kNRows = 8, kNCols = 16;
@@ -259,7 +400,7 @@ class QuantileHistMock : public QuantileHistMaker {
259400

260401
public:
261402
explicit QuantileHistMock(
262-
const std::vector<std::pair<std::string, std::string> >& args) :
403+
const std::vector<std::pair<std::string, std::string> >& args, bool batch = true) :
263404
cfg_{args} {
264405
QuantileHistMaker::Configure(args);
265406
spliteval_->Init(&param_);
@@ -271,6 +412,13 @@ class QuantileHistMock : public QuantileHistMaker {
271412
std::unique_ptr<SplitEvaluator>(spliteval_->GetHostClone()),
272413
int_constraint_,
273414
dmat_.get()));
415+
if (batch) {
416+
builder_->SetHistSynchronizer(new BatchHistSynchronizer(builder_.get()));
417+
builder_->SetHistRowsAdder(new BatchHistRowsAdder(builder_.get()));
418+
} else {
419+
builder_->SetHistSynchronizer(new DistributedHistSynchronizer(builder_.get()));
420+
builder_->SetHistRowsAdder(new DistributedHistRowsAdder(builder_.get()));
421+
}
274422
}
275423
~QuantileHistMock() override = default;
276424

@@ -305,6 +453,34 @@ class QuantileHistMock : public QuantileHistMaker {
305453

306454
builder_->TestInitDataSampling(gmat, gpair, dmat_.get(), tree);
307455
}
456+
457+
void TestAddHistRows() {
458+
size_t constexpr kMaxBins = 4;
459+
common::GHistIndexMatrix gmat;
460+
gmat.Init(dmat_.get(), kMaxBins);
461+
462+
RegTree tree = RegTree();
463+
tree.param.UpdateAllowUnknown(cfg_);
464+
std::vector<GradientPair> gpair =
465+
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
466+
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
467+
builder_->TestAddHistRows(gmat, gpair, dmat_.get(), &tree);
468+
}
469+
470+
void TestSyncHistograms() {
471+
size_t constexpr kMaxBins = 4;
472+
common::GHistIndexMatrix gmat;
473+
gmat.Init(dmat_.get(), kMaxBins);
474+
475+
RegTree tree = RegTree();
476+
tree.param.UpdateAllowUnknown(cfg_);
477+
std::vector<GradientPair> gpair =
478+
{ {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
479+
{0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
480+
builder_->TestSyncHistograms(gmat, gpair, dmat_.get(), &tree);
481+
}
482+
483+
308484
void TestBuildHist() {
309485
RegTree tree = RegTree();
310486
tree.param.UpdateAllowUnknown(cfg_);
@@ -340,6 +516,34 @@ TEST(QuantileHist, InitDataSampling) {
340516
maker.TestInitDataSampling();
341517
}
342518

519+
TEST(QuantileHist, AddHistRows) {
520+
std::vector<std::pair<std::string, std::string>> cfg
521+
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
522+
QuantileHistMock maker(cfg);
523+
maker.TestAddHistRows();
524+
}
525+
526+
TEST(QuantileHist, SyncHistograms) {
527+
std::vector<std::pair<std::string, std::string>> cfg
528+
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
529+
QuantileHistMock maker(cfg);
530+
maker.TestSyncHistograms();
531+
}
532+
533+
TEST(QuantileHist, DistributedAddHistRows) {
534+
std::vector<std::pair<std::string, std::string>> cfg
535+
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
536+
QuantileHistMock maker(cfg, false);
537+
maker.TestAddHistRows();
538+
}
539+
540+
TEST(QuantileHist, DistributedSyncHistograms) {
541+
std::vector<std::pair<std::string, std::string>> cfg
542+
{{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())}};
543+
QuantileHistMock maker(cfg, false);
544+
maker.TestSyncHistograms();
545+
}
546+
343547
TEST(QuantileHist, BuildHist) {
344548
// Don't enable feature grouping
345549
std::vector<std::pair<std::string, std::string>> cfg

0 commit comments

Comments
 (0)