Skip to content

Commit 19abc7b

Browse files
committed
Group aware GPU weighted sketching.
1 parent d6d1035 commit 19abc7b

File tree

5 files changed

+205
-110
lines changed

5 files changed

+205
-110
lines changed

src/common/hist_util.cu

Lines changed: 108 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -138,28 +138,62 @@ void GetColumnSizesScan(int device,
138138
* \param column_sizes_scan Describes the boundaries of column segments in
139139
* sorted data
140140
*/
141-
void ExtractCuts(int device, Span<SketchEntry> cuts,
142-
size_t num_cuts_per_feature, Span<Entry> sorted_data,
143-
Span<size_t> column_sizes_scan) {
144-
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
141+
void ExtractCuts(int device,
142+
size_t num_cuts_per_feature,
143+
Span<Entry const> sorted_data,
144+
Span<size_t const> column_sizes_scan,
145+
Span<SketchEntry> out_cuts) {
146+
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
145147
// Each thread is responsible for obtaining one cut from the sorted input
146148
size_t column_idx = idx / num_cuts_per_feature;
147149
size_t column_size =
148150
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
149151
size_t num_available_cuts =
150-
min(size_t(num_cuts_per_feature), column_size);
152+
min(static_cast<size_t>(num_cuts_per_feature), column_size);
151153
size_t cut_idx = idx % num_cuts_per_feature;
152154
if (cut_idx >= num_available_cuts) return;
153-
154-
Span<Entry> column_entries =
155+
Span<Entry const> column_entries =
155156
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
156-
157-
size_t rank = (column_entries.size() * cut_idx) / num_available_cuts;
158-
auto value = column_entries[rank].fvalue;
159-
cuts[idx] = SketchEntry(rank, rank + 1, 1, value);
157+
size_t rank = (column_entries.size() * cut_idx) /
158+
static_cast<float>(num_available_cuts);
159+
if (cut_idx == num_available_cuts - 1) {
160+
// Because the cut values represent upper bounds of bins, we need to place special
161+
// treatment for last cut. When a data point equals the greatest value of cuts, last
162+
// cut value of the same column is returned.
163+
rank = column_entries.size() - 1;
164+
}
165+
out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1,
166+
column_entries[rank].fvalue);
160167
});
161168
}
162169

170+
void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
171+
SketchContainer* sketch_container, int num_cuts,
172+
size_t num_columns) {
173+
dh::XGBCachingDeviceAllocator<char> alloc;
174+
const auto& host_data = page.data.ConstHostVector();
175+
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
176+
host_data.begin() + end);
177+
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
178+
sorted_entries.end(), EntryCompareOp());
179+
180+
dh::caching_device_vector<size_t> column_sizes_scan;
181+
GetColumnSizesScan(device, &column_sizes_scan,
182+
{sorted_entries.data().get(), sorted_entries.size()},
183+
num_columns);
184+
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
185+
186+
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
187+
ExtractCuts(device, num_cuts,
188+
dh::ToSpan(sorted_entries),
189+
dh::ToSpan(column_sizes_scan),
190+
dh::ToSpan(cuts));
191+
192+
// add cuts into sketches
193+
thrust::host_vector<SketchEntry> host_cuts(cuts);
194+
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
195+
}
196+
163197
/**
164198
* \brief Extracts the cuts from sorted data, considering weights.
165199
*
@@ -170,26 +204,27 @@ void ExtractCuts(int device, Span<SketchEntry> cuts,
170204
* \param weights_scan Inclusive scan of weights for each entry in sorted_data.
171205
* \param column_sizes_scan Describes the boundaries of column segments in sorted data.
172206
*/
173-
void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
174-
size_t num_cuts_per_feature, Span<Entry> sorted_data,
207+
void ExtractWeightedCuts(int device,
208+
size_t num_cuts_per_feature,
209+
Span<Entry> sorted_data,
175210
Span<float> weights_scan,
176-
Span<size_t> column_sizes_scan) {
211+
Span<size_t> column_sizes_scan,
212+
Span<SketchEntry> cuts) {
177213
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
178214
// Each thread is responsible for obtaining one cut from the sorted input
179215
size_t column_idx = idx / num_cuts_per_feature;
180216
size_t column_size =
181217
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
182218
size_t num_available_cuts =
183-
min(size_t(num_cuts_per_feature), column_size);
219+
min(static_cast<size_t>(num_cuts_per_feature), column_size);
184220
size_t cut_idx = idx % num_cuts_per_feature;
185221
if (cut_idx >= num_available_cuts) return;
186-
187222
Span<Entry> column_entries =
188223
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
189-
Span<float> column_weights =
190-
weights_scan.subspan(column_sizes_scan[column_idx], column_size);
191224

192-
float total_column_weight = column_weights.back();
225+
Span<float> column_weights_scan =
226+
weights_scan.subspan(column_sizes_scan[column_idx], column_size);
227+
float total_column_weight = column_weights_scan.back();
193228
size_t sample_idx = 0;
194229
if (cut_idx == 0) {
195230
// First cut
@@ -204,70 +239,68 @@ void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
204239
} else {
205240
bst_float rank = (total_column_weight * cut_idx) /
206241
static_cast<float>(num_available_cuts);
207-
sample_idx = thrust::upper_bound(thrust::seq, column_weights.begin(),
208-
column_weights.end(), rank) -
209-
column_weights.begin() - 1;
242+
sample_idx = thrust::upper_bound(thrust::seq,
243+
column_weights_scan.begin(),
244+
column_weights_scan.end(),
245+
rank) -
246+
column_weights_scan.begin();
210247
sample_idx =
211-
max(size_t(0), min(sample_idx, column_entries.size() - 1));
248+
max(static_cast<size_t>(0),
249+
min(sample_idx, column_entries.size() - 1));
212250
}
213251
// repeated values will be filtered out on the CPU
214-
bst_float rmin = sample_idx > 0 ? column_weights[sample_idx - 1] : 0;
215-
bst_float rmax = column_weights[sample_idx];
252+
bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0;
253+
bst_float rmax = column_weights_scan[sample_idx];
216254
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
217255
column_entries[sample_idx].fvalue);
218256
});
219257
}
220258

221-
void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
222-
SketchContainer* sketch_container, int num_cuts,
223-
size_t num_columns) {
224-
dh::XGBCachingDeviceAllocator<char> alloc;
225-
const auto& host_data = page.data.ConstHostVector();
226-
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
227-
host_data.begin() + end);
228-
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
229-
sorted_entries.end(), EntryCompareOp());
230-
231-
dh::caching_device_vector<size_t> column_sizes_scan;
232-
GetColumnSizesScan(device, &column_sizes_scan,
233-
{sorted_entries.data().get(), sorted_entries.size()},
234-
num_columns);
235-
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
236-
237-
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
238-
ExtractCuts(device, {cuts.data().get(), cuts.size()}, num_cuts,
239-
{sorted_entries.data().get(), sorted_entries.size()},
240-
{column_sizes_scan.data().get(), column_sizes_scan.size()});
241-
242-
// add cuts into sketches
243-
thrust::host_vector<SketchEntry> host_cuts(cuts);
244-
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
245-
}
246-
247259
void ProcessWeightedBatch(int device, const SparsePage& page,
248260
Span<const float> weights, size_t begin, size_t end,
249261
SketchContainer* sketch_container, int num_cuts,
250-
size_t num_columns) {
262+
size_t num_columns,
263+
bool is_ranking, Span<bst_group_t const> d_group_ptr) {
251264
dh::XGBCachingDeviceAllocator<char> alloc;
252265
const auto& host_data = page.data.ConstHostVector();
253266
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
254-
host_data.begin() + end);
267+
host_data.begin() + end);
255268

256269
// Binary search to assign weights to each element
257270
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
258271
auto d_temp_weights = temp_weights.data().get();
259272
page.offset.SetDevice(device);
260273
auto row_ptrs = page.offset.ConstDeviceSpan();
261274
size_t base_rowid = page.base_rowid;
262-
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
263-
size_t element_idx = idx + begin;
264-
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
265-
row_ptrs.end(), element_idx) -
266-
row_ptrs.begin() - 1;
267-
d_temp_weights[idx] = weights[ridx + base_rowid];
268-
});
275+
if (is_ranking) {
276+
CHECK_GE(d_group_ptr.size(), 1)
277+
<< "Must have at least 1 group for ranking.";
278+
CHECK_EQ(weights.size(), d_group_ptr.size() - 1)
279+
<< "Weight size should equal to number of groups.";
280+
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
281+
size_t element_idx = idx + begin;
282+
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
283+
row_ptrs.end(), element_idx) -
284+
row_ptrs.begin() - 1;
285+
auto it =
286+
thrust::upper_bound(thrust::seq,
287+
d_group_ptr.cbegin(), d_group_ptr.cend(),
288+
ridx + base_rowid) - 1;
289+
bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it);
290+
d_temp_weights[idx] = weights[group];
291+
});
292+
} else {
293+
CHECK_EQ(weights.size(), page.offset.Size() - 1);
294+
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
295+
size_t element_idx = idx + begin;
296+
size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(),
297+
row_ptrs.end(), element_idx) -
298+
row_ptrs.begin() - 1;
299+
d_temp_weights[idx] = weights[ridx + base_rowid];
300+
});
301+
}
269302

270-
// Sort
303+
// Sort both entries and weights.
271304
thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(),
272305
sorted_entries.end(), temp_weights.begin(),
273306
EntryCompareOp());
@@ -288,11 +321,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
288321

289322
// Extract cuts
290323
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
291-
ExtractWeightedCuts(
292-
device, {cuts.data().get(), cuts.size()}, num_cuts,
293-
{sorted_entries.data().get(), sorted_entries.size()},
294-
{temp_weights.data().get(), temp_weights.size()},
295-
{column_sizes_scan.data().get(), column_sizes_scan.size()});
324+
ExtractWeightedCuts(device, num_cuts,
325+
dh::ToSpan(sorted_entries),
326+
dh::ToSpan(temp_weights),
327+
dh::ToSpan(column_sizes_scan),
328+
dh::ToSpan(cuts));
296329

297330
// add cuts into sketches
298331
thrust::host_vector<SketchEntry> host_cuts(cuts);
@@ -320,13 +353,17 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
320353
dmat->Info().weights_.SetDevice(device);
321354
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
322355
size_t batch_nnz = batch.data.Size();
356+
auto const& info = dmat->Info();
357+
dh::caching_device_vector<uint32_t> groups(info.group_ptr_.size());
358+
thrust::copy(info.group_ptr_.cbegin(), info.group_ptr_.cend(), groups.begin());
323359
for (auto begin = 0ull; begin < batch_nnz;
324360
begin += sketch_batch_num_elements) {
325361
size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements));
326362
if (has_weights) {
363+
bool is_ranking = CutsBuilder::UseGroup(dmat);
327364
ProcessWeightedBatch(
328365
device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end,
329-
&sketch_container, num_cuts, dmat->Info().num_col_);
366+
&sketch_container, num_cuts, dmat->Info().num_col_, is_ranking, dh::ToSpan(groups));
330367
} else {
331368
ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts,
332369
dmat->Info().num_col_);
@@ -383,9 +420,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
383420

384421
// Extract the cuts from all columns concurrently
385422
dh::caching_device_vector<SketchEntry> cuts(adapter->NumColumns() * num_cuts);
386-
ExtractCuts(adapter->DeviceIdx(), {cuts.data().get(), cuts.size()}, num_cuts,
387-
{sorted_entries.data().get(), sorted_entries.size()},
388-
{column_sizes_scan.data().get(), column_sizes_scan.size()});
423+
ExtractCuts(adapter->DeviceIdx(), num_cuts,
424+
dh::ToSpan(sorted_entries),
425+
dh::ToSpan(column_sizes_scan),
426+
dh::ToSpan(cuts));
389427

390428
// Push cuts into sketches stored in host memory
391429
thrust::host_vector<SketchEntry> host_cuts(cuts);

src/common/hist_util.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,11 +127,11 @@ class HistogramCuts {
127127
class CutsBuilder {
128128
public:
129129
using WQSketch = common::WQuantileSketch<bst_float, bst_float>;
130+
/* \brief return whether group for ranking is used. */
131+
static bool UseGroup(DMatrix* dmat);
130132

131133
protected:
132134
HistogramCuts* p_cuts_;
133-
/* \brief return whether group for ranking is used. */
134-
static bool UseGroup(DMatrix* dmat);
135135

136136
public:
137137
explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {}

tests/cpp/common/test_hist_util.cu

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,19 @@
33

44
#include <algorithm>
55
#include <cmath>
6-
7-
86
#include <thrust/device_vector.h>
97

10-
#include "xgboost/c_api.h"
8+
#include <xgboost/data.h>
9+
#include <xgboost/c_api.h>
1110

11+
#include "test_hist_util.h"
12+
#include "../helpers.h"
13+
#include "../data/test_array_interface.h"
1214
#include "../../../src/common/device_helpers.cuh"
1315
#include "../../../src/common/hist_util.h"
14-
15-
#include "../helpers.h"
16-
#include <xgboost/data.h>
1716
#include "../../../src/data/device_adapter.cuh"
18-
#include "../data/test_array_interface.h"
1917
#include "../../../src/common/math.h"
2018
#include "../../../src/data/simple_dmatrix.h"
21-
#include "test_hist_util.h"
2219
#include "../../../include/xgboost/logging.h"
2320

2421
namespace xgboost {
@@ -190,8 +187,7 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) {
190187
}
191188
}
192189

193-
TEST(HistUtil, AdapterDeviceSketch)
194-
{
190+
TEST(HistUtil, AdapterDeviceSketch) {
195191
int rows = 5;
196192
int cols = 1;
197193
int num_bins = 4;
@@ -235,7 +231,7 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
235231
bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant);
236232
}
237233

238-
TEST(HistUtil, AdapterDeviceSketchCategorical) {
234+
TEST(HistUtil, AdapterDeviceSketchCategorical) {
239235
int categorical_sizes[] = {2, 6, 8, 12};
240236
int num_bins = 256;
241237
int sizes[] = {25, 100, 1000};
@@ -268,6 +264,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) {
268264
}
269265
}
270266
}
267+
271268
TEST(HistUtil, AdapterDeviceSketchBatches) {
272269
int num_bins = 256;
273270
int num_rows = 5000;
@@ -305,7 +302,38 @@ TEST(HistUtil, SketchingEquivalent) {
305302
EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues());
306303
}
307304
}
305+
}
306+
307+
TEST(HistUtil, DeviceSketchFromGroupWeights) {
308+
size_t constexpr kRows = 3000, kCols = 200, kBins = 256;
309+
size_t constexpr kGroups = 10;
310+
auto m = RandomDataGenerator {kRows, kCols, 0}.GenerateDMatrix();
311+
auto& h_weights = m->Info().weights_.HostVector();
312+
h_weights.resize(kRows);
313+
std::fill(h_weights.begin(), h_weights.end(), 1.0f);
314+
std::vector<bst_group_t> groups(kGroups);
315+
for (size_t i = 0; i < kGroups; ++i) {
316+
groups[i] = kRows / kGroups;
317+
}
318+
m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups);
319+
HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0);
308320

321+
h_weights.clear();
322+
HistogramCuts cuts = DeviceSketch(0, m.get(), kBins, 0);
323+
324+
ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size());
325+
ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size());
326+
ASSERT_EQ(cuts.Ptrs().size(), weighted_cuts.Ptrs().size());
327+
328+
for (size_t i = 0; i < cuts.Values().size(); ++i) {
329+
EXPECT_EQ(cuts.Values()[i], weighted_cuts.Values()[i]) << "i:"<< i;
330+
}
331+
for (size_t i = 0; i < cuts.MinValues().size(); ++i) {
332+
ASSERT_EQ(cuts.MinValues()[i], weighted_cuts.MinValues()[i]);
333+
}
334+
for (size_t i = 0; i < cuts.Ptrs().size(); ++i) {
335+
ASSERT_EQ(cuts.Ptrs().at(i), weighted_cuts.Ptrs().at(i));
336+
}
309337
}
310338
} // namespace common
311339
} // namespace xgboost

0 commit comments

Comments
 (0)