@@ -338,7 +338,7 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
338338 }
339339}
340340
341- void MetaInfo::Validate () const {
341+ void MetaInfo::Validate (int32_t device ) const {
342342 if (group_ptr_.size () != 0 && weights_.Size () != 0 ) {
343343 CHECK_EQ (group_ptr_.size (), weights_.Size () + 1 )
344344 << " Size of weights must equal to number of groups when ranking "
@@ -350,30 +350,44 @@ void MetaInfo::Validate() const {
350350 << " Invalid group structure. Number of rows obtained from groups "
351351 " doesn't equal to actual number of rows given by data." ;
352352 }
353+ auto check_device = [device](HostDeviceVector<float > const &v) {
354+ CHECK (v.DeviceIdx () == GenericParameter::kCpuId ||
355+ device == GenericParameter::kCpuId ||
356+ v.DeviceIdx () == device)
357+ << " Data is resided on a different device than `gpu_id`. "
358+ << " Device that data is on: " << v.DeviceIdx () << " , "
359+ << " `gpu_id` for XGBoost: " << device;
360+ };
361+
353362 if (weights_.Size () != 0 ) {
354363 CHECK_EQ (weights_.Size (), num_row_)
355364 << " Size of weights must equal to number of rows." ;
365+ check_device (weights_);
356366 return ;
357367 }
358368 if (labels_.Size () != 0 ) {
359369 CHECK_EQ (labels_.Size (), num_row_)
360370 << " Size of labels must equal to number of rows." ;
371+ check_device (labels_);
361372 return ;
362373 }
363374 if (labels_lower_bound_.Size () != 0 ) {
364375 CHECK_EQ (labels_lower_bound_.Size (), num_row_)
365376 << " Size of label_lower_bound must equal to number of rows." ;
377+ check_device (labels_lower_bound_);
366378 return ;
367379 }
368380 if (labels_upper_bound_.Size () != 0 ) {
369381 CHECK_EQ (labels_upper_bound_.Size (), num_row_)
370382 << " Size of label_upper_bound must equal to number of rows." ;
383+ check_device (labels_upper_bound_);
371384 return ;
372385 }
373386 CHECK_LE (num_nonzero_, num_col_ * num_row_);
374387 if (base_margin_.Size () != 0 ) {
375388 CHECK_EQ (base_margin_.Size () % num_row_, 0 )
376389 << " Size of base margin must be a multiple of number of rows." ;
390+ check_device (base_margin_);
377391 }
378392}
379393
0 commit comments