Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/ml-features.md
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,11 @@ for more details on the API.
## QuantileDiscretizer

`QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
categorical features. The number of bins is set by the `numBuckets` parameter.
categorical features. The number of bins is set by the `numBuckets` parameter. It is possible
that the number of buckets used will be less than this value, for example, if there are too few
distinct values of the input to create enough distinct quantiles. Note also that NaN values are
handled specially and placed into their own bucket. For example, if 4 buckets are used, then
non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4].
The bin ranges are chosen using an approximate algorithm (see the documentation for
[approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a
detailed description). The precision of the approximation can be controlled with the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,18 +106,21 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.6.0")
object Bucketizer extends DefaultParamsReadable[Bucketizer] {

/** We require splits to be of length >= 3 and to be in strictly increasing order. */
/**
* We require splits to be of length >= 3 and to be in strictly increasing order.
* No NaN split should be accepted.
*/
private[feature] def checkSplits(splits: Array[Double]): Boolean = {
if (splits.length < 3) {
false
} else {
var i = 0
val n = splits.length - 1
while (i < n) {
if (splits(i) >= splits(i + 1)) return false
if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a safe checker for NaN split here

i += 1
}
true
!splits(n).isNaN
}
}

Expand All @@ -126,7 +129,9 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
* @throws SparkException if a feature is < splits.head or > splits.last
*/
private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
if (feature == splits.last) {
if (feature.isNaN) {
splits.length - 1
} else if (feature == splits.last) {
splits.length - 2
} else {
val idx = ju.Arrays.binarySearch(splits, feature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
* default: 2
* @group param
*/
val numBuckets = new IntParam(this, "numBuckets", "Maximum number of buckets (quantiles, or " +
val numBuckets = new IntParam(this, "numBuckets", "Number of buckets (quantiles, or " +
"categories) into which data points are grouped. Must be >= 2.",
ParamValidators.gtEq(2))
setDefault(numBuckets -> 2)
Expand All @@ -65,7 +65,12 @@ private[feature] trait QuantileDiscretizerBase extends Params

/**
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
* categorical features. The number of bins can be set using the `numBuckets` parameter.
* categorical features. The number of bins can be set using the `numBuckets` parameter. It is
* possible that the number of buckets used will be less than this value, for example, if there
* are too few distinct values of the input to create enough distinct quantiles. Note also that
* NaN values are handled specially and placed into their own bucket. For example, if 4 buckets
* are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special
* bucket(4).
* The bin ranges are chosen using an approximate algorithm (see the documentation for
* [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]]
* for a detailed description). The precision of the approximation can be controlled with the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,37 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
}

test("Bucket continuous features, with NaN data but non-NaN splits") {
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity)
val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN)
val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0)
val dataFrame: DataFrame =
spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected")

val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")
.setOutputCol("result")
.setSplits(splits)

bucketizer.transform(dataFrame).select("result", "expected").collect().foreach {
case Row(x: Double, y: Double) =>
assert(x === y,
s"The feature value is not correct after bucketing. Expected $y but found $x")
}
}

test("Bucket continuous features, with NaN splits") {
val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN)
withClue("Invalid NaN split was not caught as an invalid split!") {
intercept[IllegalArgumentException] {
val bucketizer: Bucketizer = new Bucketizer()
.setInputCol("feature")
.setOutputCol("result")
.setSplits(splits)
}
}
}

test("Binary search correctness on hand-picked examples") {
import BucketizerSuite.checkBinarySearch
// length 3, with -inf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,44 @@ class QuantileDiscretizerSuite
"Bucket sizes are not within expected relative error tolerance.")
}

test("Test Bucketizer on duplicated splits") {
test("Test on data with high proportion of duplicated values") {
val spark = this.spark
import spark.implicits._

val datasetSize = 12
val numBuckets = 5
val expectedNumBuckets = 3
val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0))
.map(Tuple1.apply).toDF("input")
val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
.setNumBuckets(numBuckets)
val result = discretizer.fit(df).transform(df)
val observedNumBuckets = result.select("result").distinct.count
assert(observedNumBuckets == expectedNumBuckets,
s"Observed number of buckets are not correct." +
s" Expected $expectedNumBuckets but found $observedNumBuckets")
}

test("Test transform on data with NaN value") {
val spark = this.spark
import spark.implicits._

val numBuckets = 3
val df = sc.parallelize(Array(1.0, 1.0, 1.0, Double.NaN))
.map(Tuple1.apply).toDF("input")
val discretizer = new QuantileDiscretizer()
.setInputCol("input")
.setOutputCol("result")
.setNumBuckets(numBuckets)

// Reserve extra one bucket for NaN
val expectedNumBuckets = discretizer.fit(df).getSplits.length - 1
val result = discretizer.fit(df).transform(df)
val observedNumBuckets = result.select("result").distinct.count
assert(2 <= observedNumBuckets && observedNumBuckets <= numBuckets,
"Observed number of buckets are not within expected range.")
assert(observedNumBuckets == expectedNumBuckets,
s"Observed number of buckets are not correct." +
s" Expected $expectedNumBuckets but found $observedNumBuckets")
}

test("Test transform method on unseen data") {
Expand Down
5 changes: 5 additions & 0 deletions python/pyspark/ml/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,11 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab

`QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter.
It is possible that the number of buckets used will be less than this value, for example, if
there are too few distinct values of the input to create enough distinct quantiles. Note also
that NaN values are handled specially and placed into their own bucket. For example, if 4
buckets are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in
a special bucket(4).
The bin ranges are chosen using an approximate algorithm (see the documentation for
:py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description).
The precision of the approximation can be controlled with the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient
* Online Computation of Quantile Summaries]] by Greenwald and Khanna.
*
* Note that NaN values will be removed from the numerical column before calculation
* @param col the name of the numerical column
* @param probabilities a list of quantile probabilities
* Each number must belong to [0, 1].
Expand All @@ -67,7 +68,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
col: String,
probabilities: Array[Double],
relativeError: Double): Array[Double] = {
StatFunctions.multipleApproxQuantiles(df, Seq(col), probabilities, relativeError).head.toArray
StatFunctions.multipleApproxQuantiles(df.select(col).na.drop(),
Seq(col), probabilities, relativeError).head.toArray
}

/**
Expand Down