Skip to content

Commit af78b62

Browse files
committed
[SPARK-28610][SQL] Allow having a decimal buffer for long sum
1 parent b3394db commit af78b62

File tree

3 files changed

+41
-2
lines changed

3 files changed

+41
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.util.TypeUtils
24+
import org.apache.spark.sql.internal.SQLConf
2425
import org.apache.spark.sql.types._
2526

2627
@ExpressionDescription(
@@ -56,7 +57,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
5657
case _ => DoubleType
5758
}
5859

59-
private lazy val sumDataType = resultType
60+
private lazy val sumDataType = child.dataType match {
61+
case LongType if SQLConf.get.getConf(SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG) =>
62+
DecimalType.BigIntDecimal
63+
case _ => resultType
64+
}
6065

6166
private lazy val sum = AttributeReference("sum", sumDataType)()
6267

@@ -89,5 +94,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
8994
)
9095
}
9196

92-
override lazy val evaluateExpression: Expression = sum
97+
override lazy val evaluateExpression: Expression = {
98+
if (sumDataType == resultType) {
99+
sum
100+
} else {
101+
Cast(sum, resultType)
102+
}
103+
}
93104
}

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,6 +1789,16 @@ object SQLConf {
17891789
.booleanConf
17901790
.createWithDefault(false)
17911791

1792+
val SUM_DECIMAL_BUFFER_FOR_LONG =
1793+
buildConf("spark.sql.sum.decimalBufferForLong")
1794+
.doc("If it is set to true, sum of long uses decimal type for the buffer. When false " +
1795+
"(default), long is used for the buffer. If spark.sql.arithmeticOperations.failOnOverFlow" +
1796+
" is turned on, having this config set to true allows operations which have temporary " +
1797+
"overflows to execute properly without the exception thrown when this flag is false.")
1798+
.internal()
1799+
.booleanConf
1800+
.createWithDefault(false)
1801+
17921802
val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE =
17931803
buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere")
17941804
.internal()

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql
1919

20+
import org.apache.spark.SparkException
21+
2022
import scala.util.Random
2123

2224
import org.scalatest.Matchers.the
@@ -927,4 +929,20 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
927929
assert(error.message.contains("function count_if requires boolean type"))
928930
}
929931
}
932+
933+
test("SPARK-28610: temporary overflow on sum of long should not fail") {
934+
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
935+
SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "true") {
936+
val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
937+
checkAnswer(df.select(sum($"a")), Row(Long.MaxValue - 900L))
938+
}
939+
withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true",
940+
SQLConf.SUM_DECIMAL_BUFFER_FOR_LONG.key -> "false") {
941+
val df = sparkContext.parallelize(Seq(100L, Long.MaxValue, -1000L), 1).toDF("a")
942+
val e = intercept[SparkException] {
943+
df.select(sum($"a")).collect()
944+
}
945+
assert(e.getCause.isInstanceOf[ArithmeticException])
946+
}
947+
}
930948
}

0 commit comments

Comments
 (0)