Commit e8f5d89 (parent 2eaf4f3)

Add a Bucketizer that can bin multiple columns.

6 files changed: 291 additions & 6 deletions

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 147 additions & 1 deletion
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.expressions.UserDefinedFunction
@@ -140,6 +140,139 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   }
 }
 
+/**
+ * `MultipleBucketizer` maps columns of continuous features to columns of feature buckets.
+ */
+@Since("2.3.0")
+final class MultipleBucketizer @Since("2.3.0") (@Since("2.3.0") override val uid: String)
+  extends Model[MultipleBucketizer] with HasInputCols with DefaultParamsWritable {
+
+  @Since("2.3.0")
+  def this() = this(Identifiable.randomUID("multipleBucketizer"))
+
+  /**
+   * Parameter for mapping continuous features into buckets. With n+1 splits, there are n buckets.
+   * A bucket defined by splits x,y holds values in the range [x,y) except the last bucket, which
+   * also includes y. Splits should be of length greater than or equal to 3 and strictly increasing.
+   * Values at -inf, inf must be explicitly provided to cover all Double values;
+   * otherwise, values outside the splits specified will be treated as errors.
+   *
+   * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
+   *
+   * @group param
+   */
+  @Since("2.3.0")
+  val splitsArray: DoubleArrayArrayParam = new DoubleArrayArrayParam(this, "splitsArray",
+    "The array of split points for mapping continuous features into buckets for multiple " +
+      "columns. For each input column, with n+1 splits, there are n buckets. A bucket defined by " +
+      "splits x,y holds values in the range [x,y) except the last bucket, which also includes y. " +
+      "The splits should be of length >= 3 and strictly increasing. Values at -inf, inf must be " +
+      "explicitly provided to cover all Double values; otherwise, values outside the splits " +
+      "specified will be treated as errors.",
+    Bucketizer.checkSplitsArray)
+
+  /**
+   * Param for output column names.
+   * @group param
+   */
+  @Since("2.3.0")
+  final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
+    "output column names")
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getSplitsArray: Array[Array[Double]] = $(splitsArray)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  final def getOutputCols: Array[String] = $(outputCols)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  /**
+   * Param for how to handle invalid entries. Options are 'skip' (filter out rows with
+   * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special
+   * additional bucket).
+   * Default: "error"
+   * @group param
+   */
+  // TODO: Make MultipleBucketizer inherit from HasHandleInvalid.
+  @Since("2.3.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+    "invalid entries. Options are skip (filter out rows with invalid values), " +
+    "error (throw an error), or keep (keep invalid values in a special additional bucket).",
+    ParamValidators.inArray(Bucketizer.supportedHandleInvalids))
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
+
+  @Since("2.3.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema)
+    val (filteredDataset, keepInvalid) = {
+      if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
+        // "skip" NaN option is set, will filter out NaN values in the dataset
+        (dataset.na.drop().toDF(), false)
+      } else {
+        (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
+      }
+    }
+
+    val bucketizers: Seq[UserDefinedFunction] = $(splitsArray).map { splits =>
+      udf { (feature: Double) =>
+        Bucketizer.binarySearchForBuckets(splits, feature, keepInvalid)
+      }
+    }
+
+    val newCols = $(inputCols).zipWithIndex.map { case (inputCol, idx) =>
+      bucketizers(idx)(filteredDataset(inputCol))
+    }
+    val newFields = $(outputCols).zipWithIndex.map { case (outputCol, idx) =>
+      prepOutputField(idx, outputCol)
+    }
+    filteredDataset.withColumns($(outputCols), newCols, newFields.map(_.metadata))
+  }
+
+  private def prepOutputField(idx: Int, outputCol: String): StructField = {
+    val buckets = $(splitsArray)(idx).sliding(2).map(bucket => bucket.mkString(", ")).toArray
+    val attr = new NominalAttribute(name = Some(outputCol), isOrdinal = Some(true),
+      values = Some(buckets))
+    attr.toStructField()
+  }
+
+  @Since("2.3.0")
+  override def transformSchema(schema: StructType): StructType = {
+    var transformedSchema = schema
+    $(inputCols).zip($(outputCols)).zipWithIndex.map { case ((inputCol, outputCol), idx) =>
+      SchemaUtils.checkColumnType(transformedSchema, inputCol, DoubleType)
+      transformedSchema = SchemaUtils.appendColumn(transformedSchema,
+        prepOutputField(idx, outputCol))
+    }
+    transformedSchema
+  }
+
+  @Since("2.3.0")
+  override def copy(extra: ParamMap): MultipleBucketizer = {
+    defaultCopy[MultipleBucketizer](extra).setParent(parent)
+  }
+}
+
 @Since("1.6.0")
 object Bucketizer extends DefaultParamsReadable[Bucketizer] {
 
@@ -167,6 +300,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
     }
   }
 
+  /**
+   * Check each splits in the splits array.
+   */
+  private[feature] def checkSplitsArray(splitsArray: Array[Array[Double]]): Boolean = {
+    splitsArray.forall(checkSplits(_))
+  }
+
   /**
   * Binary searching in several buckets to place each data point.
   * @param splits array of split points
@@ -211,3 +351,9 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
   @Since("1.6.0")
   override def load(path: String): Bucketizer = super.load(path)
 }
+
+@Since("2.3.0")
+object MultipleBucketizer extends DefaultParamsReadable[MultipleBucketizer] {
+  @Since("2.3.0")
+  override def load(path: String): MultipleBucketizer = super.load(path)
+}
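As a rough usage sketch of the new transformer (not part of this patch; the DataFrame `df`, the column names, and the split points below are illustrative), each input column gets its own splits array, and with n+1 split points there are n buckets:

import org.apache.spark.ml.feature.MultipleBucketizer

// Illustrative input: `df` is assumed to be an existing DataFrame with
// Double columns "hour" and "temp".
val bucketizer = new MultipleBucketizer()
  .setInputCols(Array("hour", "temp"))
  .setOutputCols(Array("hourBucket", "tempBucket"))
  .setSplitsArray(Array(
    // 5 split points -> 4 buckets; -inf/inf must be included to cover all Double values
    Array(Double.NegativeInfinity, 6.0, 12.0, 18.0, Double.PositiveInfinity),
    // 4 split points -> 3 buckets
    Array(Double.NegativeInfinity, 0.0, 25.0, Double.PositiveInfinity)))
  .setHandleInvalid("keep") // NaN values go to an extra bucket instead of raising an error

val bucketed = bucketizer.transform(df) // appends "hourBucket" and "tempBucket" columns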

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 39 additions & 0 deletions
@@ -490,6 +490,45 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
   }
 }
 
+/**
+ * :: DeveloperApi ::
+ * Specialized version of `Param[Array[Array[Double]]]` for Java.
+ */
+@DeveloperApi
+class DoubleArrayArrayParam(
+    parent: Params,
+    name: String,
+    doc: String,
+    isValid: Array[Array[Double]] => Boolean)
+  extends Param[Array[Array[Double]]](parent, name, doc, isValid) {
+
+  def this(parent: Params, name: String, doc: String) =
+    this(parent, name, doc, ParamValidators.alwaysTrue)
+
+  /** Creates a param pair with a `java.util.List` of values (for Java and Python). */
+  def w(value: java.util.List[java.util.List[java.lang.Double]]): ParamPair[Array[Array[Double]]] =
+    w(value.asScala.map(_.asScala.map(_.asInstanceOf[Double]).toArray).toArray)
+
+  override def jsonEncode(value: Array[Array[Double]]): String = {
+    import org.json4s.JsonDSL._
+    compact(render(value.toSeq.map(_.toSeq.map(DoubleParam.jValueEncode))))
+  }
+
+  override def jsonDecode(json: String): Array[Array[Double]] = {
+    parse(json) match {
+      case JArray(values) =>
+        values.map {
+          case JArray(values) =>
+            values.map(DoubleParam.jValueDecode).toArray
+          case _ =>
+            throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].")
+        }.toArray
+      case _ =>
+        throw new IllegalArgumentException(s"Cannot decode $json to Array[Array[Double]].")
+    }
+  }
+}
+
 /**
  * :: DeveloperApi ::
  * Specialized version of `Param[Array[Int]]` for Java.
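A rough sketch of how the new param could be exercised on its own (the `dummy` owner and the values are illustrative; the exact JSON text depends on DoubleParam.jValueEncode):

// `dummy` stands in for any Params instance that owns the param, as in the test suite below.
val param = new DoubleArrayArrayParam(dummy, "splitsArray", "doc")
val original = Array(Array(Double.NegativeInfinity, 0.0, Double.PositiveInfinity), Array(1.0, 2.0))

val json = param.jsonEncode(original)  // nested JSON array of the split values
val decoded = param.jsonDecode(json)   // Array[Array[Double]] equal element-wise to `original`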

mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ import scala.util.Random
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}

mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala

Lines changed: 34 additions & 4 deletions
@@ -121,10 +121,10 @@ class ParamsSuite extends SparkFunSuite {
   { // DoubleArrayParam
     val param = new DoubleArrayParam(dummy, "name", "doc")
     val values: Seq[Array[Double]] = Seq(
-        Array(),
-        Array(1.0),
-        Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
-          Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
+      Array(),
+      Array(1.0),
+      Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
+        Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
     for (value <- values) {
       val json = param.jsonEncode(value)
       val decoded = param.jsonDecode(json)
@@ -139,6 +139,36 @@ class ParamsSuite extends SparkFunSuite {
     }
   }
 
+  { // DoubleArrayArrayParam
+    val param = new DoubleArrayArrayParam(dummy, "name", "doc")
+    val values: Seq[Array[Array[Double]]] = Seq(
+      Array(Array()),
+      Array(Array(1.0)),
+      Array(Array(1.0), Array(2.0)),
+      Array(
+        Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
+          Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity),
+        Array(Double.MaxValue, Double.PositiveInfinity, Double.MinPositiveValue, 1.0,
+          Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0)
+      ))
+
+    for (value <- values) {
+      val json = param.jsonEncode(value)
+      val decoded = param.jsonDecode(json)
+      assert(decoded.length === value.length)
+      decoded.zip(value).foreach { case (actualArray, expectedArray) =>
+        assert(actualArray.length === expectedArray.length)
+        actualArray.zip(expectedArray).foreach { case (actual, expected) =>
+          if (expected.isNaN) {
+            assert(actual.isNaN)
+          } else {
+            assert(actual === expected)
+          }
+        }
+      }
+    }
+  }
+
   { // StringArrayParam
     val param = new StringArrayParam(dummy, "name", "doc")
     val values: Seq[Array[String]] = Seq(

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 50 additions & 0 deletions
@@ -1882,6 +1882,56 @@ class Dataset[T] private[sql](
     }
   }
 
+  /**
+   * Returns a new Dataset by adding columns or replacing the existing columns that has
+   * the same names.
+   *
+   * @group untypedrel
+   * @since 2.3.0
+   */
+  def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {
+    assert(colNames.size == cols.size,
+      s"The size of column names: ${colNames.size} isn't equal to " +
+        s"the size of columns: ${cols.size}")
+
+    val resolver = sparkSession.sessionState.analyzer.resolver
+    val output = queryExecution.analyzed.output
+
+    val columnMap = colNames.zip(cols).toMap
+
+    val replacedAndExistingColumns = output.map { field =>
+      val dupColumn = columnMap.find { case (colName, col) =>
+        resolver(field.name, colName)
+      }
+      if (dupColumn.isDefined) {
+        val colName = dupColumn.get._1
+        val col = dupColumn.get._2
+        col.as(colName)
+      } else {
+        Column(field)
+      }
+    }
+
+    val newColumns = columnMap.filter { case (colName, col) =>
+      !output.exists(f => resolver(f.name, colName))
+    }.map { case (colName, col) => col.as(colName) }
+
+    select(replacedAndExistingColumns ++ newColumns : _*)
+  }
+
+  /**
+   * Returns a new Dataset by adding columns with metadata.
+   */
+  private[spark] def withColumns(
+      colNames: Seq[String],
+      cols: Seq[Column],
+      metadata: Seq[Metadata]): DataFrame = {
+    val newCols = colNames.zip(cols).zip(metadata).map { case ((colName, col), metadata) =>
+      col.as(colName, metadata)
+    }
+    withColumns(colNames, newCols)
+  }
+
   /**
   * Returns a new Dataset by adding a column with metadata.
   */
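A short sketch of the intended semantics of the public overload (the DataFrame and column names here are illustrative): names that already exist in the schema are replaced in place, and the remaining names are appended at the end.

import org.apache.spark.sql.functions.col

// Illustrative: `people` is assumed to be an existing DataFrame with columns "name" and "age".
val updated = people.withColumns(
  Seq("age", "ageInMonths"),
  Seq(col("age") + 1, col("age") * 12))
// "age" is replaced in its original position; "ageInMonths" is appended,
// so the resulting schema is: name, age, ageInMonths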

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 20 additions & 0 deletions
@@ -555,6 +555,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.schema.map(_.name) === Seq("key", "value", "newCol"))
   }
 
+  test("withColumns") {
+    val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"),
+      Seq(col("key") + 1, col("key") + 2))
+    checkAnswer(
+      df,
+      testData.collect().map { case Row(key: Int, value: String) =>
+        Row(key, value, key + 1, key + 2)
+      }.toSeq)
+    assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2"))
+  }
+
   test("replace column using withColumn") {
     val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
     val df3 = df2.withColumn("x", df2("x") + 1)
@@ -563,6 +574,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       Row(2) :: Row(3) :: Row(4) :: Nil)
   }
 
+  test("replace column using withColumns") {
+    val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y")
+    val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"),
+      Seq(df2("x") + 1, df2("y"), df2("y") + 1))
+    checkAnswer(
+      df3.select("x", "newCol1", "newCol2"),
+      Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil)
+  }
+
   test("drop column using drop") {
     val df = testData.drop("key")
     checkAnswer(
