Skip to content

Commit 75edcb1

Browse files
committed
Address comments and change Python API too.
1 parent 47d52b9 commit 75edcb1

File tree

4 files changed

+33
-19
lines changed

4 files changed

+33
-19
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,7 +1162,7 @@ def replace(self, to_replace, value, subset=None):
11621162
self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx)
11631163

11641164
@since(2.0)
1165-
def approxQuantile(self, col, probabilities, relativeError):
1165+
def approxQuantile(self, cols, probabilities, relativeError):
11661166
"""
11671167
Calculates the approximate quantiles of a numerical column of a
11681168
DataFrame.
@@ -1181,7 +1181,7 @@ def approxQuantile(self, col, probabilities, relativeError):
11811181
Space-efficient Online Computation of Quantile Summaries]]
11821182
by Greenwald and Khanna.
11831183
1184-
:param col: the name of the numerical column
1184+
:param cols: the name(s) of the numerical column(s)
11851185
:param probabilities: a list of quantile probabilities
11861186
Each number must belong to [0, 1].
11871187
For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
@@ -1191,8 +1191,13 @@ def approxQuantile(self, col, probabilities, relativeError):
11911191
accepted but give the same result as 1.
11921192
:return: the approximate quantiles at the given probabilities
11931193
"""
1194-
if not isinstance(col, str):
1195-
raise ValueError("col should be a string.")
1194+
if not isinstance(cols, (str, list, tuple)):
1195+
raise ValueError("col should be a string, list or tuple.")
1196+
1197+
if isinstance(cols, tuple):
1198+
cols = list(cols)
1199+
if isinstance(cols, list):
1200+
cols = _to_list(self._sc, cols)
11961201

11971202
if not isinstance(probabilities, (list, tuple)):
11981203
raise ValueError("probabilities should be a list or tuple")
@@ -1207,8 +1212,12 @@ def approxQuantile(self, col, probabilities, relativeError):
12071212
raise ValueError("relativeError should be numerical (float, int, long) >= 0.")
12081213
relativeError = float(relativeError)
12091214

1210-
jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError)
1211-
return list(jaq)
1215+
jaq = self._jdf.stat().approxQuantile(cols, probabilities, relativeError)
1216+
jaq = list(jaq)
1217+
for idx, a in enumerate(jaq):
1218+
if not isinstance(a, (list, float)):
1219+
jaq[idx] = list(a)
1220+
return jaq
12121221

12131222
@since(1.4)
12141223
def corr(self, col1, col2, method=None):
@@ -1440,8 +1449,8 @@ class DataFrameStatFunctions(object):
14401449
def __init__(self, df):
14411450
self.df = df
14421451

1443-
def approxQuantile(self, col, probabilities, relativeError):
1444-
return self.df.approxQuantile(col, probabilities, relativeError)
1452+
def approxQuantile(self, cols, probabilities, relativeError):
1453+
return self.df.approxQuantile(cols, probabilities, relativeError)
14451454

14461455
approxQuantile.__doc__ = DataFrame.approxQuantile.__doc__
14471456

python/pyspark/sql/tests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,14 @@ def test_approxQuantile(self):
702702
self.assertEqual(len(aq), 3)
703703
self.assertTrue(all(isinstance(q, float) for q in aq))
704704

705+
aqs = df.stat.approxQuantile(["a", "a"], [0.1, 0.5, 0.9], 0.1)
706+
self.assertTrue(isinstance(aqs[0], list))
707+
self.assertEqual(len(aqs[0]), 3)
708+
self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
709+
self.assertTrue(isinstance(aqs[1], list))
710+
self.assertEqual(len(aqs[1]), 3)
711+
self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
712+
705713
def test_corr(self):
706714
import math
707715
df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
7272

7373
/**
7474
* Calculates the approximate quantiles of numerical columns of a DataFrame.
75+
* @see approxQuantile for detailed description.
7576
*
7677
* @param cols the names of the numerical columns.
7778
* @param probabilities a list of quantile probabilities

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -149,19 +149,15 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
149149
assert(math.abs(s2 - q2 * n) < error_single)
150150
assert(math.abs(d1 - 2 * q1 * n) < error_double)
151151
assert(math.abs(d2 - 2 * q2 * n) < error_double)
152-
}
153-
154-
for (epsilon <- epsilons) {
155-
val Array(Array(s1, s2), Array(d1, d2)) = df.stat.approxQuantile(Array("singles", "doubles"),
156-
Array(q1, q2), epsilon)
157152

158-
val error_single = 2 * 1000 * epsilon
159-
val error_double = 2 * 2000 * epsilon
153+
// Multiple columns
154+
val Array(Array(ms1, ms2), Array(md1, md2)) =
155+
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon)
160156

161-
assert(math.abs(s1 - q1 * n) < error_single)
162-
assert(math.abs(s2 - q2 * n) < error_single)
163-
assert(math.abs(d1 - 2 * q1 * n) < error_double)
164-
assert(math.abs(d2 - 2 * q2 * n) < error_double)
157+
assert(math.abs(ms1 - q1 * n) < error_single)
158+
assert(math.abs(ms2 - q2 * n) < error_single)
159+
assert(math.abs(md1 - 2 * q1 * n) < error_double)
160+
assert(math.abs(md2 - 2 * q2 * n) < error_double)
165161
}
166162
}
167163

0 commit comments

Comments
 (0)