diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java index f00993833321..3e49bf04ac89 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java @@ -33,6 +33,13 @@ import org.apache.spark.sql.types.StructType; // $example off$ +/** + * An example for Bucketizer. + * Run with + * <pre>
+ * bin/run-example ml.JavaBucketizerExample
+ * </pre>
+ */ public class JavaBucketizerExample { public static void main(String[] args) { SparkSession spark = SparkSession @@ -68,6 +75,40 @@ public static void main(String[] args) { bucketedData.show(); // $example off$ + // $example on$ + // Bucketize multiple columns in one pass. + double[][] splitsArray = { + {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY}, + {Double.NEGATIVE_INFINITY, -0.3, 0.0, 0.3, Double.POSITIVE_INFINITY} + }; + + List<Row> data2 = Arrays.asList( + RowFactory.create(-999.9, -999.9), + RowFactory.create(-0.5, -0.2), + RowFactory.create(-0.3, -0.1), + RowFactory.create(0.0, 0.0), + RowFactory.create(0.2, 0.4), + RowFactory.create(999.9, 999.9) + ); + StructType schema2 = new StructType(new StructField[]{ + new StructField("features1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features2", DataTypes.DoubleType, false, Metadata.empty()) + }); + Dataset<Row> dataFrame2 = spark.createDataFrame(data2, schema2); + + Bucketizer bucketizer2 = new Bucketizer() + .setInputCols(new String[] {"features1", "features2"}) + .setOutputCols(new String[] {"bucketedFeatures1", "bucketedFeatures2"}) + .setSplitsArray(splitsArray); + // Transform original data into its bucket index. + Dataset<Row> bucketedData2 = bucketizer2.transform(dataFrame2); + + System.out.println("Bucketizer output with [" + + (bucketizer2.getSplitsArray()[0].length-1) + ", " + + (bucketizer2.getSplitsArray()[1].length-1) + "] buckets for each input column"); + bucketedData2.show(); + // $example off$ + spark.stop(); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala index 04e4eccd436e..7e65f9c88907 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -22,7 +22,13 @@ package org.apache.spark.examples.ml import org.apache.spark.ml.feature.Bucketizer // $example off$ import org.apache.spark.sql.SparkSession - +/** + * An example for Bucketizer. + * Run with + * {{{ + * bin/run-example ml.BucketizerExample + * }}} + */ object BucketizerExample { def main(args: Array[String]): Unit = { val spark = SparkSession @@ -48,6 +54,34 @@ object BucketizerExample { bucketedData.show() // $example off$ + // $example on$ + val splitsArray = Array( + Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity), + Array(Double.NegativeInfinity, -0.3, 0.0, 0.3, Double.PositiveInfinity)) + + val data2 = Array( + (-999.9, -999.9), + (-0.5, -0.2), + (-0.3, -0.1), + (0.0, 0.0), + (0.2, 0.4), + (999.9, 999.9)) + val dataFrame2 = spark.createDataFrame(data2).toDF("features1", "features2") + + val bucketizer2 = new Bucketizer() + .setInputCols(Array("features1", "features2")) + .setOutputCols(Array("bucketedFeatures1", "bucketedFeatures2")) + .setSplitsArray(splitsArray) + + // Transform original data into its bucket index.
+ val bucketedData2 = bucketizer2.transform(dataFrame2) + + println(s"Bucketizer output with [" + + s"${bucketizer2.getSplitsArray(0).length-1}, " + + s"${bucketizer2.getSplitsArray(1).length-1}] buckets for each input column") + bucketedData2.show() + // $example off$ + spark.stop() } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 6a11a75d1d56..e07f2a107bad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.expressions.UserDefinedFunction @@ -32,12 +32,16 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** - * `Bucketizer` maps a column of continuous features to a column of feature buckets. + * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0, + * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that + * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and + * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is + * only used for single column usage, and `splitsArray` is for multiple columns. */ @Since("1.4.0") final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol - with DefaultParamsWritable { + with HasInputCols with HasOutputCols with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("bucketizer")) @@ -81,7 +85,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String /** * Param for how to handle invalid entries. Options are 'skip' (filter out rows with * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special - * additional bucket). + * additional bucket). Note that in the multiple column case, the invalid handling is applied + * to all columns. That said for 'error' it will throw an error if any invalids are found in + * any column, for 'skip' it will skip rows with any invalids in any columns, etc. * Default: "error" * @group param */ @@ -96,9 +102,59 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String def setHandleInvalid(value: String): this.type = set(handleInvalid, value) setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + /** + * Parameter for specifying multiple splits parameters. Each element in this array can be used to + * map continuous features into buckets. + * + * @group param + */ + @Since("2.3.0") + val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray", + "The array of split points for mapping continuous features into buckets for multiple " + + "columns. For each input column, with n+1 splits, there are n buckets. 
A bucket defined by " + + "splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " + + "The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " + + "explicitly provided to cover all Double values; otherwise, values outside the splits " + + "specified will be treated as errors.", + Bucketizer.checkSplitsArray) + + /** @group getParam */ + @Since("2.3.0") + def getSplitsArray: Array[Array[Double]] = $(splitsArray) + + /** @group setParam */ + @Since("2.3.0") + def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + /** + * Determines whether this `Bucketizer` is going to map multiple columns. If and only if + * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified + * by `inputCol`. A warning will be printed if both are set. + */ + private[feature] def isBucketizeMultipleColumns(): Boolean = { + if (isSet(inputCols) && isSet(inputCol)) { + logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + + "`Bucketizer` only map one column specified by `inputCol`") + false + } else if (isSet(inputCols)) { + true + } else { + false + } + } + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - transformSchema(dataset.schema) + val transformedSchema = transformSchema(dataset.schema) + val (filteredDataset, keepInvalid) = { if (getHandleInvalid == Bucketizer.SKIP_INVALID) { // "skip" NaN option is set, will filter out NaN values in the dataset @@ -108,26 +164,53 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String } } - val bucketizer: UserDefinedFunction = udf { (feature: Double) => - Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) - }.withName("bucketizer") + val seqOfSplits = if (isBucketizeMultipleColumns()) { + $(splitsArray).toSeq + } else { + Seq($(splits)) + } - val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType)) - val newField = prepOutputField(filteredDataset.schema) - filteredDataset.withColumn($(outputCol), newCol, newField.metadata) + val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) => + udf { (feature: Double) => + Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid) + }.withName(s"bucketizer_$idx") + } + + val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { + ($(inputCols).toSeq, $(outputCols).toSeq) + } else { + (Seq($(inputCol)), Seq($(outputCol))) + } + val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) => + bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType)) + } + val metadata = outputColumns.map { col => + transformedSchema(col).metadata + } + filteredDataset.withColumns(outputColumns, newCols, metadata) } - private def prepOutputField(schema: StructType): StructField = { - val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray - val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true), + private def prepOutputField(splits: Array[Double], outputCol: String): StructField = { + val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray + val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true), 
values = Some(buckets)) attr.toStructField() } @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkNumericType(schema, $(inputCol)) - SchemaUtils.appendColumn(schema, prepOutputField(schema)) + if (isBucketizeMultipleColumns()) { + var transformedSchema = schema + $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => + SchemaUtils.checkNumericType(transformedSchema, inputCol) + transformedSchema = SchemaUtils.appendColumn(transformedSchema, + prepOutputField($(splitsArray)(idx), outputCol)) + } + transformedSchema + } else { + SchemaUtils.checkNumericType(schema, $(inputCol)) + SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol))) + } } @Since("1.4.1") @@ -163,6 +246,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { } } + /** + * Check each splits in the splits array. + */ + private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = { + splitsArray.forall(checkSplits(_)) + } + /** * Binary searching in several buckets to place each data point. * @param splits array of split points diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ac68b825af53..8985f2af90a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -490,6 +490,45 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array } } +/** + * :: DeveloperApi :: + * Specialized version of `Param[Array[Array[Double]]]` for Java. + */ +@DeveloperApi +class DoubleArrayArrayParam( + parent: Params, + name: String, + doc: String, + isValid: Array[Array[Double]] => Boolean) + extends Param[Array[Array[Double]]](parent, name, doc, isValid) { + + def this(parent: Params, name: String, doc: String) = + this(parent, name, doc, ParamValidators.alwaysTrue) + + /** Creates a param pair with a `java.util.List` of values (for Java and Python). */ + def w(value: java.util.List[java.util.List[java.lang.Double]]): ParamPair[Array[Array[Double]]] = + w(value.asScala.map(_.asScala.map(_.asInstanceOf[Double]).toArray).toArray) + + override def jsonEncode(value: Array[Array[Double]]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq.map(_.toSeq.map(DoubleParam.jValueEncode)))) + } + + override def jsonDecode(json: String): Array[Array[Double]] = { + parse(json) match { + case JArray(values) => + values.map { + case JArray(values) => + values.map(DoubleParam.jValueDecode).toArray + case _ => + throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].") + }.toArray + case _ => + throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].") + } + } +} + /** * :: DeveloperApi :: * Specialized version of `Param[Array[Int]]` for Java. 
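The `DoubleArrayArrayParam` added above is what lets the multi-column `splitsArray` round-trip through ML persistence as nested JSON arrays. A minimal sketch of that round trip (not part of this patch); it assumes an active SparkSession and a writable scratch path, and the column names, splits and path are illustrative only:

  import org.apache.spark.ml.feature.Bucketizer

  val bucketizer = new Bucketizer()
    .setInputCols(Array("f1", "f2"))
    .setOutputCols(Array("b1", "b2"))
    .setSplitsArray(Array(
      Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity),
      Array(Double.NegativeInfinity, 0.5, Double.PositiveInfinity)))

  // DefaultParamsWritable serializes every explicitly set param, so splitsArray goes
  // through DoubleArrayArrayParam.jsonEncode on save and jsonDecode on load.
  bucketizer.write.overwrite().save("/tmp/bucketizer-multi-col")  // hypothetical path
  val loaded = Bucketizer.load("/tmp/bucketizer-multi-col")
  assert(loaded.getSplitsArray.map(_.toSeq).toSeq == bucketizer.getSplitsArray.map(_.toSeq).toSeq)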
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index a932d28fadbd..20a1db854e3a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -60,6 +60,7 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), + ParamDesc[Array[String]]("outputCols", "output column names"), ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index e6bdf5236e72..0d5fb28ae783 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -257,6 +257,23 @@ trait HasOutputCol extends Params { final def getOutputCol: String = $(outputCol) } +/** + * Trait for shared param outputCols. This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasOutputCols extends Params { + + /** + * Param for output column names. + * @group param + */ + final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names") + + /** @group getParam */ + final def getOutputCols: Array[String] = $(outputCols) +} + /** * Trait for shared param checkpointInterval. This trait may be changed or * removed between minor versions. 
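Before the test changes below, a short sketch of the multi-column `handleInvalid` semantics documented in `Bucketizer` above: invalid handling is applied across all input columns, so 'skip' drops a row if any bucketized column is invalid, and 'keep' sends invalid values to one extra bucket per column. This is only an illustration; the column names, splits and the SparkSession value `spark` are assumptions:

  import org.apache.spark.ml.feature.Bucketizer
  import spark.implicits._

  val df = Seq((0.2, Double.NaN), (0.2, 0.3)).toDF("f1", "f2")

  val bucketizer = new Bucketizer()
    .setInputCols(Array("f1", "f2"))
    .setOutputCols(Array("b1", "b2"))
    .setSplitsArray(Array(
      Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity),
      Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity)))

  // "skip": the first row is dropped entirely because f2 is NaN, even though f1 is valid.
  bucketizer.setHandleInvalid("skip").transform(df).count()  // 1

  // "keep": NaN lands in a special additional bucket (index 2 here, since each column
  // has two regular buckets).
  bucketizer.setHandleInvalid("keep").transform(df).count()  // 2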
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index 87639380bdcf..e65265bf74a8 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -61,4 +61,39 @@ public void bucketizerTest() { Assert.assertTrue((index >= 0) && (index <= 1)); } } + + @Test + public void bucketizerMultipleColumnsTest() { + double[][] splitsArray = { + {-0.5, 0.0, 0.5}, + {-0.5, 0.0, 0.2, 0.5} + }; + + StructType schema = new StructType(new StructField[]{ + new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()), + }); + Dataset<Row> dataset = spark.createDataFrame( + Arrays.asList( + RowFactory.create(-0.5, -0.5), + RowFactory.create(-0.3, -0.3), + RowFactory.create(0.0, 0.0), + RowFactory.create(0.2, 0.3)), + schema); + + Bucketizer bucketizer = new Bucketizer() + .setInputCols(new String[] {"feature1", "feature2"}) + .setOutputCols(new String[] {"result1", "result2"}) + .setSplitsArray(splitsArray); + + List<Row> result = bucketizer.transform(dataset).select("result1", "result2").collectAsList(); + + for (Row r : result) { + double index1 = r.getDouble(0); + Assert.assertTrue((index1 >= 0) && (index1 <= 1)); + + double index2 = r.getDouble(1); + Assert.assertTrue((index2 >= 0) && (index2 <= 2)); + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 420fb17ddce8..748dbd1b995d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.ml.feature import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.Pipeline import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -187,6 +188,220 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } } + + test("multiple columns: Bucket continuous features, without -inf,inf") { + // Check a set of valid feature values.
+ val splits = Array(Array(-0.5, 0.0, 0.5), Array(-0.1, 0.3, 0.5)) + val validData1 = Array(-0.5, -0.3, 0.0, 0.2) + val validData2 = Array(0.5, 0.3, 0.0, -0.1) + val expectedBuckets1 = Array(0.0, 0.0, 1.0, 1.0) + val expectedBuckets2 = Array(1.0, 1.0, 0.0, 0.0) + + val data = (0 until validData1.length).map { idx => + (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx)) + } + val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2") + + val bucketizer1: Bucketizer = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(splits) + + assert(bucketizer1.isBucketizeMultipleColumns()) + + bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2") + BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame), + Seq("result1", "result2"), + Seq("expected1", "expected2")) + + // Check for exceptions when using a set of invalid feature values. + val invalidData1 = Array(-0.9) ++ validData1 + val invalidData2 = Array(0.51) ++ validData1 + val badDF1 = invalidData1.zipWithIndex.toSeq.toDF("feature", "idx") + + val bucketizer2: Bucketizer = new Bucketizer() + .setInputCols(Array("feature")) + .setOutputCols(Array("result")) + .setSplitsArray(Array(splits(0))) + + assert(bucketizer2.isBucketizeMultipleColumns()) + + withClue("Invalid feature value -0.9 was not caught as an invalid feature!") { + intercept[SparkException] { + bucketizer2.transform(badDF1).collect() + } + } + val badDF2 = invalidData2.zipWithIndex.toSeq.toDF("feature", "idx") + withClue("Invalid feature value 0.51 was not caught as an invalid feature!") { + intercept[SparkException] { + bucketizer2.transform(badDF2).collect() + } + } + } + + test("multiple columns: Bucket continuous features, with -inf,inf") { + val splits = Array( + Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity), + Array(Double.NegativeInfinity, -0.3, 0.2, 0.5, Double.PositiveInfinity)) + + val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9) + val validData2 = Array(-0.1, -0.5, -0.2, 0.0, 0.1, 0.3, 0.5) + val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0) + val expectedBuckets2 = Array(1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 3.0) + + val data = (0 until validData1.length).map { idx => + (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx)) + } + val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(splits) + + assert(bucketizer.isBucketizeMultipleColumns()) + + BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), + Seq("result1", "result2"), + Seq("expected1", "expected2")) + } + + test("multiple columns: Bucket continuous features, with NaN data but non-NaN splits") { + val splits = Array( + Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity), + Array(Double.NegativeInfinity, -0.1, 0.2, 0.6, Double.PositiveInfinity)) + + val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, 0.8, Double.NaN, Double.NaN) + val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0) + val expectedBuckets2 = Array(2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 3.0, 4.0, 4.0) + + val data = (0 until validData1.length).map { idx => 
+ (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx)) + } + val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(splits) + + assert(bucketizer.isBucketizeMultipleColumns()) + + bucketizer.setHandleInvalid("keep") + BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame), + Seq("result1", "result2"), + Seq("expected1", "expected2")) + + bucketizer.setHandleInvalid("skip") + val skipResults1: Array[Double] = bucketizer.transform(dataFrame) + .select("result1").as[Double].collect() + assert(skipResults1.length === 7) + assert(skipResults1.forall(_ !== 4.0)) + + val skipResults2: Array[Double] = bucketizer.transform(dataFrame) + .select("result2").as[Double].collect() + assert(skipResults2.length === 7) + assert(skipResults2.forall(_ !== 4.0)) + + bucketizer.setHandleInvalid("error") + withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") { + intercept[SparkException] { + bucketizer.transform(dataFrame).collect() + } + } + } + + test("multiple columns: Bucket continuous features, with NaN splits") { + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) + withClue("Invalid NaN split was not caught during Bucketizer initialization") { + intercept[IllegalArgumentException] { + new Bucketizer().setSplitsArray(Array(splits)) + } + } + } + + test("multiple columns: read/write") { + val t = new Bucketizer() + .setInputCols(Array("myInputCol")) + .setOutputCols(Array("myOutputCol")) + .setSplitsArray(Array(Array(0.1, 0.8, 0.9))) + assert(t.isBucketizeMultipleColumns()) + testDefaultReadWrite(t) + } + + test("Bucketizer in a pipeline") { + val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0)) + .toDF("feature1", "feature2", "expected1", "expected2") + + val bucket = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) + + assert(bucket.isBucketizeMultipleColumns()) + + val pl = new Pipeline() + .setStages(Array(bucket)) + .fit(df) + pl.transform(df).select("result1", "expected1", "result2", "expected2") + + BucketizerSuite.checkBucketResults(pl.transform(df), + Seq("result1", "result2"), Seq("expected1", "expected2")) + } + + test("Compare single/multiple column(s) Bucketizer in pipeline") { + val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0)) + .toDF("feature1", "feature2", "expected1", "expected2") + + val multiColsBucket = new Bucketizer() + .setInputCols(Array("feature1", "feature2")) + .setOutputCols(Array("result1", "result2")) + .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5))) + + val plForMultiCols = new Pipeline() + .setStages(Array(multiColsBucket)) + .fit(df) + + val bucketForCol1 = new Bucketizer() + .setInputCol("feature1") + .setOutputCol("result1") + .setSplits(Array(-0.5, 0.0, 0.5)) + val bucketForCol2 = new Bucketizer() + .setInputCol("feature2") + .setOutputCol("result2") + .setSplits(Array(-0.5, 0.0, 0.5)) + + val plForSingleCol = new Pipeline() + .setStages(Array(bucketForCol1, bucketForCol2)) + .fit(df) + + val resultForSingleCol = plForSingleCol.transform(df) + .select("result1", "expected1", "result2", "expected2") + .collect() + val resultForMultiCols = plForMultiCols.transform(df) + .select("result1", 
"expected1", "result2", "expected2") + .collect() + + resultForSingleCol.zip(resultForMultiCols).foreach { + case (rowForSingle, rowForMultiCols) => + assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && + rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) && + rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2) && + rowForSingle.getDouble(3) == rowForMultiCols.getDouble(3)) + } + } + + test("Both inputCol and inputCols are set") { + val bucket = new Bucketizer() + .setInputCol("feature1") + .setOutputCol("result") + .setSplits(Array(-0.5, 0.0, 0.5)) + .setInputCols(Array("feature1", "feature2")) + + // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`. + assert(bucket.isBucketizeMultipleColumns() == false) + } } private object BucketizerSuite extends SparkFunSuite { @@ -220,4 +435,26 @@ private object BucketizerSuite extends SparkFunSuite { i += 1 } } + + /** Checks if bucketized results match expected ones. */ + def checkBucketResults( + bucketResult: DataFrame, + resultColumns: Seq[String], + expectedColumns: Seq[String]): Unit = { + assert(resultColumns.length == expectedColumns.length, + s"Given ${resultColumns.length} result columns doesn't match " + + s"${expectedColumns.length} expected columns.") + assert(resultColumns.length > 0, "At least one result and expected columns are needed.") + + val allColumns = resultColumns ++ expectedColumns + bucketResult.select(allColumns.head, allColumns.tail: _*).collect().foreach { + case row => + for (idx <- 0 until row.length / 2) { + val result = row.getDouble(idx) + val expected = row.getDouble(idx + row.length / 2) + assert(result === expected, "The feature value is not correct after bucketing. " + + s"Expected $expected but found $result.") + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 78a33e05e0e4..85198ad4c913 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -121,10 +121,10 @@ class ParamsSuite extends SparkFunSuite { { // DoubleArrayParam val param = new DoubleArrayParam(dummy, "name", "doc") val values: Seq[Array[Double]] = Seq( - Array(), - Array(1.0), - Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, - Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity)) + Array(), + Array(1.0), + Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, + Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity)) for (value <- values) { val json = param.jsonEncode(value) val decoded = param.jsonDecode(json) @@ -139,6 +139,36 @@ class ParamsSuite extends SparkFunSuite { } } + { // DoubleArrayArrayParam + val param = new DoubleArrayArrayParam(dummy, "name", "doc") + val values: Seq[Array[Array[Double]]] = Seq( + Array(Array()), + Array(Array(1.0)), + Array(Array(1.0), Array(2.0)), + Array( + Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, + Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity), + Array(Double.MaxValue, Double.PositiveInfinity, Double.MinPositiveValue, 1.0, + Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0) + )) + + for (value <- values) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + assert(decoded.length === value.length) + decoded.zip(value).foreach { case (actualArray, expectedArray) => + 
assert(actualArray.length === expectedArray.length) + actualArray.zip(expectedArray).foreach { case (actual, expected) => + if (expected.isNaN) { + assert(actual.isNaN) + } else { + assert(actual === expected) + } + } + } + } + } + { // StringArrayParam val param = new StringArrayParam(dummy, "name", "doc") val values: Seq[Array[String]] = Seq( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index bd99ec52ce93..5eb2affa0bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2135,12 +2135,27 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset by adding a column with metadata. + * Returns a new Dataset by adding columns with metadata. */ - private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { - withColumn(colName, col.as(colName, metadata)) + private[spark] def withColumns( + colNames: Seq[String], + cols: Seq[Column], + metadata: Seq[Metadata]): DataFrame = { + require(colNames.size == metadata.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of metadata elements: ${metadata.size}") + val newCols = colNames.zip(cols).zip(metadata).map { case ((colName, col), metadata) => + col.as(colName, metadata) + } + withColumns(colNames, newCols) } + /** + * Returns a new Dataset by adding a column with metadata. + */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = + withColumns(Seq(colName), Seq(col), Seq(metadata)) + /** * Returns a new Dataset with a column renamed. * This is a no-op if schema doesn't contain existingName. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 17c88b069080..31bfa77e7632 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -686,6 +686,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } } + test("withColumns: given metadata") { + def buildMetadata(num: Int): Seq[Metadata] = { + (0 until num).map { n => + val builder = new MetadataBuilder + builder.putLong("key", n.toLong) + builder.build() + } + } + + val df = testData.toDF().withColumns( + Seq("newCol1", "newCol2"), + Seq(col("key") + 1, col("key") + 2), + buildMetadata(2)) + + df.select("newCol1", "newCol2").schema.zipWithIndex.foreach { case (col, idx) => + assert(col.metadata.getLong("key").toInt === idx) + } + + val err = intercept[IllegalArgumentException] { + testData.toDF().withColumns( + Seq("newCol1", "newCol2"), + Seq(col("key") + 1, col("key") + 2), + buildMetadata(1)) + } + assert(err.getMessage.contains( + "The size of column names: 2 isn't equal to the size of metadata elements: 1")) + } + test("replace column using withColumn") { val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1)
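For reference, a sketch of how the new metadata-aware `withColumns` overload is meant to be used from Spark-internal code such as `Bucketizer.transform` above. It is `private[spark]`, so this only compiles inside the org.apache.spark package; the DataFrame `df` with numeric columns f1/f2 and the metadata key are assumptions for illustration:

  import org.apache.spark.sql.functions.col
  import org.apache.spark.sql.types.MetadataBuilder

  val outputNames = Seq("b1", "b2")
  val outputExprs = Seq(col("f1") * 2, col("f2") * 2)
  val outputMetadata = outputNames.map { name =>
    new MetadataBuilder().putString("origin", name).build()
  }

  // Adds both columns at once and attaches per-column metadata, rather than chaining
  // two withColumn(colName, col, metadata) calls; the column-name and metadata Seqs
  // must have the same size (enforced by the require in the new overload).
  val result = df.withColumns(outputNames, outputExprs, outputMetadata)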