diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index ad7f7dd9434a..b5b712cda8ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** @@ -46,19 +47,38 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { */ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { + private val nullOnOverflow = SQLConf.get.decimalOperationsNullOnOverflow + override def dataType: DataType = DecimalType(precision, scale) - override def nullable: Boolean = true + override def nullable: Boolean = child.nullable || nullOnOverflow override def toString: String = s"MakeDecimal($child,$precision,$scale)" - protected override def nullSafeEval(input: Any): Any = - Decimal(input.asInstanceOf[Long], precision, scale) + protected override def nullSafeEval(input: Any): Any = { + val longInput = input.asInstanceOf[Long] + val result = new Decimal() + if (nullOnOverflow) { + result.setOrNull(longInput, precision, scale) + } else { + result.set(longInput, precision, scale) + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { + val setMethod = if (nullOnOverflow) { + "setOrNull" + } else { + "set" + } + val setNull = if (nullable) { + s"${ev.isNull} = ${ev.value} == null;" + } else { + "" + } s""" - ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale); - ${ev.isNull} = ${ev.value} == null; - """ + |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale); + |$setNull + |""".stripMargin }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index b7b70974f50e..1bf322af2179 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -76,7 +76,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ def set(unscaled: Long, precision: Int, scale: Int): Decimal = { if (setOrNull(unscaled, precision, scale) == null) { - throw new IllegalArgumentException("Unscaled value too large for precision") + throw new ArithmeticException("Unscaled value too large for precision") } this } @@ -111,9 +111,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { this.decimalVal = decimal.setScale(scale, ROUND_HALF_UP) - require( - decimalVal.precision <= precision, - s"Decimal precision ${decimalVal.precision} exceeds max precision $precision") + if (decimalVal.precision > precision) { + throw new ArithmeticException( + s"Decimal precision ${decimalVal.precision} exceeds max precision $precision") + } this.longVal = 0L this._precision = precision this._scale = scale diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala index d14eceb480f5..fc5e8dc5ee7f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{Decimal, DecimalType, LongType} class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -31,8 +32,23 @@ class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { } test("MakeDecimal") { - checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1")) - checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null) + withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "true") { + checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1")) + checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null) + val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1) + checkEvaluation(overflowExpr, null) + checkEvaluationWithMutableProjection(overflowExpr, null) + evaluateWithoutCodegen(overflowExpr, null) + checkEvaluationWithUnsafeProjection(overflowExpr, null) + } + withSQLConf(SQLConf.DECIMAL_OPERATIONS_NULL_ON_OVERFLOW.key -> "false") { + checkEvaluation(MakeDecimal(Literal(101L), 3, 1), Decimal("10.1")) + checkEvaluation(MakeDecimal(Literal.create(null, LongType), 3, 1), null) + val overflowExpr = MakeDecimal(Literal.create(1000L, LongType), 3, 1) + intercept[ArithmeticException](checkEvaluationWithMutableProjection(overflowExpr, null)) + intercept[ArithmeticException](evaluateWithoutCodegen(overflowExpr, null)) + intercept[ArithmeticException](checkEvaluationWithUnsafeProjection(overflowExpr, null)) + } } test("PromotePrecision") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 8abd7625c21a..d69bb2f0b6bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -56,11 +56,11 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { checkDecimal(Decimal(1000000000000000000L, 20, 2), "10000000000000000.00", 20, 2) checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0) checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0) - intercept[IllegalArgumentException](Decimal(170L, 2, 1)) - intercept[IllegalArgumentException](Decimal(170L, 2, 0)) - intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1)) - intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1)) - intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) + intercept[ArithmeticException](Decimal(170L, 2, 1)) + intercept[ArithmeticException](Decimal(170L, 2, 0)) + intercept[ArithmeticException](Decimal(BigDecimal("10.030"), 2, 1)) + intercept[ArithmeticException](Decimal(BigDecimal("-9.95"), 2, 1)) + intercept[ArithmeticException](Decimal(1e17.toLong, 17, 0)) } test("creating decimals with negative scale") {