@@ -138,28 +138,62 @@ void GetColumnSizesScan(int device,
138138 * \param column_sizes_scan Describes the boundaries of column segments in
139139 * sorted data
140140 */
141- void ExtractCuts (int device, Span<SketchEntry> cuts,
142- size_t num_cuts_per_feature, Span<Entry> sorted_data,
143- Span<size_t > column_sizes_scan) {
144- dh::LaunchN (device, cuts.size (), [=] __device__ (size_t idx) {
141+ void ExtractCuts (int device,
142+ size_t num_cuts_per_feature,
143+ Span<Entry const > sorted_data,
144+ Span<size_t const > column_sizes_scan,
145+ Span<SketchEntry> out_cuts) {
146+ dh::LaunchN (device, out_cuts.size (), [=] __device__ (size_t idx) {
145147 // Each thread is responsible for obtaining one cut from the sorted input
146148 size_t column_idx = idx / num_cuts_per_feature;
147149 size_t column_size =
148150 column_sizes_scan[column_idx + 1 ] - column_sizes_scan[column_idx];
149151 size_t num_available_cuts =
150- min (size_t (num_cuts_per_feature), column_size);
152+ min (static_cast < size_t > (num_cuts_per_feature), column_size);
151153 size_t cut_idx = idx % num_cuts_per_feature;
152154 if (cut_idx >= num_available_cuts) return ;
153-
154- Span<Entry> column_entries =
155+ Span<Entry const > column_entries =
155156 sorted_data.subspan (column_sizes_scan[column_idx], column_size);
156-
157- size_t rank = (column_entries.size () * cut_idx) / num_available_cuts;
158- auto value = column_entries[rank].fvalue ;
159- cuts[idx] = SketchEntry (rank, rank + 1 , 1 , value);
157+ size_t rank = (column_entries.size () * cut_idx) /
158+ static_cast <float >(num_available_cuts);
159+ if (cut_idx == num_available_cuts - 1 ) {
160+ // Because the cut values represent upper bounds of bins, we need to place special
161+ // treatment for last cut. When a data equals to the greates value of cuts, last
162+ // cut value of the same column is retured.
163+ rank = column_entries.size () - 1 ;
164+ }
165+ out_cuts[idx] = WQSketch::Entry (rank, rank + 1 , 1 ,
166+ column_entries[rank].fvalue );
160167 });
161168}
162169
170+ void ProcessBatch (int device, const SparsePage& page, size_t begin, size_t end,
171+ SketchContainer* sketch_container, int num_cuts,
172+ size_t num_columns) {
173+ dh::XGBCachingDeviceAllocator<char > alloc;
174+ const auto & host_data = page.data .ConstHostVector ();
175+ dh::caching_device_vector<Entry> sorted_entries (host_data.begin () + begin,
176+ host_data.begin () + end);
177+ thrust::sort (thrust::cuda::par (alloc), sorted_entries.begin (),
178+ sorted_entries.end (), EntryCompareOp ());
179+
180+ dh::caching_device_vector<size_t > column_sizes_scan;
181+ GetColumnSizesScan (device, &column_sizes_scan,
182+ {sorted_entries.data ().get (), sorted_entries.size ()},
183+ num_columns);
184+ thrust::host_vector<size_t > host_column_sizes_scan (column_sizes_scan);
185+
186+ dh::caching_device_vector<SketchEntry> cuts (num_columns * num_cuts);
187+ ExtractCuts (device, num_cuts,
188+ dh::ToSpan (sorted_entries),
189+ dh::ToSpan (column_sizes_scan),
190+ dh::ToSpan (cuts));
191+
192+ // add cuts into sketches
193+ thrust::host_vector<SketchEntry> host_cuts (cuts);
194+ sketch_container->Push (num_cuts, host_cuts, host_column_sizes_scan);
195+ }
196+
163197/* *
164198 * \brief Extracts the cuts from sorted data, considering weights.
165199 *
@@ -170,26 +204,27 @@ void ExtractCuts(int device, Span<SketchEntry> cuts,
170204 * \param weights_scan Inclusive scan of weights for each entry in sorted_data.
171205 * \param column_sizes_scan Describes the boundaries of column segments in sorted data.
172206 */
173- void ExtractWeightedCuts (int device, Span<SketchEntry> cuts,
174- size_t num_cuts_per_feature, Span<Entry> sorted_data,
207+ void ExtractWeightedCuts (int device,
208+ size_t num_cuts_per_feature,
209+ Span<Entry> sorted_data,
175210 Span<float > weights_scan,
176- Span<size_t > column_sizes_scan) {
211+ Span<size_t > column_sizes_scan,
212+ Span<SketchEntry> cuts) {
177213 dh::LaunchN (device, cuts.size (), [=] __device__ (size_t idx) {
178214 // Each thread is responsible for obtaining one cut from the sorted input
179215 size_t column_idx = idx / num_cuts_per_feature;
180216 size_t column_size =
181217 column_sizes_scan[column_idx + 1 ] - column_sizes_scan[column_idx];
182218 size_t num_available_cuts =
183- min (size_t (num_cuts_per_feature), column_size);
219+ min (static_cast < size_t > (num_cuts_per_feature), column_size);
184220 size_t cut_idx = idx % num_cuts_per_feature;
185221 if (cut_idx >= num_available_cuts) return ;
186-
187222 Span<Entry> column_entries =
188223 sorted_data.subspan (column_sizes_scan[column_idx], column_size);
189- Span<float > column_weights =
190- weights_scan.subspan (column_sizes_scan[column_idx], column_size);
191224
192- float total_column_weight = column_weights.back ();
225+ Span<float > column_weights_scan =
226+ weights_scan.subspan (column_sizes_scan[column_idx], column_size);
227+ float total_column_weight = column_weights_scan.back ();
193228 size_t sample_idx = 0 ;
194229 if (cut_idx == 0 ) {
195230 // First cut
@@ -204,70 +239,68 @@ void ExtractWeightedCuts(int device, Span<SketchEntry> cuts,
204239 } else {
205240 bst_float rank = (total_column_weight * cut_idx) /
206241 static_cast <float >(num_available_cuts);
207- sample_idx = thrust::upper_bound (thrust::seq, column_weights.begin (),
208- column_weights.end (), rank) -
209- column_weights.begin () - 1 ;
242+ sample_idx = thrust::upper_bound (thrust::seq,
243+ column_weights_scan.begin (),
244+ column_weights_scan.end (),
245+ rank) -
246+ column_weights_scan.begin ();
210247 sample_idx =
211- max (size_t (0 ), min (sample_idx, column_entries.size () - 1 ));
248+ max (static_cast <size_t >(0 ),
249+ min (sample_idx, column_entries.size () - 1 ));
212250 }
213251 // repeated values will be filtered out on the CPU
214- bst_float rmin = sample_idx > 0 ? column_weights [sample_idx - 1 ] : 0 ;
215- bst_float rmax = column_weights [sample_idx];
252+ bst_float rmin = sample_idx > 0 ? column_weights_scan [sample_idx - 1 ] : 0 ;
253+ bst_float rmax = column_weights_scan [sample_idx];
216254 cuts[idx] = WQSketch::Entry (rmin, rmax, rmax - rmin,
217255 column_entries[sample_idx].fvalue );
218256 });
219257}
220258
221- void ProcessBatch (int device, const SparsePage& page, size_t begin, size_t end,
222- SketchContainer* sketch_container, int num_cuts,
223- size_t num_columns) {
224- dh::XGBCachingDeviceAllocator<char > alloc;
225- const auto & host_data = page.data .ConstHostVector ();
226- dh::caching_device_vector<Entry> sorted_entries (host_data.begin () + begin,
227- host_data.begin () + end);
228- thrust::sort (thrust::cuda::par (alloc), sorted_entries.begin (),
229- sorted_entries.end (), EntryCompareOp ());
230-
231- dh::caching_device_vector<size_t > column_sizes_scan;
232- GetColumnSizesScan (device, &column_sizes_scan,
233- {sorted_entries.data ().get (), sorted_entries.size ()},
234- num_columns);
235- thrust::host_vector<size_t > host_column_sizes_scan (column_sizes_scan);
236-
237- dh::caching_device_vector<SketchEntry> cuts (num_columns * num_cuts);
238- ExtractCuts (device, {cuts.data ().get (), cuts.size ()}, num_cuts,
239- {sorted_entries.data ().get (), sorted_entries.size ()},
240- {column_sizes_scan.data ().get (), column_sizes_scan.size ()});
241-
242- // add cuts into sketches
243- thrust::host_vector<SketchEntry> host_cuts (cuts);
244- sketch_container->Push (num_cuts, host_cuts, host_column_sizes_scan);
245- }
246-
247259void ProcessWeightedBatch (int device, const SparsePage& page,
248260 Span<const float > weights, size_t begin, size_t end,
249261 SketchContainer* sketch_container, int num_cuts,
250- size_t num_columns) {
262+ size_t num_columns,
263+ bool is_ranking, Span<bst_group_t const > d_group_ptr) {
251264 dh::XGBCachingDeviceAllocator<char > alloc;
252265 const auto & host_data = page.data .ConstHostVector ();
253266 dh::caching_device_vector<Entry> sorted_entries (host_data.begin () + begin,
254- host_data.begin () + end);
267+ host_data.begin () + end);
255268
256269 // Binary search to assign weights to each element
257270 dh::caching_device_vector<float > temp_weights (sorted_entries.size ());
258271 auto d_temp_weights = temp_weights.data ().get ();
259272 page.offset .SetDevice (device);
260273 auto row_ptrs = page.offset .ConstDeviceSpan ();
261274 size_t base_rowid = page.base_rowid ;
262- dh::LaunchN (device, temp_weights.size (), [=] __device__ (size_t idx) {
263- size_t element_idx = idx + begin;
264- size_t ridx = thrust::upper_bound (thrust::seq, row_ptrs.begin (),
265- row_ptrs.end (), element_idx) -
266- row_ptrs.begin () - 1 ;
267- d_temp_weights[idx] = weights[ridx + base_rowid];
268- });
275+ if (is_ranking) {
276+ CHECK_GE (d_group_ptr.size (), 1 )
277+ << " Must have at least 1 group for ranking." ;
278+ CHECK_EQ (weights.size (), d_group_ptr.size () - 1 )
279+ << " Weight size should equal to number of groups." ;
280+ dh::LaunchN (device, temp_weights.size (), [=] __device__ (size_t idx) {
281+ size_t element_idx = idx + begin;
282+ size_t ridx = thrust::upper_bound (thrust::seq, row_ptrs.begin (),
283+ row_ptrs.end (), element_idx) -
284+ row_ptrs.begin () - 1 ;
285+ auto it =
286+ thrust::upper_bound (thrust::seq,
287+ d_group_ptr.cbegin (), d_group_ptr.cend (),
288+ ridx + base_rowid) - 1 ;
289+ bst_group_t group = thrust::distance (d_group_ptr.cbegin (), it);
290+ d_temp_weights[idx] = weights[group];
291+ });
292+ } else {
293+ CHECK_EQ (weights.size (), page.offset .Size () - 1 );
294+ dh::LaunchN (device, temp_weights.size (), [=] __device__ (size_t idx) {
295+ size_t element_idx = idx + begin;
296+ size_t ridx = thrust::upper_bound (thrust::seq, row_ptrs.begin (),
297+ row_ptrs.end (), element_idx) -
298+ row_ptrs.begin () - 1 ;
299+ d_temp_weights[idx] = weights[ridx + base_rowid];
300+ });
301+ }
269302
270- // Sort
303+ // Sort both entries and wegihts.
271304 thrust::sort_by_key (thrust::cuda::par (alloc), sorted_entries.begin (),
272305 sorted_entries.end (), temp_weights.begin (),
273306 EntryCompareOp ());
@@ -288,11 +321,11 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
288321
289322 // Extract cuts
290323 dh::caching_device_vector<SketchEntry> cuts (num_columns * num_cuts);
291- ExtractWeightedCuts (
292- device, {cuts. data (). get (), cuts. size ()}, num_cuts ,
293- {sorted_entries. data (). get (), sorted_entries. size ()} ,
294- {temp_weights. data (). get (), temp_weights. size ()} ,
295- {column_sizes_scan. data (). get (), column_sizes_scan. size ()} );
324+ ExtractWeightedCuts (device, num_cuts,
325+ dh::ToSpan (sorted_entries) ,
326+ dh::ToSpan (temp_weights) ,
327+ dh::ToSpan (column_sizes_scan) ,
328+ dh::ToSpan (cuts) );
296329
297330 // add cuts into sketches
298331 thrust::host_vector<SketchEntry> host_cuts (cuts);
@@ -320,13 +353,17 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
320353 dmat->Info ().weights_ .SetDevice (device);
321354 for (const auto & batch : dmat->GetBatches <SparsePage>()) {
322355 size_t batch_nnz = batch.data .Size ();
356+ auto const & info = dmat->Info ();
357+ dh::caching_device_vector<uint32_t > groups (info.group_ptr_ .size ());
358+ thrust::copy (info.group_ptr_ .cbegin (), info.group_ptr_ .cend (), groups.begin ());
323359 for (auto begin = 0ull ; begin < batch_nnz;
324360 begin += sketch_batch_num_elements) {
325361 size_t end = std::min (batch_nnz, size_t (begin + sketch_batch_num_elements));
326362 if (has_weights) {
363+ bool is_ranking = CutsBuilder::UseGroup (dmat);
327364 ProcessWeightedBatch (
328365 device, batch, dmat->Info ().weights_ .ConstDeviceSpan (), begin, end,
329- &sketch_container, num_cuts, dmat->Info ().num_col_ );
366+ &sketch_container, num_cuts, dmat->Info ().num_col_ , is_ranking, dh::ToSpan (groups) );
330367 } else {
331368 ProcessBatch (device, batch, begin, end, &sketch_container, num_cuts,
332369 dmat->Info ().num_col_ );
@@ -383,9 +420,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing,
383420
384421 // Extract the cuts from all columns concurrently
385422 dh::caching_device_vector<SketchEntry> cuts (adapter->NumColumns () * num_cuts);
386- ExtractCuts (adapter->DeviceIdx (), {cuts.data ().get (), cuts.size ()}, num_cuts,
387- {sorted_entries.data ().get (), sorted_entries.size ()},
388- {column_sizes_scan.data ().get (), column_sizes_scan.size ()});
423+ ExtractCuts (adapter->DeviceIdx (), num_cuts,
424+ dh::ToSpan (sorted_entries),
425+ dh::ToSpan (column_sizes_scan),
426+ dh::ToSpan (cuts));
389427
390428 // Push cuts into sketches stored in host memory
391429 thrust::host_vector<SketchEntry> host_cuts (cuts);
0 commit comments