From 01604c7e0ed4dc5c87ae980cc0720cf4858aa154 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Tue, 5 Dec 2017 19:49:50 +0800 Subject: [PATCH 1/2] create pr --- .../org/apache/spark/ml/feature/Bucketizer.scala | 14 ++++++++------ .../apache/spark/ml/feature/BucketizerSuite.scala | 9 +++++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index e07f2a107badb..8299a3e95d822 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -155,10 +155,16 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema) + val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { + ($(inputCols).toSeq, $(outputCols).toSeq) + } else { + (Seq($(inputCol)), Seq($(outputCol))) + } + val (filteredDataset, keepInvalid) = { if (getHandleInvalid == Bucketizer.SKIP_INVALID) { // "skip" NaN option is set, will filter out NaN values in the dataset - (dataset.na.drop().toDF(), false) + (dataset.na.drop(inputColumns).toDF(), false) } else { (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID) } @@ -176,11 +182,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String }.withName(s"bucketizer_$idx") } - val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { - ($(inputCols).toSeq, $(outputCols).toSeq) - } else { - (Seq($(inputCol)), Seq($(outputCol))) - } + val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) => bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 748dbd1b995d3..872f43e982a5f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -109,6 +109,15 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa s"The feature value is not correct after bucketing. Expected $y but found $x") } + test("Bucket should only drop NaN in input columns, with handleInvalid=skip") { + val df = spark.createDataFrame(Seq((2.3, 3.0), (Double.NaN, 3.0), (6.7, Double.NaN) + )).toDF("a", "b") + val splits = Array(Double.NegativeInfinity, 3.0, Double.PositiveInfinity) + val bucketizer = new Bucketizer().setInputCol("a").setOutputCol("x").setSplits(splits) + bucketizer.setHandleInvalid("skip") + assert(bucketizer.transform(df).count() == 2) + } + bucketizer.setHandleInvalid("skip") val skipResults: Array[Double] = bucketizer.transform(dataFrame) .select("result").as[Double].collect() From eaebedbaba514e90f1c463f1fdbc37e0b39b51da Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 5 Dec 2017 21:09:12 +0800 Subject: [PATCH 2/2] Update BucketizerSuite.scala --- .../spark/ml/feature/BucketizerSuite.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 872f43e982a5f..d9c97ae8067d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -109,15 +109,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa s"The feature value is not correct after bucketing. Expected $y but found $x") } - test("Bucket should only drop NaN in input columns, with handleInvalid=skip") { - val df = spark.createDataFrame(Seq((2.3, 3.0), (Double.NaN, 3.0), (6.7, Double.NaN) - )).toDF("a", "b") - val splits = Array(Double.NegativeInfinity, 3.0, Double.PositiveInfinity) - val bucketizer = new Bucketizer().setInputCol("a").setOutputCol("x").setSplits(splits) - bucketizer.setHandleInvalid("skip") - assert(bucketizer.transform(df).count() == 2) - } - bucketizer.setHandleInvalid("skip") val skipResults: Array[Double] = bucketizer.transform(dataFrame) .select("result").as[Double].collect() @@ -132,6 +123,15 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } + test("Bucketizer should only drop NaN in input columns, with handleInvalid=skip") { + val df = spark.createDataFrame(Seq((2.3, 3.0), (Double.NaN, 3.0), (6.7, Double.NaN))) + .toDF("a", "b") + val splits = Array(Double.NegativeInfinity, 3.0, Double.PositiveInfinity) + val bucketizer = new Bucketizer().setInputCol("a").setOutputCol("x").setSplits(splits) + bucketizer.setHandleInvalid("skip") + assert(bucketizer.transform(df).count() == 2) + } + test("Bucket continuous features, with NaN splits") { val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) withClue("Invalid NaN split was not caught during Bucketizer initialization") {