Commit bb19708

Address comments.

1 parent 1889995 commit bb19708

File tree: 5 files changed, +146 −59 lines changed


examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala

Lines changed: 35 additions & 1 deletion

@@ -22,7 +22,13 @@ package org.apache.spark.examples.ml
 import org.apache.spark.ml.feature.Bucketizer
 // $example off$
 import org.apache.spark.sql.SparkSession
-
+/**
+ * An example for Bucketizer.
+ * Run with
+ * {{{
+ * bin/run-example ml.BucketizerExample
+ * }}}
+ */
 object BucketizerExample {
   def main(args: Array[String]): Unit = {
     val spark = SparkSession
@@ -48,6 +54,34 @@ object BucketizerExample {
     bucketedData.show()
     // $example off$
 
+    // $example on$
+    val splitsArray = Array(
+      Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity),
+      Array(Double.NegativeInfinity, -0.3, 0.0, 0.3, Double.PositiveInfinity))
+
+    val data2 = Array(
+      (-999.9, -999.9),
+      (-0.5, -0.2),
+      (-0.3, -0.1),
+      (0.0, 0.0),
+      (0.2, 0.4),
+      (999.9, 999.9))
+    val dataFrame2 = spark.createDataFrame(data2).toDF("features1", "features2")
+
+    val bucketizer2 = new Bucketizer()
+      .setInputCols(Array("features1", "features2"))
+      .setOutputCols(Array("bucketedFeatures1", "bucketedFeatures2"))
+      .setSplitsArray(splitsArray)
+
+    // Transform original data into its bucket index.
+    val bucketedData2 = bucketizer2.transform(dataFrame2)
+
+    println(s"Bucketizer output with [" +
+      s"${bucketizer2.getSplitsArray(0).length-1}, " +
+      s"${bucketizer2.getSplitsArray(1).length-1}] buckets for each input column")
+    bucketedData2.show()
+    // $example off$
+
     spark.stop()
   }
 }
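For intuition about the indices this example prints: with splits [-inf, -0.5, 0.0, 0.5, +inf], a value falls into one of four half-open buckets [splits(i), splits(i+1)). The sketch below mimics that assignment with a plain binary search; it is a simplified illustration of the documented bucket semantics (ignoring handleInvalid), not Spark's internal implementation.

// Simplified sketch of bucket assignment over sorted split points.
// Bucket i covers [splits(i), splits(i+1)); the last bucket also
// includes its upper bound. Illustrative only, not Spark's code.
object BucketIndexSketch {
  def bucketIndex(value: Double, splits: Array[Double]): Int = {
    val hit = java.util.Arrays.binarySearch(splits, value)
    if (hit >= 0) {
      // Exact split point: it opens bucket `hit`, except the final
      // split, which closes the last bucket instead.
      math.min(hit, splits.length - 2)
    } else {
      val firstGreater = -hit - 1 // binarySearch's insertion point
      require(firstGreater > 0 && firstGreater < splits.length,
        s"$value lies outside the given splits")
      firstGreater - 1
    }
  }

  def main(args: Array[String]): Unit = {
    val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
    // Prints buckets 0, 1, 2, 3 for the four sample values below.
    Seq(-999.9, -0.5, 0.2, 999.9).foreach { v =>
      println(s"$v -> bucket ${bucketIndex(v, splits)}")
    }
  }
}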

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 14 additions & 23 deletions

@@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Model
 import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.expressions.UserDefinedFunction
@@ -33,14 +33,15 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
 /**
  * `Bucketizer` maps a column of continuous features to a column of feature buckets. Since 2.3.0,
- * `Bucketizer` can also map multiple columns at once. Whether it goes to map a column or multiple
- * columns, it depends on which parameter of `inputCol` and `inputCols` is set. When both are set,
- * a log warning will be printed and by default it chooses `inputCol`.
+ * `Bucketizer` can map multiple columns at once by setting the `inputCols` parameter. Note that
+ * when both the `inputCol` and `inputCols` parameters are set, a log warning will be printed and
+ * only `inputCol` will take effect, while `inputCols` will be ignored. The `splits` parameter is
+ * only used for single column usage, and `splitsArray` is for multiple columns.
  */
 @Since("1.4.0")
 final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
   extends Model[Bucketizer] with HasHandleInvalid with HasInputCol with HasOutputCol
-    with HasInputCols with DefaultParamsWritable {
+    with HasInputCols with HasOutputCols with DefaultParamsWritable {
 
   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("bucketizer"))
@@ -84,7 +85,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   /**
    * Param for how to handle invalid entries. Options are 'skip' (filter out rows with
    * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special
-   * additional bucket).
+   * additional bucket). Note that in the multiple column case, the invalid handling is applied
+   * to all columns. That said for 'error' it will throw an error if any invalids are found in
+   * any column, for 'skip' it will skip rows with any invalids in any columns, etc.
    * Default: "error"
    * @group param
    */
@@ -115,22 +118,10 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
     "specified will be treated as errors.",
     Bucketizer.checkSplitsArray)
 
-  /**
-   * Param for output column names.
-   * @group param
-   */
-  @Since("2.3.0")
-  final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols",
-    "output column names")
-
   /** @group getParam */
   @Since("2.3.0")
   def getSplitsArray: Array[Array[Double]] = $(splitsArray)
 
-  /** @group getParam */
-  @Since("2.3.0")
-  final def getOutputCols: Array[String] = $(outputCols)
-
   /** @group setParam */
   @Since("2.3.0")
   def setSplitsArray(value: Array[Array[Double]]): this.type = set(splitsArray, value)
@@ -148,7 +139,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
    * `inputCols` is set, it will map multiple columns. Otherwise, it just maps a column specified
    * by `inputCol`. A warning will be printed if both are set.
    */
-  private[ml] def isBucketizeMultipleColumns(): Boolean = {
+  private[feature] def isBucketizeMultipleColumns(): Boolean = {
     if (isSet(inputCols) && isSet(inputCol)) {
       logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " +
         "`Bucketizer` only map one column specified by `inputCol`")
@@ -162,7 +153,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema)
+    val transformedSchema = transformSchema(dataset.schema)
 
     val (filteredDataset, keepInvalid) = {
       if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
@@ -193,10 +184,10 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
       val newCols = inputColumns.zipWithIndex.map { case (inputCol, idx) =>
        bucketizers(idx)(filteredDataset(inputCol).cast(DoubleType))
      }
-      val newFields = outputColumns.zipWithIndex.map { case (outputCol, idx) =>
-        prepOutputField(seqOfSplits(idx), outputCol)
+      val metadata = outputColumns.map { col =>
+        transformedSchema(col).metadata
       }
-      filteredDataset.withColumns(outputColumns, newCols, newFields.map(_.metadata))
+      filteredDataset.withColumns(outputColumns, newCols, metadata)
     }
 
   private def prepOutputField(splits: Array[Double], outputCol: String): StructField = {
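To make the documented single- versus multi-column modes concrete, here is a short usage sketch. The column names and split values are illustrative; the setters are those shown in this diff and in the example above.

import org.apache.spark.ml.feature.Bucketizer

// Single-column mode: `splits` pairs with `inputCol`/`outputCol`.
val single = new Bucketizer()
  .setInputCol("hour")
  .setOutputCol("hourBucket")
  .setSplits(Array(Double.NegativeInfinity, 6.0, 12.0, 18.0, Double.PositiveInfinity))

// Multi-column mode (since 2.3.0): `splitsArray` pairs with `inputCols`/`outputCols`.
val multi = new Bucketizer()
  .setInputCols(Array("hour", "temperature"))
  .setOutputCols(Array("hourBucket", "tempBucket"))
  .setSplitsArray(Array(
    Array(Double.NegativeInfinity, 6.0, 12.0, 18.0, Double.PositiveInfinity),
    Array(Double.NegativeInfinity, 0.0, 20.0, Double.PositiveInfinity)))

// Per the scaladoc above: if both inputCol and inputCols were set on the
// same instance, a warning is logged and only inputCol takes effect.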

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ private[shared] object SharedParamsCodeGen {
     ParamDesc[String]("inputCol", "input column name"),
     ParamDesc[Array[String]]("inputCols", "input column names"),
     ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
+    ParamDesc[Array[String]]("outputCols", "output column names"),
     ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " +
       "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " +
       "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"),

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 15 additions & 0 deletions

@@ -230,6 +230,21 @@ private[ml] trait HasOutputCol extends Params {
   final def getOutputCol: String = $(outputCol)
 }
 
+/**
+ * Trait for shared param outputCols.
+ */
+private[ml] trait HasOutputCols extends Params {
+
+  /**
+   * Param for output column names.
+   * @group param
+   */
+  final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names")
+
+  /** @group getParam */
+  final def getOutputCols: Array[String] = $(outputCols)
+}
+
 /**
  * Trait for shared param checkpointInterval.
 */

mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 81 additions & 35 deletions

@@ -200,7 +200,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val data = (0 until validData1.length).map { idx =>
       (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
     }
-    val dataFrame: DataFrame = data.toSeq.toDF("feature1", "feature2", "expected1", "expected2")
+    val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2")
 
     val bucketizer1: Bucketizer = new Bucketizer()
       .setInputCols(Array("feature1", "feature2"))
@@ -210,16 +210,12 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     assert(bucketizer1.isBucketizeMultipleColumns())
 
     bucketizer1.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
-      .collect().foreach {
-        case Row(r1: Double, e1: Double, r2: Double, e2: Double) =>
-          assert(r1 === e1,
-            s"The feature value is not correct after bucketing. Expected $e1 but found $r1")
-          assert(r2 === e2,
-            s"The feature value is not correct after bucketing. Expected $e2 but found $r2")
-      }
+    BucketizerSuite.checkBucketResults(bucketizer1.transform(dataFrame),
+      Seq("result1", "result2"),
+      Seq("expected1", "expected2"))
 
     // Check for exceptions when using a set of invalid feature values.
-    val invalidData1: Array[Double] = Array(-0.9) ++ validData1
+    val invalidData1 = Array(-0.9) ++ validData1
     val invalidData2 = Array(0.51) ++ validData1
     val badDF1 = invalidData1.zipWithIndex.toSeq.toDF("feature", "idx")
@@ -256,7 +252,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val data = (0 until validData1.length).map { idx =>
       (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
     }
-    val dataFrame: DataFrame = data.toSeq.toDF("feature1", "feature2", "expected1", "expected2")
+    val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2")
 
     val bucketizer: Bucketizer = new Bucketizer()
       .setInputCols(Array("feature1", "feature2"))
@@ -265,14 +261,9 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
 
     assert(bucketizer.isBucketizeMultipleColumns())
 
-    bucketizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
-      .collect().foreach {
-        case Row(r1: Double, e1: Double, r2: Double, e2: Double) =>
-          assert(r1 === e1,
-            s"The feature value is not correct after bucketing. Expected $e1 but found $r1")
-          assert(r2 === e2,
-            s"The feature value is not correct after bucketing. Expected $e2 but found $r2")
-      }
+    BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
+      Seq("result1", "result2"),
+      Seq("expected1", "expected2"))
   }
 
   test("multiple columns: Bucket continuous features, with NaN data but non-NaN splits") {
@@ -288,7 +279,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val data = (0 until validData1.length).map { idx =>
       (validData1(idx), validData2(idx), expectedBuckets1(idx), expectedBuckets2(idx))
     }
-    val dataFrame: DataFrame = data.toSeq.toDF("feature1", "feature2", "expected1", "expected2")
+    val dataFrame: DataFrame = data.toDF("feature1", "feature2", "expected1", "expected2")
 
     val bucketizer: Bucketizer = new Bucketizer()
       .setInputCols(Array("feature1", "feature2"))
@@ -298,14 +289,9 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     assert(bucketizer.isBucketizeMultipleColumns())
 
     bucketizer.setHandleInvalid("keep")
-    bucketizer.transform(dataFrame).select("result1", "expected1", "result2", "expected2")
-      .collect().foreach {
-        case Row(r1: Double, e1: Double, r2: Double, e2: Double) =>
-          assert(r1 === e1,
-            s"The feature value is not correct after bucketing. Expected $e1 but found $r1")
-          assert(r2 === e2,
-            s"The feature value is not correct after bucketing. Expected $e2 but found $r2")
-      }
+    BucketizerSuite.checkBucketResults(bucketizer.transform(dataFrame),
+      Seq("result1", "result2"),
+      Seq("expected1", "expected2"))
 
     bucketizer.setHandleInvalid("skip")
     val skipResults1: Array[Double] = bucketizer.transform(dataFrame)
@@ -335,7 +321,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     }
   }
 
-  test("multiple columns:: read/write") {
+  test("multiple columns: read/write") {
     val t = new Bucketizer()
       .setInputCols(Array("myInputCol"))
       .setOutputCols(Array("myOutputCol"))
@@ -359,13 +345,51 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setStages(Array(bucket))
       .fit(df)
     pl.transform(df).select("result1", "expected1", "result2", "expected2")
-      .collect().foreach {
-        case Row(r1: Double, e1: Double, r2: Double, e2: Double) =>
-          assert(r1 === e1,
-            s"The feature value is not correct after bucketing. Expected $e1 but found $r1")
-          assert(r2 === e2,
-            s"The feature value is not correct after bucketing. Expected $e2 but found $r2")
-      }
+
+    BucketizerSuite.checkBucketResults(pl.transform(df),
+      Seq("result1", "result2"), Seq("expected1", "expected2"))
+  }
+
+  test("Compare single/multiple column(s) Bucketizer in pipeline") {
+    val df = Seq((0.5, 0.3, 1.0, 1.0), (0.5, -0.4, 1.0, 0.0))
+      .toDF("feature1", "feature2", "expected1", "expected2")
+
+    val multiColsBucket = new Bucketizer()
+      .setInputCols(Array("feature1", "feature2"))
+      .setOutputCols(Array("result1", "result2"))
+      .setSplitsArray(Array(Array(-0.5, 0.0, 0.5), Array(-0.5, 0.0, 0.5)))
+
+    val plForMultiCols = new Pipeline()
+      .setStages(Array(multiColsBucket))
+      .fit(df)
+
+    val bucketForCol1 = new Bucketizer()
+      .setInputCol("feature1")
+      .setOutputCol("result1")
+      .setSplits(Array(-0.5, 0.0, 0.5))
+    val bucketForCol2 = new Bucketizer()
+      .setInputCol("feature2")
+      .setOutputCol("result2")
+      .setSplits(Array(-0.5, 0.0, 0.5))
+
+    val plForSingleCol = new Pipeline()
+      .setStages(Array(bucketForCol1, bucketForCol2))
+      .fit(df)
+
+    val resultForSingleCol = plForSingleCol.transform(df)
+      .select("result1", "expected1", "result2", "expected2")
+      .collect()
+    val resultForMultiCols = plForMultiCols.transform(df)
+      .select("result1", "expected1", "result2", "expected2")
+      .collect()
+
+    resultForSingleCol.zip(resultForMultiCols).foreach {
+      case (rowForSingle, rowForMultiCols) =>
+        assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) &&
+          rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1) &&
+          rowForSingle.getDouble(2) == rowForMultiCols.getDouble(2) &&
+          rowForSingle.getDouble(3) == rowForMultiCols.getDouble(3))
+    }
   }
 
   test("Both inputCol and inputCols are set") {
@@ -411,4 +435,26 @@ private object BucketizerSuite extends SparkFunSuite {
       i += 1
     }
   }
+
+  /** Checks if bucketized results match expected ones. */
+  def checkBucketResults(
+      bucketResult: DataFrame,
+      resultColumns: Seq[String],
+      expectedColumns: Seq[String]): Unit = {
+    assert(resultColumns.length == expectedColumns.length,
+      s"Given ${resultColumns.length} result columns doesn't match " +
+        s"${expectedColumns.length} expected columns.")
+    assert(resultColumns.length > 0, "At least one result and expected columns are needed.")
+
+    val allColumns = resultColumns ++ expectedColumns
+    bucketResult.select(allColumns.head, allColumns.tail: _*).collect().foreach {
+      case row =>
+        for (idx <- 0 until row.length / 2) {
+          val result = row.getDouble(idx)
+          val expected = row.getDouble(idx + row.length / 2)
+          assert(result === expected, "The feature value is not correct after bucketing. " +
+            s"Expected $expected but found $result.")
+        }
+    }
+  }
 }
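A note on the new checkBucketResults helper: it selects the result columns followed by the expected columns, so in each collected row the value at index idx pairs with the one at idx + row.length / 2. A minimal plain-Scala illustration of that pairing (no Spark needed; the sample values are made up):

// Results occupy the first half of the row, expectations the second,
// mirroring select(resultColumns ++ expectedColumns) in the helper.
val row = Array(1.0, 2.0, /* expected: */ 1.0, 2.0)
for (idx <- 0 until row.length / 2) {
  val (result, expected) = (row(idx), row(idx + row.length / 2))
  assert(result == expected, s"Expected $expected but found $result.")
}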
