diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d473d6b53464..db8c6cd41a8f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1162,7 +1162,7 @@ def replace(self, to_replace, value, subset=None): self._jdf.na().replace(self._jseq(subset), self._jmap(rep_dict)), self.sql_ctx) @since(2.0) - def approxQuantile(self, col, probabilities, relativeError): + def approxQuantile(self, cols, probabilities, relativeError): """ Calculates the approximate quantiles of a numerical column of a DataFrame. @@ -1181,18 +1181,28 @@ def approxQuantile(self, col, probabilities, relativeError): Space-efficient Online Computation of Quantile Summaries]] by Greenwald and Khanna. - :param col: the name of the numerical column + :param cols: str, list. + Can be a single column name, or a list of names for multiple columns. :param probabilities: a list of quantile probabilities - Each number must belong to [0, 1]. - For example 0 is the minimum, 0.5 is the median, 1 is the maximum. - :param relativeError: The relative target precision to achieve - (>= 0). If set to zero, the exact quantiles are computed, which - could be very expensive. Note that values greater than 1 are - accepted but give the same result as 1. - :return: the approximate quantiles at the given probabilities - """ - if not isinstance(col, str): - raise ValueError("col should be a string.") + Each number must belong to [0, 1]. + For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + :param relativeError: The relative target precision to achieve + (>= 0). If set to zero, the exact quantiles are computed, which + could be very expensive. Note that values greater than 1 are + accepted but give the same result as 1. + :return: the approximate quantiles at the given probabilities for + the given column or columns. 
+ """ + if not isinstance(cols, (str, list, tuple)): + raise ValueError("cols should be a string, list or tuple.") + + if isinstance(cols, tuple): + cols = list(cols) + if isinstance(cols, list): + for c in cols: + if not isinstance(c, str): + raise ValueError("column name should be string.") + cols = _to_list(self._sc, cols) if not isinstance(probabilities, (list, tuple)): raise ValueError("probabilities should be a list or tuple") @@ -1207,8 +1217,12 @@ def approxQuantile(self, col, probabilities, relativeError): raise ValueError("relativeError should be numerical (float, int, long) >= 0.") relativeError = float(relativeError) - jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError) - return list(jaq) + jaq = self._jdf.stat().approxQuantile(cols, probabilities, relativeError) + jaq = list(jaq) + for idx, a in enumerate(jaq): + if not isinstance(a, (list, float)): + jaq[idx] = list(a) + return jaq @since(1.4) def corr(self, col1, col2, method=None): @@ -1440,8 +1454,8 @@ class DataFrameStatFunctions(object): def __init__(self, df): self.df = df - def approxQuantile(self, col, probabilities, relativeError): - return self.df.approxQuantile(col, probabilities, relativeError) + def approxQuantile(self, cols, probabilities, relativeError): + return self.df.approxQuantile(cols, probabilities, relativeError) approxQuantile.__doc__ = DataFrame.approxQuantile.__doc__ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e4f79c911c0d..5f8fa42deb34 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -702,6 +702,15 @@ def test_approxQuantile(self): self.assertEqual(len(aq), 3) self.assertTrue(all(isinstance(q, float) for q in aq)) + aqs = df.stat.approxQuantile(["a", "a"], [0.1, 0.5, 0.9], 0.1) + self.assertEqual(len(aqs), 2) + self.assertTrue(isinstance(aqs[0], list)) + self.assertEqual(len(aqs[0]), 3) + self.assertTrue(all(isinstance(q, float) for q in aqs[0])) + self.assertTrue(isinstance(aqs[1], list)) + 
self.assertEqual(len(aqs[1]), 3) + self.assertTrue(all(isinstance(q, float) for q in aqs[1])) + def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 3eb1f0f0d58f..0c2ecb2bb2dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -52,14 +52,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient * Online Computation of Quantile Summaries]] by Greenwald and Khanna. * - * @param col the name of the numerical column + * @param col the name of the numerical column. * @param probabilities a list of quantile probabilities * Each number must belong to [0, 1]. * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. * @param relativeError The relative target precision to achieve (>= 0). * If set to zero, the exact quantiles are computed, which could be very expensive. * Note that values greater than 1 are accepted but give the same result as 1. - * @return the approximate quantiles at the given probabilities + * @return the approximate quantiles at the given probabilities. * * @since 2.0.0 */ @@ -70,6 +70,29 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { StatFunctions.multipleApproxQuantiles(df, Seq(col), probabilities, relativeError).head.toArray } + /** + * Calculates the approximate quantiles of numerical columns of a DataFrame. + * @see #approxQuantile(String, Array[Double], Double) for detailed description. + * + * @param cols the names of the numerical columns. + * @param probabilities a list of quantile probabilities + * Each number must belong to [0, 1]. 
+ * For example 0 is the minimum, 0.5 is the median, 1 is the maximum. + * @param relativeError The relative target precision to achieve (>= 0). + * If set to zero, the exact quantiles are computed, which could be very expensive. + * Note that values greater than 1 are accepted but give the same result as 1. + * @return the approximate quantiles at the given probabilities for given columns. + * + * @since 2.0.0 + */ + def approxQuantile( + cols: Array[String], + probabilities: Array[Double], + relativeError: Double): Array[Array[Double]] = { + StatFunctions.multipleApproxQuantiles(df, cols, probabilities, relativeError) + .map(_.toArray).toArray + } + /** * Python-friendly version of [[approxQuantile()]] */ @@ -80,6 +103,18 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { approxQuantile(col, probabilities.toArray, relativeError).toList.asJava } + /** + * Python-friendly version of [[approxQuantile()]] that computes approximate quantiles + * for multiple columns. + */ + private[spark] def approxQuantile( + cols: List[String], + probabilities: List[Double], + relativeError: Double): java.util.List[java.util.List[Double]] = { + approxQuantile(cols.toArray, probabilities.toArray, relativeError) + .map(_.toList.asJava).toList.asJava + } + /** * Calculate the sample covariance of two numerical columns of a DataFrame. 
* @param col1 the name of the first column diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0ea7727e4502..657de10c7b18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -149,6 +149,15 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(s2 - q2 * n) < error_single) assert(math.abs(d1 - 2 * q1 * n) < error_double) assert(math.abs(d2 - 2 * q2 * n) < error_double) + + // Multiple columns + val Array(Array(ms1, ms2), Array(md1, md2)) = + df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon) + + assert(math.abs(ms1 - q1 * n) < error_single) + assert(math.abs(ms2 - q2 * n) < error_single) + assert(math.abs(md1 - 2 * q1 * n) < error_double) + assert(math.abs(md2 - 2 * q2 * n) < error_double) } }