From 1f817a058887090ebf41c510a3b6086a062433d6 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sun, 29 Jul 2018 12:53:09 +0200 Subject: [PATCH] [SPARK-24957][SQL] Average with decimal followed by aggregation returns wrong result --- .../sql/catalyst/analysis/DecimalPrecision.scala | 2 +- .../catalyst/expressions/aggregate/Average.scala | 9 ++++----- .../sql/hive/execution/AggregationQuerySuite.scala | 13 +++++++++++++ 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index fd2ac78b25dbd..a48801c5ee140 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -85,7 +85,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { } /** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */ - private val decimalAndDecimal: PartialFunction[Expression, Expression] = { + private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = { // Skip nodes whose children have not been resolved yet case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 708bdbfc36058..f80df75ac7f72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils @@ -77,10 +77,9 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit // If all input are nulls, count will be 0 and we will get null after the division. override lazy val evaluateExpression = child.dataType match { - case DecimalType.Fixed(p, s) => - // increase the precision and scale to prevent precision loss - val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)), + case _: DecimalType => + Cast( + DecimalPrecision.decimalAndDecimal(sum / Cast(count, DecimalType.LongDecimal)), resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 84f915977bd88..3a9e50c7685c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1002,6 +1002,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te ) ) } + + test("SPARK-24957: average with decimal followed by aggregation returning wrong result") { + val df = Seq(("a", BigDecimal("12.0")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("11.9999999988")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("12.0")), + ("a", BigDecimal("11.9999999988")), + ("a", BigDecimal("11.9999999988"))).toDF("text", "number") + val agg1 = df.groupBy($"text").agg(avg($"number").as("avg_res")) + val agg2 = agg1.groupBy($"text").agg(sum($"avg_res")) + checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142857143"))) + } }