@@ -33,6 +33,13 @@
import org.apache.spark.sql.types.StructType;
// $example off$

Contributor:
No Scala example?

@viirya (Member, Author) on Oct 11, 2017:
Added a Scala example.

/**
* An example for Bucketizer.
* Run with
* <pre>
* bin/run-example ml.JavaBucketizerExample
* </pre>
*/
public class JavaBucketizerExample {
public static void main(String[] args) {
SparkSession spark = SparkSession
@@ -68,6 +75,40 @@ public static void main(String[] args) {
bucketedData.show();
// $example off$

// $example on$
// Bucketize multiple columns in one pass.
double[][] splitsArray = {
{Double.NEGATIVE_INFINITY, -0.5, 0.0, 0.5, Double.POSITIVE_INFINITY},
{Double.NEGATIVE_INFINITY, -0.3, 0.0, 0.3, Double.POSITIVE_INFINITY}
};

List<Row> data2 = Arrays.asList(
RowFactory.create(-999.9, -999.9),
RowFactory.create(-0.5, -0.2),
RowFactory.create(-0.3, -0.1),
RowFactory.create(0.0, 0.0),
RowFactory.create(0.2, 0.4),
RowFactory.create(999.9, 999.9)
);
StructType schema2 = new StructType(new StructField[]{
new StructField("features1", DataTypes.DoubleType, false, Metadata.empty()),
new StructField("features2", DataTypes.DoubleType, false, Metadata.empty())
});
Dataset<Row> dataFrame2 = spark.createDataFrame(data2, schema2);

Bucketizer bucketizer2 = new Bucketizer()
.setInputCols(new String[] {"features1", "features2"})
.setOutputCols(new String[] {"bucketedFeatures1", "bucketedFeatures2"})
.setSplitsArray(splitsArray);
// Transform original data into its bucket index.
Dataset<Row> bucketedData2 = bucketizer2.transform(dataFrame2);

System.out.println("Bucketizer output with [" +
(bucketizer2.getSplitsArray()[0].length-1) + ", " +
(bucketizer2.getSplitsArray()[1].length-1) + "] buckets for each input column");
bucketedData2.show();
// $example off$

spark.stop();
}
}
@@ -22,7 +22,13 @@ package org.apache.spark.examples.ml
import org.apache.spark.ml.feature.Bucketizer
// $example off$
import org.apache.spark.sql.SparkSession

/**
* An example for Bucketizer.
* Run with
* {{{
* bin/run-example ml.BucketizerExample
* }}}
*/
object BucketizerExample {
def main(args: Array[String]): Unit = {
val spark = SparkSession
@@ -48,6 +54,34 @@ object BucketizerExample {
bucketedData.show()
// $example off$

// $example on$
val splitsArray = Array(
Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
Array(Double.NegativeInfinity, -0.3, 0.0, 0.3, Double.PositiveInfinity))

val data2 = Array(
(-999.9, -999.9),
(-0.5, -0.2),
(-0.3, -0.1),
(0.0, 0.0),
(0.2, 0.4),
(999.9, 999.9))
val dataFrame2 = spark.createDataFrame(data2).toDF("features1", "features2")

val bucketizer2 = new Bucketizer()
.setInputCols(Array("features1", "features2"))
.setOutputCols(Array("bucketedFeatures1", "bucketedFeatures2"))
.setSplitsArray(splitsArray)

// Transform original data into its bucket index.
val bucketedData2 = bucketizer2.transform(dataFrame2)

println(s"Bucketizer output with [" +
s"${bucketizer2.getSplitsArray(0).length-1}, " +
s"${bucketizer2.getSplitsArray(1).length-1}] buckets for each input column")
bucketedData2.show()
// $example off$

spark.stop()
}
}
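For reference, this is roughly what bucketedData2.show() should print for the data and splits above, given the [x, y) bucket ranges documented for the splits (exact column widths may differ):

+---------+---------+-----------------+-----------------+
|features1|features2|bucketedFeatures1|bucketedFeatures2|
+---------+---------+-----------------+-----------------+
|   -999.9|   -999.9|              0.0|              0.0|
|     -0.5|     -0.2|              1.0|              1.0|
|     -0.3|     -0.1|              1.0|              1.0|
|      0.0|      0.0|              2.0|              2.0|
|      0.2|      0.4|              2.0|              3.0|
|    999.9|    999.9|              3.0|              3.0|
+---------+---------+-----------------+-----------------+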
122 changes: 106 additions & 16 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -24,20 +24,24 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.Model
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/**
* `Bucketizer` maps a column of continuous features to a column of feature buckets.
* `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
* `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
* when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
* only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
* only used for single column usage, and `splitsArray` is for multiple columns.
*/
@Since("1.4.0")
final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
with DefaultParamsWritable {
with HasInputCols with HasOutputCols with DefaultParamsWritable {

@Since("1.4.0")
def this() = this(Identifiable.randomUID("bucketizer"))
@@ -81,7 +85,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
/**
* Param for how to handle invalid entries. Options are 'skip' (filter out rows with
* invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special
* additional bucket).
* additional bucket). Note that in the multiple-column case, the invalid handling is applied
* to all columns. That is, for 'error' it will throw an error if any invalid values are found
* in any column, and for 'skip' it will skip rows with any invalid values in any column.
* Default: "error"
* @group param
*/
@@ -96,9 +102,59 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
Contributor:
We should make it clear that in the multi column case, the invalid handling is applied to all columns (so for error it will throw the error if any invalids are found in any column, for skip it will skip rows with any invalids in any column, etc)
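To make those semantics concrete, here is a rough sketch (not part of this diff) of the "skip" behavior in the multi-column case; it assumes an existing SparkSession named spark, e.g. in spark-shell:

import org.apache.spark.ml.feature.Bucketizer

val df = spark.createDataFrame(Seq(
  (0.1, 0.2),
  (Double.NaN, 0.3),  // invalid value in the first column only
  (0.4, Double.NaN)   // invalid value in the second column only
)).toDF("f1", "f2")

val bucketizer = new Bucketizer()
  .setInputCols(Array("f1", "f2"))
  .setOutputCols(Array("b1", "b2"))
  .setSplitsArray(Array(
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity),
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity)))
  .setHandleInvalid("skip")

// With "skip", the last two rows are dropped because each contains an invalid value
// in at least one input column; with "error", transform would throw instead.
bucketizer.transform(df).show()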

setDefault(handleInvalid, Bucketizer.ERROR_INVALID)

/**
* Parameter for specifying multiple splits parameters. Each element in this array can be used to
* map continuous features into buckets.
*
* @group param
*/
@Since("2.3.0")
val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray",
"The array of split points for mapping continuous features into buckets for multiple " +
"columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by " +
"splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " +
"The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
"explicitly provided to cover all Double values; otherwise, values outside the splits " +
"specified will be treated as errors.",
Bucketizer.checkSplitsArray)

/** @group getParam */
@Since("2.3.0")
def getSplitsArray: Array[Array[Double]] = $(splitsArray)

/** @group setParam */
@Since("2.3.0")
def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)

/** @group setParam */
@Since("2.3.0")
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
@Since("2.3.0")
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

/**
* Determines whether this `Bucketizer` is going to map multiple columns. If and only if
* `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified
* by `inputCol`. A warning will be printed if both are set.
*/
private[feature] def isBucketizeMultipleColumns(): Boolean = {
if (isSet(inputCols) && isSet(inputCol)) {
logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " +
"`Bucketizer` only map one column specified by `inputCol`")
false
} else if (isSet(inputCols)) {
true
} else {
false
}
}
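As an illustration of the precedence described here (a sketch, not part of this patch): if a caller sets both the single-column and the multi-column params, only the single-column ones take effect.

import org.apache.spark.ml.feature.Bucketizer

val bucketizer = new Bucketizer()
  .setInputCol("f1")
  .setOutputCol("b1")
  .setSplits(Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity))
  .setInputCols(Array("f1", "f2"))
  .setOutputCols(Array("b1x", "b2x"))
  .setSplitsArray(Array(
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity),
    Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity)))

// A warning is logged and transform() adds only the "b1" column derived from "f1";
// the multi-column params are ignored.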

@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
val transformedSchema = transformSchema(dataset.schema)

val (filteredDataset, keepInvalid) = {
if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
// "skip" NaN option is set, will filter out NaN values in the dataset
@@ -108,26 +164,53 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
}
}

val bucketizer: UserDefinedFunction = udf { (feature: Double) =>
Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
}.withName("bucketizer")
val seqOfSplits = if (isBucketizeMultipleColumns()) {
$(splitsArray).toSeq
@tengpeng (Contributor) on Nov 24, 2017:
I am interested in the difference between .toSeq and Seq().
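For what it is worth, the distinction can be shown with a small sketch (not part of this patch). $(splitsArray) is already an Array[Array[Double]], so .toSeq simply views it as a Seq with one element per input column, whereas Seq($(splits)) wraps the single splits array as the only element of a new Seq; either way the result is a Seq[Array[Double]] that the code below can zip over uniformly.

val splitsArray: Array[Array[Double]] = Array(Array(0.0, 1.0), Array(2.0, 3.0))
val splits: Array[Double] = Array(0.0, 1.0)

splitsArray.toSeq  // Seq(Array(0.0, 1.0), Array(2.0, 3.0)): one element per column
Seq(splits)        // Seq(Array(0.0, 1.0)): a one-element Seq wrapping `splits`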

} else {
Seq($(splits))
}

val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
val newField = prepOutputField(filteredDataset.schema)
filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) =>
udf { (feature: Double) =>
Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid)
}.withName(s"bucketizer_$idx")
}

val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) {
($(inputCols).toSeq, $(outputCols).toSeq)
} else {
(Seq($(inputCol)), Seq($(outputCol)))
}
val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
}
val metadata = outputColumns.map { col =>
transformedSchema(col).metadata
}
filteredDataset.withColumns(outputColumns, newCols, metadata)
}

private def prepOutputField(schema: StructType): StructField = {
val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray
val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true),
private def prepOutputField(splits: Array[Double], outputCol: String): StructField = {
val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray
val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true),
values = Some(buckets))
attr.toStructField()
}

@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkNumericType(schema, $(inputCol))
SchemaUtils.appendColumn(schema, prepOutputField(schema))
if (isBucketizeMultipleColumns()) {
var transformedSchema = schema
$(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) =>
SchemaUtils.checkNumericType(transformedSchema, inputCol)
transformedSchema = SchemaUtils.appendColumn(transformedSchema,
prepOutputField($(splitsArray)(idx), outputCol))
}
transformedSchema
} else {
SchemaUtils.checkNumericType(schema, $(inputCol))
SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol)))
}
}

@Since("1.4.1")
@@ -163,6 +246,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
}
}

/**
* Check each set of splits in the splits array.
*/
private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = {
splitsArray.forall(checkSplits(_))
}

/**
* Binary searching in several buckets to place each data point.
* @param splits array of split points
39 changes: 39 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -490,6 +490,45 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
}
}

/**
* :: DeveloperApi ::
* Specialized version of `Param[Array[Array[Double]]]` for Java.
*/
@DeveloperApi
class DoubleArrayArrayParam(
parent: Params,
name: String,
doc: String,
isValid: Array[Array[Double]] => Boolean)
extends Param[Array[Array[Double]]](parent, name, doc, isValid) {

def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

/** Creates a param pair with a `java.util.List` of values (for Java and Python). */
def w(value: java.util.List[java.util.List[java.lang.Double]]): ParamPair[Array[Array[Double]]] =
w(value.asScala.map(_.asScala.map(_.asInstanceOf[Double]).toArray).toArray)

override def jsonEncode(value: Array[Array[Double]]): String = {
import org.json4s.JsonDSL._
compact(render(value.toSeq.map(_.toSeq.map(DoubleParam.jValueEncode))))
}

override def jsonDecode(json: String): Array[Array[Double]] = {
parse(json) match {
case JArray(values) =>
values.map {
case JArray(values) =>
values.map(DoubleParam.jValueDecode).toArray
case _ =>
throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].")
}.toArray
case _ =>
throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].")
}
}
}
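As an aside (a rough sketch, not part of this diff), the new param's JSON round trip can be exercised through the public jsonEncode/jsonDecode methods it inherits from Param; the exact rendering of the infinities is an assumption based on DoubleParam's encoding:

import org.apache.spark.ml.feature.Bucketizer

val bucketizer = new Bucketizer()
val value = Array(Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity))
val json = bucketizer.splitsArray.jsonEncode(value)
// json is expected to look like [["-Inf",0.0,"Inf"]]
val decoded: Array[Array[Double]] = bucketizer.splitsArray.jsonDecode(json)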

/**
* :: DeveloperApi ::
* Specialized version of `Param[Array[Int]]` for Java.
@@ -60,6 +60,7 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
ParamDesc[Array[String]]("outputCols", "output column names"),
ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " +
"disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " +
"every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"),
@@ -257,6 +257,23 @@ trait HasOutputCol extends Params {
final def getOutputCol: String = $(outputCol)
}

/**
* Trait for shared param outputCols. This trait may be changed or
* removed between minor versions.
*/
@DeveloperApi
trait HasOutputCols extends Params {

/**
* Param for output column names.
* @group param
*/
final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names")

/** @group getParam */
final def getOutputCols: Array[String] = $(outputCols)
}

/**
* Trait for shared param checkpointInterval. This trait may be changed or
* removed between minor versions.