
Commit 2b1b81d

Author: Vincent Xie
Merge pull request #2 from jkbradley/VinceShieh-spark-17219_followup
Small cleanups
2 parents 2f98d31 + 2644235 commit 2b1b81d

File tree

6 files changed: +97 −57 lines


docs/ml-features.md

Lines changed: 8 additions & 5 deletions
@@ -1103,13 +1103,16 @@ for more details on the API.
 
 `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
 categorical features. The number of bins is set by the `numBuckets` parameter. It is possible
-that the number of buckets used will be less than this value, for example, if there are too few
-distinct values of the input to create enough distinct quantiles. Note also that QuantileDiscretizer
-will raise an error when it finds NaN value in the dataset, but user can also choose to either
-keep or remove NaN values within the dataset by setting handleInvalid. If user chooses to keep
+that the number of buckets used will be smaller than this value, for example, if there are too few
+distinct values of the input to create enough distinct quantiles.
+
+NaN values: Note also that QuantileDiscretizer
+will raise an error when it finds NaN values in the dataset, but the user can also choose to either
+keep or remove NaN values within the dataset by setting `handleInvalid`. If the user chooses to keep
 NaN values, they will be handled specially and placed into their own bucket, for example, if 4 buckets
 are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4].
-The bin ranges are chosen using an approximate algorithm (see the documentation for
+
+Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for
 [approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a
 detailed description). The precision of the approximation can be controlled with the
 `relativeError` parameter. When set to zero, exact quantiles are calculated
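
To make the documented NaN behaviour concrete, here is a minimal sketch of the workflow the new text describes (it assumes an active SparkSession named `spark`; the column names and data are illustrative, not from the patch):

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer

import spark.implicits._ // assumes an existing SparkSession named `spark`

// A continuous input column containing one NaN value.
val data = Seq(1.0, 2.0, 3.0, 4.0, Double.NaN).toDF("input")

val discretizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(4)
  .setHandleInvalid("keep") // "error" (default) throws on NaN; "skip" drops NaN rows

// With "keep", non-NaN values land in buckets [0-3] and NaNs in the special bucket 4.
discretizer.fit(data).transform(data).show()
```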

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 28 additions & 16 deletions
@@ -47,6 +47,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
    * also includes y. Splits should be of length >= 3 and strictly increasing.
    * Values at -inf, inf must be explicitly provided to cover all Double values;
    * otherwise, values outside the splits specified will be treated as errors.
+   *
+   * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
+   *
    * @group param
    */
   @Since("1.4.0")
@@ -75,37 +78,36 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   /**
-   * Param for how to handle invalid entries. Options are skip (which will filter out rows with
-   * invalid values), or error (which will throw an error), or keep (which will keep the invalid
-   * values in certain way).
+   * Param for how to handle invalid entries. Options are skip (filter out rows with
+   * invalid values), error (throw an error), or keep (keep invalid values in a special additional
+   * bucket).
    * Default: "error"
    * @group param
    */
   @Since("2.1.0")
   val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" +
-    "invalid entries. Options are skip (which will filter out rows with invalid values), or" +
-    "error (which will throw an error), or keep (which will keep the invalid values" +
-    " in certain way). Default behaviour is to report an error for invalid entries.",
-    ParamValidators.inArray(Array("skip", "error", "keep")))
+    "invalid entries. Options are skip (filter out rows with invalid values), " +
+    "error (throw an error), or keep (keep invalid values in a special additional bucket).",
+    ParamValidators.inArray(Bucketizer.supportedHandleInvalid))
 
   /** @group getParam */
   @Since("2.1.0")
-  def gethandleInvalid: String = $(handleInvalid)
+  def getHandleInvalid: String = $(handleInvalid)
 
   /** @group setParam */
   @Since("2.1.0")
-  def sethandleInvalid(value: String): this.type = set(handleInvalid, value)
-  setDefault(handleInvalid, "error")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
     val (filteredDataset, keepInvalid) = {
-      if ("skip" == gethandleInvalid) {
+      if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
         // "skip" NaN option is set, will filter out NaN values in the dataset
-        (dataset.na.drop.toDF(), false)
+        (dataset.na.drop().toDF(), false)
       } else {
-        (dataset.toDF(), "keep" == gethandleInvalid)
+        (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
       }
     }
 
@@ -140,6 +142,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 @Since("1.6.0")
 object Bucketizer extends DefaultParamsReadable[Bucketizer] {
 
+  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val supportedHandleInvalid: Array[String] =
+    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
   /**
    * We require splits to be of length >= 3 and to be in strictly increasing order.
    * No NaN split should be accepted.
@@ -173,9 +181,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
       splits: Array[Double],
       feature: Double,
       keepInvalid: Boolean): Double = {
-    if (feature.isNaN && keepInvalid) {
-      // NaN data point found plus "keep" NaN option is set
-      splits.length - 1
+    if (feature.isNaN) {
+      if (keepInvalid) {
+        splits.length - 1
+      } else {
+        throw new SparkException("Bucketizer encountered NaN value. To handle or skip NaNs," +
+          " try setting Bucketizer.handleInvalid.")
+      }
     } else if (feature == splits.last) {
       splits.length - 2
     } else {
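
As a usage illustration of the three `handleInvalid` modes this diff introduces, a minimal sketch (assuming an existing SparkSession named `spark`; the splits and data are illustrative):

```scala
import org.apache.spark.ml.feature.Bucketizer

import spark.implicits._ // assumes an existing SparkSession named `spark`

val df = Seq(-0.5, 0.0, 0.3, Double.NaN).toDF("feature")

val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("bucketed")
  .setSplits(Array(Double.NegativeInfinity, 0.0, 0.5, Double.PositiveInfinity))

bucketizer.setHandleInvalid("keep") // NaN rows go to the extra bucket: splits.length - 1 = 3
bucketizer.transform(df).show()

bucketizer.setHandleInvalid("skip") // NaN rows are filtered out before bucketing
bucketizer.transform(df).show()

bucketizer.setHandleInvalid("error") // default: transforming NaN values throws SparkException
// bucketizer.transform(df).collect() would throw here
```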

mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala

Lines changed: 33 additions & 27 deletions
@@ -36,6 +36,9 @@ private[feature] trait QuantileDiscretizerBase extends Params
   /**
    * Number of buckets (quantiles, or categories) into which data points are grouped. Must
    * be >= 2.
+   *
+   * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
+   *
    * default: 2
    * @group param
    */
@@ -61,19 +64,41 @@ private[feature] trait QuantileDiscretizerBase extends Params
 
   /** @group getParam */
   def getRelativeError: Double = getOrDefault(relativeError)
+
+  /**
+   * Param for how to handle invalid entries. Options are skip (filter out rows with
+   * invalid values), error (throw an error), or keep (keep invalid values in a special additional
+   * bucket).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.1.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" +
+    "invalid entries. Options are skip (filter out rows with invalid values), " +
+    "error (throw an error), or keep (keep invalid values in a special additional bucket).",
+    ParamValidators.inArray(Bucketizer.supportedHandleInvalid))
+  setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
+
+  /** @group getParam */
+  @Since("2.1.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
 }
 
 /**
  * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
  * categorical features. The number of bins can be set using the `numBuckets` parameter. It is
- * possible that the number of buckets used will be less than this value, for example, if there are
- * too few distinct values of the input to create enough distinct quantiles. Note also that
- * QuantileDiscretizer will raise an error when it finds NaN value in the dataset, but user can
- * also choose to either keep or remove NaN values within the dataset by setting handleInvalid.
- * If user chooses to keep NaN values, they will be handled specially and placed into their own
+ * possible that the number of buckets used will be smaller than this value, for example, if there
+ * are too few distinct values of the input to create enough distinct quantiles.
+ *
+ * NaN handling: Note also that
+ * QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can
+ * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`.
+ * If the user chooses to keep NaN values, they will be handled specially and placed into their own
  * bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3],
  * but NaNs will be counted in a special bucket[4].
- * The bin ranges are chosen using an approximate algorithm (see the documentation for
+ *
+ * Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for
  * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]]
  * for a detailed description). The precision of the approximation can be controlled with the
  * `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`,
@@ -102,28 +127,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
   @Since("1.6.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  /**
-   * Param for how to handle invalid entries. Options are skip (which will filter out rows with
-   * invalid values), or error (which will throw an error), or keep (which will keep the invalid
-   * values in certain way). Default behaviour is to report an error for invalid entries.
-   * Default: "error"
-   * @group param
-   */
-  @Since("2.1.0")
-  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" +
-    "invalid entries. Options are skip (which will filter out rows with invalid values), or" +
-    "error (which will throw an error), or keep (which will keep the invalid values" +
-    " in certain way). Default behaviour is to report an error for invalid entries.",
-    ParamValidators.inArray(Array("skip", "error", "keep")))
-
-  /** @group getParam */
-  @Since("2.1.0")
-  def gethandleInvalid: String = $(handleInvalid)
-
   /** @group setParam */
   @Since("2.1.0")
-  def sethandleInvalid(value: String): this.type = set(handleInvalid, value)
-  setDefault(handleInvalid, "error")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
 
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
@@ -151,7 +157,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
     }
     val bucketizer = new Bucketizer(uid)
       .setSplits(distinctSplits.sorted)
-      .sethandleInvalid($(handleInvalid))
+      .setHandleInvalid($(handleInvalid))
     copyValues(bucketizer.setParent(this))
   }
 
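The last hunk shows that `fit` hands the `handleInvalid` setting down to the produced `Bucketizer` via `copyValues`. A small sketch of that round trip (again assuming a SparkSession named `spark`; the data is illustrative):

```scala
import org.apache.spark.ml.feature.{Bucketizer, QuantileDiscretizer}

import spark.implicits._ // assumes an existing SparkSession named `spark`

val df = Seq(1.0, 2.0, 3.0, 4.0, 5.0, Double.NaN).toDF("input")

val discretizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(3)
  .setHandleInvalid("keep")

// fit() computes approximate quantile splits and returns a Bucketizer model;
// copyValues propagates handleInvalid from the discretizer to the model.
val model: Bucketizer = discretizer.fit(df)
assert(model.getHandleInvalid == "keep")
model.transform(df).show()
```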
mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 16 additions & 4 deletions
@@ -98,21 +98,33 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setInputCol("feature")
       .setOutputCol("result")
       .setSplits(splits)
-      .sethandleInvalid("keep")
 
+    bucketizer.setHandleInvalid("keep")
     bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
       case Row(x: Double, y: Double) =>
         assert(x === y,
          s"The feature value is not correct after bucketing. Expected $y but found $x")
     }
+
+    bucketizer.setHandleInvalid("skip")
+    val skipResults: Array[Double] = bucketizer.transform(dataFrame)
+      .select("result").as[Double].collect()
+    assert(skipResults.length === 7)
+    assert(skipResults.forall(_ !== 4.0))
+
+    bucketizer.setHandleInvalid("error")
+    withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") {
+      intercept[SparkException] {
+        bucketizer.transform(dataFrame).collect()
+      }
+    }
   }
 
   test("Bucket continuous features, with NaN splits") {
     val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN)
-    withClue("Invalid NaN split was not caught as an invalid split!") {
+    withClue("Invalid NaN split was not caught during Bucketizer initialization") {
      intercept[IllegalArgumentException] {
-        val bucketizer: Bucketizer = new Bucketizer()
-          .setSplits(splits)
+        new Bucketizer().setSplits(splits)
      }
    }
  }
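
The `skipResults.forall(_ !== 4.0)` assertion relies on the NaN-bucket index convention from `binarySearchForBuckets`. A standalone sketch of just that rule (simplified; not the actual Spark internals, which binary-search in-range values):

```scala
// With n splits there are n - 1 regular buckets [0, n - 2]; kept NaNs map to the
// extra index n - 1. This simplified helper mirrors only the NaN branch.
def nanBucketIndex(splits: Array[Double], keepInvalid: Boolean): Double =
  if (keepInvalid) splits.length - 1
  else sys.error("NaN encountered with handleInvalid = error")

val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
assert(nanBucketIndex(splits, keepInvalid = true) == 4.0) // the suite's special bucket 4
```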

mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala

Lines changed: 9 additions & 2 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql._
@@ -85,9 +85,16 @@ class QuantileDiscretizerSuite
       .setOutputCol("result")
       .setNumBuckets(numBuckets)
 
+    withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") {
+      val dataFrame: DataFrame = validData.toSeq.toDF("input")
+      intercept[SparkException] {
+        discretizer.fit(dataFrame).transform(dataFrame).collect()
+      }
+    }
+
     List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{
       case(u, v) =>
-        discretizer.sethandleInvalid(u)
+        discretizer.setHandleInvalid(u)
         val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected")
         val result = discretizer.fit(dataFrame).transform(dataFrame)
         result.select("result", "expected").collect().foreach {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -151,9 +151,9 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
       assert(math.abs(d2 - 2 * q2 * n) < error_double)
     }
     // test approxQuantile on NaN values
-    val dfNaN = Array(Double.NaN, 1.0, Double.NaN, Double.NaN).toSeq.toDF("input")
-    val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons(0))
-    assert(resNaN.count(_.isNaN) == 0)
+    val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input")
+    val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head)
+    assert(resNaN.count(_.isNaN) === 0)
   }
 
   test("crosstab") {
