diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index e07f2a107badb..8299a3e95d822 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -155,10 +155,16 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   override def transform(dataset: Dataset[_]): DataFrame = {
     val transformedSchema = transformSchema(dataset.schema)
 
+    val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
+      ($(inputCols).toSeq, $(outputCols).toSeq)
+    } else {
+      (Seq($(inputCol)), Seq($(outputCol)))
+    }
+
     val (filteredDataset, keepInvalid) = {
       if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
         // "skip" NaN option is set, will filter out NaN values in the dataset
-        (dataset.na.drop().toDF(), false)
+        (dataset.na.drop(inputColumns).toDF(), false)
       } else {
         (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
       }
@@ -176,11 +182,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
       }.withName(s"bucketizer_$idx")
     }
 
-    val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
-      ($(inputCols).toSeq, $(outputCols).toSeq)
-    } else {
-      (Seq($(inputCol)), Seq($(outputCol)))
-    }
+
     val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
       bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
     }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 748dbd1b995d3..d9c97ae8067d3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -123,6 +123,15 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     }
   }
 
+  test("Bucketizer should only drop NaN in input columns, with handleInvalid=skip") {
+    val df = spark.createDataFrame(Seq((2.3, 3.0), (Double.NaN, 3.0), (6.7, Double.NaN)))
+      .toDF("a", "b")
+    val splits = Array(Double.NegativeInfinity, 3.0, Double.PositiveInfinity)
+    val bucketizer = new Bucketizer().setInputCol("a").setOutputCol("x").setSplits(splits)
+    bucketizer.setHandleInvalid("skip")
+    assert(bucketizer.transform(df).count() == 2)
+  }
+
   test("Bucket continuous features, with NaN splits") {
     val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN)
     withClue("Invalid NaN split was not caught during Bucketizer initialization") {
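
For context, a minimal sketch of the behavior change this patch covers, mirroring the new test above. It assumes an already-created local SparkSession named `spark`; the column names `a`/`b`/`x` come from the test and are otherwise arbitrary:

```scala
import org.apache.spark.ml.feature.Bucketizer

// Row 2 has NaN in the input column "a"; row 3 has NaN only in the
// unrelated column "b".
val df = spark.createDataFrame(Seq((2.3, 3.0), (Double.NaN, 3.0), (6.7, Double.NaN)))
  .toDF("a", "b")

val bucketizer = new Bucketizer()
  .setInputCol("a")
  .setOutputCol("x")
  .setSplits(Array(Double.NegativeInfinity, 3.0, Double.PositiveInfinity))
  .setHandleInvalid("skip")

// Before the fix, dataset.na.drop() removed every row containing a NaN in
// any column, leaving 1 row here; with na.drop(inputColumns), only rows
// with NaN in the input column "a" are dropped, leaving 2 rows.
assert(bucketizer.transform(df).count() == 2)
```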