
Commit 683e270

mickjermsurawong-stripe authored and cloud-fan committed
[SPARK-28200][SQL] Decimal overflow handling in ExpressionEncoder
## What changes were proposed in this pull request?

- Currently, `ExpressionEncoder` does not handle BigDecimal overflow: round-tripping an overflowing Java/Scala BigDecimal or BigInteger returns null.
  - The serializer encodes a Java/Scala BigDecimal to a SQL Decimal, which still holds the underlying data of the former.
  - When writing out to UnsafeRow, `changePrecision` returns false and the row gets a null value. https://github.com/apache/spark/blob/24e1e41648de58d3437e008b187b84828830e238/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java#L202-L206
- In [SPARK-23179](#20350), an option to throw an exception on decimal overflow was introduced.
- This PR adds that option to `ExpressionEncoder`, making it throw when it detects an overflowing BigDecimal/BigInteger before the corresponding Decimal is written to the Row. This gives consistent behavior between decimal arithmetic in SQL expressions (DecimalPrecision) and getting decimals from a DataFrame (RowEncoder).

Thanks to mgaido91 for the very first PR (`SPARK-23179`) and the follow-up discussion on this change. Thanks to JoshRosen for working with me on this.

## How was this patch tested?

Added unit tests.

Closes #25016 from mickjermsurawong-stripe/SPARK-28200.

Authored-by: Mick Jermsurawong <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent e299f62 commit 683e270
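For orientation before the diffs, here is a minimal sketch of the user-visible behavior this change establishes. The session setup is illustrative, and the conf key string is assumed to be the one behind `SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW` at the time of this commit:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative local session; any SparkSession works.
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("decimal-overflow-demo")
  .getOrCreate()
import spark.implicits._

// 100 nines cannot fit DecimalType.SYSTEM_DEFAULT, i.e. Decimal(38, 18).
val overflowing = BigDecimal("9" * 100)

// Default (nullOnOverflow = true): encoding silently yields null, as before.
spark.conf.set("spark.sql.decimalOperations.nullOnOverflow", "true")
spark.createDataset(Seq(overflowing)).collect() // Array(null)

// With the flag off, this PR makes the encoder fail fast instead:
spark.conf.set("spark.sql.decimalOperations.nullOnOverflow", "false")
// spark.createDataset(Seq(overflowing)).collect()
// => RuntimeException "Error while encoding ..." caused by ArithmeticException
```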

File tree

3 files changed: +107 −5 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala

Lines changed: 8 additions & 5 deletions
```diff
@@ -17,14 +17,17 @@
 
 package org.apache.spark.sql.catalyst
 
-import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{CheckOverflow, CreateNamedStruct, Expression, IsNull, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 object SerializerBuildHelper {
 
+  private def nullOnOverflow: Boolean = SQLConf.get.decimalOperationsNullOnOverflow
+
   def createSerializerForBoolean(inputObject: Expression): Expression = {
     Invoke(inputObject, "booleanValue", BooleanType)
   }
@@ -99,25 +102,25 @@ object SerializerBuildHelper {
   }
 
   def createSerializerForJavaBigDecimal(inputObject: Expression): Expression = {
-    StaticInvoke(
+    CheckOverflow(StaticInvoke(
       Decimal.getClass,
       DecimalType.SYSTEM_DEFAULT,
       "apply",
       inputObject :: Nil,
-      returnNullable = false)
+      returnNullable = false), DecimalType.SYSTEM_DEFAULT, nullOnOverflow)
   }
 
   def createSerializerForScalaBigDecimal(inputObject: Expression): Expression = {
     createSerializerForJavaBigDecimal(inputObject)
   }
 
   def createSerializerForJavaBigInteger(inputObject: Expression): Expression = {
-    StaticInvoke(
+    CheckOverflow(StaticInvoke(
       Decimal.getClass,
       DecimalType.BigIntDecimal,
       "apply",
       inputObject :: Nil,
-      returnNullable = false)
+      returnNullable = false), DecimalType.BigIntDecimal, nullOnOverflow)
   }
 
   def createSerializerForScalaBigInt(inputObject: Expression): Expression = {
```
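To make the wrapping concrete: `CheckOverflow(child, dataType, nullOnOverflow)` rounds the child's value to the target scale and then checks whether it still fits the target precision, returning null or throwing an `ArithmeticException` depending on the flag. Below is a minimal standalone sketch of that contract; `CheckOverflowSketch` and its helper are hypothetical names for illustration, not Spark APIs:

```scala
import java.math.{BigDecimal => JBigDecimal}
import java.math.RoundingMode

object CheckOverflowSketch {
  // Hypothetical helper mirroring the CheckOverflow contract for a target
  // DecimalType(precision, scale); not a Spark API.
  def checkOverflow(
      value: JBigDecimal,
      precision: Int,
      scale: Int,
      nullOnOverflow: Boolean): JBigDecimal = {
    // Round the fraction to the target scale first, so an overlong fractional
    // part alone does not overflow (matching the suite's "long fractional
    // part" cases, which overflow only because of their integer digits).
    val rounded = value.setScale(scale, RoundingMode.HALF_UP)
    if (rounded.precision > precision) {
      if (nullOnOverflow) null // pre-existing behavior: encode as null
      else throw new ArithmeticException(
        s"$value cannot be represented as Decimal($precision, $scale).")
    } else {
      rounded
    }
  }
}
```

For example, `checkOverflow(new JBigDecimal("1" * 21), 38, 18, nullOnOverflow = false)` throws, which is exactly the case the new encoder tests exercise.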

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala

Lines changed: 73 additions & 0 deletions
```diff
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.ClosureCleaner
@@ -379,6 +380,78 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
     assert(e.getMessage.contains("tuple with more than 22 elements are not supported"))
   }
 
+  // Scala / Java big decimals ----------------------------------------------------------
+
+  encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
+    "scala decimal within precision/scale limit")
+  encodeDecodeTest(new java.math.BigDecimal(("9" * 20) + "." + "9" * 18),
+    "java decimal within precision/scale limit")
+
+  encodeDecodeTest(-BigDecimal(("9" * 20) + "." + "9" * 18),
+    "negative scala decimal within precision/scale limit")
+  encodeDecodeTest(new java.math.BigDecimal(("9" * 20) + "." + "9" * 18).negate,
+    "negative java decimal within precision/scale limit")
+
+  testOverflowingBigNumeric(BigDecimal("1" * 21), "scala big decimal")
+  testOverflowingBigNumeric(new java.math.BigDecimal("1" * 21), "java big decimal")
+
+  testOverflowingBigNumeric(-BigDecimal("1" * 21), "negative scala big decimal")
+  testOverflowingBigNumeric(new java.math.BigDecimal("1" * 21).negate, "negative java big decimal")
+
+  testOverflowingBigNumeric(BigDecimal(("1" * 21) + ".123"),
+    "scala big decimal with fractional part")
+  testOverflowingBigNumeric(new java.math.BigDecimal(("1" * 21) + ".123"),
+    "java big decimal with fractional part")
+
+  testOverflowingBigNumeric(BigDecimal(("1" * 21) + "." + "9999" * 100),
+    "scala big decimal with long fractional part")
+  testOverflowingBigNumeric(new java.math.BigDecimal(("1" * 21) + "." + "9999" * 100),
+    "java big decimal with long fractional part")
+
+  // Scala / Java big integers ----------------------------------------------------------
+
+  encodeDecodeTest(BigInt("9" * 38), "scala big integer within precision limit")
+  encodeDecodeTest(new BigInteger("9" * 38), "java big integer within precision limit")
+
+  encodeDecodeTest(-BigInt("9" * 38),
+    "negative scala big integer within precision limit")
+  encodeDecodeTest(new BigInteger("9" * 38).negate(),
+    "negative java big integer within precision limit")
+
+  testOverflowingBigNumeric(BigInt("1" * 39), "scala big int")
+  testOverflowingBigNumeric(new BigInteger("1" * 39), "java big integer")
+
+  testOverflowingBigNumeric(-BigInt("1" * 39), "negative scala big int")
+  testOverflowingBigNumeric(new BigInteger("1" * 39).negate, "negative java big integer")
+
+  testOverflowingBigNumeric(BigInt("9" * 100), "scala very large big int")
+  testOverflowingBigNumeric(new BigInteger("9" * 100), "java very big int")
+
+  private def testOverflowingBigNumeric[T: TypeTag](bigNumeric: T, testName: String): Unit = {
+    Seq(true, false).foreach { allowNullOnOverflow =>
+      testAndVerifyNotLeakingReflectionObjects(
+          s"overflowing $testName, allowNullOnOverflow=$allowNullOnOverflow") {
+        withSQLConf(
+          SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> allowNullOnOverflow.toString
+        ) {
+          // Need to construct Encoder here rather than implicitly resolving it
+          // so that SQLConf changes are respected.
+          val encoder = ExpressionEncoder[T]()
+          if (allowNullOnOverflow) {
+            val convertedBack = encoder.resolveAndBind().fromRow(encoder.toRow(bigNumeric))
+            assert(convertedBack === null)
+          } else {
+            val e = intercept[RuntimeException] {
+              encoder.toRow(bigNumeric)
+            }
+            assert(e.getMessage.contains("Error while encoding"))
+            assert(e.getCause.getClass === classOf[ArithmeticException])
+          }
+        }
+      }
+    }
+  }
+
   private def encodeDecodeTest[T : ExpressionEncoder](
       input: T,
       testName: String): Unit = {
```
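A note on the magic numbers above (context, not part of the diff): the serializers target `DecimalType.SYSTEM_DEFAULT`, which Spark defines as `DecimalType(38, 18)`, and `DecimalType.BigIntDecimal`, defined as `DecimalType(38, 0)`, so the test inputs sit exactly on either side of those limits:

```scala
// Boundary arithmetic behind the test inputs (values copied from the suite).
val fits38x18  = BigDecimal(("9" * 20) + "." + ("9" * 18)) // 20 + 18 = 38 digits: max value of Decimal(38, 18)
val needs39x18 = BigDecimal("1" * 21)                      // 21 integer digits: needs Decimal(39, 18), overflows
val fits38x0   = BigInt("9" * 38)                          // 38 digits: max magnitude for Decimal(38, 0)
val needs39x0  = BigInt("1" * 39)                          // 39 digits: overflows Decimal(38, 0)
```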

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala

Lines changed: 26 additions & 0 deletions
```diff
@@ -162,6 +162,32 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
     assert(row.toSeq(schema).head == decimal)
   }
 
+  test("SPARK-23179: RowEncoder should respect nullOnOverflow for decimals") {
+    val schema = new StructType().add("decimal", DecimalType.SYSTEM_DEFAULT)
+    testDecimalOverflow(schema, Row(BigDecimal("9" * 100)))
+    testDecimalOverflow(schema, Row(new java.math.BigDecimal("9" * 100)))
+  }
+
+  private def testDecimalOverflow(schema: StructType, row: Row): Unit = {
+    withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
+      val encoder = RowEncoder(schema).resolveAndBind()
+      intercept[Exception] {
+        encoder.toRow(row)
+      } match {
+        case e: ArithmeticException =>
+          assert(e.getMessage.contains("cannot be represented as Decimal"))
+        case e: RuntimeException =>
+          assert(e.getCause.isInstanceOf[ArithmeticException])
+          assert(e.getCause.getMessage.contains("cannot be represented as Decimal"))
+      }
+    }
+
+    withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
+      val encoder = RowEncoder(schema).resolveAndBind()
+      assert(encoder.fromRow(encoder.toRow(row)).get(0) == null)
+    }
+  }
+
   test("RowEncoder should preserve schema nullability") {
     val schema = new StructType().add("int", IntegerType, nullable = false)
     val encoder = RowEncoder(schema).resolveAndBind()
```
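One detail worth noting in `testDecimalOverflow`: it matches on either an `ArithmeticException` or a `RuntimeException` wrapping one. `RowEncoderSuite` extends `CodegenInterpretedPlanTest`, so each test runs under both codegen and interpreted evaluation, and the two paths appear to surface the overflow differently; the helper therefore accepts the "cannot be represented as Decimal" message in either shape.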
