diff --git a/include/xgboost/generic_parameters.h b/include/xgboost/generic_parameters.h
index be068f74a0a0..7713d6b2b0f7 100644
--- a/include/xgboost/generic_parameters.h
+++ b/include/xgboost/generic_parameters.h
@@ -29,7 +29,6 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
   size_t gpu_page_size;
   bool enable_experimental_json_serialization {false};
   bool validate_parameters {false};
-  bool validate_features {true};
 
   void CheckDeprecated() {
     if (this->n_gpus != 0) {
@@ -75,9 +74,6 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
     DMLC_DECLARE_FIELD(validate_parameters)
         .set_default(false)
         .describe("Enable checking whether parameters are used or not.");
-    DMLC_DECLARE_FIELD(validate_features)
-        .set_default(false)
-        .describe("Enable validating input DMatrix.");
     DMLC_DECLARE_FIELD(n_gpus)
         .set_default(0)
         .set_range(0, 1)
diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala
index fa0d8b623445..c9aa1631f02e 100644
--- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala
+++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoost.scala
@@ -49,7 +49,7 @@ object XGBoost {
     Rabit.init(workerEnvs)
     val mapper = (x: LabeledVector) => {
       val (index, value) = x.vector.toSeq.unzip
-      LabeledPoint(x.label.toFloat, index.toArray, value.map(_.toFloat).toArray)
+      LabeledPoint(x.label.toFloat, x.vector.size, index.toArray, value.map(_.toFloat).toArray)
     }
     val dataIter = for (x <- it.iterator().asScala) yield mapper(x)
     val trainMat = new DMatrix(dataIter, null)
diff --git a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala
index 46b5689a508c..71b376974dc0 100644
--- a/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala
+++ b/jvm-packages/xgboost4j-flink/src/main/scala/ml/dmlc/xgboost4j/scala/flink/XGBoostModel.scala
@@ -56,7 +56,7 @@ class XGBoostModel (booster: Booster) extends Serializable {
     (it: Iterator[Vector]) => {
       val mapper = (x: Vector) => {
         val (index, value) = x.toSeq.unzip
-        LabeledPoint(0.0f, index.toArray, value.map(_.toFloat).toArray)
+        LabeledPoint(0.0f, x.size, index.toArray, value.map(_.toFloat).toArray)
       }
       val dataIter = for (x <- it) yield mapper(x)
       val dmat = new DMatrix(dataIter, null)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
index 36c76ce15a1c..df787d8eb8ab 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/DataUtils.scala
@@ -38,15 +38,11 @@ object DataUtils extends Serializable {
 
     /**
      * Returns feature of the point as [[org.apache.spark.ml.linalg.Vector]].
-     *
-     * If the point is sparse, the dimensionality of the resulting sparse
-     * vector would be [[Int.MaxValue]]. This is the only safe value, since
-     * XGBoost does not store the dimensionality explicitly.
      */
     def features: Vector = if (labeledPoint.indices == null) {
       Vectors.dense(labeledPoint.values.map(_.toDouble))
     } else {
-      Vectors.sparse(Int.MaxValue, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
+      Vectors.sparse(labeledPoint.size, labeledPoint.indices, labeledPoint.values.map(_.toDouble))
     }
   }
 
@@ -68,9 +64,9 @@ object DataUtils extends Serializable {
      */
     def asXGB: XGBLabeledPoint = v match {
       case v: DenseVector =>
-        XGBLabeledPoint(0.0f, null, v.values.map(_.toFloat))
+        XGBLabeledPoint(0.0f, v.size, null, v.values.map(_.toFloat))
       case v: SparseVector =>
-        XGBLabeledPoint(0.0f, v.indices, v.values.map(_.toFloat))
+        XGBLabeledPoint(0.0f, v.size, v.indices, v.values.map(_.toFloat))
     }
   }
 
@@ -162,18 +158,18 @@ object DataUtils extends Serializable {
       df => df.select(selectedColumns: _*).rdd.map {
         case row @ Row(label: Float, features: Vector, weight: Float, group: Int,
           baseMargin: Float) =>
-          val (indices, values) = features match {
-            case v: SparseVector => (v.indices, v.values.map(_.toFloat))
-            case v: DenseVector => (null, v.values.map(_.toFloat))
+          val (size, indices, values) = features match {
+            case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
+            case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
           }
-          val xgbLp = XGBLabeledPoint(label, indices, values, weight, group, baseMargin)
+          val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, group, baseMargin)
           attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
         case row @ Row(label: Float, features: Vector, weight: Float, baseMargin: Float) =>
-          val (indices, values) = features match {
-            case v: SparseVector => (v.indices, v.values.map(_.toFloat))
-            case v: DenseVector => (null, v.values.map(_.toFloat))
+          val (size, indices, values) = features match {
+            case v: SparseVector => (v.size, v.indices, v.values.map(_.toFloat))
+            case v: DenseVector => (v.size, null, v.values.map(_.toFloat))
           }
-          val xgbLp = XGBLabeledPoint(label, indices, values, weight, baseMargin = baseMargin)
+          val xgbLp = XGBLabeledPoint(label, size, indices, values, weight, baseMargin = baseMargin)
           attachPartitionKey(row, deterministicPartition, numWorkers, xgbLp)
       }
     }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala
new file mode 100644
index 000000000000..f82927234490
--- /dev/null
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/FeatureSizeValidatingSuite.scala
@@ -0,0 +1,71 @@
+/*
+ Copyright (c) 2014 by Contributors
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package ml.dmlc.xgboost4j.scala.spark
+
+import ml.dmlc.xgboost4j.java.XGBoostError
+import org.apache.spark.Partitioner
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.sql.SparkSession
+import org.scalatest.FunSuite
+
+import scala.util.Random
+
+class FeatureSizeValidatingSuite extends FunSuite with PerTest {
+
+  test("transform throwing exception if feature size of dataset is different with model's") {
+    val modelPath = getClass.getResource("/model/0.82/model").getPath
+    val model = XGBoostClassificationModel.read.load(modelPath)
+    val r = new Random(0)
+    // 0.82/model was trained with 251 features. and transform will throw exception
+    // if feature size of data is not equal to 251
+    val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
+      toDF("feature", "label")
+    val assembler = new VectorAssembler()
+      .setInputCols(df.columns.filter(!_.contains("label")))
+      .setOutputCol("features")
+    val thrown = intercept[Exception] {
+      model.transform(assembler.transform(df)).show()
+    }
+    assert(thrown.getMessage.contains(
+      "Number of columns does not match number of features in booster"))
+  }
+
+  test("train throwing exception if feature size of dataset is different on distributed train") {
+    val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic",
+      "num_round" -> 5, "num_workers" -> 2, "use_external_memory" -> true, "missing" -> 0)
+    import DataUtils._
+    val sparkSession = SparkSession.builder().getOrCreate()
+    import sparkSession.implicits._
+    val repartitioned = sc.parallelize(Synthetic.trainWithDiffFeatureSize, 2)
+      .map(lp => (lp.label, lp)).partitionBy(
+      new Partitioner {
+        override def numPartitions: Int = 2
+
+        override def getPartition(key: Any): Int = key.asInstanceOf[Float].toInt
+      }
+    ).map(_._2).zipWithIndex().map {
+      case (lp, id) =>
+        (id, lp.label, lp.features)
+    }.toDF("id", "label", "features")
+    val xgb = new XGBoostClassifier(paramMap)
+    intercept[XGBoostError] {
+      xgb.fit(repartitioned)
+    }
+  }
+
+}
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
index a3841772c796..ebe1d8546544 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/PersistenceSuite.scala
@@ -19,13 +19,12 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.io.File
 import java.util.Arrays
 
-import scala.io.Source
-
 import ml.dmlc.xgboost4j.scala.DMatrix
-import scala.util.Random
 
+import scala.util.Random
 import org.apache.spark.ml.feature._
 import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.sql.functions._
 import org.scalatest.FunSuite
 
 class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
@@ -138,12 +137,21 @@ class PersistenceSuite extends FunSuite with TmpFolderPerSuite with PerTest {
     val modelPath = getClass.getResource("/model/0.82/model").getPath
     val model = XGBoostClassificationModel.read.load(modelPath)
     val r = new Random(0)
-    val df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
+    var df = ss.createDataFrame(Seq.fill(100)(r.nextInt(2)).map(i => (i, i))).
       toDF("feature", "label")
+    // 0.82/model was trained with 251 features. and transform will throw exception
+    // if feature size of data is not equal to 251
+    for (x <- 1 to 250) {
+      df = df.withColumn(s"feature_${x}", lit(1))
+    }
     val assembler = new VectorAssembler()
       .setInputCols(df.columns.filter(!_.contains("label")))
       .setOutputCol("features")
-    model.transform(assembler.transform(df)).show()
+    df = assembler.transform(df)
+    for (x <- 1 to 250) {
+      df = df.drop(s"feature_${x}")
+    }
+    model.transform(df).show()
   }
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala
index 1a64e2d03529..fae241d8b990 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/TrainTestData.scala
@@ -31,11 +31,12 @@ trait TrainTestData {
     Source.fromInputStream(is).getLines()
   }
 
-  protected def getLabeledPoints(resource: String, zeroBased: Boolean): Seq[XGBLabeledPoint] = {
+  protected def getLabeledPoints(resource: String, featureSize: Int, zeroBased: Boolean):
+    Seq[XGBLabeledPoint] = {
     getResourceLines(resource).map { line =>
       val labelAndFeatures = line.split(" ")
       val label = labelAndFeatures.head.toFloat
-      val values = new Array[Float](126)
+      val values = new Array[Float](featureSize)
       for (feature <- labelAndFeatures.tail) {
         val idAndValue = feature.split(":")
         if (!zeroBased) {
@@ -45,7 +46,7 @@ trait TrainTestData {
         }
       }
 
-      XGBLabeledPoint(label, null, values)
+      XGBLabeledPoint(label, featureSize, null, values)
     }.toList
   }
 
@@ -56,14 +57,14 @@ trait TrainTestData {
       val label = original.head.toFloat
       val group = original.last.toInt
       val values = original.slice(1, length - 1).map(_.toFloat)
-      XGBLabeledPoint(label, null, values, 1f, group, Float.NaN)
+      XGBLabeledPoint(label, values.size, null, values, 1f, group, Float.NaN)
     }.toList
   }
 }
 
 object Classification extends TrainTestData {
-  val train: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.train", zeroBased = false)
-  val test: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.test", zeroBased = false)
+  val train: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.train", 126, zeroBased = false)
+  val test: Seq[XGBLabeledPoint] = getLabeledPoints("/agaricus.txt.test", 126, zeroBased = false)
 }
 
 object MultiClassification extends TrainTestData {
@@ -80,19 +81,24 @@ object MultiClassification extends TrainTestData {
       values(i) = featuresAndLabel(i).toFloat
     }
 
-    XGBLabeledPoint(label, null, values.take(values.length - 1))
+    XGBLabeledPoint(label, values.length - 1, null, values.take(values.length - 1))
    }.toList
  }
}
 
 object Regression extends TrainTestData {
-  val train: Seq[XGBLabeledPoint] = getLabeledPoints("/machine.txt.train", zeroBased = true)
-  val test: Seq[XGBLabeledPoint] = getLabeledPoints("/machine.txt.test", zeroBased = true)
+  val MACHINE_COL_NUM = 36
+  val train: Seq[XGBLabeledPoint] = getLabeledPoints(
+    "/machine.txt.train", MACHINE_COL_NUM, zeroBased = true)
+  val test: Seq[XGBLabeledPoint] = getLabeledPoints(
+    "/machine.txt.test", MACHINE_COL_NUM, zeroBased = true)
 }
 
 object Ranking extends TrainTestData {
+  val RANK_COL_NUM = 3
   val train: Seq[XGBLabeledPoint] = getLabeledPointsWithGroup("/rank.train.csv")
-  val test: Seq[XGBLabeledPoint] = getLabeledPoints("/rank.test.txt", zeroBased = false)
+  val test: Seq[XGBLabeledPoint] = getLabeledPoints(
+    "/rank.test.txt", RANK_COL_NUM, zeroBased = false)
 
   private def getGroups(resource: String): Seq[Int] = {
     getResourceLines(resource).map(_.toInt).toList
@@ -100,10 +106,17 @@ object Ranking extends TrainTestData {
 }
 
 object Synthetic extends {
+  val TRAIN_COL_NUM = 3
+  val TRAIN_WRONG_COL_NUM = 2
   val train: Seq[XGBLabeledPoint] = Seq(
-    XGBLabeledPoint(1.0f, Array(0, 1), Array(1.0f, 2.0f)),
-    XGBLabeledPoint(0.0f, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
-    XGBLabeledPoint(0.0f, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
-    XGBLabeledPoint(1.0f, Array(0, 1), Array(1.0f, 2.0f))
+    XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
+    XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
+    XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f)),
+    XGBLabeledPoint(1.0f, TRAIN_COL_NUM, Array(0, 1), Array(1.0f, 2.0f))
+  )
+
+  val trainWithDiffFeatureSize: Seq[XGBLabeledPoint] = Seq(
+    XGBLabeledPoint(1.0f, TRAIN_WRONG_COL_NUM, Array(0, 1), Array(1.0f, 2.0f)),
+    XGBLabeledPoint(0.0f, TRAIN_COL_NUM, Array(0, 1, 2), Array(1.0f, 2.0f, 3.0f))
   )
 }
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
index f2ccdd44ee73..b1dda665b642 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifierSuite.scala
@@ -17,12 +17,9 @@
 package ml.dmlc.xgboost4j.scala.spark
 
 import ml.dmlc.xgboost4j.scala.{DMatrix, XGBoost => ScalaXGBoost}
-
 import org.apache.spark.ml.linalg._
-import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.sql._
 import org.scalatest.FunSuite
-
 import org.apache.spark.Partitioner
 
 class XGBoostClassifierSuite extends FunSuite with PerTest {
@@ -308,4 +305,5 @@ class XGBoostClassifierSuite extends FunSuite with PerTest {
     val xgb = new XGBoostClassifier(paramMap)
     xgb.fit(repartitioned)
   }
+
 }
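Note on the test changes above and below: when a suite loads the bundled 0.82 model (trained on 251 features), it now has to pad the input frame with dummy columns before VectorAssembler so the assembled vector width matches the booster, and it can drop the helper columns afterwards. The following standalone sketch shows that recipe outside the suites; the SparkSession setup, the object name PadToBoosterWidth, and the 251-column figure taken from the tests' comments are illustrative assumptions, not part of this patch.

    import org.apache.spark.ml.feature.VectorAssembler
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.lit

    object PadToBoosterWidth {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("pad-features").getOrCreate()
        // Start from a frame with one real feature column plus the label.
        var df = spark.createDataFrame(Seq((1, 1), (0, 0))).toDF("feature", "label")
        // Pad with constant columns until the feature count reaches 251,
        // the width the serialized 0.82 model expects.
        for (x <- 1 to 250) {
          df = df.withColumn(s"feature_${x}", lit(1))
        }
        val assembler = new VectorAssembler()
          .setInputCols(df.columns.filter(!_.contains("label")))
          .setOutputCol("features")
        df = assembler.transform(df)
        // The helper columns are no longer needed once "features" is assembled.
        for (x <- 1 to 250) {
          df = df.drop(s"feature_${x}")
        }
        df.select("features").show(1, truncate = false)
        spark.stop()
      }
    }

Skipping the padding step is exactly what FeatureSizeValidatingSuite does on purpose: with the learner-side check now fatal, transform fails with "Number of columns does not match number of features in booster" instead of logging a warning.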
toDF("feature", "label").repartition(5) + // 0.82/model was trained with 251 features. and transform will throw exception + // if feature size of data is not equal to 251 + for (x <- 1 to 250) { + df = df.withColumn(s"feature_${x}", lit(1)) + } val assembler = new VectorAssembler() .setInputCols(df.columns.filter(!_.contains("label"))) .setOutputCol("features") - val df1 = model.transform(assembler.transform(df)).withColumnRenamed( + df = assembler.transform(df) + for (x <- 1 to 250) { + df = df.drop(s"feature_${x}") + } + val df1 = model.transform(df).withColumnRenamed( "prediction", "prediction1").withColumnRenamed( "rawPrediction", "rawPrediction1").withColumnRenamed( "probability", "probability1") @@ -363,4 +366,5 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest { df1.collect() df2.collect() } + } diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala index 12ba9366a8ad..d1b6ec0f9acc 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRabitRegressionSuite.scala @@ -69,8 +69,7 @@ class XGBoostRabitRegressionSuite extends FunSuite with PerTest { test("test regression prediction parity w/o ring reduce") { val training = buildDataFrame(Regression.train) - val testDM = new DMatrix(Regression.test.iterator, null) - val testDF = buildDataFrame(Classification.test) + val testDF = buildDataFrame(Regression.test) val xgbSettings = Map("eta" -> "1", "max_depth" -> "2", "verbosity" -> "1", "objective" -> "reg:squarederror", "num_round" -> 5, "num_workers" -> numWorkers) val model1 = new XGBoostRegressor(xgbSettings).fit(training) diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java index 8ebf781f1d97..f97aa3a39d4a 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/Booster.java @@ -49,7 +49,6 @@ public class Booster implements Serializable, KryoSerializable { */ Booster(Map params, DMatrix[] cacheMats) throws XGBoostError { init(cacheMats); - setParam("validate_features", "0"); setParams(params); } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java index 61db35764722..ad80d030a23b 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/DataBatch.java @@ -27,14 +27,17 @@ class DataBatch { final int[] featureIndex; /** value of each non-missing entry in the sparse matrix */ final float[] featureValue ; + /** feature columns */ + final int featureCols; DataBatch(long[] rowOffset, float[] weight, float[] label, int[] featureIndex, - float[] featureValue) { + float[] featureValue, int featureCols) { this.rowOffset = rowOffset; this.weight = weight; this.label = label; this.featureIndex = featureIndex; this.featureValue = featureValue; + this.featureCols = featureCols; } static class BatchIterator implements Iterator { @@ -56,9 +59,15 @@ public DataBatch next() { try { int numRows = 0; int numElem = 0; + int numCol = -1; List batch = 
         while (base.hasNext() && batch.size() < batchSize) {
           LabeledPoint labeledPoint = base.next();
+          if (numCol == -1) {
+            numCol = labeledPoint.size();
+          } else if (numCol != labeledPoint.size()) {
+            throw new RuntimeException("Feature size is not the same");
+          }
           batch.add(labeledPoint);
           numElem += labeledPoint.values().length;
           numRows++;
@@ -91,7 +100,7 @@ public DataBatch next() {
         }
         rowOffset[batch.size()] = offset;
 
-        return new DataBatch(rowOffset, weight, label, featureIndex, featureValue);
+        return new DataBatch(rowOffset, weight, label, featureIndex, featureValue, numCol);
       } catch (RuntimeException runtimeError) {
         logger.error(runtimeError);
         return null;
diff --git a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/LabeledPoint.scala b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/LabeledPoint.scala
index 9a92d1b913ab..ccdedbaa3704 100644
--- a/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/LabeledPoint.scala
+++ b/jvm-packages/xgboost4j/src/main/scala/ml/dmlc/xgboost4j/LabeledPoint.scala
@@ -20,6 +20,7 @@ package ml.dmlc.xgboost4j
 * Labeled training data point.
 *
 * @param label Label of this point.
+ * @param size Feature dimensionality
 * @param indices Feature indices of this point or `null` if the data is dense.
 * @param values Feature values of this point.
 * @param weight Weight of this point.
@@ -28,6 +29,7 @@ package ml.dmlc.xgboost4j
 */
 case class LabeledPoint(
     label: Float,
+    size: Int,
     indices: Array[Int],
     values: Array[Float],
     weight: Float = 1f,
@@ -36,8 +38,11 @@ case class LabeledPoint(
   require(indices == null || indices.length == values.length,
     "indices and values must have the same number of elements")
 
-  def this(label: Float, indices: Array[Int], values: Array[Float]) = {
+  require(indices == null || size >= indices.length,
+    "feature dimensionality must be greater equal than size of indices")
+
+  def this(label: Float, size: Int, indices: Array[Int], values: Array[Float]) = {
     // [[weight]] default duplicated to disambiguate the constructor call.
-    this(label, indices, values, 1.0f)
+    this(label, size, indices, values, 1.0f)
   }
 }
diff --git a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
index 5bc6f13b11a9..636db2fb4041 100644
--- a/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
+++ b/jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -91,9 +91,11 @@ XGB_EXTERN_C int XGBoost4jCallbackDataIterNext(
       batch, jenv->GetFieldID(batchClass, "featureIndex", "[I"));
   jfloatArray jvalue = (jfloatArray)jenv->GetObjectField(
       batch, jenv->GetFieldID(batchClass, "featureValue", "[F"));
+  jint jcols = jenv->GetIntField(
+      batch, jenv->GetFieldID(batchClass, "featureCols", "I"));
   XGBoostBatchCSR cbatch;
   cbatch.size = jenv->GetArrayLength(joffset) - 1;
-  cbatch.columns = std::numeric_limits<size_t>::max();
+  cbatch.columns = jcols;
   cbatch.offset = reinterpret_cast<long *>(
       jenv->GetLongArrayElements(joffset, 0));
   if (jlabel != nullptr) {
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
index 75f52d877eac..3f6435b9e9b1 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/DMatrixTest.java
@@ -45,7 +45,7 @@ public void testCreateFromDataIterator() throws XGBoostError {
     java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
     for (int i = 0; i < nrep; ++i) {
       LabeledPoint p = new LabeledPoint(
-          0.1f + i, new int[]{0, 2, 3}, new float[]{3, 4, 5});
+          0.1f + i, 4, new int[]{0, 2, 3}, new float[]{3, 4, 5});
       blist.add(p);
       labelall.add(p.label());
     }
@@ -57,6 +57,33 @@ public void testCreateFromDataIterator() throws XGBoostError {
     }
   }
 
+  @Test
+  public void testCreateFromDataIteratorWithDiffFeatureSize() throws XGBoostError {
+    //create DMatrix from DataIterator
+
+    java.util.ArrayList<Float> labelall = new java.util.ArrayList<Float>();
+    int nrep = 3000;
+    java.util.List<LabeledPoint> blist = new java.util.LinkedList<LabeledPoint>();
+    int featureSize = 4;
+    for (int i = 0; i < nrep; ++i) {
+      // set some rows with wrong feature size
+      if (i % 10 == 1) {
+        featureSize = 5;
+      }
+      LabeledPoint p = new LabeledPoint(
+          0.1f + i, featureSize, new int[]{0, 2, 3}, new float[]{3, 4, 5});
+      blist.add(p);
+      labelall.add(p.label());
+    }
+    boolean success = true;
+    try {
+      DMatrix dmat = new DMatrix(blist.iterator(), null);
+    } catch (XGBoostError e) {
+      success = false;
+    }
+    TestCase.assertTrue(success == false);
+  }
+
   @Test
   public void testCreateFromFile() throws XGBoostError {
     //create DMatrix from file
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index 64637695cd53..d0c1396dd2a6 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -125,6 +125,7 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
   } else {
     info_.num_col_ = adapter->NumColumns();
   }
+  // Synchronise worker columns
   rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
 
diff --git a/src/learner.cc b/src/learner.cc
index 339f8aa7eb54..1d1acd2aec99 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -1063,19 +1063,9 @@ class LearnerImpl : public LearnerIO {
       return tparam_.dsplit == DataSplitMode::kRow ||
              tparam_.dsplit == DataSplitMode::kAuto;
     };
-    bool const valid_features =
-        !row_based_split() ||
-        (learner_model_param_.num_feature == p_fmat->Info().num_col_);
-    std::string const msg {
-      "Number of columns does not match number of features in booster."
-    };
-    if (generic_parameters_.validate_features) {
-      CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_) << msg;
-    } else if (!valid_features) {
-      // Remove this and make the equality check fatal once spark can fix all failing tests.
-      LOG(WARNING) << msg << " "
-                   << "Columns: " << p_fmat->Info().num_col_ << " "
-                   << "Features: " << learner_model_param_.num_feature;
+    if (row_based_split()) {
+      CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_)
+          << "Number of columns does not match number of features in booster.";
     }
   }
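Taken together, the patch threads an explicit feature dimensionality from every LabeledPoint constructor through DataBatch and the JNI layer into XGBoostBatchCSR.columns, and turns the learner's column/feature mismatch warning into a hard error for row-based splits. A rough sketch of what callers of the xgboost4j Scala API see after the change; the object name and the toy four-feature rows are made up for illustration, and running it needs the xgboost4j native library on the library path.

    import ml.dmlc.xgboost4j.LabeledPoint
    import ml.dmlc.xgboost4j.scala.DMatrix

    object LabeledPointSizeSketch {
      def main(args: Array[String]): Unit = {
        // Dense row: indices == null, size is simply the number of values.
        val dense = LabeledPoint(1.0f, 4, null, Array(0.5f, 0.0f, 1.2f, 3.0f))
        // Sparse row: size is the full feature dimensionality (4), not the
        // number of stored entries (2).
        val sparse = LabeledPoint(0.0f, 4, Array(0, 3), Array(0.5f, 3.0f))

        // All points handed to one DMatrix must now report the same size;
        // DataBatch.BatchIterator fails fast otherwise.
        val dmat = new DMatrix(Iterator(dense, sparse), null)
        println(s"rows = ${dmat.rowNum}")
      }
    }

A point whose size disagrees with the rest of the iterator, for example LabeledPoint(0.0f, 5, Array(0, 3), Array(0.5f, 3.0f)) mixed into the sequence above, now surfaces as an XGBoostError when the DMatrix is built, mirroring testCreateFromDataIteratorWithDiffFeatureSize.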