diff --git a/include/xgboost/data.h b/include/xgboost/data.h index c2a80576c395..2eb036d287f4 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -87,11 +87,23 @@ class MetaInfo { this->weights_.Resize(that.weights_.Size()); this->weights_.Copy(that.weights_); + this->base_margin_.Resize(that.base_margin_.Size()); this->base_margin_.Copy(that.base_margin_); + + this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size()); + this->labels_lower_bound_.Copy(that.labels_lower_bound_); + + this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size()); + this->labels_upper_bound_.Copy(that.labels_upper_bound_); return *this; } + /*! + * \brief Validate all metainfo. + */ + void Validate() const; + MetaInfo Slice(common::Span ridxs) const; /*! * \brief Get weight of each instances. diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 93d02d0ef18a..ca07f83cca18 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -138,25 +138,26 @@ void GetColumnSizesScan(int device, * \param column_sizes_scan Describes the boundaries of column segments in * sorted data */ -void ExtractCuts(int device, Span cuts, - size_t num_cuts_per_feature, Span sorted_data, - Span column_sizes_scan) { - dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) { +void ExtractCuts(int device, + size_t num_cuts_per_feature, + Span sorted_data, + Span column_sizes_scan, + Span out_cuts) { + dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) { // Each thread is responsible for obtaining one cut from the sorted input size_t column_idx = idx / num_cuts_per_feature; size_t column_size = column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; size_t num_available_cuts = - min(size_t(num_cuts_per_feature), column_size); + min(static_cast(num_cuts_per_feature), column_size); size_t cut_idx = idx % num_cuts_per_feature; if (cut_idx >= num_available_cuts) return; - - Span column_entries = + Span column_entries = sorted_data.subspan(column_sizes_scan[column_idx], column_size); - - size_t rank = (column_entries.size() * cut_idx) / num_available_cuts; - auto value = column_entries[rank].fvalue; - cuts[idx] = SketchEntry(rank, rank + 1, 1, value); + size_t rank = (column_entries.size() * cut_idx) / + static_cast(num_available_cuts); + out_cuts[idx] = WQSketch::Entry(rank, rank + 1, 1, + column_entries[rank].fvalue); }); } @@ -170,31 +171,32 @@ void ExtractCuts(int device, Span cuts, * \param weights_scan Inclusive scan of weights for each entry in sorted_data. * \param column_sizes_scan Describes the boundaries of column segments in sorted data. */ -void ExtractWeightedCuts(int device, Span cuts, - size_t num_cuts_per_feature, Span sorted_data, +void ExtractWeightedCuts(int device, + size_t num_cuts_per_feature, + Span sorted_data, Span weights_scan, - Span column_sizes_scan) { + Span column_sizes_scan, + Span cuts) { dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) { // Each thread is responsible for obtaining one cut from the sorted input size_t column_idx = idx / num_cuts_per_feature; size_t column_size = column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; size_t num_available_cuts = - min(size_t(num_cuts_per_feature), column_size); + min(static_cast(num_cuts_per_feature), column_size); size_t cut_idx = idx % num_cuts_per_feature; if (cut_idx >= num_available_cuts) return; - Span column_entries = sorted_data.subspan(column_sizes_scan[column_idx], column_size); - Span column_weights = - weights_scan.subspan(column_sizes_scan[column_idx], column_size); - float total_column_weight = column_weights.back(); + Span column_weights_scan = + weights_scan.subspan(column_sizes_scan[column_idx], column_size); + float total_column_weight = column_weights_scan.back(); size_t sample_idx = 0; if (cut_idx == 0) { // First cut sample_idx = 0; - } else if (cut_idx == num_available_cuts - 1) { + } else if (cut_idx == num_available_cuts) { // Last cut sample_idx = column_entries.size() - 1; } else if (num_available_cuts == column_size) { @@ -204,15 +206,18 @@ void ExtractWeightedCuts(int device, Span cuts, } else { bst_float rank = (total_column_weight * cut_idx) / static_cast(num_available_cuts); - sample_idx = thrust::upper_bound(thrust::seq, column_weights.begin(), - column_weights.end(), rank) - - column_weights.begin() - 1; + sample_idx = thrust::upper_bound(thrust::seq, + column_weights_scan.begin(), + column_weights_scan.end(), + rank) - + column_weights_scan.begin(); sample_idx = - max(size_t(0), min(sample_idx, column_entries.size() - 1)); + max(static_cast(0), + min(sample_idx, column_entries.size() - 1)); } // repeated values will be filtered out on the CPU - bst_float rmin = sample_idx > 0 ? column_weights[sample_idx - 1] : 0; - bst_float rmax = column_weights[sample_idx]; + bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f; + bst_float rmax = column_weights_scan[sample_idx]; cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin, column_entries[sample_idx].fvalue); }); @@ -224,7 +229,7 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end, dh::XGBCachingDeviceAllocator alloc; const auto& host_data = page.data.ConstHostVector(); dh::caching_device_vector sorted_entries(host_data.begin() + begin, - host_data.begin() + end); + host_data.begin() + end); thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), sorted_entries.end(), EntryCompareOp()); @@ -235,9 +240,10 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end, thrust::host_vector host_column_sizes_scan(column_sizes_scan); dh::caching_device_vector cuts(num_columns * num_cuts); - ExtractCuts(device, {cuts.data().get(), cuts.size()}, num_cuts, - {sorted_entries.data().get(), sorted_entries.size()}, - {column_sizes_scan.data().get(), column_sizes_scan.size()}); + ExtractCuts(device, num_cuts, + dh::ToSpan(sorted_entries), + dh::ToSpan(column_sizes_scan), + dh::ToSpan(cuts)); // add cuts into sketches thrust::host_vector host_cuts(cuts); @@ -246,12 +252,13 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end, void ProcessWeightedBatch(int device, const SparsePage& page, Span weights, size_t begin, size_t end, - SketchContainer* sketch_container, int num_cuts, - size_t num_columns) { + SketchContainer* sketch_container, int num_cuts_per_feature, + size_t num_columns, + bool is_ranking, Span d_group_ptr) { dh::XGBCachingDeviceAllocator alloc; const auto& host_data = page.data.ConstHostVector(); dh::caching_device_vector sorted_entries(host_data.begin() + begin, - host_data.begin() + end); + host_data.begin() + end); // Binary search to assign weights to each element dh::caching_device_vector temp_weights(sorted_entries.size()); @@ -259,15 +266,35 @@ void ProcessWeightedBatch(int device, const SparsePage& page, page.offset.SetDevice(device); auto row_ptrs = page.offset.ConstDeviceSpan(); size_t base_rowid = page.base_rowid; - dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) { - size_t element_idx = idx + begin; - size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(), - row_ptrs.end(), element_idx) - - row_ptrs.begin() - 1; - d_temp_weights[idx] = weights[ridx + base_rowid]; - }); + if (is_ranking) { + CHECK_GE(d_group_ptr.size(), 2) + << "Must have at least 1 group for ranking."; + CHECK_EQ(weights.size(), d_group_ptr.size() - 1) + << "Weight size should equal to number of groups."; + dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) { + size_t element_idx = idx + begin; + size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(), + row_ptrs.end(), element_idx) - + row_ptrs.begin() - 1; + auto it = + thrust::upper_bound(thrust::seq, + d_group_ptr.cbegin(), d_group_ptr.cend(), + ridx + base_rowid) - 1; + bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it); + d_temp_weights[idx] = weights[group]; + }); + } else { + CHECK_EQ(weights.size(), page.offset.Size() - 1); + dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) { + size_t element_idx = idx + begin; + size_t ridx = thrust::upper_bound(thrust::seq, row_ptrs.begin(), + row_ptrs.end(), element_idx) - + row_ptrs.begin() - 1; + d_temp_weights[idx] = weights[ridx + base_rowid]; + }); + } - // Sort + // Sort both entries and wegihts. thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(), sorted_entries.end(), temp_weights.begin(), EntryCompareOp()); @@ -287,26 +314,26 @@ void ProcessWeightedBatch(int device, const SparsePage& page, thrust::host_vector host_column_sizes_scan(column_sizes_scan); // Extract cuts - dh::caching_device_vector cuts(num_columns * num_cuts); - ExtractWeightedCuts( - device, {cuts.data().get(), cuts.size()}, num_cuts, - {sorted_entries.data().get(), sorted_entries.size()}, - {temp_weights.data().get(), temp_weights.size()}, - {column_sizes_scan.data().get(), column_sizes_scan.size()}); + dh::caching_device_vector cuts(num_columns * num_cuts_per_feature); + ExtractWeightedCuts(device, num_cuts_per_feature, + dh::ToSpan(sorted_entries), + dh::ToSpan(temp_weights), + dh::ToSpan(column_sizes_scan), + dh::ToSpan(cuts)); // add cuts into sketches thrust::host_vector host_cuts(cuts); - sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan); + sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan); } HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, size_t sketch_batch_num_elements) { // Configure batch size based on available memory bool has_weights = dmat->Info().weights_.Size() > 0; - size_t num_cuts = RequiredSampleCuts(max_bins, dmat->Info().num_row_); + size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_); if (sketch_batch_num_elements == 0) { int bytes_per_element = has_weights ? 24 : 16; - size_t bytes_cuts = num_cuts * dmat->Info().num_col_ * sizeof(SketchEntry); + size_t bytes_cuts = num_cuts_per_feature * dmat->Info().num_col_ * sizeof(SketchEntry); // use up to 80% of available space sketch_batch_num_elements = (dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element; @@ -320,15 +347,21 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, dmat->Info().weights_.SetDevice(device); for (const auto& batch : dmat->GetBatches()) { size_t batch_nnz = batch.data.Size(); - for (auto begin = 0ull; begin < batch_nnz; - begin += sketch_batch_num_elements) { + auto const& info = dmat->Info(); + dh::caching_device_vector groups(info.group_ptr_.cbegin(), + info.group_ptr_.cend()); + for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) { size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements)); if (has_weights) { + bool is_ranking = CutsBuilder::UseGroup(dmat); ProcessWeightedBatch( device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end, - &sketch_container, num_cuts, dmat->Info().num_col_); + &sketch_container, + num_cuts_per_feature, + dmat->Info().num_col_, + is_ranking, dh::ToSpan(groups)); } else { - ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts, + ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts_per_feature, dmat->Info().num_col_); } } @@ -383,9 +416,10 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing, // Extract the cuts from all columns concurrently dh::caching_device_vector cuts(adapter->NumColumns() * num_cuts); - ExtractCuts(adapter->DeviceIdx(), {cuts.data().get(), cuts.size()}, num_cuts, - {sorted_entries.data().get(), sorted_entries.size()}, - {column_sizes_scan.data().get(), column_sizes_scan.size()}); + ExtractCuts(adapter->DeviceIdx(), num_cuts, + dh::ToSpan(sorted_entries), + dh::ToSpan(column_sizes_scan), + dh::ToSpan(cuts)); // Push cuts into sketches stored in host memory thrust::host_vector host_cuts(cuts); diff --git a/src/common/hist_util.h b/src/common/hist_util.h index d94067289f1b..5e4797659789 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -127,11 +127,11 @@ class HistogramCuts { class CutsBuilder { public: using WQSketch = common::WQuantileSketch; + /* \brief return whether group for ranking is used. */ + static bool UseGroup(DMatrix* dmat); protected: HistogramCuts* p_cuts_; - /* \brief return whether group for ranking is used. */ - static bool UseGroup(DMatrix* dmat); public: explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {} diff --git a/src/data/data.cc b/src/data/data.cc index 3af47e2b6600..6052fc0f0310 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -338,6 +338,45 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t } } +void MetaInfo::Validate() const { + if (group_ptr_.size() != 0 && weights_.Size() != 0) { + CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) + << "Size of weights must equal to number of groups when ranking " + "group is used."; + return; + } + if (group_ptr_.size() != 0) { + CHECK_EQ(group_ptr_.back(), num_row_) + << "Invalid group structure. Number of rows obtained from groups " + "doesn't equal to actual number of rows given by data."; + } + if (weights_.Size() != 0) { + CHECK_EQ(weights_.Size(), num_row_) + << "Size of weights must equal to number of rows."; + return; + } + if (labels_.Size() != 0) { + CHECK_EQ(labels_.Size(), num_row_) + << "Size of labels must equal to number of rows."; + return; + } + if (labels_lower_bound_.Size() != 0) { + CHECK_EQ(labels_lower_bound_.Size(), num_row_) + << "Size of label_lower_bound must equal to number of rows."; + return; + } + if (labels_upper_bound_.Size() != 0) { + CHECK_EQ(labels_upper_bound_.Size(), num_row_) + << "Size of label_upper_bound must equal to number of rows."; + return; + } + CHECK_LE(num_nonzero_, num_col_ * num_row_); + if (base_margin_.Size() != 0) { + CHECK_EQ(base_margin_.Size() % num_row_, 0) + << "Size of base margin must be a multiple of number of rows."; + } +} + #if !defined(XGBOOST_USE_CUDA) void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) { common::AssertGPUSupport(); diff --git a/src/learner.cc b/src/learner.cc index 835ea4ad2fe0..b2f962894aba 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1048,15 +1048,7 @@ class LearnerImpl : public LearnerIO { void ValidateDMatrix(DMatrix* p_fmat) const { MetaInfo const& info = p_fmat->Info(); - auto const& weights = info.weights_; - if (info.group_ptr_.size() != 0 && weights.Size() != 0) { - CHECK(weights.Size() == info.group_ptr_.size() - 1) - << "\n" - << "weights size: " << weights.Size() << ", " - << "groups size: " << info.group_ptr_.size() -1 << ", " - << "num rows: " << p_fmat->Info().num_row_ << "\n" - << "Number of weights should be equal to number of groups in ranking task."; - } + info.Validate(); auto const row_based_split = [this]() { return tparam_.dsplit == DataSplitMode::kRow || diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index b728acb684d9..746602d9ee61 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -3,22 +3,19 @@ #include #include - - #include -#include "xgboost/c_api.h" +#include +#include +#include "test_hist_util.h" +#include "../helpers.h" +#include "../data/test_array_interface.h" #include "../../../src/common/device_helpers.cuh" #include "../../../src/common/hist_util.h" - -#include "../helpers.h" -#include #include "../../../src/data/device_adapter.cuh" -#include "../data/test_array_interface.h" #include "../../../src/common/math.h" #include "../../../src/data/simple_dmatrix.h" -#include "test_hist_util.h" #include "../../../include/xgboost/logging.h" namespace xgboost { @@ -143,7 +140,6 @@ TEST(HistUtil, DeviceSketchMultipleColumns) { ValidateCuts(cuts, dmat.get(), num_bins); } } - } TEST(HistUtil, DeviceSketchMultipleColumnsWeights) { @@ -161,6 +157,29 @@ TEST(HistUtil, DeviceSketchMultipleColumnsWeights) { } } +TEST(HistUitl, DeviceSketchWeights) { + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {100, 1000, 1500}; + int num_columns = 5; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + auto weighted_dmat = GetDMatrixFromData(x, num_rows, num_columns); + auto& h_weights = weighted_dmat->Info().weights_.HostVector(); + h_weights.resize(num_rows); + std::fill(h_weights.begin(), h_weights.end(), 1.0f); + for (auto num_bins : bin_sizes) { + auto cuts = DeviceSketch(0, dmat.get(), num_bins); + auto wcuts = DeviceSketch(0, weighted_dmat.get(), num_bins); + ASSERT_EQ(cuts.MinValues(), wcuts.MinValues()); + ASSERT_EQ(cuts.Ptrs(), wcuts.Ptrs()); + ASSERT_EQ(cuts.Values(), wcuts.Values()); + ValidateCuts(cuts, dmat.get(), num_bins); + ValidateCuts(wcuts, weighted_dmat.get(), num_bins); + } + } +} + TEST(HistUtil, DeviceSketchBatches) { int num_bins = 256; int num_rows = 5000; @@ -190,8 +209,7 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) { } } -TEST(HistUtil, AdapterDeviceSketch) -{ +TEST(HistUtil, AdapterDeviceSketch) { int rows = 5; int cols = 1; int num_bins = 4; @@ -235,7 +253,7 @@ TEST(HistUtil, AdapterDeviceSketchMemory) { bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); } - TEST(HistUtil, AdapterDeviceSketchCategorical) { +TEST(HistUtil, AdapterDeviceSketchCategorical) { int categorical_sizes[] = {2, 6, 8, 12}; int num_bins = 256; int sizes[] = {25, 100, 1000}; @@ -268,6 +286,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) { } } } + TEST(HistUtil, AdapterDeviceSketchBatches) { int num_bins = 256; int num_rows = 5000; @@ -305,7 +324,38 @@ TEST(HistUtil, SketchingEquivalent) { EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues()); } } +} + +TEST(HistUtil, DeviceSketchFromGroupWeights) { + size_t constexpr kRows = 3000, kCols = 200, kBins = 256; + size_t constexpr kGroups = 10; + auto m = RandomDataGenerator {kRows, kCols, 0}.GenerateDMatrix(); + auto& h_weights = m->Info().weights_.HostVector(); + h_weights.resize(kRows); + std::fill(h_weights.begin(), h_weights.end(), 1.0f); + std::vector groups(kGroups); + for (size_t i = 0; i < kGroups; ++i) { + groups[i] = kRows / kGroups; + } + m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + HistogramCuts weighted_cuts = DeviceSketch(0, m.get(), kBins, 0); + + h_weights.clear(); + HistogramCuts cuts = DeviceSketch(0, m.get(), kBins, 0); + ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size()); + ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size()); + ASSERT_EQ(cuts.Ptrs().size(), weighted_cuts.Ptrs().size()); + + for (size_t i = 0; i < cuts.Values().size(); ++i) { + EXPECT_EQ(cuts.Values()[i], weighted_cuts.Values()[i]) << "i:"<< i; + } + for (size_t i = 0; i < cuts.MinValues().size(); ++i) { + ASSERT_EQ(cuts.MinValues()[i], weighted_cuts.MinValues()[i]); + } + for (size_t i = 0; i < cuts.Ptrs().size(); ++i) { + ASSERT_EQ(cuts.Ptrs().at(i), weighted_cuts.Ptrs().at(i)); + } } } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index bb17704ceb0e..ec55f89d7a76 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -9,6 +9,11 @@ #include "../../../src/data/simple_dmatrix.h" #include "../../../src/data/adapter.h" +#ifdef __CUDACC__ +#include +#include "../../../src/data/device_adapter.cuh" +#endif // __CUDACC__ + // Some helper functions used to test both GPU and CPU algorithms // namespace xgboost { @@ -69,11 +74,11 @@ inline std::vector GenerateRandomCategoricalSingleColumn(int n, return x; } -inline std::shared_ptr GetDMatrixFromData(const std::vector& x, int num_rows, int num_columns) { +inline std::shared_ptr +GetDMatrixFromData(const std::vector &x, int num_rows, int num_columns) { data::DenseAdapter adapter(x.data(), num_rows, num_columns); return std::shared_ptr(new data::SimpleDMatrix( - &adapter, std::numeric_limits::quiet_NaN(), - 1)); + &adapter, std::numeric_limits::quiet_NaN(), 1)); } inline std::shared_ptr GetExternalMemoryDMatrixFromData( @@ -96,8 +101,9 @@ inline std::shared_ptr GetExternalMemoryDMatrixFromData( } // Test that elements are approximately equally distributed among bins -inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx, - const std::vector& sorted_column,const std::vector&sorted_weights, +inline void TestBinDistribution(const HistogramCuts &cuts, int column_idx, + const std::vector &sorted_column, + const std::vector &sorted_weights, int num_bins) { std::map bin_weights; for (auto i = 0ull; i < sorted_column.size(); i++) { @@ -113,29 +119,29 @@ inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx, // First and last bin can have smaller for (auto& kv : bin_weights) { EXPECT_LE(std::abs(bin_weights[kv.first] - expected_bin_weight), - allowable_error ); + allowable_error); } } - // Test sketch quantiles against the real quantiles - // Not a very strict test -inline void TestRank(const std::vector& cuts, - const std::vector& sorted_x, - const std::vector& sorted_weights) { +// Test sketch quantiles against the real quantiles Not a very strict +// test +inline void TestRank(const std::vector &column_cuts, + const std::vector &sorted_x, + const std::vector &sorted_weights) { double eps = 0.05; auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(), 0.0); // Ignore the last cut, its special double sum_weight = 0.0; size_t j = 0; - for (size_t i = 0; i < cuts.size() - 1; i++) { - while (cuts[i] > sorted_x[j]) { + for (size_t i = 0; i < column_cuts.size() - 1; i++) { + while (column_cuts[i] > sorted_x[j]) { sum_weight += sorted_weights[j]; j++; } - double expected_rank = ((i + 1) * total_weight) / cuts.size(); - double acceptable_error = std::max(2.0, total_weight * eps); - ASSERT_LE(std::abs(expected_rank - sum_weight), acceptable_error); + double expected_rank = ((i + 1) * total_weight) / column_cuts.size(); + double acceptable_error = std::max(2.9, total_weight * eps); + EXPECT_LE(std::abs(expected_rank - sum_weight), acceptable_error); } } @@ -167,15 +173,14 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, ASSERT_EQ(cuts.SearchBin(v, column_idx), cuts.Ptrs()[column_idx] + i); i++; } - } - else { + } else { int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx]; std::vector column_cuts(num_cuts_column); std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx], cuts.Values().begin() + cuts.Ptrs()[column_idx + 1], column_cuts.begin()); - TestBinDistribution(cuts, column_idx, sorted_column,sorted_weights, num_bins); - TestRank(column_cuts, sorted_column,sorted_weights); + TestBinDistribution(cuts, column_idx, sorted_column, sorted_weights, num_bins); + TestRank(column_cuts, sorted_column, sorted_weights); } } @@ -196,10 +201,8 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, const auto& w = dmat->Info().weights_.HostVector(); std::vector index(col.size()); std::iota(index.begin(), index.end(), 0); - std::sort(index.begin(), index.end(),[=](size_t a,size_t b) - { - return col[a] < col[b]; - }); + std::sort(index.begin(), index.end(), + [=](size_t a, size_t b) { return col[a] < col[b]; }); std::vector sorted_column(col.size()); std::vector sorted_weights(col.size(), 1.0); diff --git a/tests/cpp/data/test_metainfo.cc b/tests/cpp/data/test_metainfo.cc index ec4f0ca33847..64f432f35ad6 100644 --- a/tests/cpp/data/test_metainfo.cc +++ b/tests/cpp/data/test_metainfo.cc @@ -141,3 +141,17 @@ TEST(MetaInfo, LoadQid) { CHECK(batch.data.HostVector() == expected_data); } } + +TEST(MetaInfo, Validate) { + xgboost::MetaInfo info; + info.num_row_ = 10; + info.num_nonzero_ = 12; + info.num_col_ = 3; + std::vector groups (11); + info.SetInfo("group", groups.data(), xgboost::DataType::kUInt32, 11); + EXPECT_THROW(info.Validate(), dmlc::Error); + + std::vector labels(info.num_row_ + 1); + info.SetInfo("label", labels.data(), xgboost::DataType::kFloat32, info.num_row_ + 1); + EXPECT_THROW(info.Validate(), dmlc::Error); +} diff --git a/tests/python-gpu/test_gpu_ranking.py b/tests/python-gpu/test_gpu_ranking.py index 58c2fd78da87..06791e972005 100644 --- a/tests/python-gpu/test_gpu_ranking.py +++ b/tests/python-gpu/test_gpu_ranking.py @@ -1,14 +1,13 @@ import numpy as np -from scipy.sparse import csr_matrix import xgboost import os -import math import unittest import itertools import shutil import urllib.request import zipfile + class TestRanking(unittest.TestCase): @classmethod def setUpClass(cls): @@ -22,7 +21,7 @@ def setUpClass(cls): target = cls.dpath + '/MQ2008.zip' if os.path.exists(cls.dpath) and os.path.exists(target): - print ("Skipping dataset download...") + print("Skipping dataset download...") else: urllib.request.urlretrieve(url=src, filename=target) with zipfile.ZipFile(target, 'r') as f: @@ -50,17 +49,30 @@ def setUpClass(cls): cls.qid_test = qid_test cls.qid_valid = qid_valid + def setup_weighted(x, y, groups): + # Setup weighted data + data = xgboost.DMatrix(x, y) + groups_segment = [len(list(items)) + for _key, items in itertools.groupby(groups)] + data.set_group(groups_segment) + n_groups = len(groups_segment) + weights = np.ones((n_groups,)) + data.set_weight(weights) + return data + + cls.dtrain_w = setup_weighted(x_train, y_train, qid_train) + cls.dtest_w = setup_weighted(x_test, y_test, qid_test) + cls.dvalid_w = setup_weighted(x_valid, y_valid, qid_valid) + # model training parameters cls.params = {'booster': 'gbtree', 'tree_method': 'gpu_hist', 'gpu_id': 0, - 'predictor': 'gpu_predictor' - } + 'predictor': 'gpu_predictor'} cls.cpu_params = {'booster': 'gbtree', 'tree_method': 'hist', 'gpu_id': -1, - 'predictor': 'cpu_predictor' - } + 'predictor': 'cpu_predictor'} @classmethod def tearDownClass(cls): @@ -81,30 +93,46 @@ def __test_training_with_rank_objective(cls, rank_objective, metric_name, tolera # specify validations set to watch performance watchlist = [(cls.dtest, 'eval'), (cls.dtrain, 'train')] - num_trees=2500 - check_metric_improvement_rounds=10 + num_trees = 2500 + check_metric_improvement_rounds = 10 evals_result = {} cls.params['objective'] = rank_objective cls.params['eval_metric'] = metric_name - bst = xgboost.train(cls.params, cls.dtrain, num_boost_round=num_trees, - early_stopping_rounds=check_metric_improvement_rounds, - evals=watchlist, evals_result=evals_result) + bst = xgboost.train( + cls.params, cls.dtrain, num_boost_round=num_trees, + early_stopping_rounds=check_metric_improvement_rounds, + evals=watchlist, evals_result=evals_result) gpu_map_metric = evals_result['train'][metric_name][-1] evals_result = {} cls.cpu_params['objective'] = rank_objective cls.cpu_params['eval_metric'] = metric_name - bstc = xgboost.train(cls.cpu_params, cls.dtrain, num_boost_round=num_trees, - early_stopping_rounds=check_metric_improvement_rounds, - evals=watchlist, evals_result=evals_result) + bstc = xgboost.train( + cls.cpu_params, cls.dtrain, num_boost_round=num_trees, + early_stopping_rounds=check_metric_improvement_rounds, + evals=watchlist, evals_result=evals_result) cpu_map_metric = evals_result['train'][metric_name][-1] - print("{0} gpu {1} metric {2}".format(rank_objective, metric_name, gpu_map_metric)) - print("{0} cpu {1} metric {2}".format(rank_objective, metric_name, cpu_map_metric)) - print("gpu best score {0} cpu best score {1}".format(bst.best_score, bstc.best_score)) - assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance, tolerance) - assert np.allclose(bst.best_score, bstc.best_score, tolerance, tolerance) + assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance, + tolerance) + assert np.allclose(bst.best_score, bstc.best_score, tolerance, + tolerance) + + evals_result_weighted = {} + watchlist = [(cls.dtest_w, 'eval'), (cls.dtrain_w, 'train')] + bst_w = xgboost.train( + cls.params, cls.dtrain_w, num_boost_round=num_trees, + early_stopping_rounds=check_metric_improvement_rounds, + evals=watchlist, evals_result=evals_result_weighted) + weighted_metric = evals_result_weighted['train'][metric_name][-1] + # GPU Ranking is not deterministic due to `AtomicAddGpair`, + # remove tolerance once the issue is resolved. + # https://github.com/dmlc/xgboost/issues/5561 + assert np.allclose(bst_w.best_score, bst.best_score, + tolerance, tolerance) + assert np.allclose(weighted_metric, gpu_map_metric, + tolerance, tolerance) def test_training_rank_pairwise_map_metric(self): """