diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 22b29c3000c1..10c5b5627d6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -35,6 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval
   """)
 case class UnaryMinus(child: Expression) extends UnaryExpression
     with ExpectsInputTypes with NullIntolerant {
+  private val checkOverflow = SQLConf.get.arithmeticOperationsFailOnOverflow

   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

@@ -42,10 +43,28 @@ case class UnaryMinus(child: Expression) extends UnaryExpression

   override def toString: String = s"-$child"

-  private lazy val numeric = TypeUtils.getNumeric(dataType)
+  private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
     case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()")
+    case ByteType | ShortType if checkOverflow =>
+      nullSafeCodeGen(ctx, ev, eval => {
+        val javaBoxedType = CodeGenerator.boxedType(dataType)
+        val javaType = CodeGenerator.javaType(dataType)
+        val originValue = ctx.freshName("origin")
+        s"""
+           |$javaType $originValue = ($javaType)($eval);
+           |if ($originValue == $javaBoxedType.MIN_VALUE) {
+           |  throw new ArithmeticException("- " + $originValue + " caused overflow.");
+           |}
+           |${ev.value} = ($javaType)(-($originValue));
+           """.stripMargin
+      })
+    case IntegerType | LongType if checkOverflow =>
+      nullSafeCodeGen(ctx, ev, eval => {
+        val mathClass = classOf[Math].getName
+        s"${ev.value} = $mathClass.negateExact($eval);"
+      })
     case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => {
       val originValue = ctx.freshName("origin")
       // codegen would fail to compile if we just write (-($c))
@@ -117,6 +136,8 @@ case class Abs(child: Expression)

 abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {

+  protected val checkOverflow = SQLConf.get.arithmeticOperationsFailOnOverflow
+
   override def dataType: DataType = left.dataType

   override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
@@ -129,17 +150,57 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
   def calendarIntervalMethod: String =
     sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")

+  /** Name of the function for the exact version of this expression in [[Math]]. */
+  def exactMathMethod: String =
+    sys.error("BinaryArithmetics must override either exactMathMethod or genCode")
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
     case _: DecimalType =>
+      // Overflow is handled in the CheckOverflow operator
       defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
     case CalendarIntervalType =>
       defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)")
     // byte and short are casted into int when add, minus, times or divide
     case ByteType | ShortType =>
-      defineCodeGen(ctx, ev,
-        (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
-    case _ =>
-      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
+      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+        val tmpResult = ctx.freshName("tmpResult")
+        val overflowCheck = if (checkOverflow) {
+          val javaType = CodeGenerator.boxedType(dataType)
+          s"""
+             |if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) {
+             |  throw new ArithmeticException($eval1 + " $symbol " + $eval2 + " caused overflow.");
+             |}
+           """.stripMargin
+        } else {
+          ""
+        }
+        s"""
+           |${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2;
+           |$overflowCheck
+           |${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult);
+         """.stripMargin
+      })
+    case IntegerType | LongType =>
+      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+        val operation = if (checkOverflow) {
+          val mathClass = classOf[Math].getName
+          s"$mathClass.$exactMathMethod($eval1, $eval2)"
+        } else {
+          s"$eval1 $symbol $eval2"
+        }
+        s"""
+           |${ev.value} = $operation;
+         """.stripMargin
+      })
+    case DoubleType | FloatType =>
+      // When a Double/Float operation overflows, there are two possible outcomes:
+      // - precision loss: according to the SQL standard, the number is truncated;
+      // - (+/-)Infinity is returned: other DBs (e.g. Postgres) behave the same way.
+      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+        s"""
+           |${ev.value} = $eval1 $symbol $eval2;
+         """.stripMargin
+      })
   }
 }

@@ -164,7 +225,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {

   override def calendarIntervalMethod: String = "add"

-  private lazy val numeric = TypeUtils.getNumeric(dataType)
+  private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
     if (dataType.isInstanceOf[CalendarIntervalType]) {
@@ -173,6 +234,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
       numeric.plus(input1, input2)
     }
   }
+
+  override def exactMathMethod: String = "addExact"
 }

 @ExpressionDescription(
@@ -192,7 +255,7 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti

   override def calendarIntervalMethod: String = "subtract"

-  private lazy val numeric = TypeUtils.getNumeric(dataType)
+  private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
     if (dataType.isInstanceOf[CalendarIntervalType]) {
@@ -201,6 +264,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
       numeric.minus(input1, input2)
     }
   }
+
+  override def exactMathMethod: String = "subtractExact"
 }

 @ExpressionDescription(
@@ -217,9 +282,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti

   override def symbol: String = "*"
   override def decimalMethod: String = "$times"

-  private lazy val numeric = TypeUtils.getNumeric(dataType)
+  private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow)

   protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
+
+  override def exactMathMethod: String = "multiplyExact"
 }

 // Common base trait for Divide and Remainder, since these two classes are almost identical
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index fed2a1ac4b8b..9680ea3cd206 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -60,8 +60,13 @@ object TypeUtils {
     }
   }

-  def getNumeric(t: DataType): Numeric[Any] =
-    t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
+  def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = {
+    if (exactNumericRequired) {
+      t.asInstanceOf[NumericType].exactNumeric.asInstanceOf[Numeric[Any]]
+    } else {
+      t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
+    }
+  }

   def getInterpretedOrdering(t: DataType): Ordering[Any] = {
     t match {
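The `numeric`/`exactNumeric` split selected by `TypeUtils.getNumeric` above is the heart of the interpreted path: Scala's default `Numeric` instances silently wrap around on overflow, while the exact variants (defined later in `numerics.scala`) delegate to `java.lang.Math.*Exact` and throw. A minimal standalone sketch of that difference — illustrative only, not part of the patch:

```scala
// Standalone sketch (not from the patch): the default Numeric[Int] wraps
// around on overflow, while Math.addExact -- which IntegerExactNumeric.plus
// delegates to -- throws an ArithmeticException instead.
object ExactNumericDemo {
  def main(args: Array[String]): Unit = {
    val wrapping = implicitly[Numeric[Int]]
    println(wrapping.plus(Int.MaxValue, 1)) // prints -2147483648 (silent wrap-around)

    try {
      Math.addExact(Int.MaxValue, 1)
    } catch {
      case e: ArithmeticException => println(e.getMessage) // prints "integer overflow"
    }
  }
}
```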
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index fbdb1c5f957d..2fede591fc80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1780,6 +1780,15 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)

+  val ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW =
+    buildConf("spark.sql.arithmeticOperations.failOnOverFlow")
+      .doc("If set to true, all arithmetic operations on non-decimal fields throw an " +
+        "exception when an overflow occurs. If set to false (the default), a wrong " +
+        "result is silently returned on overflow.")
+      .internal()
+      .booleanConf
+      .createWithDefault(false)
+
   val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE =
     buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere")
       .internal()
@@ -2287,6 +2296,8 @@ class SQLConf extends Serializable with Logging {

   def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW)

+  def arithmeticOperationsFailOnOverflow: Boolean = getConf(ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW)
+
   def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)

   def continuousStreamingEpochBacklogQueueSize: Int =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index d2ef08873187..21ac32adca6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -142,6 +142,8 @@ abstract class NumericType extends AtomicType {
   // desugared by the compiler into an argument to the objects constructor. This means there is no
   // longer a no argument constructor and thus the JVM cannot serialize the object anymore.
   private[sql] val numeric: Numeric[InternalType]
+
+  private[sql] def exactNumeric: Numeric[InternalType] = numeric
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala
index 9d400eefc0f8..0df9518045f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala
@@ -37,6 +37,7 @@ class ByteType private() extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Byte]]
   private[sql] val integral = implicitly[Integral[Byte]]
   private[sql] val ordering = implicitly[Ordering[InternalType]]
+  override private[sql] val exactNumeric = ByteExactNumeric

   /**
    * The default size of a value of the ByteType is 1 byte.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
index 0755202d20df..c344523bdcb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala
@@ -37,6 +37,7 @@ class IntegerType private() extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Int]]
   private[sql] val integral = implicitly[Integral[Int]]
   private[sql] val ordering = implicitly[Ordering[InternalType]]
+  override private[sql] val exactNumeric = IntegerExactNumeric

   /**
    * The default size of a value of the IntegerType is 4 bytes.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala
index 3c49c721fdc8..f030920db451 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala
@@ -37,6 +37,7 @@ class LongType private() extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Long]]
   private[sql] val integral = implicitly[Integral[Long]]
   private[sql] val ordering = implicitly[Ordering[InternalType]]
+  override private[sql] val exactNumeric = LongExactNumeric

   /**
    * The default size of a value of the LongType is 8 bytes.
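For context, here is how the new flag is expected to surface to users once the patch is applied. This is an illustrative sketch, not a test from the PR; the config key and its default come straight from the SQLConf entry above, everything else is an assumption about a local build with this change:

```scala
// Hypothetical usage sketch (assumes this patch is applied). Long.MaxValue + 1
// wraps to Long.MinValue by default; with the flag on, the same query is
// expected to fail with an ArithmeticException ("... caused overflow.").
import org.apache.spark.sql.SparkSession

object FailOnOverflowDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()

    // Default behavior: silent two's-complement wrap-around.
    spark.sql("SELECT 9223372036854775807 + 1").show() // -9223372036854775808

    // With the flag enabled, the same addition should throw instead.
    spark.conf.set("spark.sql.arithmeticOperations.failOnOverFlow", "true")
    spark.sql("SELECT 9223372036854775807 + 1").show() // expected: ArithmeticException
  }
}
```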
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala
index 9b5ddfef1ccf..825268995853 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala
@@ -37,6 +37,7 @@ class ShortType private() extends IntegralType {
   private[sql] val numeric = implicitly[Numeric[Short]]
   private[sql] val integral = implicitly[Integral[Short]]
   private[sql] val ordering = implicitly[Ordering[InternalType]]
+  override private[sql] val exactNumeric = ShortExactNumeric

   /**
    * The default size of a value of the ShortType is 2 bytes.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
new file mode 100644
index 000000000000..a362afa2b31b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+import scala.math.Numeric.{ByteIsIntegral, IntIsIntegral, LongIsIntegral, ShortIsIntegral}
+import scala.math.Ordering
+
+
+object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering {
+  private def checkOverflow(res: Int, x: Byte, y: Byte, op: String): Unit = {
+    if (res > Byte.MaxValue || res < Byte.MinValue) {
+      throw new ArithmeticException(s"$x $op $y caused overflow.")
+    }
+  }
+
+  override def plus(x: Byte, y: Byte): Byte = {
+    val tmp = x + y
+    checkOverflow(tmp, x, y, "+")
+    tmp.toByte
+  }
+
+  override def minus(x: Byte, y: Byte): Byte = {
+    val tmp = x - y
+    checkOverflow(tmp, x, y, "-")
+    tmp.toByte
+  }
+
+  override def times(x: Byte, y: Byte): Byte = {
+    val tmp = x * y
+    checkOverflow(tmp, x, y, "*")
+    tmp.toByte
+  }
+
+  override def negate(x: Byte): Byte = {
+    if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow can happen
+      throw new ArithmeticException(s"- $x caused overflow.")
+    }
+    (-x).toByte
+  }
+}
+
+
+object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering {
+  private def checkOverflow(res: Int, x: Short, y: Short, op: String): Unit = {
+    if (res > Short.MaxValue || res < Short.MinValue) {
+      throw new ArithmeticException(s"$x $op $y caused overflow.")
+    }
+  }
+
+  override def plus(x: Short, y: Short): Short = {
+    val tmp = x + y
+    checkOverflow(tmp, x, y, "+")
+    tmp.toShort
+  }
+
+  override def minus(x: Short, y: Short): Short = {
+    val tmp = x - y
+    checkOverflow(tmp, x, y, "-")
+    tmp.toShort
+  }
+
+  override def times(x: Short, y: Short): Short = {
+    val tmp = x * y
+    checkOverflow(tmp, x, y, "*")
+    tmp.toShort
+  }
+
+  override def negate(x: Short): Short = {
+    if (x == Short.MinValue) { // if and only if x is Short.MinValue, overflow can happen
+      throw new ArithmeticException(s"- $x caused overflow.")
+    }
+    (-x).toShort
+  }
+}
+
+
+object IntegerExactNumeric extends IntIsIntegral with Ordering.IntOrdering {
+  override def plus(x: Int, y: Int): Int = Math.addExact(x, y)
+
+  override def minus(x: Int, y: Int): Int = Math.subtractExact(x, y)
+
+  override def times(x: Int, y: Int): Int = Math.multiplyExact(x, y)
+
+  override def negate(x: Int): Int = Math.negateExact(x)
+}
+
+object LongExactNumeric extends LongIsIntegral with Ordering.LongOrdering {
+  override def plus(x: Long, y: Long): Long = Math.addExact(x, y)
+
+  override def minus(x: Long, y: Long): Long = Math.subtractExact(x, y)
+
+  override def times(x: Long, y: Long): Long = Math.multiplyExact(x, y)
+
+  override def negate(x: Long): Long = Math.negateExact(x)
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 1318ab185983..d35fff1293ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -59,8 +59,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Add(positiveIntLit, negativeIntLit), -1)
     checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L)

-    DataTypeTestUtils.numericAndInterval.foreach { tpe =>
-      checkConsistencyBetweenInterpretedAndCodegen(Add, tpe, tpe)
+    Seq("true", "false").foreach { checkOverflow =>
+      withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> checkOverflow) {
+        DataTypeTestUtils.numericAndInterval.foreach { tpe =>
+          checkConsistencyBetweenInterpretedAndCodegenAllowingException(Add, tpe, tpe)
+        }
+      }
     }
   }

@@ -75,6 +79,22 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue)
     checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue)
     checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue)
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") {
+      checkExceptionInExpression[ArithmeticException](
+        UnaryMinus(Literal(Long.MinValue)), "overflow")
+      checkExceptionInExpression[ArithmeticException](
+        UnaryMinus(Literal(Int.MinValue)), "overflow")
+      checkExceptionInExpression[ArithmeticException](
+        UnaryMinus(Literal(Short.MinValue)), "overflow")
+      checkExceptionInExpression[ArithmeticException](
+        UnaryMinus(Literal(Byte.MinValue)), "overflow")
+      checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort)
+      checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort)
+      checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt)
+      checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt)
+      checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong)
+      checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong)
+    }
     checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort)
     checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort)
     checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt)
@@ -100,8 +120,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt)
     checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong)

-    DataTypeTestUtils.numericAndInterval.foreach { tpe =>
-      checkConsistencyBetweenInterpretedAndCodegen(Subtract, tpe, tpe)
+    Seq("true", "false").foreach { checkOverflow =>
+      withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> checkOverflow) {
+        DataTypeTestUtils.numericAndInterval.foreach { tpe =>
+          checkConsistencyBetweenInterpretedAndCodegenAllowingException(Subtract, tpe, tpe)
+        }
+      }
     }
   }

@@ -118,8 +142,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt)
     checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong)

-    DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
-      checkConsistencyBetweenInterpretedAndCodegen(Multiply, tpe, tpe)
+    Seq("true", "false").foreach { checkOverflow =>
+      withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> checkOverflow) {
+        DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
+          checkConsistencyBetweenInterpretedAndCodegenAllowingException(Multiply, tpe, tpe)
+        }
+      }
     }
   }

@@ -376,4 +404,100 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2)
     assert(ctx2.inlinedMutableStates.size == 1)
   }
+
+  test("SPARK-24598: overflow on long returns wrong result") {
+    val maxLongLiteral = Literal(Long.MaxValue)
+    val minLongLiteral = Literal(Long.MinValue)
+    val e1 = Add(maxLongLiteral, Literal(1L))
+    val e2 = Subtract(maxLongLiteral, Literal(-1L))
+    val e3 = Multiply(maxLongLiteral, Literal(2L))
+    val e4 = Add(minLongLiteral, minLongLiteral)
+    val e5 = Subtract(minLongLiteral, maxLongLiteral)
+    val e6 = Multiply(minLongLiteral, minLongLiteral)
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") {
+      Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
+        checkExceptionInExpression[ArithmeticException](e, "overflow")
+      }
+    }
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") {
+      checkEvaluation(e1, Long.MinValue)
+      checkEvaluation(e2, Long.MinValue)
+      checkEvaluation(e3, -2L)
+      checkEvaluation(e4, 0L)
+      checkEvaluation(e5, 1L)
+      checkEvaluation(e6, 0L)
+    }
+  }
+
+  test("SPARK-24598: overflow on integer returns wrong result") {
+    val maxIntLiteral = Literal(Int.MaxValue)
+    val minIntLiteral = Literal(Int.MinValue)
+    val e1 = Add(maxIntLiteral, Literal(1))
+    val e2 = Subtract(maxIntLiteral, Literal(-1))
+    val e3 = Multiply(maxIntLiteral, Literal(2))
+    val e4 = Add(minIntLiteral, minIntLiteral)
+    val e5 = Subtract(minIntLiteral, maxIntLiteral)
+    val e6 = Multiply(minIntLiteral, minIntLiteral)
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") {
+      Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
+        checkExceptionInExpression[ArithmeticException](e, "overflow")
+      }
+    }
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") {
+      checkEvaluation(e1, Int.MinValue)
+      checkEvaluation(e2, Int.MinValue)
+      checkEvaluation(e3, -2)
+      checkEvaluation(e4, 0)
+      checkEvaluation(e5, 1)
+      checkEvaluation(e6, 0)
+    }
+  }
+
+  test("SPARK-24598: overflow on short returns wrong result") {
+    val maxShortLiteral = Literal(Short.MaxValue)
+    val minShortLiteral = Literal(Short.MinValue)
+    val e1 = Add(maxShortLiteral, Literal(1.toShort))
+    val e2 = Subtract(maxShortLiteral, Literal((-1).toShort))
+    val e3 = Multiply(maxShortLiteral, Literal(2.toShort))
+    val e4 = Add(minShortLiteral, minShortLiteral)
+    val e5 = Subtract(minShortLiteral, maxShortLiteral)
+    val e6 = Multiply(minShortLiteral, minShortLiteral)
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") {
+      Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
+        checkExceptionInExpression[ArithmeticException](e, "overflow")
+      }
+    }
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") {
+      checkEvaluation(e1, Short.MinValue)
+      checkEvaluation(e2, Short.MinValue)
+      checkEvaluation(e3, (-2).toShort)
+      checkEvaluation(e4, 0.toShort)
+      checkEvaluation(e5, 1.toShort)
+      checkEvaluation(e6, 0.toShort)
+    }
+  }
+
+  test("SPARK-24598: overflow on byte returns wrong result") {
+    val maxByteLiteral = Literal(Byte.MaxValue)
+    val minByteLiteral = Literal(Byte.MinValue)
+    val e1 = Add(maxByteLiteral, Literal(1.toByte))
+    val e2 = Subtract(maxByteLiteral, Literal((-1).toByte))
+    val e3 = Multiply(maxByteLiteral, Literal(2.toByte))
+    val e4 = Add(minByteLiteral, minByteLiteral)
+    val e5 = Subtract(minByteLiteral, maxByteLiteral)
+    val e6 = Multiply(minByteLiteral, minByteLiteral)
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") {
+      Seq(e1, e2, e3, e4, e5, e6).foreach { e =>
+        checkExceptionInExpression[ArithmeticException](e, "overflow")
+      }
+    }
+    withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") {
+      checkEvaluation(e1, Byte.MinValue)
+      checkEvaluation(e2, Byte.MinValue)
+      checkEvaluation(e3, (-2).toByte)
+      checkEvaluation(e4, 0.toByte)
+      checkEvaluation(e5, 1.toByte)
+      checkEvaluation(e6, 0.toByte)
+    }
+  }
 }
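The expected values asserted under failOnOverFlow=false above are not arbitrary: they are the two's-complement wrap-arounds of the exact results. A quick standalone check of the six long identities used in the test — illustrative only, not part of the patch:

```scala
// Illustrative sketch (not from the patch): the "wrong results" the tests
// expect with the flag off are plain two's-complement wrap-around.
object WrapAroundDemo {
  def main(args: Array[String]): Unit = {
    assert(Long.MaxValue + 1L == Long.MinValue)        // e1
    assert(Long.MaxValue - -1L == Long.MinValue)       // e2
    assert(Long.MaxValue * 2L == -2L)                  // e3
    assert(Long.MinValue + Long.MinValue == 0L)        // e4
    assert(Long.MinValue - Long.MaxValue == 1L)        // e5
    assert(Long.MinValue * Long.MinValue == 0L)        // e6
    println("all six wrap-around identities hold")
  }
}
```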
index a2c0ce35df23..bc1f31b101c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -359,6 +359,26 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     }
   }

+  /**
+   * Test evaluation results between Interpreted mode and Codegen mode, making sure we have
+   * consistent result regardless of the evaluation method we use. If an exception is thrown,
+   * it checks that both modes throw the same exception.
+   *
+   * This method tests binary expressions by feeding them arbitrary literals of `dataType1`
+   * and `dataType2`.
+   */
+  def checkConsistencyBetweenInterpretedAndCodegenAllowingException(
+      c: (Expression, Expression) => Expression,
+      dataType1: DataType,
+      dataType2: DataType): Unit = {
+    forAll (
+      LiteralGenerator.randomGen(dataType1),
+      LiteralGenerator.randomGen(dataType2)
+    ) { (l1: Literal, l2: Literal) =>
+      cmpInterpretWithCodegen(EmptyRow, c(l1, l2), true)
+    }
+  }
+
   /**
    * Test evaluation results between Interpreted mode and Codegen mode, making sure we have
    * consistent result regardless of the evaluation method we use.
@@ -398,23 +418,54 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     }
   }

-  def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = {
-    val interpret = try {
-      evaluateWithoutCodegen(expr, inputRow)
+  def cmpInterpretWithCodegen(
+      inputRow: InternalRow,
+      expr: Expression,
+      exceptionAllowed: Boolean = false): Unit = {
+    val (interpret, interpretExc) = try {
+      (Some(evaluateWithoutCodegen(expr, inputRow)), None)
     } catch {
-      case e: Exception => fail(s"Exception evaluating $expr", e)
+      case e: Exception => if (exceptionAllowed) {
+        (None, Some(e))
+      } else {
+        fail(s"Exception evaluating $expr", e)
+      }
     }

     val plan = generateProject(
       GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil), expr)
-    val codegen = plan(inputRow).get(0, expr.dataType)
+    val (codegen, codegenExc) = try {
+      (Some(plan(inputRow).get(0, expr.dataType)), None)
+    } catch {
+      case e: Exception => if (exceptionAllowed) {
+        (None, Some(e))
+      } else {
+        fail(s"Exception evaluating $expr", e)
+      }
+    }

-    if (!compareResults(interpret, codegen)) {
-      fail(s"Incorrect evaluation: $expr, interpret: $interpret, codegen: $codegen")
+    if (interpret.isDefined && codegen.isDefined && !compareResults(interpret.get, codegen.get)) {
+      fail(s"Incorrect evaluation: $expr, interpret: ${interpret.get}, codegen: ${codegen.get}")
+    } else if (interpretExc.isDefined && codegenExc.isEmpty) {
+      fail(s"Incorrect evaluation: $expr, interpret threw exception ${interpretExc.get}")
+    } else if (interpretExc.isEmpty && codegenExc.isDefined) {
+      fail(s"Incorrect evaluation: $expr, codegen threw exception ${codegenExc.get}")
+    } else if (interpretExc.isDefined && codegenExc.isDefined
+        && !compareExceptions(interpretExc.get, codegenExc.get)) {
+      fail(s"Different exception evaluating: $expr, " +
+        s"interpret: ${interpretExc.get}, codegen: ${codegenExc.get}")
     }
   }

+  /**
+   * Checks the equality between two exceptions. Returns true iff the two exceptions are instances
+   * of the same class and they have the same message.
+   */
+  private[this] def compareExceptions(e1: Exception, e2: Exception): Boolean = {
+    e1.getClass == e2.getClass && e1.getMessage == e2.getMessage
+  }
+
   /**
    * Check the equality between result of expression and expected value, it will handle
    * Array[Byte] and Spread[Double].
diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql
index 86432a845b6e..1012db72e187 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql
@@ -33,6 +33,10 @@ INSERT INTO INT4_TBL VALUES ('-2147483647');
 -- INSERT INTO INT4_TBL(f1) VALUES ('123 5');
 -- INSERT INTO INT4_TBL(f1) VALUES ('');

+-- We cannot test with failOnOverFlow=true here: the exception is thrown on
+-- the executors, so the stack trace in the golden output cannot be matched
+-- exactly.
+set spark.sql.arithmeticOperations.failOnOverFlow=false;

 SELECT '' AS five, * FROM INT4_TBL;
diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out
index 456b1ef962d4..8b9c20e7eb20 100644
--- a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 53
+-- Number of queries: 54


 -- !query 0
@@ -51,30 +51,27 @@ struct<>


 -- !query 6
-SELECT '' AS five, * FROM INT4_TBL
+set spark.sql.arithmeticOperations.failOnOverFlow=false
 -- !query 6 schema
-struct
+struct
 -- !query 6 output
--123456
- -2147483647
- 0
- 123456
- 2147483647
+spark.sql.arithmeticOperations.failOnOverFlow	false


 -- !query 7
-SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> smallint('0')
+SELECT '' AS five, * FROM INT4_TBL
 -- !query 7 schema
-struct
+struct
 -- !query 7 output
 -123456
 -2147483647
+ 0
 123456
 2147483647


 -- !query 8
-SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> int('0')
+SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> smallint('0')
 -- !query 8 schema
 struct
 -- !query 8 output
@@ -85,15 +82,18 @@ struct


 -- !query 9
-SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = smallint('0')
+SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> int('0')
 -- !query 9 schema
-struct
+struct
 -- !query 9 output
-0
+-123456
+ -2147483647
+ 123456
+ 2147483647


 -- !query 10
-SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = int('0')
+SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = smallint('0')
 -- !query 10 schema
 struct
 -- !query 10 output
 0
@@ -101,16 +101,15 @@ struct


 -- !query 11
-SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < smallint('0')
+SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = int('0')
 -- !query 11 schema
-struct
+struct
 -- !query 11 output
--123456
- -2147483647
+0


 -- !query 12
-SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < int('0')
+SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < smallint('0')
 -- !query 12 schema
 struct
 -- !query 12 output
@@ -119,17 +118,16 @@ struct
 -123456
 -2147483647


 -- !query 13
-SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= smallint('0')
+SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < int('0')
 -- !query 13 schema
-struct
+struct
 -- !query 13 output
 -123456
 -2147483647
- 0


 -- !query 14
-SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= int('0')
+SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= smallint('0')
 -- !query 14 schema
 struct
 -- !query 14 output
@@ -139,16 +137,17 @@ struct


 -- !query 15
-SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > smallint('0')
+SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= int('0')
 -- !query 15 schema
-struct
+struct
 -- !query 15 output
-123456
- 2147483647
+-123456
+ -2147483647
+ 0


 -- !query 16
-SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > int('0')
+SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > smallint('0')
 -- !query 16 schema
 struct
 -- !query 16 output
@@ -157,17 +156,16 @@ struct


 -- !query 17
-SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= smallint('0')
+SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > int('0')
 -- !query 17 schema
-struct
+struct
 -- !query 17 output
-0
- 123456
+123456
 2147483647


 -- !query 18
-SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= int('0')
+SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= smallint('0')
 -- !query 18 schema
 struct
 -- !query 18 output
@@ -177,84 +175,81 @@ struct


 -- !query 19
-SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1')
+SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= int('0')
 -- !query 19 schema
-struct
+struct
 -- !query 19 output
-2147483647
+0
+ 123456
+ 2147483647


 -- !query 20
-SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0')
+SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1')
 -- !query 20 schema
-struct
+struct
 -- !query 20 output
--123456
- 0
- 123456
+2147483647


 -- !query 21
-SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i
+SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0')
 -- !query 21 schema
-struct
+struct
 -- !query 21 output
--123456	-246912
- -2147483647	2
- 0	0
- 123456	246912
- 2147483647	-2
+-123456
+ 0
+ 123456


 -- !query 22
 SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i
-WHERE abs(f1) < 1073741824
 -- !query 22 schema
 struct
 -- !query 22 output
 -123456	-246912
+ -2147483647	2
 0	0
 123456	246912
+ 2147483647	-2


 -- !query 23
-SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i
+SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i
+WHERE abs(f1) < 1073741824
 -- !query 23 schema
 struct
 -- !query 23 output
 -123456	-246912
- -2147483647	2
 0	0
 123456	246912
- 2147483647	-2


 -- !query 24
 SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i
-WHERE abs(f1) < 1073741824
 -- !query 24 schema
 struct
 -- !query 24 output
 -123456	-246912
+ -2147483647	2
 0	0
 123456	246912
+ 2147483647	-2


 -- !query 25
-SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i
+SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i
+WHERE abs(f1) < 1073741824
 -- !query 25 schema
 struct
 -- !query 25 output
--123456	-123454
- -2147483647	-2147483645
- 0	2
- 123456	123458
- 2147483647	-2147483647
+-123456	-246912
+ 0	0
+ 123456	246912


 -- !query 26
 SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i
-WHERE f1 < 2147483646
 -- !query 26 schema
 struct
 -- !query 26 output
@@ -262,10 +257,12 @@ struct
 -2147483647	-2147483645
 0	2
 123456	123458
+ 2147483647	-2147483647


 -- !query 27
-SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i
+SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i
+WHERE f1 < 2147483646
 -- !query 27 schema
 struct
 -- !query 27 output
 -123456	-123454
 -2147483647	-2147483645
 0	2
 123456	123458
- 2147483647	-2147483647


 -- !query 28
 SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i
-WHERE f1 < 2147483646
 -- !query 28 schema
 struct
 -- !query 28 output
 -123456	-123454
 -2147483647	-2147483645
 0	2
 123456	123458
+ 2147483647	-2147483647


 -- !query 29
-SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i
+SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i
+WHERE f1 < 2147483646
 -- !query 29 schema
 struct
 -- !query 29 output
--123456	-123458
- -2147483647	2147483647
- 0	-2
- 123456	123454
- 2147483647	2147483645
+-123456	-123454
+ -2147483647	-2147483645
+ 0	2
+ 123456	123458


 -- !query 30
 SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i
-WHERE f1 > -2147483647
 -- !query 30 schema
 struct
 -- !query 30 output
 -123456	-123458
+ -2147483647	2147483647
 0	-2
 123456	123454
 2147483647	2147483645


 -- !query 31
-SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
+SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i
+WHERE f1 > -2147483647
 -- !query 31 schema
 struct
 -- !query 31 output
 -123456	-123458
- -2147483647	2147483647
 0	-2
 123456	123454
 2147483647	2147483645
@@ -326,30 +322,30 @@ struct


 -- !query 32
 SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
-WHERE f1 > -2147483647
 -- !query 32 schema
 struct
 -- !query 32 output
 -123456	-123458
+ -2147483647	2147483647
 0	-2
 123456	123454
 2147483647	2147483645


 -- !query 33
-SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i
+SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i
+WHERE f1 > -2147483647
 -- !query 33 schema
 struct
 -- !query 33 output
--123456	-61728
- -2147483647	-1073741823
- 0	0
- 123456	61728
- 2147483647	1073741823
+-123456	-123458
+ 0	-2
+ 123456	123454
+ 2147483647	2147483645


 -- !query 34
-SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i
+SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i
 -- !query 34 schema
 struct
 -- !query 34 output
@@ -361,47 +357,51 @@ struct


 -- !query 35
-SELECT -2+3 AS one
+SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i
 -- !query 35 schema
-struct
+struct
 -- !query 35 output
-1
+-123456	-61728
+ -2147483647	-1073741823
+ 0	0
+ 123456	61728
+ 2147483647	1073741823


 -- !query 36
-SELECT 4-2 AS two
+SELECT -2+3 AS one
 -- !query 36 schema
-struct
+struct
 -- !query 36 output
-2
+1


 -- !query 37
-SELECT 2- -1 AS three
+SELECT 4-2 AS two
 -- !query 37 schema
-struct
+struct
 -- !query 37 output
-3
+2


 -- !query 38
-SELECT 2 - -2 AS four
+SELECT 2- -1 AS three
 -- !query 38 schema
-struct
+struct
 -- !query 38 output
-4
+3


 -- !query 39
-SELECT smallint('2') * smallint('2') = smallint('16') / smallint('4') AS true
+SELECT 2 - -2 AS four
 -- !query 39 schema
-struct
+struct
 -- !query 39 output
-true
+4


 -- !query 40
-SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true
+SELECT smallint('2') * smallint('2') = smallint('16') / smallint('4') AS true
 -- !query 40 schema
 struct
 -- !query 40 output
 true


 -- !query 41
-SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true
+SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true
 -- !query 41 schema
 struct
 -- !query 41 output
 true
@@ -417,70 +417,78 @@ struct


 -- !query 42
-SELECT int('1000') < int('999') AS `false`
+SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true
 -- !query 42 schema
-struct
+struct
 -- !query 42 output
-false
+true


 -- !query 43
-SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten
+SELECT int('1000') < int('999') AS `false`
 -- !query 43 schema
-struct
+struct
 -- !query 43 output
-10
+false


 -- !query 44
-SELECT 2 + 2 / 2 AS three
+SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten
 -- !query 44 schema
-struct
+struct
 -- !query 44 output
-3
+10


 -- !query 45
-SELECT (2 + 2) / 2 AS two
+SELECT 2 + 2 / 2 AS three
 -- !query 45 schema
-struct
+struct
 -- !query 45 output
-2
+3


 -- !query 46
-SELECT string(shiftleft(int(-1), 31))
+SELECT (2 + 2) / 2 AS two
 -- !query 46 schema
-struct
+struct
 -- !query 46 output
--2147483648
+2


 -- !query 47
-SELECT string(int(shiftleft(int(-1), 31))+1)
+SELECT string(shiftleft(int(-1), 31))
 -- !query 47 schema
-struct
+struct
 -- !query 47 output
--2147483647
+-2147483648


 -- !query 48
-SELECT int(-2147483648) % int(-1)
+SELECT string(int(shiftleft(int(-1), 31))+1)
 -- !query 48 schema
-struct<(CAST(-2147483648 AS INT) % CAST(-1 AS INT)):int>
+struct
 -- !query 48 output
-0
+-2147483647


 -- !query 49
-SELECT int(-2147483648) % smallint(-1)
+SELECT int(-2147483648) % int(-1)
 -- !query 49 schema
-struct<(CAST(-2147483648 AS INT) % CAST(CAST(-1 AS SMALLINT) AS INT)):int>
+struct<(CAST(-2147483648 AS INT) % CAST(-1 AS INT)):int>
 -- !query 49 output
 0


 -- !query 50
+SELECT int(-2147483648) % smallint(-1)
+-- !query 50 schema
+struct<(CAST(-2147483648 AS INT) % CAST(CAST(-1 AS SMALLINT) AS INT)):int>
+-- !query 50 output
+0
+
+
+-- !query 51
 SELECT x, int(x) AS int4_value
 FROM (VALUES double(-2.5),
             double(-1.5),
             double(-0.5),
@@ -489,9 +497,9 @@ FROM (VALUES double(-2.5),
             double(0.5),
             double(1.5),
             double(2.5)) t(x)
--- !query 50 schema
+-- !query 51 schema
 struct
--- !query 50 output
+-- !query 51 output
 -0.5	0
 -1.5	-1
 -2.5	-2
@@ -501,7 +509,7 @@ struct
 2.5	2


--- !query 51
+-- !query 52
 SELECT x, int(x) AS int4_value
 FROM (VALUES cast(-2.5 as decimal(38, 18)),
             cast(-1.5 as decimal(38, 18)),
             cast(-0.5 as decimal(38, 18)),
@@ -510,9 +518,9 @@ FROM (VALUES cast(-2.5 as decimal(38, 18)),
             cast(0.5 as decimal(38, 18)),
             cast(1.5 as decimal(38, 18)),
             cast(2.5 as decimal(38, 18))) t(x)
--- !query 51 schema
+-- !query 52 schema
 struct
--- !query 51 output
+-- !query 52 output
 -0.5	0
 -1.5	-1
 -2.5	-2
@@ -522,9 +530,9 @@ struct
 2.5	2


--- !query 52
+-- !query 53
 DROP TABLE INT4_TBL
--- !query 52 schema
+-- !query 53 schema
 struct<>
--- !query 52 output
+-- !query 53 output

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
index 8cc702057943..6c1a66cae227 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -113,11 +113,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
     val random = new Random(seed)

     def randomBound(): Long = {
-      val n = if (random.nextBoolean()) {
-        random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
-      } else {
-        random.nextLong() / 2
-      }
+      val n = random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS))
       if (random.nextBoolean()) n else -n
     }
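A note on the DataFrameRangeSuite change: dropping the `random.nextLong() / 2` branch caps every generated bound at Long.MaxValue / (100 * MAX_NUM_STEPS), so the difference of any two bounds — and the step arithmetic derived from it — stays well inside the Long range and can no longer overflow. A small sketch of the headroom argument; the value of MAX_NUM_STEPS here is a stand-in, the real constant lives in the suite:

```scala
// Illustrative sketch (not from the patch): with both bounds capped, even the
// widest possible spread end - start fits comfortably in a Long. Using
// Math.subtractExact (the same helper the patch relies on) would throw if not.
object RangeBoundHeadroomDemo {
  def main(args: Array[String]): Unit = {
    val MAX_NUM_STEPS = 100L // assumption: stand-in for the suite's constant
    val cap = Long.MaxValue / (100 * MAX_NUM_STEPS)
    // Widest spread between two bounds in [-cap, cap]; roughly Long.MaxValue / 5000.
    println(Math.subtractExact(cap, -cap))
  }
}
```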