Skip to content

Commit 8617bf6

Browse files
mgaido91 authored and
maropu committed
[SPARK-28470][SQL] Cast to decimal throws ArithmeticException on overflow
## What changes were proposed in this pull request? The flag `spark.sql.decimalOperations.nullOnOverflow` is not honored by the `Cast` operator. This means that a cast which causes an overflow currently always returns `null`. This PR makes `Cast` respect that flag, i.e. when it is set to false and a decimal overflow occurs, an exception is thrown. ## How was this patch tested? Added UT Closes #25253 from mgaido91/SPARK-28470. Authored-by: Marco Gaido <[email protected]> Signed-off-by: Takeshi Yamamuro <[email protected]>
1 parent 325bc8e commit 8617bf6

File tree

2 files changed

+52
-6
lines changed
  • sql/catalyst/src
    • main/scala/org/apache/spark/sql/catalyst/expressions
    • test/scala/org/apache/spark/sql/catalyst/expressions

2 files changed

+52
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
2828
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
2929
import org.apache.spark.sql.catalyst.util._
3030
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
31+
import org.apache.spark.sql.internal.SQLConf
3132
import org.apache.spark.sql.types._
3233
import org.apache.spark.unsafe.UTF8StringBuilder
3334
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -499,22 +500,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
499500
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
500501
}
501502

503+
private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow
504+
502505
/**
503506
* Change the precision / scale in a given decimal to those set in `decimalType` (if any),
504-
* returning null if it overflows or modifying `value` in-place and returning it if successful.
507+
* modifying `value` in-place and returning it if successful. If an overflow occurs, it
508+
* either returns null or throws an exception according to the value set for
509+
* `spark.sql.decimalOperations.nullOnOverflow`.
505510
*
506511
* NOTE: this modifies `value` in-place, so don't call it on external data.
507512
*/
508513
private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
509-
if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null
514+
if (value.changePrecision(decimalType.precision, decimalType.scale)) {
515+
value
516+
} else {
517+
if (nullOnOverflow) {
518+
null
519+
} else {
520+
throw new ArithmeticException(s"${value.toDebugString} cannot be represented as " +
521+
s"Decimal(${decimalType.precision}, ${decimalType.scale}).")
522+
}
523+
}
510524
}
511525

512526
/**
513-
* Create new `Decimal` with precision and scale given in `decimalType` (if any),
514-
* returning null if it overflows or creating a new `value` and returning it if successful.
527+
* Create new `Decimal` with precision and scale given in `decimalType` (if any).
528+
* If overflow occurs, if `spark.sql.decimalOperations.nullOnOverflow` is true, null is returned;
529+
* otherwise, an `ArithmeticException` is thrown.
515530
*/
516531
private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
517-
value.toPrecision(decimalType.precision, decimalType.scale)
532+
value.toPrecision(
533+
decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow)
518534

519535

520536
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
@@ -964,11 +980,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
964980
|$evPrim = $d;
965981
""".stripMargin
966982
} else {
983+
val overflowCode = if (nullOnOverflow) {
984+
s"$evNull = true;"
985+
} else {
986+
s"""
987+
|throw new ArithmeticException($d.toDebugString() + " cannot be represented as " +
988+
| "Decimal(${decimalType.precision}, ${decimalType.scale}).");
989+
""".stripMargin
990+
}
967991
code"""
968992
|if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) {
969993
| $evPrim = $d;
970994
|} else {
971-
| $evNull = true;
995+
| $overflowCode
972996
|}
973997
""".stripMargin
974998
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
2929
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._
3030
import org.apache.spark.sql.catalyst.util.DateTimeUtils
3131
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
32+
import org.apache.spark.sql.internal.SQLConf
3233
import org.apache.spark.sql.types._
3334
import org.apache.spark.unsafe.types.UTF8String
3435

@@ -1023,4 +1024,25 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
10231024
checkEvaluation(ret, InternalRow(null))
10241025
}
10251026
}
1027+
1028+
test("SPARK-28470: Cast should honor nullOnOverflow property") {
1029+
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") {
1030+
checkEvaluation(Cast(Literal("134.12"), DecimalType(3, 2)), null)
1031+
checkEvaluation(
1032+
Cast(Literal(Timestamp.valueOf("2019-07-25 22:04:36")), DecimalType(3, 2)), null)
1033+
checkEvaluation(Cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), null)
1034+
checkEvaluation(Cast(Literal(134.12), DecimalType(3, 2)), null)
1035+
}
1036+
withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") {
1037+
checkExceptionInExpression[ArithmeticException](
1038+
Cast(Literal("134.12"), DecimalType(3, 2)), "cannot be represented")
1039+
checkExceptionInExpression[ArithmeticException](
1040+
Cast(Literal(Timestamp.valueOf("2019-07-25 22:04:36")), DecimalType(3, 2)),
1041+
"cannot be represented")
1042+
checkExceptionInExpression[ArithmeticException](
1043+
Cast(Literal(BigDecimal(134.12)), DecimalType(3, 2)), "cannot be represented")
1044+
checkExceptionInExpression[ArithmeticException](
1045+
Cast(Literal(134.12), DecimalType(3, 2)), "cannot be represented")
1046+
}
1047+
}
10261048
}

0 commit comments

Comments
 (0)