diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 95f46044d324..ab9e64fe9bdd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -87,7 +87,10 @@ def _():
     'sum': 'Aggregate function: returns the sum of all values in the expression.',
     'avg': 'Aggregate function: returns the average of the values in a group.',
     'mean': 'Aggregate function: returns the average of the values in a group.',
+    'stddev': 'Aggregate function: returns the sample standard deviation of the values in a group.',
+    'stddevSamp': 'Aggregate function: returns the sample standard deviation of the values in a group.',
+    'stddevPop': 'Aggregate function: returns the population standard deviation of the values in a group.',
     'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
 }
 
 _functions_1_4 = {
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 04594d5a836c..91c377bc962b 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -154,6 +154,46 @@ def min(self, *cols):
         [Row(min(age)=2, min(height)=80)]
         """
 
+    @df_varargs_api
+    @since(1.5)
+    def stddev(self, *cols):
+        """Computes the sample standard deviation for each numeric column for each group.
+        Alias for stddevSamp.
+
+        :param cols: list of column names (string). Non-numeric columns are ignored.
+
+        >>> df.groupBy().stddev('age').collect()
+        [Row(stddev_samp(age)=2.12...)]
+        >>> df3.groupBy().stddev('age', 'height').collect()
+        [Row(stddev_samp(age)=2.12..., stddev_samp(height)=3.53...)]
+        """
+
+    @df_varargs_api
+    @since(1.5)
+    def stddevPop(self, *cols):
+        """Computes the population standard deviation for each numeric column for each group.
+
+        :param cols: list of column names (string). Non-numeric columns are ignored.
+
+        >>> df.groupBy().stddevPop('age').collect()
+        [Row(stddev_pop(age)=1.5)]
+        >>> df3.groupBy().stddevPop('age', 'height').collect()
+        [Row(stddev_pop(age)=1.5, stddev_pop(height)=2.5)]
+        """
+
+    @df_varargs_api
+    @since(1.5)
+    def stddevSamp(self, *cols):
+        """Computes the sample standard deviation for each numeric column for each group.
+
+        :param cols: list of column names (string). Non-numeric columns are ignored.
+
+        >>> df.groupBy().stddevSamp('age').collect()
+        [Row(stddev_samp(age)=2.12...)]
+        >>> df3.groupBy().stddevSamp('age', 'height').collect()
+        [Row(stddev_samp(age)=2.12..., stddev_samp(height)=3.53...)]
+        """
+
     @df_varargs_api
     @since(1.3)
     def sum(self, *cols):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a684dbc3afa4..0c8c51525829 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{StandardDeviation, Complete, AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -524,23 +524,60 @@ class Analyzer(
         q transformExpressions {
           case u @ UnresolvedFunction(name, children, isDistinct) =>
             withPosition(u) {
-              registry.lookupFunction(name, children) match {
-                // We get an aggregate function built based on AggregateFunction2 interface.
-                // So, we wrap it in AggregateExpression2.
-                case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
-                // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
-                // and COUTN(DISTINCT ...).
-                case sumDistinct: SumDistinct => sumDistinct
-                case countDistinct: CountDistinct => countDistinct
-                // DISTINCT is not meaningful with Max and Min.
-                case max: Max if isDistinct => max
-                case min: Min if isDistinct => min
-                // For other aggregate functions, DISTINCT keyword is not supported for now.
-                // Once we converted to the new code path, we will allow using DISTINCT keyword.
-                case other: AggregateExpression1 if isDistinct =>
-                  failAnalysis(s"$name does not support DISTINCT keyword.")
-                // If it does not have DISTINCT keyword, we will return it as is.
-                case other => other
+
+              // TODO: This is a hack. Hive uses stddev and std as aliases of stddev_pop,
+              // unlike most other widely used systems, which treat those two names as
+              // aliases of stddev_samp. So we explicitly map them to stddev_samp here.
+              // Once we remove AggregateExpression1, we can remove this hack. Also,
+              // stddev is not registered in SimpleFunctionRegistry (we want it to resolve
+              // to the HiveGenericUDAF based on AggregateExpression1 and then be converted
+              // to AggregateExpression2), so when the name is not found in the function
+              // registry we create a StandardDeviation directly.
+              name.toLowerCase match {
+                case "std" | "stddev" | "stddev_samp" =>
+                  if (children.length != 1) {
+                    failAnalysis(s"$name requires exactly one argument.")
+                  }
+                  val funcInRegistry =
+                    registry
+                      .lookupFunction("stddev_samp")
+                      .map(_ => registry.lookupFunction("stddev_samp", children))
+                  funcInRegistry.getOrElse {
+                    AggregateExpression2(
+                      StandardDeviation(children.head, sample = true), Complete, isDistinct)
+                  }
+                case "stddev_pop" =>
+                  if (children.length != 1) {
+                    failAnalysis(s"$name requires exactly one argument.")
+                  }
+                  val funcInRegistry =
+                    registry
+                      .lookupFunction("stddev_pop")
+                      .map(_ => registry.lookupFunction("stddev_pop", children))
+                  funcInRegistry.getOrElse {
+                    AggregateExpression2(
+                      StandardDeviation(children.head, sample = false), Complete, isDistinct)
+                  }
+                case _ =>
+                  registry.lookupFunction(name, children) match {
+                    // We get an aggregate function built based on the AggregateFunction2
+                    // interface, so we wrap it in AggregateExpression2.
+                    case agg2: AggregateFunction2 =>
+                      AggregateExpression2(agg2, Complete, isDistinct)
+                    // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
+                    // and COUNT(DISTINCT ...).
+                    case sumDistinct: SumDistinct => sumDistinct
+                    case countDistinct: CountDistinct => countDistinct
+                    // DISTINCT is not meaningful with Max and Min.
+                    case max: Max if isDistinct => max
+                    case min: Min if isDistinct => min
+                    // For other aggregate functions, the DISTINCT keyword is not supported for
+                    // now. Once we have converted to the new code path, we will allow it.
+                    case other: AggregateExpression1 if isDistinct =>
+                      failAnalysis(s"$name does not support DISTINCT keyword.")
+                    // If it does not have the DISTINCT keyword, we return it as is.
+                    case other => other
+                  }
               }
             }
         }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index a73024d6adba..8302994cd5cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -302,3 +302,127 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
 
   override val evaluateExpression = Cast(currentSum, resultType)
 }
+
+/**
+ * Calculates the standard deviation using the online update and parallel merge
+ * formulas described at:
+ * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+ * If sample is true, the unbiased (sample) standard deviation is returned.
+ */
+case class StandardDeviation(child: Expression, sample: Boolean) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+  private lazy val resultType = child.dataType match {
+    case DecimalType.Fixed(p, s) =>
+      DecimalType.bounded(p + 4, s + 4)
+    case _ => DoubleType
+  }
+
+  private lazy val sumDataType = child.dataType match {
+    case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
+    case _ => DoubleType
+  }
+
+  private lazy val currentCount = AttributeReference("currentCount", LongType)()
+  private lazy val leftCount = AttributeReference("leftCount", LongType)()
+  private lazy val rightCount = AttributeReference("rightCount", LongType)()
+  private lazy val currentDelta = AttributeReference("currentDelta", sumDataType)()
+  private lazy val currentAvg = AttributeReference("currentAverage", sumDataType)()
+  private lazy val currentMk = AttributeReference("currentMoment", sumDataType)()
+
+  // The buffer values must be updated in this exact order: later update and merge
+  // expressions read buffer slots that earlier expressions have already written.
+  override lazy val bufferAttributes =
+    leftCount :: rightCount :: currentCount :: currentDelta :: currentAvg :: currentMk :: Nil
+
+  override lazy val initialValues = Seq(
+    /* leftCount = */ Literal(0L),
+    /* rightCount = */ Literal(0L),
+    /* currentCount = */ Literal(0L),
+    /* currentDelta = */ Cast(Literal(0), sumDataType),
+    /* currentAvg = */ Cast(Literal(0), sumDataType),
+    /* currentMk = */ Cast(Literal(0), sumDataType)
+  )
+
+  override lazy val updateExpressions = {
+    val currentValue = Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)
+    val deltaX = Subtract(currentValue, currentAvg)
+    val updatedCount = If(IsNull(child), currentCount, currentCount + 1L)
+    // Skip nulls: the running average must not move when the input is null.
+    // currentDelta and currentCount below read the slots written just above,
+    // so for non-null inputs this is Welford's mean += delta / updatedCount.
+    val updatedAvg = If(IsNull(child),
+      currentAvg, Add(currentAvg, Divide(currentDelta, currentCount)))
+    Seq(
+      /* leftCount = */ leftCount, // used only during merging; dummy value
+      /* rightCount = */ rightCount, // used only during merging; dummy value
+      /* currentCount = */ updatedCount,
+      /* currentDelta = */ deltaX,
+      /* currentAvg = */ updatedAvg,
+      /* currentMk = */ If(IsNull(child),
+        currentMk, Add(currentMk, currentDelta * Subtract(currentValue, currentAvg)))
+    )
+  }
+
+  override lazy val mergeExpressions = {
+    // These expressions also rely on write order: leftCount, rightCount, currentCount
+    // and currentDelta are written first, so the unqualified references below read the
+    // merged values, while .left/.right read the two buffers being merged.
+    val totalCount = currentCount.left + currentCount.right
+    val deltaX = currentAvg.left - currentAvg.right
+    val deltaX2 = deltaX * deltaX
+    val sumMoments = currentMk.left + currentMk.right
+    val sumLeft = currentAvg.left * leftCount
+    val sumRight = currentAvg.right * rightCount
+    val mergedAvg = (sumLeft + sumRight) / currentCount
+    val mergedMk = sumMoments + currentDelta * leftCount / currentCount * rightCount
+    Seq(
+      /* leftCount = */ currentCount.left,
+      /* rightCount = */ currentCount.right,
+      /* currentCount = */ totalCount,
+      /* currentDelta = */ deltaX2,
+      /* currentAvg = */ If(EqualTo(leftCount, Cast(Literal(0L), LongType)), currentAvg.right,
+        If(EqualTo(rightCount, Cast(Literal(0L), LongType)), currentAvg.left, mergedAvg)),
+      /* currentMk = */ If(EqualTo(leftCount, Cast(Literal(0L), LongType)), currentMk.right,
+        If(EqualTo(rightCount, Cast(Literal(0L), LongType)), currentMk.left, mergedMk))
+    )
+  }
+
+  override lazy val evaluateExpression = {
+    // For the sample standard deviation, divide by (n - 1). With a single input the
+    // divisor is 0 and the division yields null, matching stddev_samp semantics.
+    val count =
+      if (sample) {
+        If(EqualTo(currentCount, Cast(Literal(0L), LongType)), currentCount,
+          currentCount - Cast(Literal(1L), LongType))
+      } else {
+        currentCount
+      }
+
+    child.dataType match {
+      case DecimalType.Fixed(p, s) =>
+        // increase the precision and scale to prevent precision loss
+        val dt = DecimalType.bounded(p + 14, s + 4)
+        Cast(Sqrt(Cast(currentMk, dt) / Cast(count, dt)), resultType)
+      case _ =>
+        Sqrt(Cast(currentMk, resultType) / Cast(count, resultType))
+    }
+  }
+
+  override def prettyName: String = {
+    if (sample) {
+      "stddev_samp"
+    } else {
+      "stddev_pop"
+    }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
index 4a43318a9549..ec6cec8d7f94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala
@@ -96,6 +96,30 @@ object Utils {
           aggregateFunction = aggregate.Sum(child),
           mode = aggregate.Complete,
           isDistinct = true)
+
+      case hiveUDAF: AggregateExpression1
+          if hiveUDAF.getClass.getSimpleName == "HiveGenericUDAF" &&
+            hiveUDAF.toString.contains("GenericUDAFStdSample") =>
+        // We get a STDDEV_SAMP, which was originally resolved as a HiveGenericUDAF.
+        // This case must stay above the GenericUDAFStd case below, because the string
+        // "GenericUDAFStdSample" also contains the substring "GenericUDAFStd".
+        require(hiveUDAF.children.length == 1, "stddev_samp requires exactly one argument.")
+        val child = hiveUDAF.children.head
+        aggregate.AggregateExpression2(
+          aggregateFunction = aggregate.StandardDeviation(child, sample = true),
+          mode = aggregate.Complete,
+          isDistinct = false)
+
+      case hiveUDAF: AggregateExpression1
+          if hiveUDAF.getClass.getSimpleName == "HiveGenericUDAF" &&
+            hiveUDAF.toString.contains("GenericUDAFStd") =>
+        // We get a STDDEV_POP, which was originally resolved as a HiveGenericUDAF.
+        require(hiveUDAF.children.length == 1, "stddev_pop requires exactly one argument.")
+        val child = hiveUDAF.children.head
+        aggregate.AggregateExpression2(
+          aggregateFunction = aggregate.StandardDeviation(child, sample = false),
+          mode = aggregate.Complete,
+          isDistinct = false)
     }
     // Check if there is any expressions.AggregateExpression1 left.
     // If so, we cannot convert this plan.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 570b8b2d5928..274d5f25c349 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
@@ -1268,15 +1269,11 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def describe(cols: String*): DataFrame = {
 
-    // TODO: Add stddev as an expression, and remove it from here.
-    def stddevExpr(expr: Expression): Expression =
-      Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
-
     // The list of summary statistics to compute, in the form of expressions.
     val statistics = List[(String, Expression => Expression)](
       "count" -> Count,
       "mean" -> Average,
-      "stddev" -> stddevExpr,
+      "stddev" -> ((e: Expression) => UnresolvedFunction("stddev_samp", e :: Nil, false)),
       "min" -> Min,
       "max" -> Max)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 99d557b03a03..85eab4dc0455 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConversions._
 import scala.language.implicitConversions
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
 import org.apache.spark.sql.types.NumericType
@@ -283,6 +283,52 @@ class GroupedData protected[sql](
     aggregateNumericColumns(colNames : _*)(Min)
   }
 
+  /**
+   * Compute the sample standard deviation for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the standard deviation for them.
+   *
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def stddev(colNames: String*): DataFrame = {
+    stddevSamp(colNames : _*)
+  }
+
+  /**
+   * Compute the population standard deviation for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the standard deviation for them.
+   *
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def stddevPop(colNames: String*): DataFrame = {
+    def builder(e: Expression): Expression = {
+      Alias(
+        UnresolvedFunction("stddev_pop", e :: Nil, false),
+        s"stddev_pop(${e.prettyString})")()
+    }
+    aggregateNumericColumns(colNames : _*)(builder)
+  }
+
+  /**
+   * Compute the sample standard deviation for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the standard deviation for them.
+   *
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def stddevSamp(colNames: String*): DataFrame = {
+    def builder(e: Expression): Expression = {
+      Alias(
+        UnresolvedFunction("stddev_samp", e :: Nil, false),
+        s"stddev_samp(${e.prettyString})")()
+    }
+    aggregateNumericColumns(colNames : _*)(builder)
+  }
+
   /**
    * Compute the sum for each numeric columns for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index ad428ad663f3..08e69d566796 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -98,6 +98,10 @@ case class SortBasedAggregate(
 
   override def simpleString: String = {
     val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
-    s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}"""
+
+    val keyString = groupingExpressions.mkString("[", ",", "]")
+    val valueString = allAggregateExpressions.mkString("[", ",", "]")
+    val outputString = output.mkString("[", ",", "]")
+    s"SortBasedAggregate(key=$keyString, functions=$valueString, output=$outputString)"
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
index 67ebafde25ad..a9d833222ba3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala
@@ -107,7 +107,6 @@ class SortBasedAggregationIterator(
 
       // Check if the current row belongs the current input row.
       if (currentGroupingKey == groupingKey) {
        processRow(sortBasedAggregationBuffer, currentRow)
-        hasNext = inputKVIterator.next()
       } else {
         // We find a new group.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 1694794a53d9..fc9a4b739414 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -96,10 +96,11 @@ case class TungstenAggregate(
       case None =>
         val keyString = groupingExpressions.mkString("[", ",", "]")
         val valueString = allAggregateExpressions.mkString("[", ",", "]")
-        s"TungstenAggregate(key=$keyString, value=$valueString"
+        val outputString = output.mkString("[", ",", "]")
+        s"TungstenAggregate(key=$keyString, functions=$valueString, output=$outputString)"
       case Some(fallbackStartsAt) =>
         s"TungstenAggregateWithControlledFallback $groupingExpressions " +
-          s"$allAggregateExpressions fallbackStartsAt=$fallbackStartsAt"
+          s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt"
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 79c5f596661d..e4bbf324a61c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -294,6 +294,64 @@ object functions {
    */
   def min(columnName: String): Column = min(Column(columnName))
 
+  /**
+   * Aggregate function: returns the sample standard deviation of the values in a group.
+   * Alias for stddevSamp.
+   *
+   * @group agg_funcs
+   * @since 1.5.0
+   */
+  def stddev(e: Column): Column = stddevSamp(e)
+
+  /**
+   * Aggregate function: returns the sample standard deviation of the values in a group.
+   * Alias for stddevSamp.
+   *
+   * @group agg_funcs
+   * @since 1.5.0
+   */
+  def stddev(columnName: String): Column = stddevSamp(Column(columnName))
+
+  /**
+   * Aggregate function: returns the population standard deviation of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.5.0
+   */
+  def stddevPop(e: Column): Column = {
+    Alias(
+      UnresolvedFunction("stddev_pop", e.expr :: Nil, false),
+      s"stddev_pop(${e.expr.prettyString})")()
+  }
+
+  /**
+   * Aggregate function: returns the population standard deviation of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.5.0
+   */
+  def stddevPop(columnName: String): Column = stddevPop(Column(columnName))
+
+  /**
+   * Aggregate function: returns the sample standard deviation of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.5.0
+   */
+  def stddevSamp(e: Column): Column = {
+    Alias(
+      UnresolvedFunction("stddev_samp", e.expr :: Nil, false),
+      s"stddev_samp(${e.expr.prettyString})")()
+  }
+
+  /**
+   * Aggregate function: returns the sample standard deviation of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.5.0
+   */
+  def stddevSamp(columnName: String): Column = stddevSamp(Column(columnName))
+
   /**
    * Aggregate function: returns the sum of all values in the expression.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index c49f256be550..f7e9be8c050f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -440,7 +440,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
     val describeResult = Seq(
       Row("count", "4", "4"),
       Row("mean", "33.0", "178.0"),
-      Row("stddev", "16.583123951777", "10.0"),
+      Row("stddev", "19.148542155126762", "11.547005383792516"),
       Row("min", "16", "164"),
       Row("max", "60", "192"))
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 7b5aa4763fd9..b910313e77b9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.sql._
+import org.apache.spark.sql.functions.{stddev, stddevPop}
 import org.scalatest.BeforeAndAfterAll
 
 import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
@@ -284,6 +285,116 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
       Row(11.125) :: Nil)
   }
 
+  /** For resilience against rounding mismatches. */
+  private def about(d: Double): BigDecimal =
+    BigDecimal(d).setScale(10, BigDecimal.RoundingMode.HALF_UP)
+
+  test("test standard deviation") {
+    // All results generated in R. Comparisons will be performed up to 10 digits of precision.
+    val df = Seq.tabulate(10)(i => (i, 1)).toDF("val", "key")
+    checkAnswer(
+      df.select(stddev("val").cast("decimal(12, 10)")),
+      Row(about(3.0276503540974917)) :: Nil)
+
+    checkAnswer(
+      df.select(stddevPop("val").cast("decimal(12, 10)")),
+      Row(about(2.8722813232690148)) :: Nil)
+
+    // Make sure we can use stddev functions in SQL.
+    {
+      val expectedGroup1 =
+        Row(1, about(10.0), about(10.0), about(10.0), about(8.16496580927726))
+      val expectedGroup2 =
+        Row(
+          2,
+          about(0.7071067811865476),
+          about(0.7071067811865476),
+          about(0.7071067811865476),
+          about(0.5))
+      val expectedGroup3 = Row(3, null, null, null, null)
+      val expectedGroupNull =
+        Row(
+          null,
+          about(81.8535277187245),
+          about(81.8535277187245),
+          about(81.8535277187245),
+          about(66.83312551921139))
+
+      checkAnswer(
+        sqlContext.sql(
+          """
+            |SELECT
+            |  key,
+            |  cast(std(value) as decimal(12, 10)),
+            |  cast(stDDev(value) as decimal(12, 10)),
+            |  cast(stddEV_samp(value) as decimal(12, 10)),
+            |  cast(stddev_pOP(value) as decimal(12, 10))
+            |FROM agg1 GROUP BY key
+          """.stripMargin),
+        expectedGroup1 ::
+          expectedGroup2 ::
+          expectedGroup3 ::
+          expectedGroupNull :: Nil)
+    }
+
+    checkAnswer(
+      sqlContext.table("agg1").groupBy("key").stddevPop("value")
+        .select($"key", $"stddev_pop(value)".cast("decimal(12, 10)")),
+      Row(1, about(8.16496580927726)) :: Row(2, about(0.5)) :: Row(3, null) ::
+        Row(null, about(66.83312551921139)) :: Nil)
+
+    checkAnswer(
+      sqlContext.table("agg1").select(stddev("key").cast("decimal(12, 10)"),
+        stddev("value").cast("decimal(12, 10)")),
+      Row(about(0.7817359599705717), about(44.898098909801135)) :: Nil)
+
+    checkAnswer(
+      sqlContext.table("agg1").select(stddevPop("key").cast("decimal(12, 10)"),
+        stddevPop("value").cast("decimal(12, 10)")),
+      Row(about(0.7370277311900889), about(41.99832585949111)) :: Nil)
+
+    checkAnswer(
+      sqlContext.table("agg2").groupBy("key", "value1").stddev("value2")
+        .select($"key", $"value1", $"stddev_samp(value2)".cast("decimal(12, 10)")),
+      Row(1, 10, null) :: Row(1, 30, about(42.42640687119285)) :: Row(2, -1, null) ::
+        Row(2, 1, about(0.0)) :: Row(2, null, null) :: Row(3, null, null) ::
+        Row(null, -10, null) :: Row(null, -60, null) :: Row(null, 100, null) ::
+        Row(null, null, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.table("agg2").groupBy("key", "value1").stddevPop("value2")
+        .select($"key", $"value1", $"stddev_pop(value2)".cast("decimal(12, 10)")),
+      Row(1, 10, about(0.0)) :: Row(1, 30, about(30.0)) :: Row(2, -1, null) ::
+        Row(2, 1, about(0.0)) :: Row(2, null, about(0.0)) :: Row(3, null, about(0.0)) ::
+        Row(null, -10, about(0.0)) :: Row(null, -60, about(0.0)) :: Row(null, 100, about(0.0)) ::
+        Row(null, null, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.table("emptyTable").select(stddev("value")),
+      Row(null) :: Nil)
+
+    checkAnswer(
+      sqlContext.table("emptyTable").select(stddevPop("value")),
+      Row(null) :: Nil)
+
+    // stddev_samp returns null when there is a single input,
+    // while stddev_pop returns 0.0 for a single input.
+    checkAnswer(
+      sqlContext.sql("SELECT stddev_samp(1), stddev_pop(1)"),
+      Row(null, 0.0) :: Nil
+    )
+
+    // TODO: stddev is first resolved to Hive's GenericUDAF, which rejects
+    // stddev_samp(null) and stddev_pop(null), so this test is commented out.
+    // Once we remove AggregateExpression1, these will resolve directly to our
+    // native implementation; re-enable this test at that time.
+    /*
+    checkAnswer(
+      sqlContext.sql("SELECT stddev_samp(null), stddev_pop(null)"),
+      Row(null, null) :: Nil
+    )
+    */
+  }
+
   test("udaf") {
     checkAnswer(
       sqlContext.sql(
@@ -505,7 +616,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
         |SELECT
         |  key,
         |  mydoublesum(value + 1.5 * key),
-        |  stddev_samp(value)
+        |  variance(value)
         |FROM agg1
         |GROUP BY key
       """.stripMargin).collect()
@@ -518,7 +629,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
         |SELECT
         |  key,
         |  sum(value + 1.5 * key),
-        |  stddev_samp(value)
+        |  variance(value)
        |FROM agg1
        |GROUP BY key
      """.stripMargin).queryExecution.executedPlan.collect {
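
The update and merge rules that `StandardDeviation` encodes as Catalyst expressions can be sanity-checked outside of Spark. The sketch below is illustrative only (none of these names appear in the patch): it implements the same online (Welford) update and the Chan et al. parallel merge referenced in the Wikipedia link above, and it reproduces the values asserted in the new test for the inputs 0..9 (stddev_samp ≈ 3.0276503541, stddev_pop ≈ 2.8722813233).

```scala
// Standalone model of the (count, avg, mk) aggregation buffer used by StandardDeviation.
final case class StddevBuffer(count: Long, avg: Double, mk: Double) {

  // Welford's online update for one non-null input row:
  //   n += 1; delta = x - avg; avg += delta / n; mk += delta * (x - newAvg)
  def update(x: Double): StddevBuffer = {
    val n = count + 1
    val delta = x - avg
    val newAvg = avg + delta / n
    StddevBuffer(n, newAvg, mk + delta * (x - newAvg))
  }

  // Parallel merge (Chan et al.) of two partial buffers, mirroring mergeExpressions:
  //   mk = mkL + mkR + (avgL - avgR)^2 * nL * nR / n
  def merge(that: StddevBuffer): StddevBuffer = {
    if (count == 0L) that
    else if (that.count == 0L) this
    else {
      val n = count + that.count
      val delta = avg - that.avg
      val newAvg = (avg * count + that.avg * that.count) / n
      StddevBuffer(n, newAvg, mk + that.mk + delta * delta * count * that.count / n)
    }
  }

  // sample = true divides by (n - 1); a single row then divides by zero and yields
  // NaN here, corresponding to the null that stddev_samp(1) returns in SQL.
  def evaluate(sample: Boolean): Double = {
    val divisor = if (sample) count - 1 else count
    math.sqrt(mk / divisor)
  }
}

object StddevSketch {
  def main(args: Array[String]): Unit = {
    val xs = (0 until 10).map(_.toDouble)
    // Update two partial buffers independently, then merge them,
    // as the partial-aggregation code path would.
    val (l, r) = xs.splitAt(4)
    val left = l.foldLeft(StddevBuffer(0L, 0.0, 0.0))(_ update _)
    val right = r.foldLeft(StddevBuffer(0L, 0.0, 0.0))(_ update _)
    val merged = left.merge(right)
    println(merged.evaluate(sample = true))  // ≈ 3.0276503541
    println(merged.evaluate(sample = false)) // ≈ 2.8722813233
  }
}
```

Note that the merge needs the two pre-merge counts alongside the merged total, which is exactly why `mergeExpressions` stashes `currentCount.left` and `currentCount.right` into the `leftCount` and `rightCount` slots before overwriting `currentCount`.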