Skip to content

Commit 10691d3

Browse files
zhengruifeng authored and gatorsmile committed
[SPARK-19573][SQL] Make NaN/null handling consistent in approxQuantile
## What changes were proposed in this pull request? update `StatFunctions.multipleApproxQuantiles` to handle NaN/null ## How was this patch tested? existing tests and added tests Author: Zheng RuiFeng <[email protected]> Closes #16971 from zhengruifeng/quantiles_nan.
1 parent c2d1761 commit 10691d3

File tree

6 files changed

+95
-54
lines changed

6 files changed

+95
-54
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ object ApproximatePercentile {
245245
val result = new Array[Double](percentages.length)
246246
var i = 0
247247
while (i < percentages.length) {
248-
result(i) = summaries.query(percentages(i))
248+
// Since summaries.count != 0, the query here never return None.
249+
result(i) = summaries.query(percentages(i)).get
249250
i += 1
250251
}
251252
result

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,17 +176,19 @@ class QuantileSummaries(
176176
* @param quantile the target quantile
177177
* @return
178178
*/
179-
def query(quantile: Double): Double = {
179+
def query(quantile: Double): Option[Double] = {
180180
require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]")
181181
require(headSampled.isEmpty,
182182
"Cannot operate on an uncompressed summary, call compress() first")
183183

184+
if (sampled.isEmpty) return None
185+
184186
if (quantile <= relativeError) {
185-
return sampled.head.value
187+
return Some(sampled.head.value)
186188
}
187189

188190
if (quantile >= 1 - relativeError) {
189-
return sampled.last.value
191+
return Some(sampled.last.value)
190192
}
191193

192194
// Target rank
@@ -200,11 +202,11 @@ class QuantileSummaries(
200202
minRank += curSample.g
201203
val maxRank = minRank + curSample.delta
202204
if (maxRank - targetError <= rank && rank <= minRank + targetError) {
203-
return curSample.value
205+
return Some(curSample.value)
204206
}
205207
i += 1
206208
}
207-
sampled.last.value
209+
Some(sampled.last.value)
208210
}
209211
}
210212

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,19 @@ class QuantileSummariesSuite extends SparkFunSuite {
5555
}
5656

5757
private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = {
58-
val approx = summary.query(quant)
59-
// The rank of the approximation.
60-
val rank = data.count(_ < approx) // has to be <, not <= to be exact
61-
val lower = math.floor((quant - summary.relativeError) * data.size)
62-
val upper = math.ceil((quant + summary.relativeError) * data.size)
63-
val msg =
64-
s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx"
65-
assert(rank >= lower, msg)
66-
assert(rank <= upper, msg)
58+
if (data.nonEmpty) {
59+
val approx = summary.query(quant).get
60+
// The rank of the approximation.
61+
val rank = data.count(_ < approx) // has to be <, not <= to be exact
62+
val lower = math.floor((quant - summary.relativeError) * data.size)
63+
val upper = math.ceil((quant + summary.relativeError) * data.size)
64+
val msg =
65+
s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx"
66+
assert(rank >= lower, msg)
67+
assert(rank <= upper, msg)
68+
} else {
69+
assert(summary.query(quant).isEmpty)
70+
}
6771
}
6872

6973
for {
@@ -74,9 +78,9 @@ class QuantileSummariesSuite extends SparkFunSuite {
7478

7579
test(s"Extremas with epsi=$epsi and seq=$seq_name, compression=$compression") {
7680
val s = buildSummary(data, epsi, compression)
77-
val min_approx = s.query(0.0)
81+
val min_approx = s.query(0.0).get
7882
assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx")
79-
val max_approx = s.query(1.0)
83+
val max_approx = s.query(1.0).get
8084
assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx")
8185
}
8286

@@ -100,6 +104,18 @@ class QuantileSummariesSuite extends SparkFunSuite {
100104
checkQuantile(0.1, data, s)
101105
checkQuantile(0.001, data, s)
102106
}
107+
108+
test(s"Tests on empty data with epsi=$epsi and seq=$seq_name, compression=$compression") {
109+
val emptyData = Seq.empty[Double]
110+
val s = buildSummary(emptyData, epsi, compression)
111+
assert(s.count == 0, s"Found count=${s.count} but data size=0")
112+
assert(s.sampled.isEmpty, s"if QuantileSummaries is empty, sampled should be empty")
113+
checkQuantile(0.9999, emptyData, s)
114+
checkQuantile(0.9, emptyData, s)
115+
checkQuantile(0.5, emptyData, s)
116+
checkQuantile(0.1, emptyData, s)
117+
checkQuantile(0.001, emptyData, s)
118+
}
103119
}
104120

105121
// Tests for merging procedure
@@ -118,9 +134,9 @@ class QuantileSummariesSuite extends SparkFunSuite {
118134
val s1 = buildSummary(data1, epsi, compression)
119135
val s2 = buildSummary(data2, epsi, compression)
120136
val s = s1.merge(s2)
121-
val min_approx = s.query(0.0)
137+
val min_approx = s.query(0.0).get
122138
assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx")
123-
val max_approx = s.query(1.0)
139+
val max_approx = s.query(1.0).get
124140
assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx")
125141
checkQuantile(0.9999, data, s)
126142
checkQuantile(0.9, data, s)
@@ -137,9 +153,9 @@ class QuantileSummariesSuite extends SparkFunSuite {
137153
val s1 = buildSummary(data11, epsi, compression)
138154
val s2 = buildSummary(data12, epsi, compression)
139155
val s = s1.merge(s2)
140-
val min_approx = s.query(0.0)
156+
val min_approx = s.query(0.0).get
141157
assert(min_approx == data.min, s"Did not return the min: min=${data.min}, got $min_approx")
142-
val max_approx = s.query(1.0)
158+
val max_approx = s.query(1.0).get
143159
assert(max_approx == data.max, s"Did not return the max: max=${data.max}, got $max_approx")
144160
checkQuantile(0.9999, data, s)
145161
checkQuantile(0.9, data, s)

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,15 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
6464
* @return the approximate quantiles at the given probabilities
6565
*
6666
* @note null and NaN values will be removed from the numerical column before calculation. If
67-
* the dataframe is empty or all rows contain null or NaN, null is returned.
67+
* the dataframe is empty or the column only contains null or NaN, an empty array is returned.
6868
*
6969
* @since 2.0.0
7070
*/
7171
def approxQuantile(
7272
col: String,
7373
probabilities: Array[Double],
7474
relativeError: Double): Array[Double] = {
75-
val res = approxQuantile(Array(col), probabilities, relativeError)
76-
Option(res).map(_.head).orNull
75+
approxQuantile(Array(col), probabilities, relativeError).head
7776
}
7877

7978
/**
@@ -89,22 +88,20 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
8988
* Note that values greater than 1 are accepted but give the same result as 1.
9089
* @return the approximate quantiles at the given probabilities of each column
9190
*
92-
* @note Rows containing any null or NaN values will be removed before calculation. If
93-
* the dataframe is empty or all rows contain null or NaN, null is returned.
91+
* @note null and NaN values will be ignored in numerical columns before calculation. For
92+
* columns only containing null or NaN values, an empty array is returned.
9493
*
9594
* @since 2.2.0
9695
*/
9796
def approxQuantile(
9897
cols: Array[String],
9998
probabilities: Array[Double],
10099
relativeError: Double): Array[Array[Double]] = {
101-
// TODO: Update NaN/null handling to keep consistent with the single-column version
102-
try {
103-
StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols,
104-
probabilities, relativeError).map(_.toArray).toArray
105-
} catch {
106-
case e: NoSuchElementException => null
107-
}
100+
StatFunctions.multipleApproxQuantiles(
101+
df.select(cols.map(col): _*),
102+
cols,
103+
probabilities,
104+
relativeError).map(_.toArray).toArray
108105
}
109106

110107

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ object StatFunctions extends Logging {
5454
* Note that values greater than 1 are accepted but give the same result as 1.
5555
*
5656
* @return for each column, returns the requested approximations
57+
*
58+
* @note null and NaN values will be ignored in numerical columns before calculation. For
59+
* a column only containing null or NaN values, an empty array is returned.
5760
*/
5861
def multipleApproxQuantiles(
5962
df: DataFrame,
@@ -78,7 +81,10 @@ object StatFunctions extends Logging {
7881
def apply(summaries: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = {
7982
var i = 0
8083
while (i < summaries.length) {
81-
summaries(i) = summaries(i).insert(row.getDouble(i))
84+
if (!row.isNullAt(i)) {
85+
val v = row.getDouble(i)
86+
if (!v.isNaN) summaries(i) = summaries(i).insert(v)
87+
}
8288
i += 1
8389
}
8490
summaries
@@ -91,7 +97,7 @@ object StatFunctions extends Logging {
9197
}
9298
val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge)
9399

94-
summaries.map { summary => probabilities.map(summary.query) }
100+
summaries.map { summary => probabilities.flatMap(summary.query) }
95101
}
96102

97103
/** Calculate the Pearson Correlation Coefficient for the given columns */

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 38 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
171171
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), -1.0)
172172
}
173173
assert(e2.getMessage.contains("Relative Error must be non-negative"))
174-
175-
// return null if the dataset is empty
176-
val res1 = df.selectExpr("*").limit(0)
177-
.stat.approxQuantile("singles", Array(q1, q2), epsilons.head)
178-
assert(res1 === null)
179-
180-
val res2 = df.selectExpr("*").limit(0)
181-
.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilons.head)
182-
assert(res2 === null)
183174
}
184175

185176
test("approximate quantile 2: test relativeError greater than 1 return the same result as 1") {
@@ -214,20 +205,48 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
214205
val q1 = 0.5
215206
val q2 = 0.8
216207
val epsilon = 0.1
217-
val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0), Row(1.0, 1.0),
218-
Row(-1.0, Double.NaN), Row(Double.NaN, Double.NaN), Row(null, null), Row(null, 1.0),
219-
Row(-1.0, null), Row(Double.NaN, null)))
208+
val rows = spark.sparkContext.parallelize(Seq(Row(Double.NaN, 1.0, Double.NaN),
209+
Row(1.0, -1.0, null), Row(-1.0, Double.NaN, null), Row(Double.NaN, Double.NaN, null),
210+
Row(null, null, Double.NaN), Row(null, 1.0, null), Row(-1.0, null, Double.NaN),
211+
Row(Double.NaN, null, null)))
220212
val schema = StructType(Seq(StructField("input1", DoubleType, nullable = true),
221-
StructField("input2", DoubleType, nullable = true)))
213+
StructField("input2", DoubleType, nullable = true),
214+
StructField("input3", DoubleType, nullable = true)))
222215
val dfNaN = spark.createDataFrame(rows, schema)
223-
val resNaN = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon)
224-
assert(resNaN.count(_.isNaN) === 0)
225-
assert(resNaN.count(_ == null) === 0)
226216

227-
val resNaN2 = dfNaN.stat.approxQuantile(Array("input1", "input2"),
217+
val resNaN1 = dfNaN.stat.approxQuantile("input1", Array(q1, q2), epsilon)
218+
assert(resNaN1.count(_.isNaN) === 0)
219+
assert(resNaN1.count(_ == null) === 0)
220+
221+
val resNaN2 = dfNaN.stat.approxQuantile("input2", Array(q1, q2), epsilon)
222+
assert(resNaN2.count(_.isNaN) === 0)
223+
assert(resNaN2.count(_ == null) === 0)
224+
225+
val resNaN3 = dfNaN.stat.approxQuantile("input3", Array(q1, q2), epsilon)
226+
assert(resNaN3.isEmpty)
227+
228+
val resNaNAll = dfNaN.stat.approxQuantile(Array("input1", "input2", "input3"),
228229
Array(q1, q2), epsilon)
229-
assert(resNaN2.flatten.count(_.isNaN) === 0)
230-
assert(resNaN2.flatten.count(_ == null) === 0)
230+
assert(resNaNAll.flatten.count(_.isNaN) === 0)
231+
assert(resNaNAll.flatten.count(_ == null) === 0)
232+
233+
assert(resNaN1(0) === resNaNAll(0)(0))
234+
assert(resNaN1(1) === resNaNAll(0)(1))
235+
assert(resNaN2(0) === resNaNAll(1)(0))
236+
assert(resNaN2(1) === resNaNAll(1)(1))
237+
238+
// return empty array for columns only containing null or NaN values
239+
assert(resNaNAll(2).isEmpty)
240+
241+
// return empty array if the dataset is empty
242+
val res1 = dfNaN.selectExpr("*").limit(0)
243+
.stat.approxQuantile("input1", Array(q1, q2), epsilon)
244+
assert(res1.isEmpty)
245+
246+
val res2 = dfNaN.selectExpr("*").limit(0)
247+
.stat.approxQuantile(Array("input1", "input2"), Array(q1, q2), epsilon)
248+
assert(res2(0).isEmpty)
249+
assert(res2(1).isEmpty)
231250
}
232251

233252
test("crosstab") {

0 commit comments

Comments
 (0)