From 27ae62566e1f34a407ecde7fb8fef0415073c466 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Sat, 8 Aug 2015 23:10:26 -0700 Subject: [PATCH 01/15] unbiased standard deviation aggregation function --- .../expressions/aggregate/functions.scala | 84 +++++++++++++++++++ .../expressions/aggregate/utils.scala | 8 ++ .../org/apache/spark/sql/DataFrame.scala | 7 +- .../org/apache/spark/sql/GroupedData.scala | 13 +++ .../SortBasedAggregationIterator.scala | 7 +- .../org/apache/spark/sql/functions.scala | 17 ++++ .../org/apache/spark/sql/DataFrameSuite.scala | 2 +- .../execution/AggregationQuerySuite.scala | 27 ++++++ 8 files changed, 158 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index a73024d6adba..bc14c1d1fd4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -302,3 +302,87 @@ case class Sum(child: Expression) extends AlgebraicAggregate { override val evaluateExpression = Cast(currentSum, resultType) } + +/** + * Calculates the unbiased Standard Deviation using the online formula here: + * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + */ +case class StandardDeviation(child: Expression) extends AlgebraicAggregate { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + // Return data type. + override def dataType: DataType = resultType + + // Expected input data type. + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + + private lazy val resultType = child.dataType match { + case DecimalType.Fixed(p, s) => + DecimalType.bounded(p + 4, s + 4) + case _ => DoubleType + } + + private lazy val sumDataType = child.dataType match { + case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) + case _ => DoubleType + } + + private lazy val currentCount = AttributeReference("currentCount", LongType)() + private lazy val currentAvg = AttributeReference("currentAverage", sumDataType)() + private lazy val currentMk = AttributeReference("currentMoment", sumDataType)() + + // the values should be updated in a special order, because they re-use each other + override lazy val bufferAttributes = currentCount :: currentAvg :: currentMk :: Nil + + override lazy val initialValues = Seq( + /* currentCount = */ Literal(0L), + /* currentAvg = */ Cast(Literal(0), sumDataType), + /* currentMk = */ Cast(Literal(0), sumDataType) + ) + + override lazy val updateExpressions = { + val currentValue = Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil) + val deltaX = Subtract(currentValue, currentAvg) + val updatedCount = If(IsNull(child), currentCount, currentCount + 1L) + val updatedAvg = Add(currentAvg, Divide(deltaX, updatedCount)) + Seq( + /* currentCount = */ updatedCount, + /* currentAvg = */ If(IsNull(child), currentAvg, updatedAvg), + /* currentMk = */ If(IsNull(child), + currentMk, Add(currentMk, deltaX * Subtract(currentValue, updatedAvg))) + ) + } + + override lazy val mergeExpressions = { + val totalCount = currentCount.left + currentCount.right + val deltaX = currentAvg.left - currentAvg.right + val deltaX2 = deltaX * deltaX + val sumMoments = currentMk.left + currentMk.right + val sumLeft = 
currentAvg.left * currentCount.left + val sumRight = currentAvg.right * currentCount.right + Seq( + /* currentCount = */ totalCount, + /* currentAvg = */ If(EqualTo(totalCount, Cast(Literal(0L), LongType)), + Cast(Literal(0), sumDataType), (sumLeft + sumRight) / totalCount), + /* currentMk = */ If(EqualTo(totalCount, Cast(Literal(0L), LongType)), + Cast(Literal(0), sumDataType), + sumMoments + deltaX2 * currentCount.left / totalCount * currentCount.right) + ) + } + + override lazy val evaluateExpression = { + val count = If(EqualTo(currentCount, Cast(Literal(0L), LongType)), + currentCount, currentCount - Cast(Literal(1L), LongType)) + child.dataType match { + case DecimalType.Fixed(p, s) => + // increase the precision and scale to prevent precision loss + val dt = DecimalType.bounded(p + 14, s + 4) + Cast(Sqrt(Cast(currentMk, dt) / Cast(count, dt)), resultType) + case _ => + Sqrt(Cast(currentMk, resultType) / Cast(count, resultType)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 4a43318a9549..41cfdd2d5e61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -164,4 +164,12 @@ object Utils { } case other => None } + + def standardDeviation(e: Expression): Expression = { + val std = aggregate.AggregateExpression2( + aggregateFunction = aggregate.StandardDeviation(e), + mode = aggregate.Complete, + isDistinct = false) + Alias(std, s"std(${e.prettyString})")() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 405b5a4a9a7f..2c1151996872 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} @@ -1268,15 +1269,11 @@ class DataFrame private[sql]( @scala.annotation.varargs def describe(cols: String*): DataFrame = { - // TODO: Add stddev as an expression, and remove it from here. - def stddevExpr(expr: Expression): Expression = - Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr)))) - // The list of summary statistics to compute, in the form of expressions. 
val statistics = List[(String, Expression => Expression)]( "count" -> Count, "mean" -> Average, - "stddev" -> stddevExpr, + "stddev" -> aggregate.Utils.standardDeviation, "min" -> Min, "max" -> Max) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 99d557b03a03..0dfadcf63cb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -23,6 +23,7 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -283,6 +284,18 @@ class GroupedData protected[sql]( aggregateNumericColumns(colNames : _*)(Min) } + /** + * Compute the sample standard deviation for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the standard deviation for them. + * + * @since 1.5.0 + */ + @scala.annotation.varargs + def std(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(aggregate.Utils.standardDeviation) + } + /** * Compute the sum for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 67ebafde25ad..7b1941068a65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -87,6 +87,8 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + private val dataTypes = allAggregateFunctions.flatMap(_.bufferAttributes).map(_.dataType) + /** Processes rows in the current group. It will stop when it find a new group. */ protected def processCurrentSortedGroup(): Unit = { currentGroupingKey = nextGroupingKey @@ -95,6 +97,7 @@ class SortBasedAggregationIterator( var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. processRow(sortBasedAggregationBuffer, firstRowInNextGroup) + println(dataTypes.zipWithIndex.map(d => sortBasedAggregationBuffer.get(d._2, d._1)).mkString("[", ",", "]")) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -107,7 +110,9 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { processRow(sortBasedAggregationBuffer, currentRow) - + println("Second") + println(currentRow) + println(dataTypes.zipWithIndex.map(d => sortBasedAggregationBuffer.get(d._2, d._1)).mkString("[", ",", "]")) hasNext = inputKVIterator.next() } else { // We find a new group. 
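For reference, the updateExpressions added to StandardDeviation above encode Welford's online update, evaluated slot-by-slot against the aggregation buffer. A minimal plain-Scala sketch of the same per-row step, with illustrative names that are not part of the patch:

    // Running state mirroring the buffer slots: (currentCount, currentAvg, currentMk).
    final case class StdBuffer(count: Long, avg: Double, mk: Double)

    // One Welford step: fold a new observation into the buffer.
    def update(buf: StdBuffer, x: Double): StdBuffer = {
      val count = buf.count + 1
      val delta = x - buf.avg                 // deltaX = currentValue - currentAvg
      val avg   = buf.avg + delta / count     // updatedAvg
      val mk    = buf.mk + delta * (x - avg)  // running sum of squared deviations
      StdBuffer(count, avg, mk)
    }

    // Sample standard deviation, as evaluateExpression computes it
    // (the null handling for empty input is omitted in this sketch).
    def sampleStddev(buf: StdBuffer): Double =
      math.sqrt(buf.mk / (buf.count - 1))     // divisor n - 1 gives the unbiased estimate

For example, folding Seq(1.0, 2.0, 3.0) through update starting from StdBuffer(0, 0.0, 0.0) ends with mk = 2.0, so sampleStddev returns sqrt(2.0 / 2) = 1.0.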
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 79c5f596661d..bc2620fd90ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ @@ -294,6 +295,22 @@ object functions { */ def min(columnName: String): Column = min(Column(columnName)) + /** + * Aggregate function: returns the sample standard deviation of the values in a group. + * + * @group agg_funcs + * @since 1.5.0 + */ + def std(e: Column): Column = aggregate.Utils.standardDeviation(e.expr) + + /** + * Aggregate function: returns the sample standard deviation of the values in a group. + * + * @group agg_funcs + * @since 1.5.0 + */ + def std(columnName: String): Column = std(Column(columnName)) + /** * Aggregate function: returns the sum of all values in the expression. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f9cc6d1f3c25..d3d84a283cc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -442,7 +442,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { val describeResult = Seq( Row("count", "4", "4"), Row("mean", "33.0", "178.0"), - Row("stddev", "16.583123951777", "10.0"), + Row("stddev", "19.148542155126762", "11.547005383792516"), Row("min", "16", "164"), Row("max", "60", "192")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7b5aa4763fd9..d7d4321277e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.sql._ +import org.apache.spark.sql.functions.std import org.scalatest.BeforeAndAfterAll import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} @@ -84,6 +85,32 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) } + test("test standard deviation") { + val df = Seq.tabulate(10)(i => (i, 1)).toDF("val", "key") + checkAnswer( + df.select(std("val")), + Row(3.0276503540974917) :: Nil) + + checkAnswer( + sqlContext.table("agg1").groupBy("key").std("value"), + Row(1, 10.0) :: Row(2, 0.7071067811865476) :: Row(3, null) :: + Row(null, 81.8535277187245) :: Nil) + + checkAnswer( + sqlContext.table("agg1").select(std("key"), std("value")), + Row(0.7817359599705717, 44.898098909801135) :: Nil) + + checkAnswer( + sqlContext.table("agg2").groupBy("key", "value1").std("value2"), + Row(1, 10, null) :: 
Row(1, 30, 42.42640687119285) :: Row(2, -1, null) :: + Row(2, 1, 0.0) :: Row(2, null, null) :: Row(3, null, null) :: Row(null, -10, null) :: + Row(null, -60, null) :: Row(null, 100, null) :: Row(null, null, null) :: Nil) + + checkAnswer( + sqlContext.table("emptyTable").select(std("value")), + Row(null) :: Nil) + } + test("empty table") { // If there is no GROUP BY clause and the table is empty, we will generate a single row. checkAnswer( From 0a6d2c04c3b54c92330b54cfa75bc9ea014d6d69 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 10 Aug 2015 10:11:11 -0700 Subject: [PATCH 02/15] save changes --- .../spark/sql/catalyst/expressions/aggregate/functions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index bc14c1d1fd4c..37fe2f4a8576 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -350,7 +350,7 @@ case class StandardDeviation(child: Expression) extends AlgebraicAggregate { val updatedAvg = Add(currentAvg, Divide(deltaX, updatedCount)) Seq( /* currentCount = */ updatedCount, - /* currentAvg = */ If(IsNull(child), currentAvg, updatedAvg), + /* currentAvg = */ currentAvg, /* currentMk = */ If(IsNull(child), currentMk, Add(currentMk, deltaX * Subtract(currentValue, updatedAvg))) ) From 3c117342b40dd1fdd9de98c6a5cec94f5accb376 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 10 Aug 2015 12:09:40 -0700 Subject: [PATCH 03/15] fixed test --- .../expressions/aggregate/functions.scala | 36 ++++++++---- .../expressions/aggregate/utils.scala | 5 ++ .../org/apache/spark/sql/GroupedData.scala | 1 - .../SortBasedAggregationIterator.scala | 6 -- .../org/apache/spark/sql/QueryTest.scala | 3 + .../execution/AggregationQuerySuite.scala | 55 ++++++++++--------- 6 files changed, 62 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 37fe2f4a8576..657a2c5d2c59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -331,14 +331,21 @@ case class StandardDeviation(child: Expression) extends AlgebraicAggregate { } private lazy val currentCount = AttributeReference("currentCount", LongType)() + private lazy val leftCount = AttributeReference("leftCount", LongType)() + private lazy val rightCount = AttributeReference("rightCount", LongType)() + private lazy val currentDelta = AttributeReference("currentDelta", sumDataType)() private lazy val currentAvg = AttributeReference("currentAverage", sumDataType)() private lazy val currentMk = AttributeReference("currentMoment", sumDataType)() // the values should be updated in a special order, because they re-use each other - override lazy val bufferAttributes = currentCount :: currentAvg :: currentMk :: Nil + override lazy val bufferAttributes = + leftCount :: rightCount :: currentCount :: currentDelta :: currentAvg :: currentMk :: Nil override lazy val initialValues = Seq( + /* leftCount = */ Literal(0L), + /* rightCount = */ Literal(0L), /* 
currentCount = */ Literal(0L), + /* currentDelta = */ Cast(Literal(0), sumDataType), /* currentAvg = */ Cast(Literal(0), sumDataType), /* currentMk = */ Cast(Literal(0), sumDataType) ) @@ -347,12 +354,15 @@ case class StandardDeviation(child: Expression) extends AlgebraicAggregate { val currentValue = Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil) val deltaX = Subtract(currentValue, currentAvg) val updatedCount = If(IsNull(child), currentCount, currentCount + 1L) - val updatedAvg = Add(currentAvg, Divide(deltaX, updatedCount)) + val updatedAvg = Add(currentAvg, Divide(currentDelta, currentCount)) Seq( + /* leftCount = */ leftCount, // used only during merging. dummy value + /* rightCount = */ rightCount, // used only during merging. dummy value /* currentCount = */ updatedCount, - /* currentAvg = */ currentAvg, + /* currentDelta = */ deltaX, + /* currentAvg = */ updatedAvg, /* currentMk = */ If(IsNull(child), - currentMk, Add(currentMk, deltaX * Subtract(currentValue, updatedAvg))) + currentMk, Add(currentMk, currentDelta * Subtract(currentValue, currentAvg))) ) } @@ -361,15 +371,19 @@ case class StandardDeviation(child: Expression) extends AlgebraicAggregate { val deltaX = currentAvg.left - currentAvg.right val deltaX2 = deltaX * deltaX val sumMoments = currentMk.left + currentMk.right - val sumLeft = currentAvg.left * currentCount.left - val sumRight = currentAvg.right * currentCount.right + val sumLeft = currentAvg.left * leftCount + val sumRight = currentAvg.right * rightCount + val mergedAvg = (sumLeft + sumRight) / currentCount + val mergedMk = sumMoments + currentDelta * leftCount / currentCount * rightCount Seq( + /* leftCount = */ currentCount.left, + /* rightCount = */ currentCount.right, /* currentCount = */ totalCount, - /* currentAvg = */ If(EqualTo(totalCount, Cast(Literal(0L), LongType)), - Cast(Literal(0), sumDataType), (sumLeft + sumRight) / totalCount), - /* currentMk = */ If(EqualTo(totalCount, Cast(Literal(0L), LongType)), - Cast(Literal(0), sumDataType), - sumMoments + deltaX2 * currentCount.left / totalCount * currentCount.right) + /* currentDelta = */ deltaX2, + /* currentAvg = */ If(EqualTo(leftCount, Cast(Literal(0L), LongType)), currentAvg.right, + If(EqualTo(rightCount, Cast(Literal(0L), LongType)), currentAvg.left, mergedAvg)), + /* currentMk = */ If(EqualTo(leftCount, Cast(Literal(0L), LongType)), currentMk.right, + If(EqualTo(rightCount, Cast(Literal(0L), LongType)), currentMk.left, mergedMk)) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 41cfdd2d5e61..2b2e57631d5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -165,6 +165,11 @@ object Utils { case other => None } + /** + * All Aggregate functions have previous versions as expressions in the old AggregateExpression + * format, but standard deviation uses the new format directly. We wrap it here in one place, + * and use an Alias so that the column name looks pretty as well instead of a long identifier. 
+ */ def standardDeviation(e: Expression): Expression = { val std = aggregate.AggregateExpression2( aggregateFunction = aggregate.StandardDeviation(e), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 0dfadcf63cb2..42e77fdb32b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -23,7 +23,6 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 7b1941068a65..a9d833222ba3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -87,8 +87,6 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer - private val dataTypes = allAggregateFunctions.flatMap(_.bufferAttributes).map(_.dataType) - /** Processes rows in the current group. It will stop when it find a new group. */ protected def processCurrentSortedGroup(): Unit = { currentGroupingKey = nextGroupingKey @@ -97,7 +95,6 @@ class SortBasedAggregationIterator( var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. processRow(sortBasedAggregationBuffer, firstRowInNextGroup) - println(dataTypes.zipWithIndex.map(d => sortBasedAggregationBuffer.get(d._2, d._1)).mkString("[", ",", "]")) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -110,9 +107,6 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { processRow(sortBasedAggregationBuffer, currentRow) - println("Second") - println(currentRow) - println(dataTypes.zipWithIndex.map(d => sortBasedAggregationBuffer.get(d._2, d._1)).mkString("[", ",", "]")) hasNext = inputKVIterator.next() } else { // We find a new group. 
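The reworked mergeExpressions above implement the parallel combine step of the Chan et al. formula cited in the class scaladoc, using the new leftCount and rightCount slots to express the empty-side guards. A hedged plain-Scala equivalent, reusing the hypothetical StdBuffer from the note under patch 01:

    // Combine two partial buffers; the count == 0 branches mirror the
    // If(EqualTo(leftCount/rightCount, 0L), ...) guards above.
    def merge(l: StdBuffer, r: StdBuffer): StdBuffer = {
      if (l.count == 0L) r
      else if (r.count == 0L) l
      else {
        val n     = l.count + r.count
        val delta = r.avg - l.avg
        // Count-weighted mean, i.e. (sumLeft + sumRight) / currentCount.
        val avg = (l.avg * l.count + r.avg * r.count) / n
        // M2 merge: the correction term grows with the squared distance between means.
        val mk = l.mk + r.mk + delta * delta * l.count * r.count / n
        StdBuffer(n, avg, mk)
      }
    }

Keeping leftCount and rightCount as explicit buffer slots is what lets the merge be written as expressions that reuse earlier slots, per the buffer-ordering comment introduced in patch 01.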
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 98ba3c99283a..1dc462130874 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -21,6 +21,8 @@ import java.util.{Locale, TimeZone} import scala.collection.JavaConversions._ +import org.scalactic.TripleEqualsSupport.Spread + import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation @@ -114,6 +116,7 @@ object QueryTest { Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) case b: Array[Byte] => b.toSeq + case d: Double => BigDecimal(d).setScale(10, BigDecimal.RoundingMode.HALF_UP) case o => o }) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index d7d4321277e9..acb48507a9db 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import org.scalactic.TripleEqualsSupport.Spread + import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils @@ -85,32 +87,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) } - test("test standard deviation") { - val df = Seq.tabulate(10)(i => (i, 1)).toDF("val", "key") - checkAnswer( - df.select(std("val")), - Row(3.0276503540974917) :: Nil) - - checkAnswer( - sqlContext.table("agg1").groupBy("key").std("value"), - Row(1, 10.0) :: Row(2, 0.7071067811865476) :: Row(3, null) :: - Row(null, 81.8535277187245) :: Nil) - - checkAnswer( - sqlContext.table("agg1").select(std("key"), std("value")), - Row(0.7817359599705717, 44.898098909801135) :: Nil) - - checkAnswer( - sqlContext.table("agg2").groupBy("key", "value1").std("value2"), - Row(1, 10, null) :: Row(1, 30, 42.42640687119285) :: Row(2, -1, null) :: - Row(2, 1, 0.0) :: Row(2, null, null) :: Row(3, null, null) :: Row(null, -10, null) :: - Row(null, -60, null) :: Row(null, 100, null) :: Row(null, null, null) :: Nil) - - checkAnswer( - sqlContext.table("emptyTable").select(std("value")), - Row(null) :: Nil) - } - test("empty table") { // If there is no GROUP BY clause and the table is empty, we will generate a single row. checkAnswer( @@ -311,6 +287,33 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(11.125) :: Nil) } + test("test standard deviation") { + // All results generated in R. Comparisons will be performed up to 10 digits of precision. 
+ val df = Seq.tabulate(10)(i => (i, 1)).toDF("val", "key") + checkAnswer( + df.select(std("val")), + Row(3.0276503540974917) :: Nil) + + checkAnswer( + sqlContext.table("agg1").groupBy("key").std("value"), + Row(1, 10.0) :: Row(2, 0.7071067811865476) :: Row(3, null) :: + Row(null, 81.8535277187245) :: Nil) + + checkAnswer( + sqlContext.table("agg1").select(std("key"), std("value")), + Row(0.7817359599705717, 44.898098909801135) :: Nil) + + checkAnswer( + sqlContext.table("agg2").groupBy("key", "value1").std("value2"), + Row(1, 10, null) :: Row(1, 30, 42.42640687119285) :: Row(2, -1, null) :: + Row(2, 1, 0.0) :: Row(2, null, null) :: Row(3, null, null) :: Row(null, -10, null) :: + Row(null, -60, null) :: Row(null, 100, null) :: Row(null, null, null) :: Nil) + + checkAnswer( + sqlContext.table("emptyTable").select(std("value")), + Row(null) :: Nil) + } + test("udaf") { checkAnswer( sqlContext.sql( From 1175ace107a13d0efdc928f0d45eb2bd89de9c48 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 10 Aug 2015 12:10:47 -0700 Subject: [PATCH 04/15] remove unnecessary import --- .../apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index acb48507a9db..7f4d11de235e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.hive.execution -import org.scalactic.TripleEqualsSupport.Spread - import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils From 941bb9e3564dc6709b8f3bf5d4de79e95ca5982d Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 10 Aug 2015 12:34:57 -0700 Subject: [PATCH 05/15] remove nan and infinity for checkAnswer --- sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 1dc462130874..48d76986664b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -116,7 +116,8 @@ object QueryTest { Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) case b: Array[Byte] => b.toSeq - case d: Double => BigDecimal(d).setScale(10, BigDecimal.RoundingMode.HALF_UP) + case d: Double if !d.isNaN && !d.isInfinity => + BigDecimal(d).setScale(10, BigDecimal.RoundingMode.HALF_UP) case o => o }) } From 34b22e8ef94a56e6ee06488a6d9d24330eb69a42 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 10 Aug 2015 14:27:39 -0700 Subject: [PATCH 06/15] addressed comments --- python/pyspark/sql/functions.py | 5 +- python/pyspark/sql/group.py | 42 ++++++++++++++ .../expressions/aggregate/functions.scala | 15 +++-- .../expressions/aggregate/utils.scala | 8 ++- .../org/apache/spark/sql/DataFrame.scala | 2 +- .../org/apache/spark/sql/GroupedData.scala | 29 +++++++++- .../org/apache/spark/sql/functions.scala | 39 ++++++++++++- .../org/apache/spark/sql/QueryTest.scala | 2 - .../execution/AggregationQuerySuite.scala | 57 +++++++++++++++---- 9 files changed, 171 insertions(+), 28 deletions(-) diff --git 
a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 95f46044d324..ab9e64fe9bdd 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -87,7 +87,10 @@ def _():
     'sum': 'Aggregate function: returns the sum of all values in the expression.',
     'avg': 'Aggregate function: returns the average of the values in a group.',
     'mean': 'Aggregate function: returns the average of the values in a group.',
-    'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
+    'stddev': 'Aggregate function: returns the sample standard deviation in a group.',
+    'stddevSamp': 'Aggregate function: returns the sample standard deviation in a group.',
+    'stddevPop': 'Aggregate function: returns the population standard deviation in a group.',
+    'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.'
 }

 _functions_1_4 = {
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 04594d5a836c..91c377bc962b 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -154,6 +154,46 @@ def min(self, *cols):
        [Row(min(age)=2, min(height)=80)]
        """

+    @df_varargs_api
+    @since(1.5)
+    def stddev(self, *cols):
+        """Computes the sample standard deviation for each numeric column for each group.
+        Alias for stddevSamp.
+
+        :param cols: list of column names (string). Non-numeric columns are ignored.
+
+        >>> df.groupBy().stddev('age').collect()
+        [Row(stddev_samp(age)=2.12...)]
+        >>> df3.groupBy().stddev('age', 'height').collect()
+        [Row(stddev_samp(age)=2.12..., stddev_samp(height)=3.53...)]
+        """
+
+    @df_varargs_api
+    @since(1.5)
+    def stddevPop(self, *cols):
+        """Computes the population standard deviation for each numeric column for each group.
+
+        :param cols: list of column names (string). Non-numeric columns are ignored.
+
+        >>> df.groupBy().stddevPop('age').collect()
+        [Row(stddev_pop(age)=1.06...)]
+        >>> df3.groupBy().stddevPop('age', 'height').collect()
+        [Row(stddev_pop(age)=1.06..., stddev_pop(height)=1.76...)]
+        """
+
+    @df_varargs_api
+    @since(1.5)
+    def stddevSamp(self, *cols):
+        """Computes the sample standard deviation for each numeric column for each group.
+
+        :param cols: list of column names (string). Non-numeric columns are ignored.
+
+        >>> df.groupBy().stddevSamp('age').collect()
+        [Row(stddev_samp(age)=2.12...)]
+        >>> df3.groupBy().stddevSamp('age', 'height').collect()
+        [Row(stddev_samp(age)=2.12..., stddev_samp(height)=3.53...)]
+        """
+
     @df_varargs_api
     @since(1.3)
     def sum(self, *cols):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 657a2c5d2c59..7ada3d043da4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -304,10 +304,11 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
 }

 /**
- * Calculates the unbiased Standard Deviation using the online formula here:
+ * Calculates the Standard Deviation using the online formula here:
  * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+ * If sample is true, then we will return the unbiased standard deviation.
*/ -case class StandardDeviation(child: Expression) extends AlgebraicAggregate { +case class StandardDeviation(child: Expression, sample: Boolean) extends AlgebraicAggregate { override def children: Seq[Expression] = child :: Nil @@ -388,8 +389,14 @@ case class StandardDeviation(child: Expression) extends AlgebraicAggregate { } override lazy val evaluateExpression = { - val count = If(EqualTo(currentCount, Cast(Literal(0L), LongType)), - currentCount, currentCount - Cast(Literal(1L), LongType)) + val count = + if (sample) { + If(EqualTo(currentCount, Cast(Literal(0L), LongType)), currentCount, + currentCount - Cast(Literal(1L), LongType)) + } else { + currentCount + } + child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index 2b2e57631d5a..aefe736bd122 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -170,11 +170,13 @@ object Utils { * format, but standard deviation uses the new format directly. We wrap it here in one place, * and use an Alias so that the column name looks pretty as well instead of a long identifier. */ - def standardDeviation(e: Expression): Expression = { + def standardDeviation(e: Expression, sample: Boolean, name: String): Expression = { val std = aggregate.AggregateExpression2( - aggregateFunction = aggregate.StandardDeviation(e), + aggregateFunction = aggregate.StandardDeviation(e, sample), mode = aggregate.Complete, isDistinct = false) - Alias(std, s"std(${e.prettyString})")() + Alias(std, s"$name(${e.prettyString})")() } + + def sampleStandardDeviation(e: Expression): Expression = standardDeviation(e, true, "stddev_samp") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index e04ffafb9279..fb6d05288e48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1273,7 +1273,7 @@ class DataFrame private[sql]( val statistics = List[(String, Expression => Expression)]( "count" -> Count, "mean" -> Average, - "stddev" -> aggregate.Utils.standardDeviation, + "stddev" -> aggregate.Utils.sampleStandardDeviation, "min" -> Min, "max" -> Max) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 42e77fdb32b4..bc0f3b514c07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -291,8 +291,33 @@ class GroupedData protected[sql]( * @since 1.5.0 */ @scala.annotation.varargs - def std(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(aggregate.Utils.standardDeviation) + def stddev(colNames: String*): DataFrame = { + stddevSamp(colNames : _*) + } + + /** + * Compute the population standard deviation for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the standard deviation for them. 
+ * + * @since 1.5.0 + */ + @scala.annotation.varargs + def stddevPop(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(aggregate.Utils.standardDeviation(_, sample = false, + "stddev_pop")) + } + + /** + * Compute the sample standard deviation for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the standard deviation for them. + * + * @since 1.5.0 + */ + @scala.annotation.varargs + def stddevSamp(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames : _*)(aggregate.Utils.sampleStandardDeviation) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index bc2620fd90ad..848c8bd6586b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -22,7 +22,6 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ @@ -297,19 +296,53 @@ object functions { /** * Aggregate function: returns the sample standard deviation of the values in a group. + * Alias for stddevSamp. * * @group agg_funcs * @since 1.5.0 */ - def std(e: Column): Column = aggregate.Utils.standardDeviation(e.expr) + def stddev(e: Column): Column = stddevSamp(e) /** * Aggregate function: returns the sample standard deviation of the values in a group. + * Alias for stddevSamp. * * @group agg_funcs * @since 1.5.0 */ - def std(columnName: String): Column = std(Column(columnName)) + def stddev(columnName: String): Column = stddev(Column(columnName)) + + /** + * Aggregate function: returns the population standard deviation of the values in a group. + * + * @group agg_funcs + * @since 1.5.0 + */ + def stddevPop(e: Column): Column = aggregate.Utils.standardDeviation(e.expr, false, "stddev_pomp") + + /** + * Aggregate function: returns the population standard deviation of the values in a group. + * + * @group agg_funcs + * @since 1.5.0 + */ + def stddevPop(columnName: String): Column = stddevPop(Column(columnName)) + + /** + * Aggregate function: returns the sample standard deviation of the values in a group. + * + * @group agg_funcs + * @since 1.5.0 + */ + def stddevSamp(e: Column): Column = aggregate.Utils.sampleStandardDeviation(e.expr) + + /** + * Aggregate function: returns the sample standard deviation of the values in a group. + * + * @group agg_funcs + * @since 1.5.0 + */ + def stddevSamp(columnName: String): Column = stddev(Column(columnName)) /** * Aggregate function: returns the sum of all values in the expression. 
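The sample flag threaded through this patch only changes the final divisor in evaluateExpression. A small self-contained Scala illustration of the difference (the numbers are worked out here, not taken from the test suite):

    val xs   = Seq(1.0, 2.0, 3.0, 4.0)
    val mean = xs.sum / xs.size                          // 2.5
    val m2   = xs.map(x => (x - mean) * (x - mean)).sum  // 5.0, the currentMk analogue

    val stddevSamp = math.sqrt(m2 / (xs.size - 1))       // ~1.2910, unbiased, divisor n - 1
    val stddevPop  = math.sqrt(m2 / xs.size)             // ~1.1180, population, divisor n

The two also differ on single-row groups: stddev_samp is undefined there while stddev_pop is 0.0, which the SQL test added in patch 07 pins down with SELECT stddev_samp(1), stddev_pop(1).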
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 48d76986664b..ba37947952b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -116,8 +116,6 @@ object QueryTest { Row.fromSeq(s.toSeq.map { case d: java.math.BigDecimal => BigDecimal(d) case b: Array[Byte] => b.toSeq - case d: Double if !d.isNaN && !d.isInfinity => - BigDecimal(d).setScale(10, BigDecimal.RoundingMode.HALF_UP) case o => o }) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7f4d11de235e..34fb60266ce3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.sql._ -import org.apache.spark.sql.functions.std +import org.apache.spark.sql.functions.{stddev, stddevPop} import org.scalatest.BeforeAndAfterAll import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} @@ -285,30 +285,63 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be Row(11.125) :: Nil) } + /** For resilience against rounding mismatches. */ + private def about(d: Double): BigDecimal = BigDecimal(d).setScale(10, BigDecimal.RoundingMode.HALF_UP) + test("test standard deviation") { // All results generated in R. Comparisons will be performed up to 10 digits of precision. 
val df = Seq.tabulate(10)(i => (i, 1)).toDF("val", "key") checkAnswer( - df.select(std("val")), - Row(3.0276503540974917) :: Nil) + df.select(stddev("val").cast("decimal(12, 10)")), + Row(about(3.0276503540974917)) :: Nil) + + checkAnswer( + df.select(stddevPop("val").cast("decimal(12, 10)")), + Row(about(2.8722813232690148)) :: Nil) + + checkAnswer( + sqlContext.table("agg1").groupBy("key").stddev("value") + .select($"key", $"stddev_samp(value)".cast("decimal(12, 10)")), + Row(1, about(10.0)) :: Row(2, about(0.7071067811865476)) :: Row(3, null) :: + Row(null, about(81.8535277187245)) :: Nil) + + checkAnswer( + sqlContext.table("agg1").groupBy("key").stddevPop("value") + .select($"key", $"stddev_pop(value)".cast("decimal(12, 10)")), + Row(1, about(8.16496580927726)) :: Row(2, about(0.5)) :: Row(3, null) :: + Row(null, about(66.83312551921139)) :: Nil) checkAnswer( - sqlContext.table("agg1").groupBy("key").std("value"), - Row(1, 10.0) :: Row(2, 0.7071067811865476) :: Row(3, null) :: - Row(null, 81.8535277187245) :: Nil) + sqlContext.table("agg1").select(stddev("key").cast("decimal(12, 10)"), + stddev("value").cast("decimal(12, 10)")), + Row(about(0.7817359599705717), about(44.898098909801135)) :: Nil) checkAnswer( - sqlContext.table("agg1").select(std("key"), std("value")), - Row(0.7817359599705717, 44.898098909801135) :: Nil) + sqlContext.table("agg1").select(stddevPop("key").cast("decimal(12, 10)"), + stddevPop("value").cast("decimal(12, 10)")), + Row(about(0.7370277311900889), about(41.99832585949111)) :: Nil) checkAnswer( - sqlContext.table("agg2").groupBy("key", "value1").std("value2"), - Row(1, 10, null) :: Row(1, 30, 42.42640687119285) :: Row(2, -1, null) :: - Row(2, 1, 0.0) :: Row(2, null, null) :: Row(3, null, null) :: Row(null, -10, null) :: + sqlContext.table("agg2").groupBy("key", "value1").stddev("value2") + .select($"key", $"value1", $"stddev_samp(value2)".cast("decimal(12, 10)")), + Row(1, 10, null) :: Row(1, 30, about(42.42640687119285)) :: Row(2, -1, null) :: + Row(2, 1, about(0.0)) :: Row(2, null, null) :: Row(3, null, null) :: Row(null, -10, null) :: Row(null, -60, null) :: Row(null, 100, null) :: Row(null, null, null) :: Nil) checkAnswer( - sqlContext.table("emptyTable").select(std("value")), + sqlContext.table("agg2").groupBy("key", "value1").stddevPop("value2") + .select($"key", $"value1", $"stddev_pop(value2)".cast("decimal(12, 10)")), + Row(1, 10, about(0.0)) :: Row(1, 30, about(30.0)) :: Row(2, -1, null) :: + Row(2, 1, about(0.0)) :: Row(2, null, about(0.0)) :: Row(3, null, about(0.0)) :: + Row(null, -10, about(0.0)) :: Row(null, -60, about(0.0)) :: Row(null, 100, about(0.0)) :: + Row(null, null, null) :: Nil) + + checkAnswer( + sqlContext.table("emptyTable").select(stddev("value")), + Row(null) :: Nil) + + checkAnswer( + sqlContext.table("emptyTable").select(stddevPop("value")), Row(null) :: Nil) } From 9e6ac9da988642ea74a026dd31b22a3cd7dd9c69 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 10 Aug 2015 20:14:49 -0700 Subject: [PATCH 07/15] First resolve stddev to Hive's UDAF and replace it to our native implementation based on AggregateFunction2 if possible. 
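Sketch of the two-stage flow this commit introduces (simplified stand-in types, not the patch's exact code): the analyzer canonicalizes the function name before lookup, and the conversion in utils.scala then swaps the resolved Hive UDAF for the native implementation when the new aggregation path is usable.

    // Stand-ins so the sketch is self-contained.
    sealed trait Expression
    final case class StandardDeviation(child: Expression, sample: Boolean) extends Expression

    // Hive aliases stddev/std to stddev_pop; most other engines (and Spark here)
    // treat them as stddev_samp, so the name is rewritten before lookup.
    def rewriteName(name: String): String = name match {
      case "stddev" | "std" => "stddev_samp"
      case other            => other
    }

    // String matching on the resolved UDAF is the patch's stopgap until
    // AggregateExpression1 is removed. The StdSample check must come first,
    // since "GenericUDAFStdSample" also contains "GenericUDAFStd".
    def toNative(udafDescription: String, child: Expression): Option[StandardDeviation] =
      if (udafDescription.contains("GenericUDAFStdSample")) {
        Some(StandardDeviation(child, sample = true))   // stddev_samp
      } else if (udafDescription.contains("GenericUDAFStd")) {
        Some(StandardDeviation(child, sample = false))  // stddev_pop
      } else {
        None
      }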
--- .../sql/catalyst/analysis/Analyzer.scala | 13 +++- .../expressions/aggregate/functions.scala | 8 +++ .../expressions/aggregate/utils.scala | 37 ++++++----- .../org/apache/spark/sql/DataFrame.scala | 2 +- .../org/apache/spark/sql/GroupedData.scala | 17 +++-- .../org/apache/spark/sql/functions.scala | 16 +++-- .../org/apache/spark/sql/QueryTest.scala | 2 - .../execution/AggregationQuerySuite.scala | 62 ++++++++++++++++--- 8 files changed, 122 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a684dbc3afa4..2d5be382f9a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -524,7 +524,16 @@ class Analyzer( q transformExpressions { case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { - registry.lookupFunction(name, children) match { + // TODO: This is a hack. Hive uses stddev and std as aliases of stddev_pop, which + // is different from other widely used systems (these systems use these two function + // names as aliases of stddev_samp). So, we explicitly rename it to stddev_samp. + // Once we remove AggregateExpression1, we can remove this hack. + val rewrittenName = name match { + case "stddev" | "std" => "stddev_samp" + case other => other + } + + registry.lookupFunction(rewrittenName, children) match { // We get an aggregate function built based on AggregateFunction2 interface. // So, we wrap it in AggregateExpression2. case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) @@ -538,7 +547,7 @@ class Analyzer( // For other aggregate functions, DISTINCT keyword is not supported for now. // Once we converted to the new code path, we will allow using DISTINCT keyword. case other: AggregateExpression1 if isDistinct => - failAnalysis(s"$name does not support DISTINCT keyword.") + failAnalysis(s"$rewrittenName does not support DISTINCT keyword.") // If it does not have DISTINCT keyword, we will return it as is. 
case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 7ada3d043da4..8302994cd5cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -406,4 +406,12 @@ case class StandardDeviation(child: Expression, sample: Boolean) extends Algebra Sqrt(Cast(currentMk, resultType) / Cast(count, resultType)) } } + + override def prettyName: String = { + if (sample) { + "stddev_samp" + } else { + "stddev_pop" + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala index aefe736bd122..ec6cec8d7f94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala @@ -96,6 +96,28 @@ object Utils { aggregateFunction = aggregate.Sum(child), mode = aggregate.Complete, isDistinct = true) + + case hiveUDAF: AggregateExpression1 + if hiveUDAF.getClass.getSimpleName == "HiveGenericUDAF" && + hiveUDAF.toString.contains("GenericUDAFStdSample") => + // We get a STDDEV_SAMP, which is originally resolved as a HiveGenericUDAF. + require(hiveUDAF.children.length == 1, "stddev_samp requires exactly one argument.") + val child = hiveUDAF.children.head + aggregate.AggregateExpression2( + aggregateFunction = aggregate.StandardDeviation(child, sample = true), + mode = aggregate.Complete, + isDistinct = false) + + case hiveUDAF: AggregateExpression1 + if hiveUDAF.getClass.getSimpleName == "HiveGenericUDAF" && + hiveUDAF.toString.contains("GenericUDAFStd") => + // We get a STDDEV_POP, which is originally resolved as a HiveGenericUDAF. + require(hiveUDAF.children.length == 1, "stddev_pop requires exactly one argument.") + val child = hiveUDAF.children.head + aggregate.AggregateExpression2( + aggregateFunction = aggregate.StandardDeviation(child, sample = false), + mode = aggregate.Complete, + isDistinct = false) } // Check if there is any expressions.AggregateExpression1 left. // If so, we cannot convert this plan. @@ -164,19 +186,4 @@ object Utils { } case other => None } - - /** - * All Aggregate functions have previous versions as expressions in the old AggregateExpression - * format, but standard deviation uses the new format directly. We wrap it here in one place, - * and use an Alias so that the column name looks pretty as well instead of a long identifier. 
- */ - def standardDeviation(e: Expression, sample: Boolean, name: String): Expression = { - val std = aggregate.AggregateExpression2( - aggregateFunction = aggregate.StandardDeviation(e, sample), - mode = aggregate.Complete, - isDistinct = false) - Alias(std, s"$name(${e.prettyString})")() - } - - def sampleStandardDeviation(e: Expression): Expression = standardDeviation(e, true, "stddev_samp") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index fb6d05288e48..274d5f25c349 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1273,7 +1273,7 @@ class DataFrame private[sql]( val statistics = List[(String, Expression => Expression)]( "count" -> Count, "mean" -> Average, - "stddev" -> aggregate.Utils.sampleStandardDeviation, + "stddev" -> ((e: Expression) => UnresolvedFunction("stddev_samp", e :: Nil, false)), "min" -> Min, "max" -> Max) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index bc0f3b514c07..b601a5522e21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -304,8 +304,12 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def stddevPop(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(aggregate.Utils.standardDeviation(_, sample = false, - "stddev_pop")) + val builder = (e: Expression) => { + Alias( + UnresolvedFunction("stddev_pop", e :: Nil, false), + s"stddev_pop(${e.prettyString})")() + } + aggregateNumericColumns(colNames : _*)(builder) } /** @@ -317,7 +321,12 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def stddevSamp(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames : _*)(aggregate.Utils.sampleStandardDeviation) + val builder = (e: Expression) => { + Alias( + UnresolvedFunction("stddev_samp", e :: Nil, false), + s"stddev_samp(${e.prettyString})")() + } + aggregateNumericColumns(colNames : _*)(builder) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 848c8bd6586b..e4bbf324a61c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -310,7 +310,7 @@ object functions { * @group agg_funcs * @since 1.5.0 */ - def stddev(columnName: String): Column = stddev(Column(columnName)) + def stddev(columnName: String): Column = stddevSamp(Column(columnName)) /** * Aggregate function: returns the population standard deviation of the values in a group. 
@@ -318,7 +318,11 @@ object functions { * @group agg_funcs * @since 1.5.0 */ - def stddevPop(e: Column): Column = aggregate.Utils.standardDeviation(e.expr, false, "stddev_pomp") + def stddevPop(e: Column): Column = { + Alias( + UnresolvedFunction("stddev_pop", e.expr :: Nil, false), + s"stddev_pop(${e.expr.prettyString})")() + } /** * Aggregate function: returns the population standard deviation of the values in a group. @@ -334,7 +338,11 @@ object functions { * @group agg_funcs * @since 1.5.0 */ - def stddevSamp(e: Column): Column = aggregate.Utils.sampleStandardDeviation(e.expr) + def stddevSamp(e: Column): Column = { + Alias( + UnresolvedFunction("stddev_samp", e.expr :: Nil, false), + s"stddev_samp(${e.expr.prettyString})")() + } /** * Aggregate function: returns the sample standard deviation of the values in a group. @@ -342,7 +350,7 @@ object functions { * @group agg_funcs * @since 1.5.0 */ - def stddevSamp(columnName: String): Column = stddev(Column(columnName)) + def stddevSamp(columnName: String): Column = stddevSamp(Column(columnName)) /** * Aggregate function: returns the sum of all values in the expression. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index ba37947952b7..98ba3c99283a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -21,8 +21,6 @@ import java.util.{Locale, TimeZone} import scala.collection.JavaConversions._ -import org.scalactic.TripleEqualsSupport.Spread - import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 34fb60266ce3..8a2c5cd0094d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -299,11 +299,42 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be df.select(stddevPop("val").cast("decimal(12, 10)")), Row(about(2.8722813232690148)) :: Nil) - checkAnswer( - sqlContext.table("agg1").groupBy("key").stddev("value") - .select($"key", $"stddev_samp(value)".cast("decimal(12, 10)")), - Row(1, about(10.0)) :: Row(2, about(0.7071067811865476)) :: Row(3, null) :: - Row(null, about(81.8535277187245)) :: Nil) + // Make sure we can use stddev functions in SQL. 
+ { + val expectedGroup1 = + Row(1, about(10.0), about(10.0), about(10.0), about(8.16496580927726)) + val expectedGroup2 = + Row( + 2, + about(0.7071067811865476), + about(0.7071067811865476), + about(0.7071067811865476), + about(0.5)) + val expectedGroup3 = Row(3, null, null, null, null) + val expectedGroupNull = + Row( + null, + about(81.8535277187245), + about(81.8535277187245), + about(81.8535277187245), + about(66.83312551921139)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | cast(std(value) as decimal(12, 10)), + | cast(stddev(value) as decimal(12, 10)), + | cast(stddev_samp(value) as decimal(12, 10)), + | cast(stddev_pop(value) as decimal(12, 10)) + |FROM agg1 GROUP BY key + """.stripMargin), + expectedGroup1 :: + expectedGroup2 :: + expectedGroup3 :: + expectedGroupNull :: Nil) + } checkAnswer( sqlContext.table("agg1").groupBy("key").stddevPop("value") @@ -343,6 +374,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be checkAnswer( sqlContext.table("emptyTable").select(stddevPop("value")), Row(null) :: Nil) + + // stddev_samp returns null when there is a single input. + // While, stddev_pop returns 0.0 when there is a single input. + checkAnswer( + sqlContext.sql("SELECT stddev_samp(1), stddev_pop(1)"), + Row(null, 0.0) :: Nil + ) + + // TODO: Because we will first resolve stddev to Hive's GenericUDAF and it will + // complain about stddev_samp(null) and stddev_pop(null). So, we comment out this + // test. Once we remove AggregateExpression1, we will resolve them directly to + // out native implementation. We should re-enable this test at that time. + /* + checkAnswer( + sqlContext.sql("SELECT stddev_samp(null), stddev_pop(null)"), + Row(null, null) :: Nil + )*/ } test("udaf") { @@ -566,7 +614,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |SELECT | key, | mydoublesum(value + 1.5 * key), - | stddev_samp(value) + | variance(value) |FROM agg1 |GROUP BY key """.stripMargin).collect() @@ -579,7 +627,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be |SELECT | key, | sum(value + 1.5 * key), - | stddev_samp(value) + | variance(value) |FROM agg1 |GROUP BY key """.stripMargin).queryExecution.executedPlan.collect { From 9f24d5e3748f2cd12cb44c28f9e27e4338d03fe3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 10 Aug 2015 20:27:57 -0700 Subject: [PATCH 08/15] Also update the simpleString of aggregate operators. 
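A minimal sketch of the uniform operator string both changes converge on, with plain strings standing in for the expression sequences:

    def simpleString(name: String, keys: Seq[String], functions: Seq[String],
        output: Seq[String]): String = {
      val keyString    = keys.mkString("[", ",", "]")
      val valueString  = functions.mkString("[", ",", "]")
      val outputString = output.mkString("[", ",", "]")
      s"$name(key=$keyString, functions=$valueString, output=$outputString)"
    }

With hypothetical attribute names, simpleString("TungstenAggregate", Seq("key#1"), Seq("stddev_samp(value#2)"), Seq("key#1", "stddev#3")) renders as TungstenAggregate(key=[key#1], functions=[stddev_samp(value#2)], output=[key#1,stddev#3]).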
---
 .../spark/sql/execution/aggregate/SortBasedAggregate.scala | 6 +++++-
 .../spark/sql/execution/aggregate/TungstenAggregate.scala  | 5 +++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index ad428ad663f3..08e69d566796 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -98,6 +98,10 @@ case class SortBasedAggregate(
 
   override def simpleString: String = {
     val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
-    s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}"""
+
+    val keyString = groupingExpressions.mkString("[", ",", "]")
+    val valueString = allAggregateExpressions.mkString("[", ",", "]")
+    val outputString = output.mkString("[", ",", "]")
+    s"SortBasedAggregate(key=$keyString, functions=$valueString, output=$outputString)"
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 1694794a53d9..fc9a4b739414 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -96,10 +96,11 @@ case class TungstenAggregate(
       case None =>
         val keyString = groupingExpressions.mkString("[", ",", "]")
         val valueString = allAggregateExpressions.mkString("[", ",", "]")
-        s"TungstenAggregate(key=$keyString, value=$valueString"
+        val outputString = output.mkString("[", ",", "]")
+        s"TungstenAggregate(key=$keyString, functions=$valueString, output=$outputString)"
       case Some(fallbackStartsAt) =>
         s"TungstenAggregateWithControlledFallback $groupingExpressions " +
-          s"$allAggregateExpressions fallbackStartsAt=$fallbackStartsAt"
+          s"$allAggregateExpressions $resultExpressions fallbackStartsAt=$fallbackStartsAt"
     }
   }
 }

From 91dc106ee9d6912fa49cc91635329438751b8cc7 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Mon, 10 Aug 2015 22:11:54 -0700
Subject: [PATCH 09/15] delete spaces

---
 .../spark/sql/hive/execution/AggregationQuerySuite.scala | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 8a2c5cd0094d..1e804b295b60 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -286,7 +286,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
   }
 
   /** For resilience against rounding mismatches. */
-  private def about(d: Double): BigDecimal = BigDecimal(d).setScale(10, BigDecimal.RoundingMode.HALF_UP)
+  private def about(d: Double): BigDecimal =
+    BigDecimal(d).setScale(10, BigDecimal.RoundingMode.HALF_UP)
 
   test("test standard deviation") {
     // All results generated in R. Comparisons will be performed up to 10 digits of precision.
@@ -364,7 +365,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
         .select($"key", $"value1", $"stddev_pop(value2)".cast("decimal(12, 10)")),
       Row(1, 10, about(0.0)) :: Row(1, 30, about(30.0)) :: Row(2, -1, null) ::
         Row(2, 1, about(0.0)) :: Row(2, null, about(0.0)) :: Row(3, null, about(0.0)) ::
-        Row(null, -10, about(0.0)) :: Row(null, -60, about(0.0)) :: Row(null, 100, about(0.0)) :: 
+        Row(null, -10, about(0.0)) :: Row(null, -60, about(0.0)) :: Row(null, 100, about(0.0)) ::
         Row(null, null, null) :: Nil)
 
     checkAnswer(
@@ -385,7 +386,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
     // TODO: Because we will first resolve stddev to Hive's GenericUDAF and it will
     // complain about stddev_samp(null) and stddev_pop(null). So, we comment out this
     // test. Once we remove AggregateExpression1, we will resolve them directly to
-    // out native implementation. We should re-enable this test at that time.
+    // our native implementation. We should re-enable this test at that time.
     /*
     checkAnswer(
       sqlContext.sql("SELECT stddev_samp(null), stddev_pop(null)"),

From 89946055b522190a2fcb38f19c918f93a83a8f19 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Mon, 10 Aug 2015 22:13:58 -0700
Subject: [PATCH 10/15] change to defs

---
 .../src/main/scala/org/apache/spark/sql/GroupedData.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index b601a5522e21..85eab4dc0455 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -304,7 +304,7 @@ class GroupedData protected[sql](
    */
   @scala.annotation.varargs
   def stddevPop(colNames: String*): DataFrame = {
-    val builder = (e: Expression) => {
+    def builder(e: Expression): Expression = {
       Alias(
         UnresolvedFunction("stddev_pop", e :: Nil, false),
         s"stddev_pop(${e.prettyString})")()
@@ -321,11 +321,11 @@ class GroupedData protected[sql](
    */
   @scala.annotation.varargs
   def stddevSamp(colNames: String*): DataFrame = {
-    val builder = (e: Expression) => {
+    def builder(e: Expression): Expression = {
       Alias(
         UnresolvedFunction("stddev_samp", e :: Nil, false),
         s"stddev_samp(${e.prettyString})")()
-    } 
+    }
     aggregateNumericColumns(colNames : _*)(builder)
   }

From 4a83f75e29514e2fcaa424d8891342d0f012a1c7 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Mon, 10 Aug 2015 23:28:58 -0700
Subject: [PATCH 11/15] tried to fix scalastyle

---
 .../apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 1e804b295b60..aab5060a32a0 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -387,6 +387,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
     // complain about stddev_samp(null) and stddev_pop(null). So, we comment out this
    // test. Once we remove AggregateExpression1, we will resolve them directly to
    // our native implementation. We should re-enable this test at that time.
+
    /*
     checkAnswer(
       sqlContext.sql("SELECT stddev_samp(null), stddev_pop(null)"),

From 3e8c46296e538a2b3471223d8e2c6e812f52df81 Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Mon, 10 Aug 2015 23:42:10 -0700
Subject: [PATCH 12/15] added space

---
 .../apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index aab5060a32a0..a97ef6fdb903 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -388,7 +388,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
     // test. Once we remove AggregateExpression1, we will resolve them directly to
     // our native implementation. We should re-enable this test at that time.
 
-   /*
+    /*
     checkAnswer(
       sqlContext.sql("SELECT stddev_samp(null), stddev_pop(null)"),
       Row(null, null) :: Nil

From 48fa619350d7f1e9f107126b960b289df30b042d Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Tue, 11 Aug 2015 09:27:40 -0700
Subject: [PATCH 13/15] locally scalastyle passes

---
 .../spark/sql/hive/execution/AggregationQuerySuite.scala | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index a97ef6fdb903..42cf882bb159 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -387,12 +387,12 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
     // complain about stddev_samp(null) and stddev_pop(null). So, we comment out this
     // test. Once we remove AggregateExpression1, we will resolve them directly to
     // our native implementation. We should re-enable this test at that time.
-
-    /*
+    /*
     checkAnswer(
       sqlContext.sql("SELECT stddev_samp(null), stddev_pop(null)"),
       Row(null, null) :: Nil
-    )*/
+    )
+    */
   }
 
   test("udaf") {

From f221d2a99199d33438d3e5a2e3b8ac86f530dc10 Mon Sep 17 00:00:00 2001
From: Yin Huai
Date: Tue, 11 Aug 2015 11:03:37 -0700
Subject: [PATCH 14/15] Make describe work in SQLContext.
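
DataFrame.describe() relies on stddev, but a plain SQLContext has no Hive
registry entry for it, so the analyzer used to fail to resolve the function
outside of HiveContext. The function resolution rule now special-cases the
stddev family: if the function registry knows the name (the Hive path), that
implementation wins; otherwise we fall back to the native StandardDeviation
aggregate directly. A rough sketch of what should now work on a plain
SQLContext (the data is illustrative):

    val df = sqlContext.range(0, 10)
    // describe computes count, mean, stddev, min, and max per column; the
    // stddev row now resolves even though stddev is not registered in
    // SimpleFunctionRegistry.
    df.describe("id").show()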
---
 .../sql/catalyst/analysis/Analyzer.scala      | 75 +++++++++++++------
 .../execution/AggregationQuerySuite.scala     |  6 +-
 2 files changed, 54 insertions(+), 27 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2d5be382f9a3..bc040f7d2cd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{StandardDeviation, Complete, AggregateExpression2, AggregateFunction2}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -524,32 +524,59 @@ class Analyzer(
       q transformExpressions {
         case u @ UnresolvedFunction(name, children, isDistinct) =>
           withPosition(u) {
+
             // TODO: This is a hack. Hive uses stddev and std as aliases of stddev_pop, which
             // is different from other widely used systems (these systems use these two function
             // names as aliases of stddev_samp). So, we explicitly rename it to stddev_samp.
-            // Once we remove AggregateExpression1, we can remove this hack.
-            val rewrittenName = name match {
-              case "stddev" | "std" => "stddev_samp"
-              case other => other
-            }
-
-            registry.lookupFunction(rewrittenName, children) match {
-              // We get an aggregate function built based on AggregateFunction2 interface.
-              // So, we wrap it in AggregateExpression2.
-              case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
-              // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
-              // and COUTN(DISTINCT ...).
-              case sumDistinct: SumDistinct => sumDistinct
-              case countDistinct: CountDistinct => countDistinct
-              // DISTINCT is not meaningful with Max and Min.
-              case max: Max if isDistinct => max
-              case min: Min if isDistinct => min
-              // For other aggregate functions, DISTINCT keyword is not supported for now.
-              // Once we converted to the new code path, we will allow using DISTINCT keyword.
-              case other: AggregateExpression1 if isDistinct =>
-                failAnalysis(s"$rewrittenName does not support DISTINCT keyword.")
-              // If it does not have DISTINCT keyword, we will return it as is.
-              case other => other
+            // Once we remove AggregateExpression1, we can remove this hack. Also, because
+            // we do not have stddev in SimpleFunctionRegistry (we want to resolve
+            // it to the HiveGenericUDAF based on AggregateExpression1 and then do the
+            // conversion to AggregateExpression2), if it does not exist in the
+            // function registry, we create StandardDeviation directly.
+            name.toLowerCase match {
+              case "std" | "stddev" | "stddev_samp" =>
+                if (children.length != 1) {
+                  failAnalysis(s"$name requires exactly one argument.")
+                }
+                val funcInRegistry =
+                  registry
+                    .lookupFunction("stddev_samp")
+                    .map(_ => registry.lookupFunction("stddev_samp", children))
+                funcInRegistry.getOrElse {
+                  AggregateExpression2(
+                    StandardDeviation(children.head, sample = true), Complete, isDistinct)
+                }
+              case "stddev_pop" =>
+                if (children.length != 1) {
+                  failAnalysis(s"$name requires exactly one argument.")
+                }
+                val funcInRegistry =
+                  registry
+                    .lookupFunction("stddev_pop")
+                    .map(_ => registry.lookupFunction("stddev_pop", children))
+                funcInRegistry.getOrElse {
+                  AggregateExpression2(
+                    StandardDeviation(children.head, sample = false), Complete, isDistinct)
+                }
+              case _ =>
+                registry.lookupFunction(name, children) match {
+                  // We get an aggregate function built based on AggregateFunction2 interface.
+                  // So, we wrap it in AggregateExpression2.
+                  case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
+                  // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
+                  // and COUNT(DISTINCT ...).
+                  case sumDistinct: SumDistinct => sumDistinct
+                  case countDistinct: CountDistinct => countDistinct
+                  // DISTINCT is not meaningful with Max and Min.
+                  case max: Max if isDistinct => max
+                  case min: Min if isDistinct => min
+                  // For other aggregate functions, DISTINCT keyword is not supported for now.
+                  // Once we converted to the new code path, we will allow using DISTINCT keyword.
+                  case other: AggregateExpression1 if isDistinct =>
+                    failAnalysis(s"$name does not support DISTINCT keyword.")
+                  // If it does not have DISTINCT keyword, we will return it as is.
+                  case other => other
+                }
             }
           }
       }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 42cf882bb159..b910313e77b9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -326,9 +326,9 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
             |SELECT
             |  key,
             |  cast(std(value) as decimal(12, 10)),
-            |  cast(stddev(value) as decimal(12, 10)),
-            |  cast(stddev_samp(value) as decimal(12, 10)),
-            |  cast(stddev_pop(value) as decimal(12, 10))
+            |  cast(stDDev(value) as decimal(12, 10)),
+            |  cast(stddEV_samp(value) as decimal(12, 10)),
+            |  cast(stddev_pOP(value) as decimal(12, 10))
             |FROM agg1 GROUP BY key
           """.stripMargin),
         expectedGroup1 ::

From a170f433cd9d8304212b5d3aae66cc8aff7df12d Mon Sep 17 00:00:00 2001
From: Burak Yavuz
Date: Tue, 11 Aug 2015 11:20:14 -0700
Subject: [PATCH 15/15] fix long line

---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bc040f7d2cd3..0c8c51525829 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -562,7 +562,8 @@ class Analyzer(
                 registry.lookupFunction(name, children) match {
                   // We get an aggregate function built based on AggregateFunction2 interface.
                  // So, we wrap it in AggregateExpression2.
-                  case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct)
+                  case agg2: AggregateFunction2 =>
+                    AggregateExpression2(agg2, Complete, isDistinct)
                   // Currently, our old aggregate function interface supports SUM(DISTINCT ...)
                   // and COUNT(DISTINCT ...).
                   case sumDistinct: SumDistinct => sumDistinct