
Commit 6888089

cloud-fan authored and gengliangwang committed
[SPARK-36926][3.2][SQL] Decimal average mistakenly overflow
backport #34180

### What changes were proposed in this pull request?

This bug was introduced by #33177. When checking overflow of the sum value in the average function, we should use the `sumDataType` instead of the input decimal type.

### Why are the changes needed?

Fix a regression.

### Does this PR introduce _any_ user-facing change?

Yes, the result was wrong before this PR.

### How was this patch tested?

A new test.

Closes #34193 from cloud-fan/bug.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
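To see why checking the sum against the input type reports a false overflow, the following standalone sketch reproduces the arithmetic with `java.math.BigDecimal` rather than Spark itself. The `fits` helper and the object name are hypothetical, and the rule that the sum buffer for `decimal(p, s)` is `decimal(p + 10, s)` is stated here as an assumption about how `Average` derives `sumDataType`.

```scala
import java.math.{BigDecimal, RoundingMode}

object DecimalAvgOverflowSketch {
  // Hypothetical helper (not a Spark API): does `v` fit in decimal(precision, scale)?
  def fits(v: BigDecimal, precision: Int, scale: Int): Boolean = {
    val rescaled = v.setScale(scale, RoundingMode.HALF_UP)
    rescaled.precision() <= precision
  }

  // Input column type from the new test: decimal(12, 2).
  val inputPrecision = 12
  val inputScale = 2

  // Assumption: the sum buffer widens the precision by 10, i.e. decimal(22, 2).
  val sumPrecision: Int = inputPrecision + 10

  // Ten rows of 9999999999.99 add up to 99999999999.90, which needs 13 digits.
  val sum: BigDecimal = new BigDecimal("9999999999.99").multiply(new BigDecimal(10))

  def main(args: Array[String]): Unit = {
    // Checking against the input type rejects the sum (the regression).
    println(s"fits decimal($inputPrecision, $inputScale): ${fits(sum, inputPrecision, inputScale)}")
    // Checking against the wider sum type accepts it (the fix).
    println(s"fits decimal($sumPrecision, $inputScale): ${fits(sum, sumPrecision, inputScale)}")
  }
}
```

The intermediate sum legitimately exceeds the input precision even though the final average does not, so the overflow check has to use the wider buffer type.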
1 parent c542297 commit 6888089

2 files changed: +8 -2 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 2 additions & 2 deletions
@@ -94,10 +94,10 @@ case class Average(
   // If all input are nulls, count will be 0 and we will get null after the division.
   // We can't directly use `/` as it throws an exception under ansi mode.
   override lazy val evaluateExpression = child.dataType match {
-    case d: DecimalType =>
+    case _: DecimalType =>
       DecimalPrecision.decimalAndDecimal()(
         Divide(
-          CheckOverflowInSum(sum, d, !failOnError),
+          CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !failOnError),
           count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 6 additions & 0 deletions
@@ -1416,6 +1416,12 @@ class DataFrameAggregateSuite extends QueryTest
     val df2 = Seq(Period.ofYears(1)).toDF("a").groupBy("a").count()
     checkAnswer(df2, Row(Period.ofYears(1), 1))
   }
+
+  test("SPARK-36926: decimal average mistakenly overflow") {
+    val df = (1 to 10).map(_ => "9999999999.99").toDF("d")
+    val res = df.select($"d".cast("decimal(12, 2)").as("d")).agg(avg($"d").cast("string"))
+    checkAnswer(res, Row("9999999999.990000"))
+  }
 }
 
 case class B(c: Option[Double])
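
The expected string in the new test has six decimal places even though the input has two. The sketch below reproduces that value outside Spark, assuming that the average of a `decimal(p, s)` column is returned as `decimal(p + 4, s + 4)` and that the division rounds half-up. The object and method names are hypothetical.

```scala
import java.math.{BigDecimal, RoundingMode}

object DecimalAvgResultSketch {
  // Input scale from the test column: decimal(12, 2).
  val inputScale = 2

  // Assumption: the result type of avg adds 4 to the scale, giving decimal(16, 6).
  val resultScale: Int = inputScale + 4

  // Sums the values exactly, then divides by the row count at the result scale.
  def average(values: Seq[String]): String = {
    val sum = values.foldLeft(BigDecimal.ZERO)((acc, v) => acc.add(new BigDecimal(v)))
    sum.divide(new BigDecimal(values.size), resultScale, RoundingMode.HALF_UP).toPlainString
  }

  def main(args: Array[String]): Unit = {
    // Same data as the test: ten rows of 9999999999.99.
    println(average(Seq.fill(10)("9999999999.99")))
  }
}
```

The average itself fits comfortably in the input precision. Only the intermediate sum needs the extra digits.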
