
Commit 1fda011

karenfeng authored and gengliangwang committed
[SPARK-35955][SQL] Check for overflow in Average in ANSI mode
### What changes were proposed in this pull request?

Fixes decimal overflow issues for decimal average in ANSI mode, so that overflows throw an exception rather than returning null.

### Why are the changes needed?

Query:
```
scala> import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions._

scala> spark.conf.set("spark.sql.ansi.enabled", true)

scala> val df = Seq(
     |   (BigDecimal("10000000000000000000"), 1),
     |   (BigDecimal("10000000000000000000"), 1),
     |   (BigDecimal("10000000000000000000"), 2),
     |   (BigDecimal("10000000000000000000"), 2),
     |   (BigDecimal("10000000000000000000"), 2),
     |   (BigDecimal("10000000000000000000"), 2),
     |   (BigDecimal("10000000000000000000"), 2),
     |   (BigDecimal("10000000000000000000"), 2),
     |   (BigDecimal("10000000000000000000"), 2),
     |   (BigDecimal("10000000000000000000"), 2),
     |   (BigDecimal("10000000000000000000"), 2),
     |   (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
df: org.apache.spark.sql.DataFrame = [decNum: decimal(38,18), intNum: int]

scala> val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(mean("decNum"))
df2: org.apache.spark.sql.DataFrame = [avg(decNum): decimal(38,22)]

scala> df2.show(40,false)
```

Before:
```
+-----------+
|avg(decNum)|
+-----------+
|null       |
+-----------+
```

After:
```
21/07/01 19:48:31 ERROR Executor: Exception in task 0.0 in stage 3.0 (TID 24)
java.lang.ArithmeticException: Overflow in sum of decimals.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.overflowInSumOfDecimalError(QueryExecutionErrors.scala:162)
	at org.apache.spark.sql.errors.QueryExecutionErrors.overflowInSumOfDecimalError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:759)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:349)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:898)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:898)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:373)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:337)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:499)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1462)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:502)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
```

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test

Closes #33177 from karenfeng/SPARK-35955.

Authored-by: Karen Feng <[email protected]>
Signed-off-by: Gengliang Wang <[email protected]>
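For callers, the practical effect in ANSI mode is that an overflowing decimal average now fails the action instead of producing a null row. A minimal caller-side sketch, reusing `df2` from the repro above (the exact exception wrapping depends on how the job is run, so the catch below is deliberately broad):

```scala
// Sketch only: with spark.sql.ansi.enabled=true, the overflowing
// average fails the action instead of returning a null row.
try {
  df2.show(40, false)
} catch {
  case e: Exception =>
    // The task failure carries "java.lang.ArithmeticException:
    // Overflow in sum of decimals." (see the stack trace above).
    println(s"avg(decNum) overflowed: ${e.getMessage}")
}
```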
1 parent 47485a3 commit 1fda011

2 files changed: +19 −8 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala

Lines changed: 5 additions & 2 deletions

```diff
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._

 @ExpressionDescription(
@@ -87,9 +88,11 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   // If all input are nulls, count will be 0 and we will get null after the division.
   // We can't directly use `/` as it throws an exception under ansi mode.
   override lazy val evaluateExpression = child.dataType match {
-    case _: DecimalType =>
+    case d: DecimalType =>
       DecimalPrecision.decimalAndDecimal()(
-        Divide(sum, count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
+        Divide(
+          CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled),
+          count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),
         Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
```
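The change itself is small: `evaluateExpression` computes the average as `sum / count`, and an overflowed decimal `sum` becomes null before the division ever runs, so the average silently came out null. Wrapping the sum in `CheckOverflowInSum(sum, d, !SQLConf.get.ansiEnabled)` surfaces the overflow first: it throws when ANSI mode is on and keeps the null behavior when it is off. Below is a self-contained sketch of that idea in plain Scala, assuming only Spark's 38-digit decimal precision cap; it illustrates the technique, it is not Spark's implementation:

```scala
// Illustration of "check the intermediate sum before dividing".
// Assumes non-empty input; 38 is Spark's maximum decimal precision.
import java.math.{BigDecimal => JBigDecimal, MathContext, RoundingMode}

val MaxPrecision = 38

def checkedAverage(values: Seq[JBigDecimal], ansiEnabled: Boolean): Option[JBigDecimal] = {
  val sum = values.reduce(_.add(_))
  if (sum.precision > MaxPrecision) {
    // Mirrors CheckOverflowInSum: throw under ANSI mode, null (None) otherwise.
    if (ansiEnabled) throw new ArithmeticException("Overflow in sum of decimals.")
    else None
  } else {
    Some(sum.divide(new JBigDecimal(values.size),
      new MathContext(MaxPrecision, RoundingMode.HALF_UP)))
  }
}
```

With the repro's data, `Seq.fill(12)(new JBigDecimal("10000000000000000000").setScale(18))` sums to 39 significant digits, so `checkedAverage(vals, ansiEnabled = true)` throws while `checkedAverage(vals, ansiEnabled = false)` returns `None`, matching the before/after output above.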

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 14 additions & 6 deletions

```diff
@@ -235,7 +235,7 @@ class DataFrameSuite extends QueryTest
     }
   }

-  test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
+  def checkAggResultsForDecimalOverflow(aggFn: Column => Column): Unit = {
     Seq("true", "false").foreach { wholeStageEnabled =>
       withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) {
         Seq(true, false).foreach { ansiEnabled =>
@@ -256,30 +256,30 @@ class DataFrameSuite extends QueryTest
               (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum")
             val df = df0.union(df1)
             val df2 = df.withColumnRenamed("decNum", "decNum2").
-              join(df, "intNum").agg(sum("decNum"))
+              join(df, "intNum").agg(aggFn($"decNum"))

             val expectedAnswer = Row(null)
             assertDecimalSumOverflow(df2, ansiEnabled, expectedAnswer)

             val decStr = "1" + "0" * 19
             val d1 = spark.range(0, 12, 1, 1)
-            val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
+            val d2 = d1.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
             assertDecimalSumOverflow(d2, ansiEnabled, expectedAnswer)

             val d3 = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
-            val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(sum($"d"))
+            val d4 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d")).agg(aggFn($"d"))
             assertDecimalSumOverflow(d4, ansiEnabled, expectedAnswer)

             val d5 = d3.select(expr(s"cast('$decStr' as decimal (38, 18)) as d"),
-              lit(1).as("key")).groupBy("key").agg(sum($"d").alias("sumd")).select($"sumd")
+              lit(1).as("key")).groupBy("key").agg(aggFn($"d").alias("aggd")).select($"aggd")
             assertDecimalSumOverflow(d5, ansiEnabled, expectedAnswer)

             val nullsDf = spark.range(1, 4, 1).select(expr(s"cast(null as decimal(38,18)) as d"))

             val largeDecimals = Seq(BigDecimal("1"* 20 + ".123"), BigDecimal("9"* 20 + ".123")).
               toDF("d")
             assertDecimalSumOverflow(
-              nullsDf.union(largeDecimals).agg(sum($"d")), ansiEnabled, expectedAnswer)
+              nullsDf.union(largeDecimals).agg(aggFn($"d")), ansiEnabled, expectedAnswer)

             val df3 = Seq(
               (BigDecimal("10000000000000000000"), 1),
@@ -306,6 +306,14 @@ class DataFrameSuite extends QueryTest
     }
   }

+  test("SPARK-28067: Aggregate sum should not return wrong results for decimal overflow") {
+    checkAggResultsForDecimalOverflow(c => sum(c))
+  }
+
+  test("SPARK-35955: Aggregate avg should not return wrong results for decimal overflow") {
+    checkAggResultsForDecimalOverflow(c => avg(c))
+  }
+
   test("Star Expansion - ds.explode should fail with a meaningful message if it takes a star") {
     val df = Seq(("1", "1,2"), ("2", "4"), ("3", "7,8,9")).toDF("prefix", "csv")
     val e = intercept[AnalysisException] {
```
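The test change is a mechanical extraction: the SPARK-28067 test body becomes `checkAggResultsForDecimalOverflow(aggFn: Column => Column)`, and the two `test(...)` cases added at the bottom re-run the identical overflow scenarios with `sum` and `avg` injected. The `assertDecimalSumOverflow` helper is defined elsewhere in the suite; judging only from its call sites in this diff, it does roughly the following (a hypothetical sketch, not the suite's actual code):

```scala
// Hypothetical shape of the assertion helper, inferred from the call
// sites above; the real assertDecimalSumOverflow lives elsewhere in
// DataFrameSuite and may differ.
import org.apache.spark.sql.{DataFrame, Row}
import org.scalatest.Assertions.{assert, intercept}

def assertOverflowSketch(df: DataFrame, ansiEnabled: Boolean, expected: Row): Unit = {
  if (!ansiEnabled) {
    // Non-ANSI mode: the overflowed aggregate quietly evaluates to null.
    assert(df.collect().toSeq == Seq(expected))
  } else {
    // ANSI mode: the overflow must surface as an arithmetic error.
    val e = intercept[Exception](df.collect())
    assert(e.getMessage.contains("Overflow in sum of decimals"))
  }
}
```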
