From 1a1252d1db0af0485b98c9bca0d442c9235bd2a0 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Sun, 29 Jul 2018 12:53:09 +0200
Subject: [PATCH 1/2] [SPARK-24957][SQL] Average with decimal followed by aggregation returns wrong result

---
 .../sql/catalyst/analysis/DecimalPrecision.scala    |  2 +-
 .../catalyst/expressions/aggregate/Average.scala    |  9 ++++-----
 .../sql/hive/execution/AggregationQuerySuite.scala  | 13 +++++++++++++
 3 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index 65a5888222f2e..23d146e71ed19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -89,7 +89,7 @@ object DecimalPrecision extends TypeCoercionRule {
   }
 
   /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
-  private val decimalAndDecimal: PartialFunction[Expression, Expression] = {
+  private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = {
     // Skip nodes whose children have not been resolved yet
     case e if !e.childrenResolved => e
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index a133bc2361eb5..4180294ae7f15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -57,10 +57,9 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
 
   // If all input are nulls, count will be 0 and we will get null after the division.
   override lazy val evaluateExpression = child.dataType match {
-    case DecimalType.Fixed(p, s) =>
-      // increase the precision and scale to prevent precision loss
-      val dt = DecimalType.bounded(p + 14, s + 4)
-      Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)),
+    case _: DecimalType =>
+      Cast(
+        DecimalPrecision.decimalAndDecimal.lift(sum / Cast(count, DecimalType.LongDecimal)).get,
         resultType)
     case _ =>
       Cast(sum, resultType) / Cast(count, resultType)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index ae675149df5e2..c65bf7c14c7a5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -1005,6 +1005,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       )
     )
   }
+
+  test("SPARK-24957: average with decimal followed by aggregation returning wrong result") {
+    val df = Seq(("a", BigDecimal("12.0")),
+      ("a", BigDecimal("12.0")),
+      ("a", BigDecimal("11.9999999988")),
+      ("a", BigDecimal("12.0")),
+      ("a", BigDecimal("12.0")),
+      ("a", BigDecimal("11.9999999988")),
+      ("a", BigDecimal("11.9999999988"))).toDF("text", "number")
+    val agg1 = df.groupBy($"text").agg(avg($"number").as("avg_res"))
+    val agg2 = agg1.groupBy($"text").agg(sum($"avg_res"))
+    checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142860000")))
+  }
 }

From 1938be02edce7c34c7d1d3686a5b0c15a158f0e4 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Mon, 30 Jul 2018 09:40:23 +0200
Subject: [PATCH 2/2] address comment

---
 .../spark/sql/catalyst/expressions/aggregate/Average.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 4180294ae7f15..9ccf5aa092d11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -59,7 +59,7 @@ abstract class AverageLike(child: Expression) extends DeclarativeAggregate {
   override lazy val evaluateExpression = child.dataType match {
     case _: DecimalType =>
       Cast(
-        DecimalPrecision.decimalAndDecimal.lift(sum / Cast(count, DecimalType.LongDecimal)).get,
+        DecimalPrecision.decimalAndDecimal(sum / Cast(count, DecimalType.LongDecimal)),
         resultType)
     case _ =>
       Cast(sum, resultType) / Cast(count, resultType)
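
Reproduction sketch (illustrative, not part of either patch): the snippet below rebuilds the scenario covered by the new AggregationQuerySuite test as a small standalone Spark application. The Spark24957Repro object name, the app name, and the local[*] master are assumptions made only for this sketch; the data and the avg-then-sum aggregation are taken from the test above.

// Minimal standalone sketch of the SPARK-24957 scenario. Harness details
// (object name, appName, local master) are hypothetical; the data and the
// two-level aggregation mirror the test added in AggregationQuerySuite.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, sum}

object Spark24957Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SPARK-24957 repro")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("a", BigDecimal("12.0")),
      ("a", BigDecimal("12.0")),
      ("a", BigDecimal("11.9999999988")),
      ("a", BigDecimal("12.0")),
      ("a", BigDecimal("12.0")),
      ("a", BigDecimal("11.9999999988")),
      ("a", BigDecimal("11.9999999988"))).toDF("text", "number")

    // avg(...) yields a decimal column; feeding it into a second aggregation
    // is what exposed the precision problem. With the fix applied, the sum of
    // the per-group average should match the value the test checks:
    // 11.9999999994857142860000.
    val avgPerGroup = df.groupBy($"text").agg(avg($"number").as("avg_res"))
    val sumOfAvg = avgPerGroup.groupBy($"text").agg(sum($"avg_res"))
    sumOfAvg.show(truncate = false)

    spark.stop()
  }
}

As the diffs above show, the fix widens decimalAndDecimal to private[catalyst] so that Average can reuse the same precision promotion that analysis applies to any decimal division for its sum / count expression, instead of the previous hand-rolled (p + 14, s + 4) bound; the second patch then applies the partial function directly rather than going through .lift(...).get.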