diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
index f00993833321..3e49bf04ac89 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketizerExample.java
@@ -33,6 +33,13 @@
import org.apache.spark.sql.types.StructType;
// $example off$
+/**
+ * An example for Bucketizer.
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaBucketizerExample
+ * </pre>
+ */
public class JavaBucketizerExample {
public static void main(String[] args) {
SparkSession spark = SparkSession
@@ -68,6 +75,40 @@ public static void main(String[] args) {
bucketedData.show();
// $example off$
+ // $example on$
+ // Bucketize multiple columns in one pass.
+ double[][] splitsArray = {
+ {Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY},
+ {Double.NEGATIVE_INFINITY, -0.3, 0.0, 0.3, Double.POSITIVE_INFINITY}
+ };
+
+ List<Row> data2 = Arrays.asList(
+ RowFactory.create(-999.9, -999.9),
+ RowFactory.create(-0.5, -0.2),
+ RowFactory.create(-0.3, -0.1),
+ RowFactory.create(0.0, 0.0),
+ RowFactory.create(0.2, 0.4),
+ RowFactory.create(999.9, 999.9)
+ );
+ StructType schema2 = new StructType(new StructField[]{
+ new StructField("features1", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("features2", DataTypes.DoubleType, false, Metadata.empty())
+ });
+ Dataset<Row> dataFrame2 = spark.createDataFrame(data2, schema2);
+
+ Bucketizer bucketizer2 = new Bucketizer()
+ .setInputCols(new String[] {"features1", "features2"})
+ .setOutputCols(new String[] {"bucketedFeatures1", "bucketedFeatures2"})
+ .setSplitsArray(splitsArray);
+ // Transform original data into its bucket index.
+ Dataset<Row> bucketedData2 = bucketizer2.transform(dataFrame2);
+
+ System.out.println("Bucketizer output with [" +
+ (bucketizer2.getSplitsArray()[0].length-1) + ", " +
+ (bucketizer2.getSplitsArray()[1].length-1) + "] buckets for each input column");
+ bucketedData2.show();
+ // $example off$
+
spark.stop();
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala
index 04e4eccd436e..7e65f9c88907 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala
@@ -22,7 +22,13 @@ package org.apache.spark.examples.ml
import org.apache.spark.ml.feature.Bucketizer
// $example off$
import org.apache.spark.sql.SparkSession
-
+/**
+ * An example for Bucketizer.
+ * Run with
+ * {{{
+ * bin/run-example ml.BucketizerExample
+ * }}}
+ */
object BucketizerExample {
def main(args: Array[String]): Unit = {
val spark = SparkSession
@@ -48,6 +54,34 @@ object BucketizerExample {
bucketedData.show()
// $example off$
+ // $example on$
+ val splitsArray = Array(
+ Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
+ Array(Double.NegativeInfinity, -0.3, 0.0, 0.3, Double.PositiveInfinity))
+
+ val data2 = Array(
+ (-999.9, -999.9),
+ (-0.5, -0.2),
+ (-0.3, -0.1),
+ (0.0, 0.0),
+ (0.2, 0.4),
+ (999.9, 999.9))
+ val dataFrame2 = spark.createDataFrame(data2).toDF("features1", "features2")
+
+ val bucketizer2 = new Bucketizer()
+ .setInputCols(Array("features1", "features2"))
+ .setOutputCols(Array("bucketedFeatures1", "bucketedFeatures2"))
+ .setSplitsArray(splitsArray)
+
+ // Transform original data into its bucket index.
+ val bucketedData2 = bucketizer2.transform(dataFrame2)
+
+ println(s"Bucketizer output with [" +
+ s"${bucketizer2.getSplitsArray(0).length-1}, " +
+ s"${bucketizer2.getSplitsArray(1).length-1}] buckets for each input column")
+ bucketedData2.show()
+ // $example off$
+
spark.stop()
}
}
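
For readers checking the Java and Scala examples above by hand, the bucket semantics are: with n+1 splits there are n buckets, bucket i covers [splits(i), splits(i+1)), and the last bucket also includes its upper bound. Below is a minimal standalone sketch that is not part of the patch; the `bucketOf` helper is hypothetical and only mirrors those rules for the example data.

```scala
object BucketIndexSketch {
  // Hypothetical helper mirroring Bucketizer's bucket boundaries for finite, in-range values:
  // bucket i covers [splits(i), splits(i + 1)); the last bucket also includes the final split.
  def bucketOf(splits: Array[Double], v: Double): Double = {
    val idx = splits.lastIndexWhere(_ <= v)
    math.min(idx, splits.length - 2).toDouble
  }

  def main(args: Array[String]): Unit = {
    val splits1 = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
    val splits2 = Array(Double.NegativeInfinity, -0.3, 0.0, 0.3, Double.PositiveInfinity)
    val rows = Seq((-999.9, -999.9), (-0.5, -0.2), (-0.3, -0.1), (0.0, 0.0), (0.2, 0.4), (999.9, 999.9))
    rows.foreach { case (f1, f2) =>
      // Expected to match bucketedData2 above, e.g. (-999.9, -999.9) -> (0.0, 0.0) and (0.2, 0.4) -> (2.0, 3.0).
      println(s"($f1, $f2) -> (${bucketOf(splits1, f1)}, ${bucketOf(splits2, f2)})")
    }
  }
}
```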
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 6a11a75d1d56..e07f2a107bad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.Model
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction
@@ -32,12 +32,16 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
/**
- * `Bucketizer` maps a column of continuous features to a column of feature buckets.
+ * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
+ * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
+ * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
+ * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
+ * only used for single-column usage, and `splitsArray` is for multiple columns.
*/
@Since("1.4.0")
final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
- with DefaultParamsWritable {
+ with HasInputCols with HasOutputCols with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("bucketizer"))
@@ -81,7 +85,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
/**
* Param for how to handle invalid entries. Options are 'skip' (filter out rows with
* invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special
- * additional bucket).
+ * additional bucket). Note that in the multiple-column case, the invalid handling is applied
+ * to all columns. That is, for 'error' it throws an error if any invalid value is found in any
+ * column, and for 'skip' it drops any row that contains an invalid value in any column.
* Default: "error"
* @group param
*/
@@ -96,9 +102,59 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
+ /**
+ * Parameter for specifying the split points for multiple columns. Each element in this array
+ * is the set of split points used to map the corresponding input column's continuous features
+ * into buckets.
+ *
+ * @group param
+ */
+ @Since("2.3.0")
+ val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray",
+ "The array of split points for mapping continuous features into buckets for multiple " +
+ "columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by " +
+ "splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " +
+ "The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
+ "explicitly provided to cover all Double values; otherwise, values outside the splits " +
+ "specified will be treated as errors.",
+ Bucketizer.checkSplitsArray)
+
+ /** @group getParam */
+ @Since("2.3.0")
+ def getSplitsArray: Array[Array[Double]] = $(splitsArray)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+ /**
+ * Determines whether this `Bucketizer` is going to map multiple columns. It maps multiple
+ * columns only when `inputCols` is set and `inputCol` is not; otherwise it maps the single
+ * column specified by `inputCol`. A warning is logged if both are set.
+ */
+ private[feature] def isBucketizeMultipleColumns(): Boolean = {
+ if (isSet(inputCols) && isSet(inputCol)) {
+ logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " +
+ "`Bucketizer` only map one column specified by `inputCol`")
+ false
+ } else if (isSet(inputCols)) {
+ true
+ } else {
+ false
+ }
+ }
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
- transformSchema(dataset.schema)
+ val transformedSchema = transformSchema(dataset.schema)
+
val (filteredDataset, keepInvalid) = {
if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
// "skip" NaN option is set, will filter out NaN values in the dataset
@@ -108,26 +164,53 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
}
}
- val bucketizer: UserDefinedFunction = udf { (feature: Double) =>
- Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
- }.withName("bucketizer")
+ val seqOfSplits = if (isBucketizeMultipleColumns()) {
+ $(splitsArray).toSeq
+ } else {
+ Seq($(splits))
+ }
- val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
- val newField = prepOutputField(filteredDataset.schema)
- filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
+ val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) =>
+ udf { (feature: Double) =>
+ Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid)
+ }.withName(s"bucketizer_$idx")
+ }
+
+ val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
+ ($(inputCols).toSeq, $(outputCols).toSeq)
+ } else {
+ (Seq($(inputCol)), Seq($(outputCol)))
+ }
+ val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
+ bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
+ }
+ val metadata = outputColumns.map { col =>
+ transformedSchema(col).metadata
+ }
+ filteredDataset.withColumns(outputColumns, newCols, metadata)
}
- private def prepOutputField(schema: StructType): StructField = {
- val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
- val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
+ private def prepOutputField(splits: Array[Double], outputCol: String): StructField = {
+ val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray
+ val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true),
values = Some(buckets))
attr.toStructField()
}
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
- SchemaUtils.checkNumericType(schema, $(inputCol))
- SchemaUtils.appendColumn(schema, prepOutputField(schema))
+ if (isBucketizeMultipleColumns()) {
+ var transformedSchema = schema
+ $(inputCols).zip($(outputCols)).zipWithIndex.foreach { case ((inputCol, outputCol), idx) =>
+ SchemaUtils.checkNumericType(transformedSchema, inputCol)
+ transformedSchema = SchemaUtils.appendColumn(transformedSchema,
+ prepOutputField($(splitsArray)(idx), outputCol))
+ }
+ transformedSchema
+ } else {
+ SchemaUtils.checkNumericType(schema, $(inputCol))
+ SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol)))
+ }
}
@Since("1.4.1")
@@ -163,6 +246,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
}
}
+ /**
+ * Check each set of splits in the splits array.
+ */
+ private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = {
+ splitsArray.forall(checkSplits(_))
+ }
+
/**
* Binary searching in several buckets to place each data point.
* @param splits array of split points
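
To make the precedence and invalid-handling rules documented above concrete, here is a minimal, self-contained sketch. It is not part of the patch; the column names, data, and local master setting are illustrative, and only public API from this change is used.

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

object BucketizerBehaviorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("BucketizerBehaviorSketch").master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq((0.1, Double.NaN), (-0.2, 0.4)).toDF("f1", "f2")

    // Both single- and multi-column params set: per the Scaladoc above, a warning is logged and
    // only `inputCol`/`outputCol`/`splits` take effect, so the output gains just one column, "b1".
    val mixed = new Bucketizer()
      .setInputCol("f1").setOutputCol("b1").setSplits(Array(-0.5, 0.0, 0.5))
      .setInputCols(Array("f1", "f2"))
    mixed.transform(df).show()

    // Multi-column mode with handleInvalid = "skip": the row holding NaN in f2 is dropped
    // entirely, even though its f1 value is valid.
    val multi = new Bucketizer()
      .setInputCols(Array("f1", "f2"))
      .setOutputCols(Array("b1", "b2"))
      .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))
      .setHandleInvalid("skip")
    multi.transform(df).show()  // expected: only the (-0.2, 0.4) row remains

    spark.stop()
  }
}
```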
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index ac68b825af53..8985f2af90a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -490,6 +490,45 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
}
}
+/**
+ * :: DeveloperApi ::
+ * Specialized version of `Param[Array[Array[Double]]]` for Java.
+ */
+@DeveloperApi
+class DoubleArrayArrayParam(
+ parent: Params,
+ name: String,
+ doc: String,
+ isValid: Array[Array[Double]] => Boolean)
+ extends Param[Array[Array[Double]]](parent, name, doc, isValid) {
+
+ def this(parent: Params, name: String, doc: String) =
+ this(parent, name, doc, ParamValidators.alwaysTrue)
+
+ /** Creates a param pair with a `java.util.List` of values (for Java and Python). */
+ def w(value: java.util.List[java.util.List[java.lang.Double]]): ParamPair[Array[Array[Double]]] =
+ w(value.asScala.map(_.asScala.map(_.asInstanceOf[Double]).toArray).toArray)
+
+ override def jsonEncode(value: Array[Array[Double]]): String = {
+ import org.json4s.JsonDSL._
+ compact(render(value.toSeq.map(_.toSeq.map(DoubleParam.jValueEncode))))
+ }
+
+ override def jsonDecode(json: String): Array[Array[Double]] = {
+ parse(json) match {
+ case JArray(values) =>
+ values.map {
+ case JArray(values) =>
+ values.map(DoubleParam.jValueDecode).toArray
+ case _ =>
+ throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].")
+ }.toArray
+ case _ =>
+ throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].")
+ }
+ }
+}
+
/**
* :: DeveloperApi ::
* Specialized version of `Param[Array[Int]]` for Java.
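
A short round-trip sketch for the new `DoubleArrayArrayParam` JSON codec. It is not part of the patch; it reuses `Bucketizer.splitsArray` as a convenient instance of the param, and the commented JSON rendering is approximate (non-finite values are encoded as quoted strings).

```scala
import org.apache.spark.ml.feature.Bucketizer

object DoubleArrayArrayParamSketch {
  def main(args: Array[String]): Unit = {
    // splitsArray is a DoubleArrayArrayParam, so this exercises the jsonEncode/jsonDecode above.
    val param = new Bucketizer().splitsArray

    val value = Array(
      Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
      Array(-0.3, 0.0, 0.3))

    val json = param.jsonEncode(value)
    println(json)  // roughly [["-Inf",-0.5,0.0,0.5,"Inf"],[-0.3,0.0,0.3]]

    val decoded = param.jsonDecode(json)
    assert(decoded.map(_.toSeq).toSeq == value.map(_.toSeq).toSeq)
  }
}
```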
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index a932d28fadbd..20a1db854e3a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -60,6 +60,7 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
+ ParamDesc[Array[String]]("outputCols", "output column names"),
ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " +
"disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " +
"every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index e6bdf5236e72..0d5fb28ae783 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -257,6 +257,23 @@ trait HasOutputCol extends Params {
final def getOutputCol: String = $(outputCol)
}
+/**
+ * Trait for shared param outputCols. This trait may be changed or
+ * removed between minor versions.
+ */
+@DeveloperApi
+trait HasOutputCols extends Params {
+
+ /**
+ * Param for output column names.
+ * @group param
+ */
+ final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names")
+
+ /** @group getParam */
+ final def getOutputCols: Array[String] = $(outputCols)
+}
+
/**
* Trait for shared param checkpointInterval. This trait may be changed or
* removed between minor versions.
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index 87639380bdcf..e65265bf74a8 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -61,4 +61,39 @@ public void bucketizerTest() {
Assert.assertTrue((index >= 0) && (index <= 1));
}
}
+
+ @Test
+ public void bucketizerMultipleColumnsTest() {
+ double[][] splitsArray = {
+ {-0.5, 0.0, 0.5},
+ {-0.5, 0.0, 0.2, 0.5}
+ };
+
+ StructType schema = new StructType(new StructField[]{
+ new StructField("feature1", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("feature2", DataTypes.DoubleType, false, Metadata.empty()),
+ });
+ Dataset<Row> dataset = spark.createDataFrame(
+ Arrays.asList(
+ RowFactory.create(-0.5, -0.5),
+ RowFactory.create(-0.3, -0.3),
+ RowFactory.create(0.0, 0.0),
+ RowFactory.create(0.2, 0.3)),
+ schema);
+
+ Bucketizer bucketizer = new Bucketizer()
+ .setInputCols(new String[] {"feature1", "feature2"})
+ .setOutputCols(new String[] {"result1", "result2"})
+ .setSplitsArray(splitsArray);
+
+ List<Row> result = bucketizer.transform(dataset).select("result1", "result2").collectAsList();
+
+ for (Row r : result) {
+ double index1 = r.getDouble(0);
+ Assert.assertTrue((index1 >= 0) && (index1 <= 1));
+
+ double index2 = r.getDouble(1);
+ Assert.assertTrue((index2 >= 0) && (index2 <= 2));
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 420fb17ddce8..748dbd1b995d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -20,9 +20,10 @@ package org.apache.spark.ml.feature
import scala.util.Random
import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -187,6 +188,220 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
}
}
+
+ test("multiple columns: Bucket continuous features, without -inf,inf") {
+ // Check a set of valid feature values.
+ val splits = Array(Array(-0.5, 0.0, 0.5), Array(-0.1, 0.3, 0.5))
+ val validData1 = Array(-0.5, -0.3, 0.0, 0.2)
+ val validData2 = Array(0.5, 0.3, 0.0, -0.1)
+ val expectedBuckets1 = Array(0.0, 0.0, 1.0, 1.0)
+ val expectedBuckets2 = Array(1.0, 1.0, 0.0, 0.0)
+
+ val data = (0 until validData1.length).map { idx =>
+ (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
+ }
+ val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2")
+
+ val bucketizer1: Bucketizer = new Bucketizer()
+ .setInputCols(Array("feature1", "feature2"))
+ .setOutputCols(Array("result1", "result2"))
+ .setSplitsArray(splits)
+
+ assert(bucketizer1.isBucketizeMultipleColumns())
+
+ bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
+ BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame),
+ Seq("result1", "result2"),
+ Seq("expected1", "expected2"))
+
+ // Check for exceptions when using a set of invalid feature values.
+ val invalidData1 = Array(-0.9) ++ validData1
+ val invalidData2 = Array(0.51) ++ validData1
+ val badDF1 = invalidData1.zipWithIndex.toSeq.toDF("feature", "idx")
+
+ val bucketizer2: Bucketizer = new Bucketizer()
+ .setInputCols(Array("feature"))
+ .setOutputCols(Array("result"))
+ .setSplitsArray(Array(splits(0)))
+
+ assert(bucketizer2.isBucketizeMultipleColumns())
+
+ withClue("Invalid feature value -0.9 was not caught as an invalid feature!") {
+ intercept[SparkException] {
+ bucketizer2.transform(badDF1).collect()
+ }
+ }
+ val badDF2 = invalidData2.zipWithIndex.toSeq.toDF("feature", "idx")
+ withClue("Invalid feature value 0.51 was not caught as an invalid feature!") {
+ intercept[SparkException] {
+ bucketizer2.transform(badDF2).collect()
+ }
+ }
+ }
+
+ test("multiple columns: Bucket continuous features, with -inf,inf") {
+ val splits = Array(
+ Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
+ Array(Double.NegativeInfinity, -0.3, 0.2, 0.5, Double.PositiveInfinity))
+
+ val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9)
+ val validData2 = Array(-0.1, -0.5, -0.2, 0.0, 0.1, 0.3, 0.5)
+ val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0)
+ val expectedBuckets2 = Array(1.0, 0.0, 1.0, 1.0, 1.0, 2.0, 3.0)
+
+ val data = (0 until validData1.length).map { idx =>
+ (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
+ }
+ val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2")
+
+ val bucketizer: Bucketizer = new Bucketizer()
+ .setInputCols(Array("feature1", "feature2"))
+ .setOutputCols(Array("result1", "result2"))
+ .setSplitsArray(splits)
+
+ assert(bucketizer.isBucketizeMultipleColumns())
+
+ BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
+ Seq("result1", "result2"),
+ Seq("expected1", "expected2"))
+ }
+
+ test("multiple columns: Bucket continuous features, with NaN data but non-NaN splits") {
+ val splits = Array(
+ Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
+ Array(Double.NegativeInfinity, -0.1, 0.2, 0.6, Double.PositiveInfinity))
+
+ val validData1 = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN)
+ val validData2 = Array(0.2, -0.1, 0.3, 0.0, 0.1, 0.3, 0.5, 0.8, Double.NaN, Double.NaN)
+ val expectedBuckets1 = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0)
+ val expectedBuckets2 = Array(2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 3.0, 4.0, 4.0)
+
+ val data = (0 until validData1.length).map { idx =>
+ (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
+ }
+ val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2")
+
+ val bucketizer: Bucketizer = new Bucketizer()
+ .setInputCols(Array("feature1", "feature2"))
+ .setOutputCols(Array("result1", "result2"))
+ .setSplitsArray(splits)
+
+ assert(bucketizer.isBucketizeMultipleColumns())
+
+ bucketizer.setHandleInvalid("keep")
+ BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
+ Seq("result1", "result2"),
+ Seq("expected1", "expected2"))
+
+ bucketizer.setHandleInvalid("skip")
+ val skipResults1: Array[Double] = bucketizer.transform(dataFrame)
+ .select("result1").as[Double].collect()
+ assert(skipResults1.length === 7)
+ assert(skipResults1.forall(_ !== 4.0))
+
+ val skipResults2: Array[Double] = bucketizer.transform(dataFrame)
+ .select("result2").as[Double].collect()
+ assert(skipResults2.length === 7)
+ assert(skipResults2.forall(_ !== 4.0))
+
+ bucketizer.setHandleInvalid("error")
+ withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") {
+ intercept[SparkException] {
+ bucketizer.transform(dataFrame).collect()
+ }
+ }
+ }
+
+ test("multiple columns: Bucket continuous features, with NaN splits") {
+ val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN)
+ withClue("Invalid NaN split was not caught during Bucketizer initialization") {
+ intercept[IllegalArgumentException] {
+ new Bucketizer().setSplitsArray(Array(splits))
+ }
+ }
+ }
+
+ test("multiple columns: read/write") {
+ val t = new Bucketizer()
+ .setInputCols(Array("myInputCol"))
+ .setOutputCols(Array("myOutputCol"))
+ .setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
+ assert(t.isBucketizeMultipleColumns())
+ testDefaultReadWrite(t)
+ }
+
+ test("Bucketizer in a pipeline") {
+ val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0))
+ .toDF("feature1", "feature2", "expected1", "expected2")
+
+ val bucket = new Bucketizer()
+ .setInputCols(Array("feature1", "feature2"))
+ .setOutputCols(Array("result1", "result2"))
+ .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))
+
+ assert(bucket.isBucketizeMultipleColumns())
+
+ val pl = new Pipeline()
+ .setStages(Array(bucket))
+ .fit(df)
+ pl.transform(df).select("result1", "expected1", "result2", "expected2")
+
+ BucketizerSuite.checkBucketResults(pl.transform(df),
+ Seq("result1", "result2"), Seq("expected1", "expected2"))
+ }
+
+ test("Compare single/multiple column(s) Bucketizer in pipeline") {
+ val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0))
+ .toDF("feature1", "feature2", "expected1", "expected2")
+
+ val multiColsBucket = new Bucketizer()
+ .setInputCols(Array("feature1", "feature2"))
+ .setOutputCols(Array("result1", "result2"))
+ .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))
+
+ val plForMultiCols = new Pipeline()
+ .setStages(Array(multiColsBucket))
+ .fit(df)
+
+ val bucketForCol1 = new Bucketizer()
+ .setInputCol("feature1")
+ .setOutputCol("result1")
+ .setSplits(Array(-0.5, 0.0, 0.5))
+ val bucketForCol2 = new Bucketizer()
+ .setInputCol("feature2")
+ .setOutputCol("result2")
+ .setSplits(Array(-0.5, 0.0, 0.5))
+
+ val plForSingleCol = new Pipeline()
+ .setStages(Array(bucketForCol1, bucketForCol2))
+ .fit(df)
+
+ val resultForSingleCol = plForSingleCol.transform(df)
+ .select("result1", "expected1", "result2", "expected2")
+ .collect()
+ val resultForMultiCols = plForMultiCols.transform(df)
+ .select("result1", "expected1", "result2", "expected2")
+ .collect()
+
+ resultForSingleCol.zip(resultForMultiCols).foreach {
+ case (rowForSingle, rowForMultiCols) =>
+ assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) &&
+ rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) &&
+ rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2) &&
+ rowForSingle.getDouble(3) == rowForMultiCols.getDouble(3))
+ }
+ }
+
+ test("Both inputCol and inputCols are set") {
+ val bucket = new Bucketizer()
+ .setInputCol("feature1")
+ .setOutputCol("result")
+ .setSplits(Array(-0.5, 0.0, 0.5))
+ .setInputCols(Array("feature1", "feature2"))
+
+ // When both are set, we ignore `inputCols` and just map the column specified by `inputCol`.
+ assert(!bucket.isBucketizeMultipleColumns())
+ }
}
private object BucketizerSuite extends SparkFunSuite {
@@ -220,4 +435,26 @@ private object BucketizerSuite extends SparkFunSuite {
i += 1
}
}
+
+ /** Checks if bucketized results match expected ones. */
+ def checkBucketResults(
+ bucketResult: DataFrame,
+ resultColumns: Seq[String],
+ expectedColumns: Seq[String]): Unit = {
+ assert(resultColumns.length == expectedColumns.length,
+ s"Given ${resultColumns.length} result columns doesn't match " +
+ s"${expectedColumns.length} expected columns.")
+ assert(resultColumns.length > 0, "At least one result and expected columns are needed.")
+
+ val allColumns = resultColumns ++ expectedColumns
+ bucketResult.select(allColumns.head, allColumns.tail: _*).collect().foreach {
+ case row =>
+ for (idx <- 0 until row.length / 2) {
+ val result = row.getDouble(idx)
+ val expected = row.getDouble(idx + row.length / 2)
+ assert(result === expected, "The feature value is not correct after bucketing. " +
+ s"Expected $expected but found $result.")
+ }
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 78a33e05e0e4..85198ad4c913 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -121,10 +121,10 @@ class ParamsSuite extends SparkFunSuite {
{ // DoubleArrayParam
val param = new DoubleArrayParam(dummy, "name", "doc")
val values: Seq[Array[Double]] = Seq(
- Array(),
- Array(1.0),
- Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
- Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
+ Array(),
+ Array(1.0),
+ Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
+ Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
for (value <- values) {
val json = param.jsonEncode(value)
val decoded = param.jsonDecode(json)
@@ -139,6 +139,36 @@ class ParamsSuite extends SparkFunSuite {
}
}
+ { // DoubleArrayArrayParam
+ val param = new DoubleArrayArrayParam(dummy, "name", "doc")
+ val values: Seq[Array[Array[Double]]] = Seq(
+ Array(Array()),
+ Array(Array(1.0)),
+ Array(Array(1.0), Array(2.0)),
+ Array(
+ Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
+ Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity),
+ Array(Double.MaxValue, Double.PositiveInfinity, Double.MinPositiveValue, 1.0,
+ Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0)
+ ))
+
+ for (value <- values) {
+ val json = param.jsonEncode(value)
+ val decoded = param.jsonDecode(json)
+ assert(decoded.length === value.length)
+ decoded.zip(value).foreach { case (actualArray, expectedArray) =>
+ assert(actualArray.length === expectedArray.length)
+ actualArray.zip(expectedArray).foreach { case (actual, expected) =>
+ if (expected.isNaN) {
+ assert(actual.isNaN)
+ } else {
+ assert(actual === expected)
+ }
+ }
+ }
+ }
+ }
+
{ // StringArrayParam
val param = new StringArrayParam(dummy, "name", "doc")
val values: Seq[Array[String]] = Seq(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bd99ec52ce93..5eb2affa0bd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2135,12 +2135,27 @@ class Dataset[T] private[sql](
}
/**
- * Returns a new Dataset by adding a column with metadata.
+ * Returns a new Dataset by adding columns with metadata.
*/
- private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
- withColumn(colName, col.as(colName, metadata))
+ private[spark] def withColumns(
+ colNames: Seq[String],
+ cols: Seq[Column],
+ metadata: Seq[Metadata]): DataFrame = {
+ require(colNames.size == metadata.size,
+ s"The size of column names: ${colNames.size} isn't equal to " +
+ s"the size of metadata elements: ${metadata.size}")
+ val newCols = colNames.zip(cols).zip(metadata).map { case ((colName, col), metadata) =>
+ col.as(colName, metadata)
+ }
+ withColumns(colNames, newCols)
}
+ /**
+ * Returns a new Dataset by adding a column with metadata.
+ */
+ private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame =
+ withColumns(Seq(colName), Seq(col), Seq(metadata))
+
/**
* Returns a new Dataset with a column renamed.
* This is a no-op if schema doesn't contain existingName.
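
The new multi-column `withColumns` overload above is `private[spark]` (the multi-column `Bucketizer.transform` is its caller), but its effect can be reproduced from user code by attaching metadata through `Column.as`. A minimal sketch under that assumption, with illustrative names:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.MetadataBuilder

object WithColumnsMetadataSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("WithColumnsMetadataSketch").master("local[*]").getOrCreate()
    import spark.implicits._

    val df = Seq(1, 2, 3).toDF("key")
    val m1 = new MetadataBuilder().putLong("key", 0L).build()
    val m2 = new MetadataBuilder().putLong("key", 1L).build()

    // User-level equivalent of withColumns(Seq("newCol1", "newCol2"), Seq(...), Seq(m1, m2)):
    // each column is aliased together with its metadata before being added.
    val out = df
      .withColumn("newCol1", (col("key") + 1).as("newCol1", m1))
      .withColumn("newCol2", (col("key") + 2).as("newCol2", m2))

    out.schema.filter(_.name.startsWith("newCol")).foreach { f =>
      println(s"${f.name} -> ${f.metadata}")  // metadata carries the "key" entry set above
    }
    spark.stop()
  }
}
```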
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 17c88b069080..31bfa77e7632 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -686,6 +686,34 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
}
+ test("withColumns: given metadata") {
+ def buildMetadata(num: Int): Seq[Metadata] = {
+ (0 until num).map { n =>
+ val builder = new MetadataBuilder
+ builder.putLong("key", n.toLong)
+ builder.build()
+ }
+ }
+
+ val df = testData.toDF().withColumns(
+ Seq("newCol1", "newCol2"),
+ Seq(col("key") + 1, col("key") + 2),
+ buildMetadata(2))
+
+ df.select("newCol1", "newCol2").schema.zipWithIndex.foreach { case (col, idx) =>
+ assert(col.metadata.getLong("key").toInt === idx)
+ }
+
+ val err = intercept[IllegalArgumentException] {
+ testData.toDF().withColumns(
+ Seq("newCol1", "newCol2"),
+ Seq(col("key") + 1, col("key") + 2),
+ buildMetadata(1))
+ }
+ assert(err.getMessage.contains(
+ "The size of column names: 2 isn't equal to the size of metadata elements: 1"))
+ }
+
test("replace column using withColumn") {
val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
val df3 = df2.withColumn("x", df2("x") + 1)