-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns #17819
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e8f5d89
38dce8b
6ff9c79
8386d1e
f8dedd1
7c38b77
92ef9bd
60d3ba1
f70fc2a
2abca6b
000844a
1889995
bb19708
a970723
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,20 +24,24 @@ import org.apache.spark.annotation.Since | |
| import org.apache.spark.ml.Model | ||
| import org.apache.spark.ml.attribute.NominalAttribute | ||
| import org.apache.spark.ml.param._ | ||
| import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol} | ||
| import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} | ||
| import org.apache.spark.ml.util._ | ||
| import org.apache.spark.sql._ | ||
| import org.apache.spark.sql.expressions.UserDefinedFunction | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.types.{DoubleType, StructField, StructType} | ||
|
|
||
| /** | ||
| * `Bucketizer` maps a column of continuous features to a column of feature buckets. | ||
| * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0, | ||
| * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that | ||
| * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and | ||
| * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is | ||
| * only used for single column usage, and `splitsArray` is for multiple columns. | ||
| */ | ||
| @Since("1.4.0") | ||
| final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) | ||
| extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol | ||
| with DefaultParamsWritable { | ||
| with HasInputCols with HasOutputCols with DefaultParamsWritable { | ||
|
|
||
| @Since("1.4.0") | ||
| def this() = this(Identifiable.randomUID("bucketizer")) | ||
|
|
@@ -81,7 +85,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String | |
| /** | ||
| * Param for how to handle invalid entries. Options are 'skip' (filter out rows with | ||
| * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special | ||
| * additional bucket). | ||
| * additional bucket). Note that in the multiple column case, the invalid handling is applied | ||
| * to all columns. That said for 'error' it will throw an error if any invalids are found in | ||
| * any column, for 'skip' it will skip rows with any invalids in any columns, etc. | ||
| * Default: "error" | ||
| * @group param | ||
| */ | ||
|
|
@@ -96,9 +102,59 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String | |
| def setHandleInvalid(value: String): this.type = set(handleInvalid, value) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should make it clear that in the multi-column case, the invalid handling is applied to all columns (so for |
||
| setDefault(handleInvalid, Bucketizer.ERROR_INVALID) | ||
|
|
||
| /** | ||
| * Parameter for specifying multiple splits parameters. Each element in this array can be used to | ||
| * map continuous features into buckets. | ||
| * | ||
| * @group param | ||
| */ | ||
| @Since("2.3.0") | ||
| val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray", | ||
| "The array of split points for mapping continuous features into buckets for multiple " + | ||
| "columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by " + | ||
| "splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " + | ||
| "The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " + | ||
| "explicitly provided to cover all Double values; otherwise, values outside the splits " + | ||
| "specified will be treated as errors.", | ||
| Bucketizer.checkSplitsArray) | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.3.0") | ||
| def getSplitsArray: Array[Array[Double]] = $(splitsArray) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setInputCols(value: Array[String]): this.type = set(inputCols, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.3.0") | ||
| def setOutputCols(value: Array[String]): this.type = set(outputCols, value) | ||
|
|
||
| /** | ||
| * Determines whether this `Bucketizer` is going to map multiple columns. If and only if | ||
| * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified | ||
| * by `inputCol`. A warning will be printed if both are set. | ||
| */ | ||
| private[feature] def isBucketizeMultipleColumns(): Boolean = { | ||
| if (isSet(inputCols) && isSet(inputCol)) { | ||
| logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + | ||
| "`Bucketizer` only map one column specified by `inputCol`") | ||
| false | ||
| } else if (isSet(inputCols)) { | ||
| true | ||
| } else { | ||
| false | ||
| } | ||
| } | ||
|
|
||
| @Since("2.0.0") | ||
| override def transform(dataset: Dataset[_]): DataFrame = { | ||
| transformSchema(dataset.schema) | ||
| val transformedSchema = transformSchema(dataset.schema) | ||
|
|
||
| val (filteredDataset, keepInvalid) = { | ||
| if (getHandleInvalid == Bucketizer.SKIP_INVALID) { | ||
| // "skip" NaN option is set, will filter out NaN values in the dataset | ||
|
|
@@ -108,26 +164,53 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String | |
| } | ||
| } | ||
|
|
||
| val bucketizer: UserDefinedFunction = udf { (feature: Double) => | ||
| Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) | ||
| }.withName("bucketizer") | ||
| val seqOfSplits = if (isBucketizeMultipleColumns()) { | ||
| $(splitsArray).toSeq | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am interested in the difference between |
||
| } else { | ||
| Seq($(splits)) | ||
| } | ||
|
|
||
| val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType)) | ||
| val newField = prepOutputField(filteredDataset.schema) | ||
| filteredDataset.withColumn($(outputCol), newCol, newField.metadata) | ||
| val bucketizers: Seq[UserDefinedFunction] = seqOfSplits.zipWithIndex.map { case (splits, idx) => | ||
| udf { (feature: Double) => | ||
| Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid) | ||
| }.withName(s"bucketizer_$idx") | ||
| } | ||
|
|
||
| val (inputColumns, outputColumns) = if (isBucketizeMultipleColumns()) { | ||
| ($(inputCols).toSeq, $(outputCols).toSeq) | ||
| } else { | ||
| (Seq($(inputCol)), Seq($(outputCol))) | ||
| } | ||
| val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) => | ||
| bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType)) | ||
| } | ||
| val metadata = outputColumns.map { col => | ||
| transformedSchema(col).metadata | ||
| } | ||
| filteredDataset.withColumns(outputColumns, newCols, metadata) | ||
| } | ||
|
|
||
| private def prepOutputField(schema: StructType): StructField = { | ||
| val buckets = $(splits).sliding(2).map(bucket => bucket.mkString(", ")).toArray | ||
| val attr = new NominalAttribute(name = Some($(outputCol)), isOrdinal = Some(true), | ||
| private def prepOutputField(splits: Array[Double], outputCol: String): StructField = { | ||
| val buckets = splits.sliding(2).map(bucket => bucket.mkString(", ")).toArray | ||
| val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true), | ||
| values = Some(buckets)) | ||
| attr.toStructField() | ||
| } | ||
|
|
||
| @Since("1.4.0") | ||
| override def transformSchema(schema: StructType): StructType = { | ||
| SchemaUtils.checkNumericType(schema, $(inputCol)) | ||
| SchemaUtils.appendColumn(schema, prepOutputField(schema)) | ||
| if (isBucketizeMultipleColumns()) { | ||
| var transformedSchema = schema | ||
| $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) => | ||
| SchemaUtils.checkNumericType(transformedSchema, inputCol) | ||
| transformedSchema = SchemaUtils.appendColumn(transformedSchema, | ||
| prepOutputField($(splitsArray)(idx), outputCol)) | ||
| } | ||
| transformedSchema | ||
| } else { | ||
| SchemaUtils.checkNumericType(schema, $(inputCol)) | ||
| SchemaUtils.appendColumn(schema, prepOutputField($(splits), $(outputCol))) | ||
| } | ||
| } | ||
|
|
||
| @Since("1.4.1") | ||
|
|
@@ -163,6 +246,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Check each splits in the splits array. | ||
| */ | ||
| private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = { | ||
| splitsArray.forall(checkSplits(_)) | ||
| } | ||
|
|
||
| /** | ||
| * Binary searching in several buckets to place each data point. | ||
| * @param splits array of split points | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No Scala example?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a Scala example.