Skip to content

Commit 25ea27b

Browse files
mgaido91cloud-fan
authored andcommitted
[SPARK-24957][SQL] Average with decimal followed by aggregation returns wrong result
## What changes were proposed in this pull request? When we do an average, the result is computed dividing the sum of the values by their count. In the case the result is a DecimalType, the way we are casting/managing the precision and scale is not really optimized and it is not coherent with what we do normally. In particular, a problem can happen when the `Divide` operand returns a result which contains a precision and scale different by the ones which are expected as output of the `Divide` operand. In the case reported in the JIRA, for instance, the result of the `Divide` operand is a `Decimal(38, 36)`, while the output data type for `Divide` is 38, 22. This is not an issue when the `Divide` is followed by a `CheckOverflow` or a `Cast` to the right data type, as these operations return a decimal with the defined precision and scale. Despite in the `Average` operator we do have a `Cast`, this may be bypassed if the result of `Divide` is the same type which it is casted to, hence the issue reported in the JIRA may arise. The PR proposes to use the normal rules/handling of the arithmetic operators with Decimal data type, so we both reuse the existing code (having a single logic for operations between decimals) and we fix this problem as the result is always guarded by `CheckOverflow`. ## How was this patch tested? added UT Author: Marco Gaido <[email protected]> Closes #21910 from mgaido91/SPARK-24957. (cherry picked from commit 85505fc) Signed-off-by: Wenchen Fan <[email protected]>
1 parent aa51c07 commit 25ea27b

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ object DecimalPrecision extends TypeCoercionRule {
8989
}
9090

9191
/** Decimal precision promotion for +, -, *, /, %, pmod, and binary comparison. */
92-
private val decimalAndDecimal: PartialFunction[Expression, Expression] = {
92+
private[catalyst] val decimalAndDecimal: PartialFunction[Expression, Expression] = {
9393
// Skip nodes whose children have not been resolved yet
9494
case e if !e.childrenResolved => e
9595

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

20-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
20+
import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, TypeCheckResult}
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -77,10 +77,9 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
7777

7878
// If all input are nulls, count will be 0 and we will get null after the division.
7979
override lazy val evaluateExpression = child.dataType match {
80-
case DecimalType.Fixed(p, s) =>
81-
// increase the precision and scale to prevent precision loss
82-
val dt = DecimalType.bounded(p + 14, s + 4)
83-
Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)),
80+
case _: DecimalType =>
81+
Cast(
82+
DecimalPrecision.decimalAndDecimal(sum / Cast(count, DecimalType.LongDecimal)),
8483
resultType)
8584
case _ =>
8685
Cast(sum, resultType) / Cast(count, resultType)

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,19 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
10051005
)
10061006
)
10071007
}
1008+
1009+
test("SPARK-24957: average with decimal followed by aggregation returning wrong result") {
1010+
val df = Seq(("a", BigDecimal("12.0")),
1011+
("a", BigDecimal("12.0")),
1012+
("a", BigDecimal("11.9999999988")),
1013+
("a", BigDecimal("12.0")),
1014+
("a", BigDecimal("12.0")),
1015+
("a", BigDecimal("11.9999999988")),
1016+
("a", BigDecimal("11.9999999988"))).toDF("text", "number")
1017+
val agg1 = df.groupBy($"text").agg(avg($"number").as("avg_res"))
1018+
val agg2 = agg1.groupBy($"text").agg(sum($"avg_res"))
1019+
checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142860000")))
1020+
}
10081021
}
10091022

10101023

0 commit comments

Comments
 (0)