From 5c662f6987b7cdad016e4134f980fd0b8e02d220 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 20 Jun 2018 14:20:04 +0200
Subject: [PATCH 01/23] [SPARK-24598][SQL] Overflow on arithmetic operation returns incorrect result

---
 .../sql/catalyst/expressions/arithmetic.scala | 72 +++++++++++++++++--
 .../ArithmeticExpressionSuite.scala           | 24 ++++++-
 .../expressions/ExpressionEvalHelper.scala    | 65 +++++++++++++++--
 3 files changed, 145 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index fe91e520169b..774aa332cfa8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -128,17 +128,31 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
   def calendarIntervalMethod: String =
     sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode")
 
+  def checkOverflowCode(result: String, op1: String, op2: String): String =
+    sys.error("BinaryArithmetics must override either checkOverflowCode or genCode")
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
     case _: DecimalType =>
       defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")
     case CalendarIntervalType =>
       defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)")
+    // In the following cases, overflow can happen, so we need to check whether the result is
+    // valid. Otherwise we throw an ArithmeticException.
     // byte and short are casted into int when add, minus, times or divide
     case ByteType | ShortType =>
-      defineCodeGen(ctx, ev,
-        (eval1, eval2) => s"(${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2)")
+      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+        s"""
+           |${ev.value} = (${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2);
+           |${checkOverflowCode(ev.value, eval1, eval2)}
+         """.stripMargin
+      })
     case _ =>
-      defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2")
+      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
+        s"""
+           |${ev.value} = $eval1 $symbol $eval2;
+           |${checkOverflowCode(ev.value, eval1, eval2)}
+         """.stripMargin
+      })
   }
 }
@@ -169,9 +183,25 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
     if (dataType.isInstanceOf[CalendarIntervalType]) {
       input1.asInstanceOf[CalendarInterval].add(input2.asInstanceOf[CalendarInterval])
     } else {
-      numeric.plus(input1, input2)
+      val result = numeric.plus(input1, input2)
+      val resSignum = numeric.signum(result)
+      val input1Signum = numeric.signum(input1)
+      val input2Signum = numeric.signum(input2)
+      if (resSignum != -1 && input1Signum == -1 && input2Signum == -1
+        || resSignum != 1 && input1Signum == 1 && input2Signum == 1) {
+        throw new ArithmeticException(s"$input1 + $input2 caused overflow.")
+      }
+      result
     }
   }
+
+  override def checkOverflowCode(result: String, op1: String, op2: String): String = {
+    s"""
+       |if ($result >= 0 && $op1 < 0 && $op2 < 0 || $result <= 0 && $op1 > 0 && $op2 > 0) {
+       |  throw new ArithmeticException($op1 + " + " + $op2 + " caused overflow.");
+       |}
+     """.stripMargin
+  }
 }
 
 @ExpressionDescription(
@@ -197,9 +227,25 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
     if (dataType.isInstanceOf[CalendarIntervalType]) {
input1.asInstanceOf[CalendarInterval].subtract(input2.asInstanceOf[CalendarInterval]) } else { - numeric.minus(input1, input2) + val result = numeric.minus(input1, input2) + val resSignum = numeric.signum(result) + val input1Signum = numeric.signum(input1) + val input2Signum = numeric.signum(input2) + if (resSignum != 1 && input1Signum == 1 && input2Signum == -1 + || resSignum != -1 && input1Signum == -1 && input2Signum == 1) { + throw new ArithmeticException(s"$input1 - $input2 caused overflow.") + } + result } } + + override def checkOverflowCode(result: String, op1: String, op2: String): String = { + s""" + |if ($result <= 0 && $op1 > 0 && $op2 < 0 || $result >= 0 && $op1 < 0 && $op2 > 0) { + | throw new ArithmeticException($op1 + " - " + $op2 + " caused overflow."); + |} + """.stripMargin + } } @ExpressionDescription( @@ -218,7 +264,21 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti private lazy val numeric = TypeUtils.getNumeric(dataType) - protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val result = numeric.times(input1, input2) + if (numeric.signum(result) != numeric.signum(input1) * numeric.signum(input2)) { + throw new ArithmeticException(s"$input1 * $input2 caused overflow.") + } + result + } + + override def checkOverflowCode(result: String, op1: String, op2: String): String = { + s""" + |if (Math.signum($result) != Math.signum($op1) * Math.signum($op2)) { + | throw new ArithmeticException($op1 + " * " + $op2 + " caused overflow."); + |} + """.stripMargin + } } // Common base trait for Divide and Remainder, since these two classes are almost identical diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 6edb4348f830..3837a4b9cea8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -59,7 +59,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) DataTypeTestUtils.numericAndInterval.foreach { tpe => - checkConsistencyBetweenInterpretedAndCodegen(Add, tpe, tpe) + checkConsistencyBetweenInterpretedAndCodegenAllowingException(Add, tpe, tpe) } } @@ -100,7 +100,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong) DataTypeTestUtils.numericAndInterval.foreach { tpe => - checkConsistencyBetweenInterpretedAndCodegen(Subtract, tpe, tpe) + checkConsistencyBetweenInterpretedAndCodegenAllowingException(Subtract, tpe, tpe) } } @@ -118,7 +118,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong) DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => - checkConsistencyBetweenInterpretedAndCodegen(Multiply, tpe, tpe) + checkConsistencyBetweenInterpretedAndCodegenAllowingException(Multiply, tpe, tpe) } } @@ -354,4 +354,22 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Greatest(Seq(Literal(1), Literal(1))).genCode(ctx2) 
assert(ctx2.inlinedMutableStates.size == 1) } + + test("SPARK-24598: overflow on BigInt returns wrong result") { + val maxLongLiteral = Literal(Long.MaxValue) + val minLongLiteral = Literal(Long.MinValue) + checkExceptionInExpression[ArithmeticException]( + Add(maxLongLiteral, Literal(1L)), "caused overflow") + checkExceptionInExpression[ArithmeticException]( + Subtract(maxLongLiteral, Literal(-1L)), "caused overflow") + checkExceptionInExpression[ArithmeticException]( + Multiply(maxLongLiteral, Literal(2L)), "caused overflow") + + checkExceptionInExpression[ArithmeticException]( + Add(minLongLiteral, minLongLiteral), "caused overflow") + checkExceptionInExpression[ArithmeticException]( + Subtract(minLongLiteral, maxLongLiteral), "caused overflow") + checkExceptionInExpression[ArithmeticException]( + Multiply(minLongLiteral, minLongLiteral), "caused overflow") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 14bfa212b549..0e4cb01c640b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -315,6 +315,26 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa } } + /** + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. If an exception is thrown, + * it checks that both modes throw the same exception. + * + * This method test against binary expressions by feeding them arbitrary literals of `dataType1` + * and `dataType2`. + */ + def checkConsistencyBetweenInterpretedAndCodegenAllowingException( + c: (Expression, Expression) => Expression, + dataType1: DataType, + dataType2: DataType): Unit = { + forAll ( + LiteralGenerator.randomGen(dataType1), + LiteralGenerator.randomGen(dataType2) + ) { (l1: Literal, l2: Literal) => + cmpInterpretWithCodegen(EmptyRow, c(l1, l2), true) + } + } + /** * Test evaluation results between Interpreted mode and Codegen mode, making sure we have * consistent result regardless of the evaluation method we use. 
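A note on the detection logic this patch relies on: for two's-complement integers, addition overflows iff both operands share a sign and the result's sign differs, and subtraction overflows iff the operands have opposite signs and the result's sign differs from the minuend's. The sketch below (a hypothetical standalone helper, not part of the patch) distills what the interpreted paths of Add, Subtract and Multiply express through Numeric.signum:

    object OverflowSketch {
      def checkedAdd(a: Long, b: Long): Long = {
        val result = a + b
        // Both operands negative but the result is not, or both positive but
        // the result is not: the addition wrapped around.
        if (a < 0 && b < 0 && result >= 0 || a > 0 && b > 0 && result <= 0) {
          throw new ArithmeticException(s"$a + $b caused overflow.")
        }
        result
      }

      def checkedMultiply(a: Long, b: Long): Long = {
        val result = a * b
        // The sign of the product must equal the product of the signs.
        if (math.signum(result) != math.signum(a) * math.signum(b)) {
          throw new ArithmeticException(s"$a * $b caused overflow.")
        }
        result
      }
    }

The addition and subtraction tests are exhaustive, but the multiplication test is only a heuristic: for example 4294967297L * 4294967297L wraps to a positive value and passes the sign check, which is why java.lang.Math.multiplyExact instead verifies the result by dividing it back.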
@@ -354,23 +374,54 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
     }
   }
 
-  private def cmpInterpretWithCodegen(inputRow: InternalRow, expr: Expression): Unit = {
-    val interpret = try {
-      evaluateWithoutCodegen(expr, inputRow)
+  private def cmpInterpretWithCodegen(
+      inputRow: InternalRow,
+      expr: Expression,
+      exceptionAllowed: Boolean = false): Unit = {
+    val (interpret, interpretExc) = try {
+      (Some(evaluateWithoutCodegen(expr, inputRow)), None)
     } catch {
-      case e: Exception => fail(s"Exception evaluating $expr", e)
+      case e: Exception => if (exceptionAllowed) {
+        (None, Some(e))
+      } else {
+        fail(s"Exception evaluating $expr", e)
+      }
     }
 
     val plan = generateProject(
       GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil), expr)
-    val codegen = plan(inputRow).get(0, expr.dataType)
+    val (codegen, codegenExc) = try {
+      (Some(plan(inputRow).get(0, expr.dataType)), None)
+    } catch {
+      case e: Exception => if (exceptionAllowed) {
+        (None, Some(e))
+      } else {
+        fail(s"Exception evaluating $expr", e)
+      }
+    }
 
-    if (!compareResults(interpret, codegen)) {
-      fail(s"Incorrect evaluation: $expr, interpret: $interpret, codegen: $codegen")
+    if (interpret.isDefined && codegen.isDefined && !compareResults(interpret.get, codegen.get)) {
+      fail(s"Incorrect evaluation: $expr, interpret: ${interpret.get}, codegen: ${codegen.get}")
+    } else if (interpretExc.isDefined && codegenExc.isEmpty) {
+      fail(s"Incorrect evaluation: $expr, interpret threw exception ${interpretExc.get}")
+    } else if (interpretExc.isEmpty && codegenExc.isDefined) {
+      fail(s"Incorrect evaluation: $expr, codegen threw exception ${codegenExc.get}")
+    } else if (interpretExc.isDefined && codegenExc.isDefined
+        && !compareExceptions(interpretExc.get, codegenExc.get)) {
+      fail(s"Different exception evaluating: $expr, " +
+        s"interpret: ${interpretExc.get}, codegen: ${codegenExc.get}")
     }
   }
 
+  /**
+   * Checks the equality between two exceptions. Returns true iff the two exceptions are instances
+   * of the same class and they have the same message.
+   */
+  private[this] def compareExceptions(e1: Exception, e2: Exception): Boolean = {
+    e1.getClass == e2.getClass && e1.getMessage == e2.getMessage
+  }
+
   /**
    * Check the equality between result of expression and expected value, it will handle
    * Array[Byte] and Spread[Double].

From fad75fa137778270dc83a83aa34103b77c5f2bf8 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 20 Jun 2018 14:39:51 +0200
Subject: [PATCH 02/23] fix scalastyle

---
 .../expressions/ExpressionEvalHelper.scala | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 0e4cb01c640b..444ba3a425dc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -316,13 +316,13 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
   }
 
   /**
- * - * This method test against binary expressions by feeding them arbitrary literals of `dataType1` - * and `dataType2`. - */ + * Test evaluation results between Interpreted mode and Codegen mode, making sure we have + * consistent result regardless of the evaluation method we use. If an exception is thrown, + * it checks that both modes throw the same exception. + * + * This method test against binary expressions by feeding them arbitrary literals of `dataType1` + * and `dataType2`. + */ def checkConsistencyBetweenInterpretedAndCodegenAllowingException( c: (Expression, Expression) => Expression, dataType1: DataType, From 8591417be062240931b15597e788d7aa08e99a43 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 20 Jun 2018 18:05:18 +0200 Subject: [PATCH 03/23] fix ut failures --- .../spark/sql/catalyst/expressions/bitwiseExpressions.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index cc24e397cc14..8889d0438161 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -51,6 +51,8 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme } protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2) + + override def checkOverflowCode(result: String, op1: String, op2: String): String = "" } /** @@ -83,6 +85,8 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2) + + override def checkOverflowCode(result: String, op1: String, op2: String): String = "" } /** @@ -115,6 +119,8 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme } protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2) + + override def checkOverflowCode(result: String, op1: String, op2: String): String = "" } /** From 9c3df7d553f523581fe79a64a1bb167062728103 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 21 Jun 2018 16:48:42 +0200 Subject: [PATCH 04/23] use larger intermediate buffer for sum --- .../spark/sql/catalyst/expressions/aggregate/Sum.scala | 10 ++++++++-- .../scala/org/apache/spark/sql/types/Decimal.scala | 3 ++- .../spark/sql/catalyst/expressions/CastSuite.scala | 6 ++++++ .../org/apache/spark/sql/DataFrameRangeSuite.scala | 6 +----- .../org/apache/spark/sql/DatasetAggregatorSuite.scala | 9 +++++++++ 5 files changed, 26 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 86e40a9713b3..ad62c16a273f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -45,7 +45,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => DoubleType } - private lazy val sumDataType = resultType + private lazy val sumDataType = child.dataType match { + case LongType => DecimalType.BigIntDecimal + case _ => resultType + } + + private lazy val castToResultType: (Expression) => 
Expression = + if (sumDataType == resultType) (e: Expression) => e else (e: Expression) => Cast(e, resultType) private lazy val sum = AttributeReference("sum", sumDataType)() @@ -78,5 +84,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast ) } - override lazy val evaluateExpression: Expression = sum + override lazy val evaluateExpression: Expression = castToResultType(sum) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 6da4f28b1296..76e2e719f863 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -215,7 +215,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null)) { longVal / POW_10(_scale) } else { - decimalVal.longValue() + // This will throw an exception if overflow occurs + decimalVal.toLongExact } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5b25bdf907c3..28e206ca5679 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -922,4 +922,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val ret6 = cast(Literal.create((1, Map(1 -> "a", 2 -> "b", 3 -> "c"))), StringType) checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]") } + + test("SPARK-24598: Cast to long should fail on overflow") { + checkExceptionInExpression[ArithmeticException]( + cast(Literal.create(Decimal(Long.MaxValue) + Decimal(1)), LongType), "Overflow") + checkEvaluation(cast(Literal.create(Decimal(Long.MaxValue)), LongType), Long.MaxValue) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index b0b46640ff31..721846c6b9c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -114,11 +114,7 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall val random = new Random(seed) def randomBound(): Long = { - val n = if (random.nextBoolean()) { - random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS)) - } else { - random.nextLong() / 2 - } + val n = random.nextLong() % (Long.MaxValue / (100 * MAX_NUM_STEPS)) if (random.nextBoolean()) n else -n } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 0e7eaa9e88d5..5a9f4227694f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scalalang.typed @@ -333,4 +334,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { df.groupBy($"i").agg(VeryComplexResultAgg.toColumn), Row(1, Row(Row(1, "a"), Row(1, "a"))) :: Row(2, Row(Row(2, "bc"), Row(2, "bc"))) :: Nil) } 
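Sum now accumulates LongType input into a Decimal(38, 0) buffer, so an overflow can only surface when the buffer is cast back to long in evaluateExpression; that is what the Decimal.toLong change above guards. scala.math.BigDecimal's toLongExact delegates to java.math.BigDecimal#longValueExact, which throws instead of silently keeping the low-order 64 bits. A minimal sketch of the difference (standalone, hypothetical values):

    import scala.math.BigDecimal

    val tooBig = BigDecimal(Long.MaxValue) + 1
    // Old behaviour: longValue keeps only the low-order 64 bits, so
    // Long.MaxValue + 1 silently wraps to Long.MinValue.
    val wrapped: Long = tooBig.longValue      // -9223372036854775808
    // New behaviour: the exact conversion throws java.lang.ArithmeticException.
    val exact: Long = tooBig.toLongExact      // throws "Overflow"

Note that toLongExact also rejects fractional values ("Rounding necessary"), which is why a later patch in this series replaces it with an explicit Long.MinValue/Long.MaxValue bound check followed by longValue(), keeping truncation of fractions while still failing on out-of-range values.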
+ + test("SPARK-24598: sum throws exception instead of silently overflow") { + val df1 = Seq(Long.MinValue, -10, Long.MaxValue).toDF("i") + checkAnswer(df1.agg(sum($"i")), Row(-11)) + val df2 = Seq(Long.MinValue, -10, 8).toDF("i") + val e = intercept[SparkException](df2.agg(sum($"i")).collect()) + assert(e.getCause.isInstanceOf[ArithmeticException]) + } } From ebdaf61cd2d1a663e86fe1905cbb92740ce92908 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 22 Jun 2018 13:38:39 +0200 Subject: [PATCH 05/23] fix UT error --- .../sql-tests/inputs/udaf-regrfunctions.sql | 2 ++ .../results/udaf-regrfunctions.sql.out | 20 +++++++++++++------ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql index 92c7e26e3add..5244545fa6bc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf-regrfunctions.sql @@ -47,6 +47,8 @@ CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (800, 7, 1, 1) as t1(id, px, y, x); +set spark.sql.codegen.wholeStage=false; + select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) diff --git a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out index d7d009a64bf8..50e8840ffef8 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf-regrfunctions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 3 +-- Number of queries: 4 -- !query 0 @@ -41,13 +41,21 @@ struct<> -- !query 1 +set spark.sql.codegen.wholeStage=false +-- !query 1 schema +struct +-- !query 1 output +spark.sql.codegen.wholeStage false + + +-- !query 2 select px, var_pop(x), var_pop(y), corr(y,x), covar_samp(y,x), covar_pop(y,x), regr_count(y,x), regr_slope(y,x), regr_intercept(y,x), regr_r2(y,x), regr_sxx(y,x), regr_syy(y,x), regr_sxy(y,x), regr_avgx(y,x), regr_avgy(y,x), regr_count(y,x) from t1 group by px order by px --- !query 1 schema +-- !query 2 schema struct --- !query 1 output +-- !query 2 output 1 1.25 1.25 1.0 1.6666666666666667 1.25 4 1.0 0.0 1.0 5.0 5.0 5.0 2.5 2.5 4 2 1.25 0.0 NULL 0.0 0.0 4 0.0 1.0 1.0 5.0 0.0 0.0 2.5 1.0 4 3 0.0 1.25 NULL 0.0 0.0 4 NULL NULL NULL 0.0 5.0 0.0 1.0 2.5 4 @@ -57,11 +65,11 @@ struct --- !query 2 output +-- !query 3 output 101 4 102 4 103 4 From a0b862e04b628d957f30b0ffb32d1000c035dd9b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 22 Jun 2018 17:41:42 +0200 Subject: [PATCH 06/23] allow precision loss when converting decimal to long --- .../main/scala/org/apache/spark/sql/types/Decimal.scala | 8 +++++++- .../scala/org/apache/spark/sql/types/DecimalSuite.scala | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 76e2e719f863..4a95f13b0e2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -216,7 +216,10 @@ final class Decimal extends Ordered[Decimal] with Serializable 
{ longVal / POW_10(_scale) } else { // This will throw an exception if overflow occurs - decimalVal.toLongExact + if (decimalVal.compare(LONG_MIN_BIG_DEC) < 0 || decimalVal.compare(LONG_MAX_BIG_DEC) > 0) { + throw new ArithmeticException("Overflow") + } + decimalVal.longValue() } } @@ -433,6 +436,9 @@ object Decimal { private val LONG_MAX_BIG_INT = BigInteger.valueOf(JLong.MAX_VALUE) private val LONG_MIN_BIG_INT = BigInteger.valueOf(JLong.MIN_VALUE) + private val LONG_MAX_BIG_DEC = BigDecimal.valueOf(JLong.MAX_VALUE) + private val LONG_MIN_BIG_DEC = BigDecimal.valueOf(JLong.MIN_VALUE) + def apply(value: Double): Decimal = new Decimal().set(value) def apply(value: Long): Decimal = new Decimal().set(value) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 10de90c6a44c..088c7ad4b7d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -94,8 +94,8 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong) checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue) checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue) - checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L) - checkValues(Decimal(Double.MinValue), Double.MinValue, 0L) + assert(Decimal(Double.MaxValue).toDouble == Double.MaxValue) + assert(Decimal(Double.MinValue).toDouble == Double.MinValue) } // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs From 74cd0a4e2f2dd3ee231eb212efbfb552a6d1ca24 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 22 Jun 2019 13:11:02 +0200 Subject: [PATCH 07/23] Handle NaN --- .../spark/sql/catalyst/expressions/arithmetic.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d1b572a427d5..fad0d11e707f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -267,15 +267,21 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = { val result = numeric.times(input1, input2) - if (numeric.signum(result) != numeric.signum(input1) * numeric.signum(input2)) { + if (numeric.signum(result) != numeric.signum(input1) * numeric.signum(input2) && + !(result.isInstanceOf[Double] && !result.asInstanceOf[Double].isNaN) && + !(result.isInstanceOf[Float] && !result.asInstanceOf[Float].isNaN)) { throw new ArithmeticException(s"$input1 * $input2 caused overflow.") } result } override def checkOverflowCode(result: String, op1: String, op2: String): String = { + val isNaNCheck = dataType match { + case DoubleType | FloatType => s" && !java.lang.Double.isNaN($result)" + case _ => "" + } s""" - |if (Math.signum($result) != Math.signum($op1) * Math.signum($op2)) { + |if (Math.signum($result) != Math.signum($op1) * Math.signum($op2)$isNaNCheck) { | throw new ArithmeticException($op1 + " * " + $op2 + " caused overflow."); |} """.stripMargin From 2cfd946d2a568a469c25b204832616ce65af93c4 Mon Sep 17 00:00:00 2001 
From: Marco Gaido Date: Wed, 26 Jun 2019 13:39:41 +0200 Subject: [PATCH 08/23] Add conf flag for checking overflow --- .../sql/catalyst/expressions/arithmetic.scala | 46 +++++++++++++------ .../apache/spark/sql/internal/SQLConf.scala | 8 ++++ .../org/apache/spark/sql/types/Decimal.scala | 9 ++-- 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fad0d11e707f..6c971f46c00d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -117,6 +117,8 @@ case class Abs(child: Expression) abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { + protected val checkOverflow = SQLConf.get.getConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK) + override def dataType: DataType = left.dataType override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess @@ -142,16 +144,26 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val overflowCheck = if (overflowCheck) { + checkOverflowCode(ev.value, eval1, eval2) + } else { + "" + } s""" |${ev.value} = (${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2); - |${checkOverflowCode(ev.value, eval1, eval2)} + |$overflowCheck """.stripMargin }) case _ => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val overflowCheck = if (overflowCheck) { + checkOverflowCode(ev.value, eval1, eval2) + } else { + "" + } s""" |${ev.value} = $eval1 $symbol $eval2; - |${checkOverflowCode(ev.value, eval1, eval2)} + |$overflowCheck """.stripMargin }) } @@ -185,12 +197,14 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { input1.asInstanceOf[CalendarInterval].add(input2.asInstanceOf[CalendarInterval]) } else { val result = numeric.plus(input1, input2) - val resSignum = numeric.signum(result) - val input1Signum = numeric.signum(input1) - val input2Signum = numeric.signum(input2) - if (resSignum != -1 && input1Signum == -1 && input2Signum == -1 + if (checkOverflow) { + val resSignum = numeric.signum(result) + val input1Signum = numeric.signum(input1) + val input2Signum = numeric.signum(input2) + if (resSignum != -1 && input1Signum == -1 && input2Signum == -1 || resSignum != 1 && input1Signum == 1 && input2Signum == 1) { - throw new ArithmeticException(s"$input1 + $input2 caused overflow.") + throw new ArithmeticException(s"$input1 + $input2 caused overflow.") + } } result } @@ -229,12 +243,14 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti input1.asInstanceOf[CalendarInterval].subtract(input2.asInstanceOf[CalendarInterval]) } else { val result = numeric.minus(input1, input2) - val resSignum = numeric.signum(result) - val input1Signum = numeric.signum(input1) - val input2Signum = numeric.signum(input2) - if (resSignum != 1 && input1Signum == 1 && input2Signum == -1 + if (checkOverflow) { + val resSignum = numeric.signum(result) + val input1Signum = numeric.signum(input1) + val input2Signum = numeric.signum(input2) + if (resSignum != 1 && input1Signum == 1 && input2Signum == -1 || resSignum != -1 && input1Signum == -1 && input2Signum == 1) { - throw new 
ArithmeticException(s"$input1 - $input2 caused overflow.") + throw new ArithmeticException(s"$input1 - $input2 caused overflow.") + } } result } @@ -267,10 +283,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = { val result = numeric.times(input1, input2) - if (numeric.signum(result) != numeric.signum(input1) * numeric.signum(input2) && + if (checkOverflow) { + if (numeric.signum(result) != numeric.signum(input1) * numeric.signum(input2) && !(result.isInstanceOf[Double] && !result.asInstanceOf[Double].isNaN) && !(result.isInstanceOf[Float] && !result.asInstanceOf[Float].isNaN)) { - throw new ArithmeticException(s"$input1 * $input2 caused overflow.") + throw new ArithmeticException(s"$input1 * $input2 caused overflow.") + } } result } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d26cd2ca7343..d69fd3f8bdaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1719,6 +1719,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ARITHMETIC_OPERATION_OVERFLOW_CHECK = buildConf("spark.sql.arithmetic.checkOverflow") + .doc("If it is set to true (default), all arithmetic operations on non-decimal fields throw " + + "an exception if an overflow occurs. If it is false, in case of overflow a wrong result " + + "is returned.") + .internal() + .booleanConf + .createWithDefault(true) + val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE = buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere") .internal() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index a7de07e88763..fd4ba85e5e62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -22,6 +22,7 @@ import java.math.{BigInteger, MathContext, RoundingMode} import org.apache.spark.annotation.Unstable import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.internal.SQLConf /** * A mutable implementation of BigDecimal that can hold a Long if values are small enough. 
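As a usage sketch of the new flag (assumed behaviour, mirroring the set-then-query pattern the int4.sql golden file below adopts):

    // Assuming an active SparkSession named `spark`:
    spark.sql("set spark.sql.arithmetic.checkOverflow=false")
    // Legacy behaviour: the long addition wraps around silently.
    spark.sql("SELECT 9223372036854775807 + 1").show()   // -9223372036854775808
    spark.sql("set spark.sql.arithmetic.checkOverflow=true")
    // Default behaviour after this patch: the same query should fail at runtime
    // with java.lang.ArithmeticException: ... caused overflow.
    spark.sql("SELECT 9223372036854775807 + 1").show()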
@@ -227,9 +228,11 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null)) { longVal / POW_10(_scale) } else { - // This will throw an exception if overflow occurs - if (decimalVal.compare(LONG_MIN_BIG_DEC) < 0 || decimalVal.compare(LONG_MAX_BIG_DEC) > 0) { - throw new ArithmeticException("Overflow") + if (SQLConf.get.getConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK)) { + // This will throw an exception if overflow occurs + if (decimalVal.compare(LONG_MIN_BIG_DEC) < 0 || decimalVal.compare(LONG_MAX_BIG_DEC) > 0) { + throw new ArithmeticException("Overflow") + } } decimalVal.longValue() } From 25c853c0dd29abfa4eb70787c1806ca4ebdd143c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 26 Jun 2019 17:35:31 +0200 Subject: [PATCH 09/23] fix --- .../apache/spark/sql/catalyst/expressions/arithmetic.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 6c971f46c00d..a18f53897efa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -144,7 +144,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val overflowCheck = if (overflowCheck) { + val overflowCheck = if (checkOverflow) { checkOverflowCode(ev.value, eval1, eval2) } else { "" @@ -156,7 +156,7 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { }) case _ => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { - val overflowCheck = if (overflowCheck) { + val overflowCheck = if (checkOverflow) { checkOverflowCode(ev.value, eval1, eval2) } else { "" From 00fae1d7e3d4a2db30575726c99db3ad357bbf28 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Thu, 27 Jun 2019 11:50:56 +0200 Subject: [PATCH 10/23] fix tests --- .../resources/sql-tests/inputs/pgSQL/int4.sql | 144 +++ .../sql-tests/results/pgSQL/int4.sql.out | 850 +++++++++++++++--- 2 files changed, 864 insertions(+), 130 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql index 89cac00228f7..3345998da661 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql @@ -33,6 +33,150 @@ INSERT INTO INT4_TBL VALUES ('-2147483647'); -- INSERT INTO INT4_TBL(f1) VALUES ('123 5'); -- INSERT INTO INT4_TBL(f1) VALUES (''); +set spark.sql.arithmetic.checkOverflow=false; + +SELECT '' AS five, * FROM INT4_TBL; + +SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> smallint('0'); + +SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> int('0'); + +SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = smallint('0'); + +SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = int('0'); + +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < smallint('0'); + +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < int('0'); + +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= smallint('0'); + +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= int('0'); + +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > smallint('0'); + +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > 
int('0'); + +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= smallint('0'); + +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= int('0'); + +-- positive odds +SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1'); + +-- any evens +SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0'); + +-- [SPARK-28024] Incorrect value when out of range +SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i; + +SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i +WHERE abs(f1) < 1073741824; + +-- [SPARK-28024] Incorrect value when out of range +SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i; + +SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i +WHERE abs(f1) < 1073741824; + +-- [SPARK-28024] Incorrect value when out of range +SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i; + +SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i +WHERE f1 < 2147483646; + +-- [SPARK-28024] Incorrect value when out of range +SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i; + +SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i +WHERE f1 < 2147483646; + +-- [SPARK-28024] Incorrect value when out of range +SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i; + +SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i +WHERE f1 > -2147483647; + +-- [SPARK-28024] Incorrect value when out of range +SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i; + +SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i +WHERE f1 > -2147483647; + +SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i; + +SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i; + +-- +-- more complex expressions +-- + +-- variations on unary minus parsing +SELECT -2+3 AS one; + +SELECT 4-2 AS two; + +SELECT 2- -1 AS three; + +SELECT 2 - -2 AS four; + +SELECT smallint('2') * smallint('2') = smallint('16') / smallint('4') AS true; + +SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true; + +SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true; + +SELECT int('1000') < int('999') AS false; + +-- [SPARK-28027] Our ! and !! has different meanings +-- SELECT 4! 
AS twenty_four; + +-- SELECT !!3 AS six; + +SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten; + +-- [SPARK-2659] HiveQL: Division operator should always perform fractional division +SELECT 2 + 2 / 2 AS three; + +SELECT (2 + 2) / 2 AS two; + +-- [SPARK-28027] Add bitwise shift left/right operators +-- corner case +SELECT string(shiftleft(int(-1), 31)); +SELECT string(int(shiftleft(int(-1), 31))+1); + +-- [SPARK-28024] Incorrect numeric values when out of range +-- check sane handling of INT_MIN overflow cases +-- SELECT (-2147483648)::int4 * (-1)::int4; +-- SELECT (-2147483648)::int4 / (-1)::int4; +SELECT int(-2147483648) % int(-1); +-- SELECT (-2147483648)::int4 * (-1)::int2; +-- SELECT (-2147483648)::int4 / (-1)::int2; +SELECT int(-2147483648) % smallint(-1); + +-- [SPARK-28028] Cast numeric to integral type need round +-- check rounding when casting from float +SELECT x, int(x) AS int4_value +FROM (VALUES double(-2.5), + double(-1.5), + double(-0.5), + double(0.0), + double(0.5), + double(1.5), + double(2.5)) t(x); + +-- [SPARK-28028] Cast numeric to integral type need round +-- check rounding when casting from numeric +SELECT x, int(x) AS int4_value +FROM (VALUES cast(-2.5 as decimal(38, 18)), + cast(-1.5 as decimal(38, 18)), + cast(-0.5 as decimal(38, 18)), + cast(-0.0 as decimal(38, 18)), + cast(0.5 as decimal(38, 18)), + cast(1.5 as decimal(38, 18)), + cast(2.5 as decimal(38, 18))) t(x); + +set spark.sql.arithmetic.checkOverflow=true; SELECT '' AS five, * FROM INT4_TBL; diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out index 9c17e9a1a197..058fac18476d 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 53 +-- Number of queries: 101 -- !query 0 @@ -51,30 +51,27 @@ struct<> -- !query 6 -SELECT '' AS five, * FROM INT4_TBL +set spark.sql.arithmetic.checkOverflow=false -- !query 6 schema -struct +struct -- !query 6 output --123456 - -2147483647 - 0 - 123456 - 2147483647 +spark.sql.arithmetic.checkOverflow false -- !query 7 -SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> smallint('0') +SELECT '' AS five, * FROM INT4_TBL -- !query 7 schema -struct +struct -- !query 7 output -123456 -2147483647 + 0 123456 2147483647 -- !query 8 -SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> int('0') +SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> smallint('0') -- !query 8 schema struct -- !query 8 output @@ -85,15 +82,18 @@ struct -- !query 9 -SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = smallint('0') +SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> int('0') -- !query 9 schema -struct +struct -- !query 9 output -0 +-123456 + -2147483647 + 123456 + 2147483647 -- !query 10 -SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = int('0') +SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = smallint('0') -- !query 10 schema struct -- !query 10 output @@ -101,16 +101,15 @@ struct -- !query 11 -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < smallint('0') +SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = int('0') -- !query 11 schema -struct +struct -- !query 11 output --123456 - -2147483647 +0 -- !query 12 -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < int('0') +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < smallint('0') -- !query 12 schema struct -- !query 12 output @@ -119,17 +118,16 @@ struct -- 
!query 13 -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= smallint('0') +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < int('0') -- !query 13 schema -struct +struct -- !query 13 output -123456 -2147483647 - 0 -- !query 14 -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= int('0') +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= smallint('0') -- !query 14 schema struct -- !query 14 output @@ -139,16 +137,17 @@ struct -- !query 15 -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > smallint('0') +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= int('0') -- !query 15 schema -struct +struct -- !query 15 output -123456 - 2147483647 +-123456 + -2147483647 + 0 -- !query 16 -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > int('0') +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > smallint('0') -- !query 16 schema struct -- !query 16 output @@ -157,17 +156,16 @@ struct -- !query 17 -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= smallint('0') +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > int('0') -- !query 17 schema -struct +struct -- !query 17 output -0 - 123456 +123456 2147483647 -- !query 18 -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= int('0') +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= smallint('0') -- !query 18 schema struct -- !query 18 output @@ -177,84 +175,81 @@ struct -- !query 19 -SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1') +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= int('0') -- !query 19 schema -struct +struct -- !query 19 output -2147483647 +0 + 123456 + 2147483647 -- !query 20 -SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0') +SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1') -- !query 20 schema -struct +struct -- !query 20 output --123456 - 0 - 123456 +2147483647 -- !query 21 -SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i +SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0') -- !query 21 schema -struct +struct -- !query 21 output --123456 -246912 - -2147483647 2 - 0 0 - 123456 246912 - 2147483647 -2 +-123456 + 0 + 123456 -- !query 22 SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i -WHERE abs(f1) < 1073741824 -- !query 22 schema struct -- !query 22 output -123456 -246912 + -2147483647 2 0 0 123456 246912 + 2147483647 -2 -- !query 23 -SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i +SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i +WHERE abs(f1) < 1073741824 -- !query 23 schema struct -- !query 23 output -123456 -246912 - -2147483647 2 0 0 123456 246912 - 2147483647 -2 -- !query 24 SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i -WHERE abs(f1) < 1073741824 -- !query 24 schema struct -- !query 24 output -123456 -246912 + -2147483647 2 0 0 123456 246912 + 2147483647 -2 -- !query 25 -SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i +SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i +WHERE abs(f1) < 1073741824 -- !query 25 schema struct -- !query 25 output --123456 -123454 - -2147483647 -2147483645 - 0 2 - 123456 123458 - 2147483647 -2147483647 +-123456 -246912 + 0 0 + 123456 246912 -- !query 26 SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i -WHERE f1 < 2147483646 -- !query 26 schema struct -- !query 26 output @@ -262,10 +257,12 @@ struct -2147483647 -2147483645 0 2 123456 123458 + 2147483647 -2147483647 -- !query 27 -SELECT '' AS five, i.f1, i.f1 + 
int('2') AS x FROM INT4_TBL i +SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i +WHERE f1 < 2147483646 -- !query 27 schema struct -- !query 27 output @@ -273,12 +270,10 @@ struct -2147483647 -2147483645 0 2 123456 123458 - 2147483647 -2147483647 -- !query 28 SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i -WHERE f1 < 2147483646 -- !query 28 schema struct -- !query 28 output @@ -286,39 +281,40 @@ struct -2147483647 -2147483645 0 2 123456 123458 + 2147483647 -2147483647 -- !query 29 -SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i +SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i +WHERE f1 < 2147483646 -- !query 29 schema struct -- !query 29 output --123456 -123458 - -2147483647 2147483647 - 0 -2 - 123456 123454 - 2147483647 2147483645 +-123456 -123454 + -2147483647 -2147483645 + 0 2 + 123456 123458 -- !query 30 SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i -WHERE f1 > -2147483647 -- !query 30 schema struct -- !query 30 output -123456 -123458 + -2147483647 2147483647 0 -2 123456 123454 2147483647 2147483645 -- !query 31 -SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i +SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i +WHERE f1 > -2147483647 -- !query 31 schema struct -- !query 31 output -123456 -123458 - -2147483647 2147483647 0 -2 123456 123454 2147483647 2147483645 @@ -326,30 +322,30 @@ struct -- !query 32 SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i -WHERE f1 > -2147483647 -- !query 32 schema struct -- !query 32 output -123456 -123458 + -2147483647 2147483647 0 -2 123456 123454 2147483647 2147483645 -- !query 33 -SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i +SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i +WHERE f1 > -2147483647 -- !query 33 schema -struct +struct -- !query 33 output --123456 -61728.0 - -2147483647 -1.0737418235E9 - 0 0.0 - 123456 61728.0 - 2147483647 1.0737418235E9 +-123456 -123458 + 0 -2 + 123456 123454 + 2147483647 2147483645 -- !query 34 -SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i +SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i -- !query 34 schema struct -- !query 34 output @@ -361,47 +357,51 @@ struct -- !query 35 -SELECT -2+3 AS one +SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i -- !query 35 schema -struct +struct -- !query 35 output -1 +-123456 -61728.0 + -2147483647 -1.0737418235E9 + 0 0.0 + 123456 61728.0 + 2147483647 1.0737418235E9 -- !query 36 -SELECT 4-2 AS two +SELECT -2+3 AS one -- !query 36 schema -struct +struct -- !query 36 output -2 +1 -- !query 37 -SELECT 2- -1 AS three +SELECT 4-2 AS two -- !query 37 schema -struct +struct -- !query 37 output -3 +2 -- !query 38 -SELECT 2 - -2 AS four +SELECT 2- -1 AS three -- !query 38 schema -struct +struct -- !query 38 output -4 +3 -- !query 39 -SELECT smallint('2') * smallint('2') = smallint('16') / smallint('4') AS true +SELECT 2 - -2 AS four -- !query 39 schema -struct +struct -- !query 39 output -true +4 -- !query 40 -SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true +SELECT smallint('2') * smallint('2') = smallint('16') / smallint('4') AS true -- !query 40 schema struct -- !query 40 output @@ -409,7 +409,7 @@ true -- !query 41 -SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true +SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true -- !query 41 schema struct -- !query 41 output @@ -417,70 +417,78 @@ true -- !query 42 -SELECT int('1000') < 
int('999') AS false +SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true -- !query 42 schema -struct +struct -- !query 42 output -false +true -- !query 43 -SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten +SELECT int('1000') < int('999') AS false -- !query 43 schema -struct +struct -- !query 43 output -10 +false -- !query 44 -SELECT 2 + 2 / 2 AS three +SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten -- !query 44 schema -struct +struct -- !query 44 output -3.0 +10 -- !query 45 -SELECT (2 + 2) / 2 AS two +SELECT 2 + 2 / 2 AS three -- !query 45 schema -struct +struct -- !query 45 output -2.0 +3.0 -- !query 46 -SELECT string(shiftleft(int(-1), 31)) +SELECT (2 + 2) / 2 AS two -- !query 46 schema -struct +struct -- !query 46 output --2147483648 +2.0 -- !query 47 -SELECT string(int(shiftleft(int(-1), 31))+1) +SELECT string(shiftleft(int(-1), 31)) -- !query 47 schema -struct +struct -- !query 47 output --2147483647 +-2147483648 -- !query 48 -SELECT int(-2147483648) % int(-1) +SELECT string(int(shiftleft(int(-1), 31))+1) -- !query 48 schema -struct<(CAST(-2147483648 AS INT) % CAST(-1 AS INT)):int> +struct -- !query 48 output -0 +-2147483647 -- !query 49 -SELECT int(-2147483648) % smallint(-1) +SELECT int(-2147483648) % int(-1) -- !query 49 schema -struct<(CAST(-2147483648 AS INT) % CAST(CAST(-1 AS SMALLINT) AS INT)):int> +struct<(CAST(-2147483648 AS INT) % CAST(-1 AS INT)):int> -- !query 49 output 0 -- !query 50 +SELECT int(-2147483648) % smallint(-1) +-- !query 50 schema +struct<(CAST(-2147483648 AS INT) % CAST(CAST(-1 AS SMALLINT) AS INT)):int> +-- !query 50 output +0 + + +-- !query 51 SELECT x, int(x) AS int4_value FROM (VALUES double(-2.5), double(-1.5), @@ -489,9 +497,9 @@ FROM (VALUES double(-2.5), double(0.5), double(1.5), double(2.5)) t(x) --- !query 50 schema +-- !query 51 schema struct --- !query 50 output +-- !query 51 output -0.5 0 -1.5 -1 -2.5 -2 @@ -501,7 +509,7 @@ struct 2.5 2 --- !query 51 +-- !query 52 SELECT x, int(x) AS int4_value FROM (VALUES cast(-2.5 as decimal(38, 18)), cast(-1.5 as decimal(38, 18)), @@ -510,9 +518,9 @@ FROM (VALUES cast(-2.5 as decimal(38, 18)), cast(0.5 as decimal(38, 18)), cast(1.5 as decimal(38, 18)), cast(2.5 as decimal(38, 18))) t(x) --- !query 51 schema +-- !query 52 schema struct --- !query 51 output +-- !query 52 output -0.5 0 -1.5 -1 -2.5 -2 @@ -522,9 +530,591 @@ struct 2.5 2 --- !query 52 +-- !query 53 +set spark.sql.arithmetic.checkOverflow=true +-- !query 53 schema +struct +-- !query 53 output +spark.sql.arithmetic.checkOverflow true + + +-- !query 54 +SELECT '' AS five, * FROM INT4_TBL +-- !query 54 schema +struct +-- !query 54 output +-123456 + -2147483647 + 0 + 123456 + 2147483647 + + +-- !query 55 +SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> smallint('0') +-- !query 55 schema +struct +-- !query 55 output +-123456 + -2147483647 + 123456 + 2147483647 + + +-- !query 56 +SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> int('0') +-- !query 56 schema +struct +-- !query 56 output +-123456 + -2147483647 + 123456 + 2147483647 + + +-- !query 57 +SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = smallint('0') +-- !query 57 schema +struct +-- !query 57 output +0 + + +-- !query 58 +SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = int('0') +-- !query 58 schema +struct +-- !query 58 output +0 + + +-- !query 59 +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < smallint('0') +-- !query 59 schema +struct +-- !query 59 output +-123456 + -2147483647 + + +-- !query 60 +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < 
int('0') +-- !query 60 schema +struct +-- !query 60 output +-123456 + -2147483647 + + +-- !query 61 +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= smallint('0') +-- !query 61 schema +struct +-- !query 61 output +-123456 + -2147483647 + 0 + + +-- !query 62 +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= int('0') +-- !query 62 schema +struct +-- !query 62 output +-123456 + -2147483647 + 0 + + +-- !query 63 +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > smallint('0') +-- !query 63 schema +struct +-- !query 63 output +123456 + 2147483647 + + +-- !query 64 +SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > int('0') +-- !query 64 schema +struct +-- !query 64 output +123456 + 2147483647 + + +-- !query 65 +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= smallint('0') +-- !query 65 schema +struct +-- !query 65 output +0 + 123456 + 2147483647 + + +-- !query 66 +SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= int('0') +-- !query 66 schema +struct +-- !query 66 output +0 + 123456 + 2147483647 + + +-- !query 67 +SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1') +-- !query 67 schema +struct +-- !query 67 output +2147483647 + + +-- !query 68 +SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0') +-- !query 68 schema +struct +-- !query 68 output +-123456 + 0 + 123456 + + +-- !query 69 +SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i +-- !query 69 schema +struct<> +-- !query 69 output +org.apache.spark.SparkException +Job aborted due to stage failure: Task 0 in stage 2069.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2069.0 (TID 91890, localhost, executor driver): java.lang.ArithmeticException: 2147483647 * 2 caused overflow. + at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) + at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) + at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) + at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) + at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) + at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) + at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) + at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) + at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) + at org.apache.spark.scheduler.Task.run(Task.scala:126) + at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) + at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) + at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) + at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) + at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) + at java.lang.Thread.run(Thread.java:748) + +Driver stacktrace: + + +-- !query 70 +SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i +WHERE abs(f1) < 1073741824 +-- !query 70 schema +struct +-- !query 70 output +-123456 -246912 + 0 0 + 123456 246912 + + +-- !query 71 +SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM 
INT4_TBL i +-- !query 71 schema +struct<> +-- !query 71 output +org.apache.spark.SparkException +Job aborted due to stage failure: Task 0 in stage 2071.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2071.0 (TID 91894, localhost, executor driver): java.lang.ArithmeticException: 2147483647 * 2 caused overflow. + at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) + at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) + at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) + at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) + at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) + at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) + at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) + at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) + at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) + at org.apache.spark.scheduler.Task.run(Task.scala:126) + at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) + at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) + at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) + at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) + at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) + at java.lang.Thread.run(Thread.java:748) + +Driver stacktrace: + + +-- !query 72 +SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i +WHERE abs(f1) < 1073741824 +-- !query 72 schema +struct +-- !query 72 output +-123456 -246912 + 0 0 + 123456 246912 + + +-- !query 73 +SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i +-- !query 73 schema +struct<> +-- !query 73 output +org.apache.spark.SparkException +Job aborted due to stage failure: Task 0 in stage 2073.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2073.0 (TID 91898, localhost, executor driver): java.lang.ArithmeticException: 2147483647 + 2 caused overflow. 
+ at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) + at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) + at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) + at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) + at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) + at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) + at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) + at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) + at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) + at org.apache.spark.scheduler.Task.run(Task.scala:126) + at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) + at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) + at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) + at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) + at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) + at java.lang.Thread.run(Thread.java:748) + +Driver stacktrace: + + +-- !query 74 +SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i +WHERE f1 < 2147483646 +-- !query 74 schema +struct +-- !query 74 output +-123456 -123454 + -2147483647 -2147483645 + 0 2 + 123456 123458 + + +-- !query 75 +SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i +-- !query 75 schema +struct<> +-- !query 75 output +org.apache.spark.SparkException +Job aborted due to stage failure: Task 0 in stage 2075.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2075.0 (TID 91902, localhost, executor driver): java.lang.ArithmeticException: 2147483647 + 2 caused overflow. 
+ at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) + at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) + at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) + at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) + at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) + at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) + at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) + at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) + at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) + at org.apache.spark.scheduler.Task.run(Task.scala:126) + at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) + at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) + at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) + at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) + at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) + at java.lang.Thread.run(Thread.java:748) + +Driver stacktrace: + + +-- !query 76 +SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i +WHERE f1 < 2147483646 +-- !query 76 schema +struct +-- !query 76 output +-123456 -123454 + -2147483647 -2147483645 + 0 2 + 123456 123458 + + +-- !query 77 +SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i +-- !query 77 schema +struct<> +-- !query 77 output +org.apache.spark.SparkException +Job aborted due to stage failure: Task 1 in stage 2077.0 failed 1 times, most recent failure: Lost task 1.0 in stage 2077.0 (TID 91907, localhost, executor driver): java.lang.ArithmeticException: -2147483647 - 2 caused overflow. 
+ at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) + at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) + at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) + at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) + at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) + at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) + at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) + at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) + at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) + at org.apache.spark.scheduler.Task.run(Task.scala:126) + at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) + at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) + at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) + at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) + at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) + at java.lang.Thread.run(Thread.java:748) + +Driver stacktrace: + + +-- !query 78 +SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i +WHERE f1 > -2147483647 +-- !query 78 schema +struct +-- !query 78 output +-123456 -123458 + 0 -2 + 123456 123454 + 2147483647 2147483645 + + +-- !query 79 +SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i +-- !query 79 schema +struct<> +-- !query 79 output +org.apache.spark.SparkException +Job aborted due to stage failure: Task 1 in stage 2079.0 failed 1 times, most recent failure: Lost task 1.0 in stage 2079.0 (TID 91911, localhost, executor driver): java.lang.ArithmeticException: -2147483647 - 2 caused overflow. 
+ at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) + at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) + at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) + at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) + at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) + at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) + at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) + at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) + at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) + at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) + at org.apache.spark.scheduler.Task.run(Task.scala:126) + at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) + at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) + at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) + at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) + at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) + at java.lang.Thread.run(Thread.java:748) + +Driver stacktrace: + + +-- !query 80 +SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i +WHERE f1 > -2147483647 +-- !query 80 schema +struct +-- !query 80 output +-123456 -123458 + 0 -2 + 123456 123454 + 2147483647 2147483645 + + +-- !query 81 +SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i +-- !query 81 schema +struct +-- !query 81 output +-123456 -61728.0 + -2147483647 -1.0737418235E9 + 0 0.0 + 123456 61728.0 + 2147483647 1.0737418235E9 + + +-- !query 82 +SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i +-- !query 82 schema +struct +-- !query 82 output +-123456 -61728.0 + -2147483647 -1.0737418235E9 + 0 0.0 + 123456 61728.0 + 2147483647 1.0737418235E9 + + +-- !query 83 +SELECT -2+3 AS one +-- !query 83 schema +struct +-- !query 83 output +1 + + +-- !query 84 +SELECT 4-2 AS two +-- !query 84 schema +struct +-- !query 84 output +2 + + +-- !query 85 +SELECT 2- -1 AS three +-- !query 85 schema +struct +-- !query 85 output +3 + + +-- !query 86 +SELECT 2 - -2 AS four +-- !query 86 schema +struct +-- !query 86 output +4 + + +-- !query 87 +SELECT smallint('2') * smallint('2') = smallint('16') / smallint('4') AS true +-- !query 87 schema +struct +-- !query 87 output +true + + +-- !query 88 +SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true +-- !query 88 schema +struct +-- !query 88 output +true + + +-- !query 89 +SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true +-- !query 89 schema +struct +-- !query 89 output +true + + +-- !query 90 +SELECT int('1000') < int('999') AS false +-- !query 90 schema +struct +-- !query 90 output +false + + +-- !query 91 +SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten +-- !query 91 schema +struct +-- !query 91 output +10 + + +-- !query 92 +SELECT 2 + 2 / 2 AS three +-- !query 92 schema +struct +-- !query 92 output +3.0 + + +-- !query 93 +SELECT (2 + 2) / 2 AS two +-- !query 93 schema +struct +-- !query 93 output +2.0 + + +-- !query 94 +SELECT string(shiftleft(int(-1), 31)) +-- !query 94 schema +struct +-- !query 94 
output +-2147483648 + + +-- !query 95 +SELECT string(int(shiftleft(int(-1), 31))+1) +-- !query 95 schema +struct +-- !query 95 output +-2147483647 + + +-- !query 96 +SELECT int(-2147483648) % int(-1) +-- !query 96 schema +struct<(CAST(-2147483648 AS INT) % CAST(-1 AS INT)):int> +-- !query 96 output +0 + + +-- !query 97 +SELECT int(-2147483648) % smallint(-1) +-- !query 97 schema +struct<(CAST(-2147483648 AS INT) % CAST(CAST(-1 AS SMALLINT) AS INT)):int> +-- !query 97 output +0 + + +-- !query 98 +SELECT x, int(x) AS int4_value +FROM (VALUES double(-2.5), + double(-1.5), + double(-0.5), + double(0.0), + double(0.5), + double(1.5), + double(2.5)) t(x) +-- !query 98 schema +struct +-- !query 98 output +-0.5 0 +-1.5 -1 +-2.5 -2 +0.0 0 +0.5 0 +1.5 1 +2.5 2 + + +-- !query 99 +SELECT x, int(x) AS int4_value +FROM (VALUES cast(-2.5 as decimal(38, 18)), + cast(-1.5 as decimal(38, 18)), + cast(-0.5 as decimal(38, 18)), + cast(-0.0 as decimal(38, 18)), + cast(0.5 as decimal(38, 18)), + cast(1.5 as decimal(38, 18)), + cast(2.5 as decimal(38, 18))) t(x) +-- !query 99 schema +struct +-- !query 99 output +-0.5 0 +-1.5 -1 +-2.5 -2 +0 0 +0.5 0 +1.5 1 +2.5 2 + + +-- !query 100 DROP TABLE INT4_TBL --- !query 52 schema +-- !query 100 schema struct<> --- !query 52 output +-- !query 100 output From 8e9715c35e4d4c80b7a4d5bf412bdc15044763ba Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 28 Jun 2019 22:20:21 +0200 Subject: [PATCH 11/23] change default value and fix tests --- .../apache/spark/sql/internal/SQLConf.scala | 8 +- .../ArithmeticExpressionSuite.scala | 56 +- .../sql/catalyst/expressions/CastSuite.scala | 14 +- .../resources/sql-tests/inputs/pgSQL/int4.sql | 146 +---- .../sql-tests/results/pgSQL/int4.sql.out | 588 +----------------- .../spark/sql/DatasetAggregatorSuite.scala | 13 +- 6 files changed, 68 insertions(+), 757 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d69fd3f8bdaf..a107395701a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1720,12 +1720,12 @@ object SQLConf { .createWithDefault(false) val ARITHMETIC_OPERATION_OVERFLOW_CHECK = buildConf("spark.sql.arithmetic.checkOverflow") - .doc("If it is set to true (default), all arithmetic operations on non-decimal fields throw " + - "an exception if an overflow occurs. If it is false, in case of overflow a wrong result " + - "is returned.") + .doc("If it is set to true, all arithmetic operations on non-decimal fields throw an " + + "exception if an overflow occurs. 
If it is false (default), in case of overflow a wrong " + + "result is returned.") .internal() .booleanConf - .createWithDefault(true) + .createWithDefault(false) val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE = buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index c3172249796e..32a97aff9f41 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -59,8 +59,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Add(positiveIntLit, negativeIntLit), -1) checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) - DataTypeTestUtils.numericAndInterval.foreach { tpe => - checkConsistencyBetweenInterpretedAndCodegenAllowingException(Add, tpe, tpe) + Seq("true", "false").foreach { checkOverflow => + withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) { + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegenAllowingException(Add, tpe, tpe) + } + } } } @@ -100,8 +104,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Subtract(positiveIntLit, negativeIntLit), positiveInt - negativeInt) checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong) - DataTypeTestUtils.numericAndInterval.foreach { tpe => - checkConsistencyBetweenInterpretedAndCodegenAllowingException(Subtract, tpe, tpe) + Seq("true", "false").foreach { checkOverflow => + withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) { + DataTypeTestUtils.numericAndInterval.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegenAllowingException(Subtract, tpe, tpe) + } + } } } @@ -118,8 +126,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Multiply(positiveIntLit, negativeIntLit), positiveInt * negativeInt) checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong) - DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => - checkConsistencyBetweenInterpretedAndCodegenAllowingException(Multiply, tpe, tpe) + Seq("true", "false").foreach { checkOverflow => + withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) { + DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + checkConsistencyBetweenInterpretedAndCodegenAllowingException(Multiply, tpe, tpe) + } + } } } @@ -380,18 +392,24 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("SPARK-24598: overflow on BigInt returns wrong result") { val maxLongLiteral = Literal(Long.MaxValue) val minLongLiteral = Literal(Long.MinValue) - checkExceptionInExpression[ArithmeticException]( - Add(maxLongLiteral, Literal(1L)), "caused overflow") - checkExceptionInExpression[ArithmeticException]( - Subtract(maxLongLiteral, Literal(-1L)), "caused overflow") - checkExceptionInExpression[ArithmeticException]( - Multiply(maxLongLiteral, Literal(2L)), "caused overflow") - - checkExceptionInExpression[ArithmeticException]( - Add(minLongLiteral, minLongLiteral), "caused overflow") - checkExceptionInExpression[ArithmeticException]( - 
Subtract(minLongLiteral, maxLongLiteral), "caused overflow") - checkExceptionInExpression[ArithmeticException]( - Multiply(minLongLiteral, minLongLiteral), "caused overflow") + val e1 = Add(maxLongLiteral, Literal(1L)) + val e2 = Subtract(maxLongLiteral, Literal(-1L)) + val e3 = Multiply(maxLongLiteral, Literal(2L)) + val e4 = Add(minLongLiteral, minLongLiteral) + val e5 = Subtract(minLongLiteral, maxLongLiteral) + val e6 = Multiply(minLongLiteral, minLongLiteral) + withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") { + Seq(e1, e2, e3, e4, e5, e6).foreach { e => + checkExceptionInExpression[ArithmeticException](e, "caused overflow") + } + } + withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "false") { + checkEvaluation(e1, Long.MinValue) + checkEvaluation(e2, Long.MinValue) + checkEvaluation(e3, -2L) + checkEvaluation(e4, 0L) + checkEvaluation(e5, 1L) + checkEvaluation(e6, 0L) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 80b349fa8c05..093a9d2cd7ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -957,9 +958,16 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPARK-24598: Cast to long should fail on overflow") { - checkExceptionInExpression[ArithmeticException]( - cast(Literal.create(Decimal(Long.MaxValue) + Decimal(1)), LongType), "Overflow") - checkEvaluation(cast(Literal.create(Decimal(Long.MaxValue)), LongType), Long.MaxValue) + val overflowCast = cast(Literal.create(Decimal(Long.MaxValue) + Decimal(1)), LongType) + val nonOverflowCast = cast(Literal.create(Decimal(Long.MaxValue)), LongType) + withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") { + checkExceptionInExpression[ArithmeticException](overflowCast, "Overflow") + checkEvaluation(nonOverflowCast, Long.MaxValue) + } + withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "false") { + checkEvaluation(overflowCast, Long.MinValue) + checkEvaluation(nonOverflowCast, Long.MaxValue) + } } test("up-cast") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql index 3345998da661..9a0bfd8a2751 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql @@ -33,6 +33,9 @@ INSERT INTO INT4_TBL VALUES ('-2147483647'); -- INSERT INTO INT4_TBL(f1) VALUES ('123 5'); -- INSERT INTO INT4_TBL(f1) VALUES (''); +-- We cannot test this when checkOverflow=false here +-- because exception happens in the executors and the +-- output stacktrace cannot have an exact match set spark.sql.arithmetic.checkOverflow=false; SELECT '' AS five, * FROM INT4_TBL; @@ -154,149 +157,6 @@ SELECT int(-2147483648) % int(-1); -- SELECT (-2147483648)::int4 / (-1)::int2; SELECT int(-2147483648) % smallint(-1); --- [SPARK-28028] Cast numeric to 
integral type need round --- check rounding when casting from float -SELECT x, int(x) AS int4_value -FROM (VALUES double(-2.5), - double(-1.5), - double(-0.5), - double(0.0), - double(0.5), - double(1.5), - double(2.5)) t(x); - --- [SPARK-28028] Cast numeric to integral type need round --- check rounding when casting from numeric -SELECT x, int(x) AS int4_value -FROM (VALUES cast(-2.5 as decimal(38, 18)), - cast(-1.5 as decimal(38, 18)), - cast(-0.5 as decimal(38, 18)), - cast(-0.0 as decimal(38, 18)), - cast(0.5 as decimal(38, 18)), - cast(1.5 as decimal(38, 18)), - cast(2.5 as decimal(38, 18))) t(x); - -set spark.sql.arithmetic.checkOverflow=true; - -SELECT '' AS five, * FROM INT4_TBL; - -SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> smallint('0'); - -SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> int('0'); - -SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = smallint('0'); - -SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = int('0'); - -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < smallint('0'); - -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < int('0'); - -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= smallint('0'); - -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= int('0'); - -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > smallint('0'); - -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > int('0'); - -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= smallint('0'); - -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= int('0'); - --- positive odds -SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1'); - --- any evens -SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0'); - --- [SPARK-28024] Incorrect value when out of range -SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i; - -SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i -WHERE abs(f1) < 1073741824; - --- [SPARK-28024] Incorrect value when out of range -SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i; - -SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i -WHERE abs(f1) < 1073741824; - --- [SPARK-28024] Incorrect value when out of range -SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i; - -SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i -WHERE f1 < 2147483646; - --- [SPARK-28024] Incorrect value when out of range -SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i; - -SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i -WHERE f1 < 2147483646; - --- [SPARK-28024] Incorrect value when out of range -SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i; - -SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i -WHERE f1 > -2147483647; - --- [SPARK-28024] Incorrect value when out of range -SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i; - -SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i -WHERE f1 > -2147483647; - -SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i; - -SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i; - --- --- more complex expressions --- - --- variations on unary minus parsing -SELECT -2+3 AS one; - -SELECT 4-2 AS two; - -SELECT 2- -1 AS three; - -SELECT 2 - -2 AS four; - -SELECT smallint('2') * smallint('2') = smallint('16') / smallint('4') AS true; - -SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true; - -SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true; - -SELECT 
int('1000') < int('999') AS false; - --- [SPARK-28027] Our ! and !! has different meanings --- SELECT 4! AS twenty_four; - --- SELECT !!3 AS six; - -SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten; - --- [SPARK-2659] HiveQL: Division operator should always perform fractional division -SELECT 2 + 2 / 2 AS three; - -SELECT (2 + 2) / 2 AS two; - --- [SPARK-28027] Add bitwise shift left/right operators --- corner case -SELECT string(shiftleft(int(-1), 31)); -SELECT string(int(shiftleft(int(-1), 31))+1); - --- [SPARK-28024] Incorrect numeric values when out of range --- check sane handling of INT_MIN overflow cases --- SELECT (-2147483648)::int4 * (-1)::int4; --- SELECT (-2147483648)::int4 / (-1)::int4; -SELECT int(-2147483648) % int(-1); --- SELECT (-2147483648)::int4 * (-1)::int2; --- SELECT (-2147483648)::int4 / (-1)::int2; -SELECT int(-2147483648) % smallint(-1); - -- [SPARK-28028] Cast numeric to integral type need round -- check rounding when casting from float SELECT x, int(x) AS int4_value diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out index 058fac18476d..c6826546435f 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 101 +-- Number of queries: 54 -- !query 0 @@ -531,590 +531,8 @@ struct -- !query 53 -set spark.sql.arithmetic.checkOverflow=true --- !query 53 schema -struct --- !query 53 output -spark.sql.arithmetic.checkOverflow true - - --- !query 54 -SELECT '' AS five, * FROM INT4_TBL --- !query 54 schema -struct --- !query 54 output --123456 - -2147483647 - 0 - 123456 - 2147483647 - - --- !query 55 -SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> smallint('0') --- !query 55 schema -struct --- !query 55 output --123456 - -2147483647 - 123456 - 2147483647 - - --- !query 56 -SELECT '' AS four, i.* FROM INT4_TBL i WHERE i.f1 <> int('0') --- !query 56 schema -struct --- !query 56 output --123456 - -2147483647 - 123456 - 2147483647 - - --- !query 57 -SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = smallint('0') --- !query 57 schema -struct --- !query 57 output -0 - - --- !query 58 -SELECT '' AS one, i.* FROM INT4_TBL i WHERE i.f1 = int('0') --- !query 58 schema -struct --- !query 58 output -0 - - --- !query 59 -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < smallint('0') --- !query 59 schema -struct --- !query 59 output --123456 - -2147483647 - - --- !query 60 -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 < int('0') --- !query 60 schema -struct --- !query 60 output --123456 - -2147483647 - - --- !query 61 -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= smallint('0') --- !query 61 schema -struct --- !query 61 output --123456 - -2147483647 - 0 - - --- !query 62 -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 <= int('0') --- !query 62 schema -struct --- !query 62 output --123456 - -2147483647 - 0 - - --- !query 63 -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > smallint('0') --- !query 63 schema -struct --- !query 63 output -123456 - 2147483647 - - --- !query 64 -SELECT '' AS two, i.* FROM INT4_TBL i WHERE i.f1 > int('0') --- !query 64 schema -struct --- !query 64 output -123456 - 2147483647 - - --- !query 65 -SELECT '' AS three, i.* FROM INT4_TBL i WHERE i.f1 >= smallint('0') --- !query 65 schema -struct --- !query 65 output -0 - 123456 - 2147483647 - - --- !query 66 -SELECT '' 
AS three, i.* FROM INT4_TBL i WHERE i.f1 >= int('0') --- !query 66 schema -struct --- !query 66 output -0 - 123456 - 2147483647 - - --- !query 67 -SELECT '' AS one, i.* FROM INT4_TBL i WHERE (i.f1 % smallint('2')) = smallint('1') --- !query 67 schema -struct --- !query 67 output -2147483647 - - --- !query 68 -SELECT '' AS three, i.* FROM INT4_TBL i WHERE (i.f1 % int('2')) = smallint('0') --- !query 68 schema -struct --- !query 68 output --123456 - 0 - 123456 - - --- !query 69 -SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i --- !query 69 schema -struct<> --- !query 69 output -org.apache.spark.SparkException -Job aborted due to stage failure: Task 0 in stage 2069.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2069.0 (TID 91890, localhost, executor driver): java.lang.ArithmeticException: 2147483647 * 2 caused overflow. - at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) - at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) - at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) - at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) - at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) - at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) - at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) - at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) - at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) - at org.apache.spark.scheduler.Task.run(Task.scala:126) - at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) - at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) - at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) - at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) - at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) - at java.lang.Thread.run(Thread.java:748) - -Driver stacktrace: - - --- !query 70 -SELECT '' AS five, i.f1, i.f1 * smallint('2') AS x FROM INT4_TBL i -WHERE abs(f1) < 1073741824 --- !query 70 schema -struct --- !query 70 output --123456 -246912 - 0 0 - 123456 246912 - - --- !query 71 -SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i --- !query 71 schema -struct<> --- !query 71 output -org.apache.spark.SparkException -Job aborted due to stage failure: Task 0 in stage 2071.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2071.0 (TID 91894, localhost, executor driver): java.lang.ArithmeticException: 2147483647 * 2 caused overflow. 
- at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) - at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) - at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) - at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) - at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) - at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) - at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) - at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) - at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) - at org.apache.spark.scheduler.Task.run(Task.scala:126) - at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) - at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) - at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) - at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) - at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) - at java.lang.Thread.run(Thread.java:748) - -Driver stacktrace: - - --- !query 72 -SELECT '' AS five, i.f1, i.f1 * int('2') AS x FROM INT4_TBL i -WHERE abs(f1) < 1073741824 --- !query 72 schema -struct --- !query 72 output --123456 -246912 - 0 0 - 123456 246912 - - --- !query 73 -SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i --- !query 73 schema -struct<> --- !query 73 output -org.apache.spark.SparkException -Job aborted due to stage failure: Task 0 in stage 2073.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2073.0 (TID 91898, localhost, executor driver): java.lang.ArithmeticException: 2147483647 + 2 caused overflow. 
- at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) - at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) - at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) - at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) - at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) - at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) - at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) - at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) - at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) - at org.apache.spark.scheduler.Task.run(Task.scala:126) - at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) - at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) - at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) - at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) - at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) - at java.lang.Thread.run(Thread.java:748) - -Driver stacktrace: - - --- !query 74 -SELECT '' AS five, i.f1, i.f1 + smallint('2') AS x FROM INT4_TBL i -WHERE f1 < 2147483646 --- !query 74 schema -struct --- !query 74 output --123456 -123454 - -2147483647 -2147483645 - 0 2 - 123456 123458 - - --- !query 75 -SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i --- !query 75 schema -struct<> --- !query 75 output -org.apache.spark.SparkException -Job aborted due to stage failure: Task 0 in stage 2075.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2075.0 (TID 91902, localhost, executor driver): java.lang.ArithmeticException: 2147483647 + 2 caused overflow. 
- at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) - at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) - at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) - at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) - at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) - at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) - at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) - at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) - at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) - at org.apache.spark.scheduler.Task.run(Task.scala:126) - at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) - at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) - at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) - at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) - at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) - at java.lang.Thread.run(Thread.java:748) - -Driver stacktrace: - - --- !query 76 -SELECT '' AS five, i.f1, i.f1 + int('2') AS x FROM INT4_TBL i -WHERE f1 < 2147483646 --- !query 76 schema -struct --- !query 76 output --123456 -123454 - -2147483647 -2147483645 - 0 2 - 123456 123458 - - --- !query 77 -SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i --- !query 77 schema -struct<> --- !query 77 output -org.apache.spark.SparkException -Job aborted due to stage failure: Task 1 in stage 2077.0 failed 1 times, most recent failure: Lost task 1.0 in stage 2077.0 (TID 91907, localhost, executor driver): java.lang.ArithmeticException: -2147483647 - 2 caused overflow. 
- at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) - at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) - at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) - at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) - at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) - at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) - at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) - at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) - at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) - at org.apache.spark.scheduler.Task.run(Task.scala:126) - at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) - at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) - at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) - at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) - at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) - at java.lang.Thread.run(Thread.java:748) - -Driver stacktrace: - - --- !query 78 -SELECT '' AS five, i.f1, i.f1 - smallint('2') AS x FROM INT4_TBL i -WHERE f1 > -2147483647 --- !query 78 schema -struct --- !query 78 output --123456 -123458 - 0 -2 - 123456 123454 - 2147483647 2147483645 - - --- !query 79 -SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i --- !query 79 schema -struct<> --- !query 79 output -org.apache.spark.SparkException -Job aborted due to stage failure: Task 1 in stage 2079.0 failed 1 times, most recent failure: Lost task 1.0 in stage 2079.0 (TID 91911, localhost, executor driver): java.lang.ArithmeticException: -2147483647 - 2 caused overflow. 
- at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.project_doConsume_0$(Unknown Source) - at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source) - at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43) - at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:701) - at org.apache.spark.sql.execution.SparkPlan.$anonfun$getByteArrayRdd$1(SparkPlan.scala:292) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2(RDD.scala:852) - at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsInternal$2$adapted(RDD.scala:852) - at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) - at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:327) - at org.apache.spark.rdd.RDD.iterator(RDD.scala:291) - at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) - at org.apache.spark.scheduler.Task.run(Task.scala:126) - at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:426) - at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1350) - at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:429) - at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) - at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) - at java.lang.Thread.run(Thread.java:748) - -Driver stacktrace: - - --- !query 80 -SELECT '' AS five, i.f1, i.f1 - int('2') AS x FROM INT4_TBL i -WHERE f1 > -2147483647 --- !query 80 schema -struct --- !query 80 output --123456 -123458 - 0 -2 - 123456 123454 - 2147483647 2147483645 - - --- !query 81 -SELECT '' AS five, i.f1, i.f1 / smallint('2') AS x FROM INT4_TBL i --- !query 81 schema -struct --- !query 81 output --123456 -61728.0 - -2147483647 -1.0737418235E9 - 0 0.0 - 123456 61728.0 - 2147483647 1.0737418235E9 - - --- !query 82 -SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i --- !query 82 schema -struct --- !query 82 output --123456 -61728.0 - -2147483647 -1.0737418235E9 - 0 0.0 - 123456 61728.0 - 2147483647 1.0737418235E9 - - --- !query 83 -SELECT -2+3 AS one --- !query 83 schema -struct --- !query 83 output -1 - - --- !query 84 -SELECT 4-2 AS two --- !query 84 schema -struct --- !query 84 output -2 - - --- !query 85 -SELECT 2- -1 AS three --- !query 85 schema -struct --- !query 85 output -3 - - --- !query 86 -SELECT 2 - -2 AS four --- !query 86 schema -struct --- !query 86 output -4 - - --- !query 87 -SELECT smallint('2') * smallint('2') = smallint('16') / smallint('4') AS true --- !query 87 schema -struct --- !query 87 output -true - - --- !query 88 -SELECT int('2') * smallint('2') = smallint('16') / int('4') AS true --- !query 88 schema -struct --- !query 88 output -true - - --- !query 89 -SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true --- !query 89 schema -struct --- !query 89 output -true - - --- !query 90 -SELECT int('1000') < int('999') AS false --- !query 90 schema -struct --- !query 90 output -false - - --- !query 91 -SELECT 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1 AS ten --- !query 91 schema -struct --- !query 91 output -10 - - --- !query 92 -SELECT 2 + 2 / 2 AS three --- !query 92 schema -struct --- !query 92 output -3.0 - - --- !query 93 -SELECT (2 + 2) / 2 AS two --- !query 93 schema -struct --- !query 93 output -2.0 - - --- !query 94 -SELECT string(shiftleft(int(-1), 31)) --- !query 94 schema -struct --- !query 94 
output --2147483648 - - --- !query 95 -SELECT string(int(shiftleft(int(-1), 31))+1) --- !query 95 schema -struct --- !query 95 output --2147483647 - - --- !query 96 -SELECT int(-2147483648) % int(-1) --- !query 96 schema -struct<(CAST(-2147483648 AS INT) % CAST(-1 AS INT)):int> --- !query 96 output -0 - - --- !query 97 -SELECT int(-2147483648) % smallint(-1) --- !query 97 schema -struct<(CAST(-2147483648 AS INT) % CAST(CAST(-1 AS SMALLINT) AS INT)):int> --- !query 97 output -0 - - --- !query 98 -SELECT x, int(x) AS int4_value -FROM (VALUES double(-2.5), - double(-1.5), - double(-0.5), - double(0.0), - double(0.5), - double(1.5), - double(2.5)) t(x) --- !query 98 schema -struct --- !query 98 output --0.5 0 --1.5 -1 --2.5 -2 -0.0 0 -0.5 0 -1.5 1 -2.5 2 - - --- !query 99 -SELECT x, int(x) AS int4_value -FROM (VALUES cast(-2.5 as decimal(38, 18)), - cast(-1.5 as decimal(38, 18)), - cast(-0.5 as decimal(38, 18)), - cast(-0.0 as decimal(38, 18)), - cast(0.5 as decimal(38, 18)), - cast(1.5 as decimal(38, 18)), - cast(2.5 as decimal(38, 18))) t(x) --- !query 99 schema -struct --- !query 99 output --0.5 0 --1.5 -1 --2.5 -2 -0 0 -0.5 0 -1.5 1 -2.5 2 - - --- !query 100 DROP TABLE INT4_TBL --- !query 100 schema +-- !query 53 schema struct<> --- !query 100 output +-- !query 53 output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 3e9c812bafb0..1de01f114b88 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType} @@ -410,10 +411,16 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { test("SPARK-24598: sum throws exception instead of silently overflow") { val df1 = Seq(Long.MinValue, -10, Long.MaxValue).toDF("i") - checkAnswer(df1.agg(sum($"i")), Row(-11)) val df2 = Seq(Long.MinValue, -10, 8).toDF("i") - val e = intercept[SparkException](df2.agg(sum($"i")).collect()) - assert(e.getCause.isInstanceOf[ArithmeticException]) + withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") { + checkAnswer(df1.agg(sum($"i")), Row(-11)) + val e = intercept[SparkException](df2.agg(sum($"i")).collect()) + assert(e.getCause.isInstanceOf[ArithmeticException]) + } + withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "false") { + checkAnswer(df1.agg(sum($"i")), Row(-11)) + checkAnswer(df2.agg(sum($"i")), Row(Long.MaxValue - 1)) + } } test("SPARK-24569: Aggregator with output type Option[Boolean] creates column of type Row") { From 38fc1f4ae65e12b02c42d751af184a5dd88fd69c Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 15 Jul 2019 22:25:13 +0200 Subject: [PATCH 12/23] fix typo --- sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql index 9a0bfd8a2751..218249f4a2aa 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql +++ 
b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql @@ -33,7 +33,7 @@ INSERT INTO INT4_TBL VALUES ('-2147483647'); -- INSERT INTO INT4_TBL(f1) VALUES ('123 5'); -- INSERT INTO INT4_TBL(f1) VALUES (''); --- We cannot test this when checkOverflow=false here +-- We cannot test this when checkOverflow=true here -- because exception happens in the executors and the -- output stacktrace cannot have an exact match set spark.sql.arithmetic.checkOverflow=false; From 37e19ceae061875bd2d24023ff2f275b358a69f1 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Mon, 15 Jul 2019 23:13:20 +0200 Subject: [PATCH 13/23] fix --- .../src/test/resources/sql-tests/results/pgSQL/int4.sql.out | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out index a1e81854d149..1f48205258b3 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out @@ -417,7 +417,7 @@ true -- !query 42 -SELECT int('1000') < int('999') AS `false` +SELECT smallint('2') * int('2') = int('16') / smallint('4') AS true -- !query 42 schema struct -- !query 42 output @@ -425,7 +425,7 @@ true -- !query 43 -SELECT int('1000') < int('999') AS false +SELECT int('1000') < int('999') AS `false` -- !query 43 schema struct -- !query 43 output From 98bbf8321c679ec1255eac629fb47a6825ca0986 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 20 Jul 2019 10:50:47 +0200 Subject: [PATCH 14/23] address comments --- .../sql/catalyst/expressions/aggregate/Sum.scala | 11 +++++++---- .../org/apache/spark/sql/internal/SQLConf.scala | 15 ++++++++------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 773e722efa31..f9fffb52378b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -61,9 +61,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => resultType } - private lazy val castToResultType: (Expression) => Expression = - if (sumDataType == resultType) (e: Expression) => e else (e: Expression) => Cast(e, resultType) - private lazy val sum = AttributeReference("sum", sumDataType)() private lazy val zero = Cast(Literal(0), sumDataType) @@ -95,5 +92,11 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast ) } - override lazy val evaluateExpression: Expression = castToResultType(sum) + override lazy val evaluateExpression: Expression = { + if (sumDataType == resultType) { + sum + } else { + Cast(sum, resultType) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cb8d3d50ef31..09376bd0abbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1780,13 +1780,14 @@ object SQLConf { .booleanConf .createWithDefault(false) - val ARITHMETIC_OPERATION_OVERFLOW_CHECK = buildConf("spark.sql.arithmetic.checkOverflow") - .doc("If it is set to true, all arithmetic operations on non-decimal fields 
throw an " + - "exception if an overflow occurs. If it is false (default), in case of overflow a wrong " + - "result is returned.") - .internal() - .booleanConf - .createWithDefault(false) + val ARITHMETIC_OPERATION_OVERFLOW_CHECK = + buildConf("spark.sql.arithmeticOperations.failOnOverFlow") + .doc("If it is set to true, all arithmetic operations on non-decimal fields throw an " + + "exception if an overflow occurs. If it is false (default), in case of overflow a wrong " + + "result is returned.") + .internal() + .booleanConf + .createWithDefault(false) val LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE = buildConf("spark.sql.legacy.parser.havingWithoutGroupByAsWhere") From 650ea796dafa5115f49115042b409e300ebc79de Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 20 Jul 2019 15:58:32 +0200 Subject: [PATCH 15/23] fix --- .../sql-tests/results/pgSQL/int4.sql.out | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out index 247b7652a298..fd2eb657fec1 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out @@ -359,13 +359,13 @@ struct -- !query 35 SELECT '' AS five, i.f1, i.f1 / int('2') AS x FROM INT4_TBL i -- !query 35 schema -struct +struct -- !query 35 output --123456 -61728.0 - -2147483647 -1.0737418235E9 - 0 0.0 - 123456 61728.0 - 2147483647 1.0737418235E9 +-123456 -61728 + -2147483647 -1073741823 + 0 0 + 123456 61728 + 2147483647 1073741823 -- !query 36 @@ -443,17 +443,17 @@ struct -- !query 45 SELECT 2 + 2 / 2 AS three -- !query 45 schema -struct +struct -- !query 45 output -3.0 +3 -- !query 46 SELECT (2 + 2) / 2 AS two -- !query 46 schema -struct +struct -- !query 46 output -2.0 +2 -- !query 47 From 1d20f735d911ec3f7c0e0558eaec7e5744693a8e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 26 Jul 2019 23:55:32 +0200 Subject: [PATCH 16/23] address comments --- .../sql/catalyst/expressions/arithmetic.scala | 111 ++++++------------ .../expressions/bitwiseExpressions.scala | 6 - .../spark/sql/catalyst/util/TypeUtils.scala | 9 +- .../spark/sql/types/AbstractDataType.scala | 2 + .../org/apache/spark/sql/types/ByteType.scala | 1 + .../apache/spark/sql/types/IntegerType.scala | 1 + .../org/apache/spark/sql/types/LongType.scala | 1 + .../apache/spark/sql/types/ShortType.scala | 1 + 8 files changed, 51 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a18f53897efa..d38fb37cc57d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -131,39 +131,56 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { def calendarIntervalMethod: String = sys.error("BinaryArithmetics must override either calendarIntervalMethod or genCode") - def checkOverflowCode(result: String, op1: String, op2: String): String = - sys.error("BinaryArithmetics must override either checkOverflowCode or genCode") + /** Name of the function for the exact version of this expression in [[Math]]. 
*/ + def exactMathMethod: String = + sys.error("BinaryArithmetics must override either exactMathMethod or genCode") override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case _: DecimalType => + // Overflow is handled in the CheckOverflow operator defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)") - // In the following cases, overflow can happen, so we need to check the result is valid. - // Otherwise we throw an ArithmeticException // byte and short are casted into int when add, minus, times or divide - case ByteType | ShortType => + case dt @ (ByteType | ShortType) => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { + val tmpResult = ctx.freshName("tmpResult") val overflowCheck = if (checkOverflow) { - checkOverflowCode(ev.value, eval1, eval2) + val maxValue = (2 << (dt.defaultSize * 8 - 2)) - 1 + val minValue = -(2 << (dt.defaultSize * 8 - 2)) + s""" + |if ($tmpResult < $minValue || $tmpResult > $maxValue) { + | throw new ArithmeticException($eval1 + " $symbol " + $eval2 + " caused overflow."); + |} + """.stripMargin } else { "" } s""" - |${ev.value} = (${CodeGenerator.javaType(dataType)})($eval1 $symbol $eval2); + |${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2; |$overflowCheck + |${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult); """.stripMargin }) - case _ => + case IntegerType | LongType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val operation = if (checkOverflow) { + val mathClass = classOf[Math].getName + s"$mathClass.$exactMathMethod($eval1, $eval2)" } else { s"$eval1 $symbol $eval2" } + s""" + |${ev.value} = $operation; + """.stripMargin + }) + case DoubleType | FloatType => + // When Double/Float overflows, there can be 2 cases: + // - precision loss: according to the SQL standard, the number is truncated; + // - returns (+/-)Infinite: the same behavior other DBs have (e.g.
Postgres) + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" |${ev.value} = $eval1 $symbol $eval2; - |$overflowCheck """.stripMargin }) } @@ -190,33 +207,17 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { override def calendarIntervalMethod: String = "add" - private lazy val numeric = TypeUtils.getNumeric(dataType) + private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) protected override def nullSafeEval(input1: Any, input2: Any): Any = { if (dataType.isInstanceOf[CalendarIntervalType]) { input1.asInstanceOf[CalendarInterval].add(input2.asInstanceOf[CalendarInterval]) } else { - val result = numeric.plus(input1, input2) - if (checkOverflow) { - val resSignum = numeric.signum(result) - val input1Signum = numeric.signum(input1) - val input2Signum = numeric.signum(input2) - if (resSignum != -1 && input1Signum == -1 && input2Signum == -1 - || resSignum != 1 && input1Signum == 1 && input2Signum == 1) { - throw new ArithmeticException(s"$input1 + $input2 caused overflow.") - } - } - result + numeric.plus(input1, input2) } } - override def checkOverflowCode(result: String, op1: String, op2: String): String = { - s""" - |if ($result >= 0 && $op1 < 0 && $op2 < 0 || $result <= 0 && $op1 > 0 && $op2 > 0) { - | throw new ArithmeticException($op1 + " + " + $op2 + " caused overflow."); - |} - """.stripMargin - } + override def exactMathMethod: String = "addExact" } @ExpressionDescription( @@ -236,33 +237,17 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti override def calendarIntervalMethod: String = "subtract" - private lazy val numeric = TypeUtils.getNumeric(dataType) + private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) protected override def nullSafeEval(input1: Any, input2: Any): Any = { if (dataType.isInstanceOf[CalendarIntervalType]) { input1.asInstanceOf[CalendarInterval].subtract(input2.asInstanceOf[CalendarInterval]) } else { - val result = numeric.minus(input1, input2) - if (checkOverflow) { - val resSignum = numeric.signum(result) - val input1Signum = numeric.signum(input1) - val input2Signum = numeric.signum(input2) - if (resSignum != 1 && input1Signum == 1 && input2Signum == -1 - || resSignum != -1 && input1Signum == -1 && input2Signum == 1) { - throw new ArithmeticException(s"$input1 - $input2 caused overflow.") - } - } - result + numeric.minus(input1, input2) } } - override def checkOverflowCode(result: String, op1: String, op2: String): String = { - s""" - |if ($result <= 0 && $op1 > 0 && $op2 < 0 || $result >= 0 && $op1 < 0 && $op2 > 0) { - | throw new ArithmeticException($op1 + " - " + $op2 + " caused overflow."); - |} - """.stripMargin - } + override def exactMathMethod: String = "subtractExact" } @ExpressionDescription( @@ -279,31 +264,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti override def symbol: String = "*" override def decimalMethod: String = "$times" - private lazy val numeric = TypeUtils.getNumeric(dataType) + private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val result = numeric.times(input1, input2) - if (checkOverflow) { - if (numeric.signum(result) != numeric.signum(input1) * numeric.signum(input2) && - !(result.isInstanceOf[Double] && !result.asInstanceOf[Double].isNaN) && - !(result.isInstanceOf[Float] && !result.asInstanceOf[Float].isNaN)) { - throw new ArithmeticException(s"$input1 * $input2 caused overflow.") - } - } - 
result - } + protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) - override def checkOverflowCode(result: String, op1: String, op2: String): String = { - val isNaNCheck = dataType match { - case DoubleType | FloatType => s" && !java.lang.Double.isNaN($result)" - case _ => "" - } - s""" - |if (Math.signum($result) != Math.signum($op1) * Math.signum($op2)$isNaNCheck) { - | throw new ArithmeticException($op1 + " * " + $op2 + " caused overflow."); - |} - """.stripMargin - } + override def exactMathMethod: String = "multiplyExact" } // Common base trait for Divide and Remainder, since these two classes are almost identical diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 9b2c73891cea..c766bd8e56bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -51,8 +51,6 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme } protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2) - - override def checkOverflowCode(result: String, op1: String, op2: String): String = "" } /** @@ -85,8 +83,6 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet } protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2) - - override def checkOverflowCode(result: String, op1: String, op2: String): String = "" } /** @@ -119,8 +115,6 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme } protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2) - - override def checkOverflowCode(result: String, op1: String, op2: String): String = "" } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index fed2a1ac4b8b..9680ea3cd206 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -60,8 +60,13 @@ object TypeUtils { } } - def getNumeric(t: DataType): Numeric[Any] = - t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] + def getNumeric(t: DataType, exactNumericRequired: Boolean = false): Numeric[Any] = { + if (exactNumericRequired) { + t.asInstanceOf[NumericType].exactNumeric.asInstanceOf[Numeric[Any]] + } else { + t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] + } + } def getInterpretedOrdering(t: DataType): Ordering[Any] = { t match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index d2ef08873187..20caf05bc580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -142,6 +142,8 @@ abstract class NumericType extends AtomicType { // desugared by the compiler into an argument to the objects constructor. This means there is no // longer a no argument constructor and thus the JVM cannot serialize the object anymore. 
private[sql] val numeric: Numeric[InternalType] + + private[sql] val exactNumeric: Numeric[InternalType] = numeric } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 9d400eefc0f8..0df9518045f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -37,6 +37,7 @@ class ByteType private() extends IntegralType { private[sql] val numeric = implicitly[Numeric[Byte]] private[sql] val integral = implicitly[Integral[Byte]] private[sql] val ordering = implicitly[Ordering[InternalType]] + override private[sql] val exactNumeric = ByteExactNumeric /** * The default size of a value of the ByteType is 1 byte. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index 0755202d20df..c344523bdcb8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -37,6 +37,7 @@ class IntegerType private() extends IntegralType { private[sql] val numeric = implicitly[Numeric[Int]] private[sql] val integral = implicitly[Integral[Int]] private[sql] val ordering = implicitly[Ordering[InternalType]] + override private[sql] val exactNumeric = IntegerExactNumeric /** * The default size of a value of the IntegerType is 4 bytes. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 3c49c721fdc8..f030920db451 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -37,6 +37,7 @@ class LongType private() extends IntegralType { private[sql] val numeric = implicitly[Numeric[Long]] private[sql] val integral = implicitly[Integral[Long]] private[sql] val ordering = implicitly[Ordering[InternalType]] + override private[sql] val exactNumeric = LongExactNumeric /** * The default size of a value of the LongType is 8 bytes. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index 9b5ddfef1ccf..825268995853 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -37,6 +37,7 @@ class ShortType private() extends IntegralType { private[sql] val numeric = implicitly[Numeric[Short]] private[sql] val integral = implicitly[Integral[Short]] private[sql] val ordering = implicitly[Ordering[InternalType]] + override private[sql] val exactNumeric = ShortExactNumeric /** * The default size of a value of the ShortType is 2 bytes. 
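A note on the refactor in the patch above: Add, Subtract and Multiply now delegate overflow detection to java.lang.Math through an exactMathMethod hook ("addExact", "subtractExact", "multiplyExact") instead of hand-rolled sign comparisons. The JDK methods are also sound where the removed signum test was not: a wrapped product can land on the expected sign, e.g. Long.MaxValue * Long.MaxValue wraps to 1L, which a signum comparison would accept as valid. A minimal sketch of how such a hook can drive the generated code (illustrative only; the genArith name and shape are assumptions, not the literal Spark codegen):

    // Sketch: checkOverflow selects Math.<exactMathMethod> over the raw operator.
    def genArith(ev: String, e1: String, e2: String,
        symbol: String, exactMathMethod: String, checkOverflow: Boolean): String =
      if (checkOverflow) {
        s"$ev = java.lang.Math.$exactMathMethod($e1, $e2);" // throws ArithmeticException on overflow
      } else {
        s"$ev = $e1 $symbol $e2;" // wraps silently, the legacy behavior
      }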
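On the interpreted side, surfacing exactNumeric on NumericType means the operators no longer branch per operation: each asks TypeUtils.getNumeric for a Numeric once, and the overflow policy travels with the returned instance. A condensed sketch of the dispatch (member names mirror the diff above and are only visible inside the sql package; the Numeric[Any] cast follows the existing getNumeric convention):

    // Sketch: one call site serves both wrapping and exact semantics.
    def plus(dt: NumericType, checkOverflow: Boolean)(a: Any, b: Any): Any = {
      val num: Numeric[Any] =
        if (checkOverflow) dt.exactNumeric.asInstanceOf[Numeric[Any]] // throws on overflow
        else dt.numeric.asInstanceOf[Numeric[Any]] // wraps silently
      num.plus(a, b)
    }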
From 538e3324d30c9f22c23988b397ccff3f24cba4bb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Fri, 26 Jul 2019 23:55:56 +0200 Subject: [PATCH 17/23] address comments --- .../org/apache/spark/sql/types/numerics.scala | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala new file mode 100644 index 000000000000..dac7c2b2aebc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Numeric.{ByteIsIntegral, IntIsIntegral, LongIsIntegral, ShortIsIntegral} +import scala.math.Ordering + + +object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering { + private def checkOverflow(res: Int, x: Byte, y: Byte, op: String): Unit = { + if (x > Byte.MaxValue || x < Byte.MinValue) { + throw new ArithmeticException(s"$x $op $y caused overflow.") + } + } + + override def plus(x: Byte, y: Byte): Byte = { + val tmp = x + y + checkOverflow(tmp, x, y, "+") + tmp.toByte + } + + override def minus(x: Byte, y: Byte): Byte = { + val tmp = x - y + checkOverflow(tmp, x, y, "-") + tmp.toByte + } + + override def times(x: Byte, y: Byte): Byte = { + val tmp = x * y + checkOverflow(tmp, x, y, "*") + tmp.toByte + } +} + + +object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering { + private def checkOverflow(res: Int, x: Short, y: Short, op: String): Unit = { + if (x > Short.MaxValue || x < Short.MinValue) { + throw new ArithmeticException(s"$x $op $y caused overflow.") + } + } + + override def plus(x: Short, y: Short): Short = { + val tmp = x + y + checkOverflow(tmp, x, y, "+") + tmp.toShort + } + + override def minus(x: Short, y: Short): Short = { + val tmp = x - y + checkOverflow(tmp, x, y, "-") + tmp.toShort + } + + override def times(x: Short, y: Short): Short = { + val tmp = x * y + checkOverflow(tmp, x, y, "*") + tmp.toShort + } +} + + +object IntegerExactNumeric extends IntIsIntegral with Ordering.IntOrdering { + override def plus(x: Int, y: Int): Int = Math.addExact(x, y) + + override def minus(x: Int, y: Int): Int = Math.subtractExact(x, y) + + override def times(x: Int, y: Int): Int = Math.multiplyExact(x, y) +} + +object LongExactNumeric extends LongIsIntegral with Ordering.LongOrdering { + override def plus(x: Long, y: Long): Long = Math.addExact(x, y) + + override def minus(x: Long, y: Long): Long = Math.subtractExact(x, y) + + override def times(x: Long, y: Long): Long = Math.multiplyExact(x, y) +} From 
3de4bfbcd3417f30013a84013cf65f0d152b8948 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 27 Jul 2019 10:03:53 +0200 Subject: [PATCH 18/23] fix --- .../apache/spark/sql/catalyst/expressions/arithmetic.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d38fb37cc57d..85c7366a04da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -142,14 +142,13 @@ abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { case CalendarIntervalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$calendarIntervalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide - case dt @ ByteType | ShortType => + case ByteType | ShortType => nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val tmpResult = ctx.freshName("tmpResult") val overflowCheck = if (checkOverflow) { - val maxValue = 2 << (dt.defaultSize * 8 - 2) - 1 - val minValue = - 2 << (dt.defaultSize * 8 - 2) + val javaType = CodeGenerator.boxedType(dataType) s""" - |if ($tmpResult < $minValue || $tmpResult > $maxValue) { + |if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) { | throw new ArithmeticException($eval1 + " $symbol " + $eval2 + " caused overflow."); |} """.stripMargin From 3baecbca9e3c39fb939d64e81ef466aea32d1bd3 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 27 Jul 2019 13:02:53 +0200 Subject: [PATCH 19/23] fixes --- .../scala/org/apache/spark/sql/types/AbstractDataType.scala | 2 +- .../src/main/scala/org/apache/spark/sql/types/numerics.scala | 4 ++-- .../sql/catalyst/expressions/ArithmeticExpressionSuite.scala | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 20caf05bc580..21ac32adca6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -143,7 +143,7 @@ abstract class NumericType extends AtomicType { // longer a no argument constructor and thus the JVM cannot serialize the object anymore. 
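The fix in patch 18 is worth unpacking. The replaced bound was computed as 2 << (dt.defaultSize * 8 - 2) - 1, but Scala's additive operators bind tighter than <<, so for ByteType this evaluates as 2 << 5 == 64 rather than the intended 127; reading MIN_VALUE/MAX_VALUE off the boxed Java type sidesteps the arithmetic entirely. The check itself is exact because byte and short operations are carried out in 32-bit int, which cannot overflow for 8- or 16-bit operands. A standalone Scala illustration of the same idea (mirroring ByteExactNumeric above, not the generated Java):

    // The sum of two bytes always fits in an Int, so overflow detection
    // reduces to a range check on the widened result.
    def addByteExact(x: Byte, y: Byte): Byte = {
      val tmp = x + y // widened to Int; exact for all byte inputs
      if (tmp < Byte.MinValue || tmp > Byte.MaxValue) {
        throw new ArithmeticException(s"$x + $y caused overflow.")
      }
      tmp.toByte
    }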
private[sql] val numeric: Numeric[InternalType] - private[sql] val exactNumeric: Numeric[InternalType] = numeric + private[sql] def exactNumeric: Numeric[InternalType] = numeric } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala index dac7c2b2aebc..6fc5f63c5c15 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala @@ -23,7 +23,7 @@ import scala.math.Ordering object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering { private def checkOverflow(res: Int, x: Byte, y: Byte, op: String): Unit = { - if (x > Byte.MaxValue || x < Byte.MinValue) { + if (res > Byte.MaxValue || res < Byte.MinValue) { throw new ArithmeticException(s"$x $op $y caused overflow.") } } @@ -50,7 +50,7 @@ object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering { object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering { private def checkOverflow(res: Int, x: Short, y: Short, op: String): Unit = { - if (x > Short.MaxValue || x < Short.MinValue) { + if (res > Short.MaxValue || res < Short.MinValue) { throw new ArithmeticException(s"$x $op $y caused overflow.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 32a97aff9f41..d753a57055cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -400,7 +400,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper val e6 = Multiply(minLongLiteral, minLongLiteral) withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") { Seq(e1, e2, e3, e4, e5, e6).foreach { e => - checkExceptionInExpression[ArithmeticException](e, "caused overflow") + checkExceptionInExpression[ArithmeticException](e, "overflow") } } withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "false") { From a247f9fd504412b20b52d5a59b8a0f9040ab8171 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 27 Jul 2019 16:42:09 +0200 Subject: [PATCH 20/23] fix unaryminus --- .../sql/catalyst/expressions/arithmetic.scala | 23 +++++++++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 2 ++ .../org/apache/spark/sql/types/Decimal.scala | 2 +- .../org/apache/spark/sql/types/numerics.scala | 18 +++++++++++++++ .../ArithmeticExpressionSuite.scala | 10 ++++++++ 5 files changed, 52 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 85c7366a04da..0ed05b9b5911 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -35,6 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval """) case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { + private val checkOverflow = SQLConf.get.arithmeticOperationOverflowCheck override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -42,10 
+43,28 @@ case class UnaryMinus(child: Expression) extends UnaryExpression override def toString: String = s"-$child" - private lazy val numeric = TypeUtils.getNumeric(dataType) + private lazy val numeric = TypeUtils.getNumeric(dataType, checkOverflow) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case _: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case ByteType | ShortType if checkOverflow => + nullSafeCodeGen(ctx, ev, eval => { + val javaBoxedType = CodeGenerator.boxedType(dataType) + val javaType = CodeGenerator.javaType(dataType) + val originValue = ctx.freshName("origin") + s""" + |$javaType $originValue = ($javaType)($eval); + |if ($originValue == $javaBoxedType.MIN_VALUE) { + | throw new ArithmeticException("- " + $originValue + " caused overflow."); + |} + |${ev.value} = ($javaType)(-($originValue)); + """.stripMargin + }) + case IntegerType | LongType if checkOverflow => + nullSafeCodeGen(ctx, ev, eval => { + val mathClass = classOf[Math].getName + s"${ev.value} = $mathClass.negateExact(-($eval));" + }) case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") // codegen would fail to compile if we just write (-($c)) @@ -117,7 +136,7 @@ case class Abs(child: Expression) abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { - protected val checkOverflow = SQLConf.get.getConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK) + protected val checkOverflow = SQLConf.get.arithmeticOperationOverflowCheck override def dataType: DataType = left.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 09376bd0abbc..7afe704b4ef7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2296,6 +2296,8 @@ class SQLConf extends Serializable with Logging { def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW) + def arithmeticOperationOverflowCheck: Boolean = getConf(ARITHMETIC_OPERATION_OVERFLOW_CHECK) + def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION) def continuousStreamingEpochBacklogQueueSize: Int = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index fe7691977e01..90318b14149e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -229,7 +229,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null)) { longVal / POW_10(_scale) } else { - if (SQLConf.get.getConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK)) { + if (SQLConf.get.arithmeticOperationOverflowCheck) { // This will throw an exception if overflow occurs if (decimalVal.compare(LONG_MIN_BIG_DEC) < 0 || decimalVal.compare(LONG_MAX_BIG_DEC) > 0) { throw new ArithmeticException("Overflow") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala index 6fc5f63c5c15..8f2844b557bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala @@ -45,6 +45,13 @@ object ByteExactNumeric 
extends ByteIsIntegral with Ordering.ByteOrdering { checkOverflow(tmp, x, y, "*") tmp.toByte } + + override def negate(x: Byte): Byte = { + if (x == Byte.MinValue) { // if and only if x is Byte.MinValue, overflow can happen + throw new ArithmeticException(s"- $x caused overflow.") + } + (-x).toByte + } } @@ -72,6 +79,13 @@ object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering { checkOverflow(tmp, x, y, "*") tmp.toShort } + + override def negate(x: Short): Short = { + if (x == Short.MinValue) { // if and only if x is Byte.MinValue, overflow can happen + throw new ArithmeticException(s"- $x caused overflow.") + } + (-x).toByte + } } @@ -81,6 +95,8 @@ object IntegerExactNumeric extends IntIsIntegral with Ordering.IntOrdering { override def minus(x: Int, y: Int): Int = Math.subtractExact(x, y) override def times(x: Int, y: Int): Int = Math.multiplyExact(x, y) + + override def negate(x: Int): Int = Math.negateExact(x) } object LongExactNumeric extends LongIsIntegral with Ordering.LongOrdering { @@ -89,4 +105,6 @@ object LongExactNumeric extends LongIsIntegral with Ordering.LongOrdering { override def minus(x: Long, y: Long): Long = Math.subtractExact(x, y) override def times(x: Long, y: Long): Long = Math.multiplyExact(x, y) + + override def negate(x: Long): Long = Math.negateExact(x) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index d753a57055cd..a0f90dd5a92c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -79,6 +79,16 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue) checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue) checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue) + withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") { + checkExceptionInExpression[ArithmeticException]( + UnaryMinus(Literal(Long.MinValue)), "overflow") + checkExceptionInExpression[ArithmeticException]( + UnaryMinus(Literal(Int.MinValue)), "overflow") + checkExceptionInExpression[ArithmeticException]( + UnaryMinus(Literal(Short.MinValue)), "overflow") + checkExceptionInExpression[ArithmeticException]( + UnaryMinus(Literal(Byte.MinValue)), "overflow") + } checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort) checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort) checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt) From 582d148f11737bb67f6bc126f540bdde452f1d85 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 30 Jul 2019 18:08:24 +0200 Subject: [PATCH 21/23] address comments --- .../catalyst/expressions/aggregate/Sum.scala | 13 ++----------- .../sql/catalyst/expressions/arithmetic.scala | 6 +++--- .../apache/spark/sql/internal/SQLConf.scala | 4 ++-- .../org/apache/spark/sql/types/Decimal.scala | 10 ---------- .../ArithmeticExpressionSuite.scala | 18 ++++++++++++------ .../sql/catalyst/expressions/CastSuite.scala | 14 -------------- .../resources/sql-tests/inputs/pgSQL/int4.sql | 4 ++-- .../sql-tests/results/pgSQL/int4.sql.out | 4 ++-- .../spark/sql/DatasetAggregatorSuite.scala | 16 ---------------- 9 files changed, 23 insertions(+), 66 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index f9fffb52378b..ef204ec82c52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -56,10 +56,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => DoubleType } - private lazy val sumDataType = child.dataType match { - case LongType => DecimalType.BigIntDecimal - case _ => resultType - } + private lazy val sumDataType = resultType private lazy val sum = AttributeReference("sum", sumDataType)() @@ -92,11 +89,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast ) } - override lazy val evaluateExpression: Expression = { - if (sumDataType == resultType) { - sum - } else { - Cast(sum, resultType) - } - } + override lazy val evaluateExpression: Expression = sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 0ed05b9b5911..10c5b5627d6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -35,7 +35,7 @@ import org.apache.spark.unsafe.types.CalendarInterval """) case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { - private val checkOverflow = SQLConf.get.arithmeticOperationOverflowCheck + private val checkOverflow = SQLConf.get.arithmeticOperationsFailOnOverflow override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -63,7 +63,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression case IntegerType | LongType if checkOverflow => nullSafeCodeGen(ctx, ev, eval => { val mathClass = classOf[Math].getName - s"${ev.value} = $mathClass.negateExact(-($eval));" + s"${ev.value} = $mathClass.negateExact($eval);" }) case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") @@ -136,7 +136,7 @@ case class Abs(child: Expression) abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant { - protected val checkOverflow = SQLConf.get.arithmeticOperationOverflowCheck + protected val checkOverflow = SQLConf.get.arithmeticOperationsFailOnOverflow override def dataType: DataType = left.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7afe704b4ef7..2fede591fc80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1780,7 +1780,7 @@ object SQLConf { .booleanConf .createWithDefault(false) - val ARITHMETIC_OPERATION_OVERFLOW_CHECK = + val ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW = buildConf("spark.sql.arithmeticOperations.failOnOverFlow") .doc("If it is set to true, all arithmetic operations on non-decimal fields throw an " + "exception if an overflow occurs. 
If it is false (default), in case of overflow a wrong " + @@ -2296,7 +2296,7 @@ class SQLConf extends Serializable with Logging { def decimalOperationsNullOnOverflow: Boolean = getConf(DECIMAL_OPERATIONS_NULL_ON_OVERFLOW) - def arithmeticOperationOverflowCheck: Boolean = getConf(ARITHMETIC_OPERATION_OVERFLOW_CHECK) + def arithmeticOperationsFailOnOverflow: Boolean = getConf(ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW) def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 90318b14149e..1bf322af2179 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -22,7 +22,6 @@ import java.math.{BigInteger, MathContext, RoundingMode} import org.apache.spark.annotation.Unstable import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.internal.SQLConf /** * A mutable implementation of BigDecimal that can hold a Long if values are small enough. @@ -229,12 +228,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null)) { longVal / POW_10(_scale) } else { - if (SQLConf.get.arithmeticOperationOverflowCheck) { - // This will throw an exception if overflow occurs - if (decimalVal.compare(LONG_MIN_BIG_DEC) < 0 || decimalVal.compare(LONG_MAX_BIG_DEC) > 0) { - throw new ArithmeticException("Overflow") - } - } decimalVal.longValue() } } @@ -463,9 +456,6 @@ object Decimal { private val LONG_MAX_BIG_INT = BigInteger.valueOf(JLong.MAX_VALUE) private val LONG_MIN_BIG_INT = BigInteger.valueOf(JLong.MIN_VALUE) - private val LONG_MAX_BIG_DEC = BigDecimal.valueOf(JLong.MAX_VALUE) - private val LONG_MIN_BIG_DEC = BigDecimal.valueOf(JLong.MIN_VALUE) - def apply(value: Double): Decimal = new Decimal().set(value) def apply(value: Long): Decimal = new Decimal().set(value) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index a0f90dd5a92c..fd43f53a331c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -60,7 +60,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Add(positiveLongLit, negativeLongLit), -1L) Seq("true", "false").foreach { checkOverflow => - withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) { + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> checkOverflow) { DataTypeTestUtils.numericAndInterval.foreach { tpe => checkConsistencyBetweenInterpretedAndCodegenAllowingException(Add, tpe, tpe) } @@ -79,7 +79,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(UnaryMinus(Literal(Int.MinValue)), Int.MinValue) checkEvaluation(UnaryMinus(Literal(Short.MinValue)), Short.MinValue) checkEvaluation(UnaryMinus(Literal(Byte.MinValue)), Byte.MinValue) - withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") { + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") { checkExceptionInExpression[ArithmeticException]( UnaryMinus(Literal(Long.MinValue)), "overflow") 
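A property worth noting for the UnaryMinus work above: in two's-complement, -x overflows only when x is the minimum value, since |MinValue| equals MaxValue + 1 and is not representable. Math.negateExact encodes that rule for Int and Long; for the narrow types the series spells it out by hand, roughly as below (a sketch following ShortExactNumeric, already using the toShort conversion that patch 22 later corrects):

    // Only MinValue can overflow under negation: (-Short.MinValue).toShort
    // wraps back to Short.MinValue.
    def negateShortExact(x: Short): Short = {
      if (x == Short.MinValue) {
        throw new ArithmeticException(s"- $x caused overflow.")
      }
      (-x).toShort // negation happens in Int, so this is exact for all other inputs
    }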
checkExceptionInExpression[ArithmeticException]( @@ -88,6 +88,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper UnaryMinus(Literal(Short.MinValue)), "overflow") checkExceptionInExpression[ArithmeticException]( UnaryMinus(Literal(Byte.MinValue)), "overflow") + checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort) + checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort) + checkEvaluation(UnaryMinus(positiveIntLit), - positiveInt) + checkEvaluation(UnaryMinus(negativeIntLit), - negativeInt) + checkEvaluation(UnaryMinus(positiveLongLit), - positiveLong) + checkEvaluation(UnaryMinus(negativeLongLit), - negativeLong) } checkEvaluation(UnaryMinus(positiveShortLit), (- positiveShort).toShort) checkEvaluation(UnaryMinus(negativeShortLit), (- negativeShort).toShort) @@ -115,7 +121,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Subtract(positiveLongLit, negativeLongLit), positiveLong - negativeLong) Seq("true", "false").foreach { checkOverflow => - withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) { + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> checkOverflow) { DataTypeTestUtils.numericAndInterval.foreach { tpe => checkConsistencyBetweenInterpretedAndCodegenAllowingException(Subtract, tpe, tpe) } @@ -137,7 +143,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Multiply(positiveLongLit, negativeLongLit), positiveLong * negativeLong) Seq("true", "false").foreach { checkOverflow => - withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> checkOverflow) { + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> checkOverflow) { DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => checkConsistencyBetweenInterpretedAndCodegenAllowingException(Multiply, tpe, tpe) } @@ -408,12 +414,12 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper val e4 = Add(minLongLiteral, minLongLiteral) val e5 = Subtract(minLongLiteral, maxLongLiteral) val e6 = Multiply(minLongLiteral, minLongLiteral) - withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") { + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") { Seq(e1, e2, e3, e4, e5, e6).foreach { e => checkExceptionInExpression[ArithmeticException](e, "overflow") } } - withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "false") { + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") { checkEvaluation(e1, Long.MinValue) checkEvaluation(e2, Long.MinValue) checkEvaluation(e3, -2L) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 093a9d2cd7ac..4d667fd61ae0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -957,19 +956,6 @@ class CastSuite extends 
SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]") } - test("SPARK-24598: Cast to long should fail on overflow") { - val overflowCast = cast(Literal.create(Decimal(Long.MaxValue) + Decimal(1)), LongType) - val nonOverflowCast = cast(Literal.create(Decimal(Long.MaxValue)), LongType) - withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") { - checkExceptionInExpression[ArithmeticException](overflowCast, "Overflow") - checkEvaluation(nonOverflowCast, Long.MaxValue) - } - withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "false") { - checkEvaluation(overflowCast, Long.MinValue) - checkEvaluation(nonOverflowCast, Long.MaxValue) - } - } - test("up-cast") { def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match { case (_, dt: DecimalType) => dt.isWiderThan(from) diff --git a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql index d26bd6a0ac4e..1012db72e187 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pgSQL/int4.sql @@ -33,10 +33,10 @@ INSERT INTO INT4_TBL VALUES ('-2147483647'); -- INSERT INTO INT4_TBL(f1) VALUES ('123 5'); -- INSERT INTO INT4_TBL(f1) VALUES (''); --- We cannot test this when checkOverflow=true here +-- We cannot test this when failOnOverFlow=true here -- because exception happens in the executors and the -- output stacktrace cannot have an exact match -set spark.sql.arithmetic.checkOverflow=false; +set spark.sql.arithmeticOperations.failOnOverFlow=false; SELECT '' AS five, * FROM INT4_TBL; diff --git a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out index fd2eb657fec1..8b9c20e7eb20 100644 --- a/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pgSQL/int4.sql.out @@ -51,11 +51,11 @@ struct<> -- !query 6 -set spark.sql.arithmetic.checkOverflow=false +set spark.sql.arithmeticOperations.failOnOverFlow=false -- !query 6 schema struct -- !query 6 output -spark.sql.arithmetic.checkOverflow false +spark.sql.arithmeticOperations.failOnOverFlow false -- !query 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 1de01f114b88..e581211e4e76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType} @@ -409,20 +407,6 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { Row(1, Row(Row(1, "a"), Row(1, "a"))) :: Row(2, Row(Row(2, "bc"), Row(2, "bc"))) :: Nil) } - test("SPARK-24598: sum throws exception instead of silently overflow") { - val df1 = Seq(Long.MinValue, -10, Long.MaxValue).toDF("i") - val df2 = Seq(Long.MinValue, -10, 8).toDF("i") - 
withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "true") { - checkAnswer(df1.agg(sum($"i")), Row(-11)) - val e = intercept[SparkException](df2.agg(sum($"i")).collect()) - assert(e.getCause.isInstanceOf[ArithmeticException]) - } - withSQLConf(SQLConf.ARITHMETIC_OPERATION_OVERFLOW_CHECK.key -> "false") { - checkAnswer(df1.agg(sum($"i")), Row(-11)) - checkAnswer(df2.agg(sum($"i")), Row(Long.MaxValue - 1)) - } - } - test("SPARK-24569: Aggregator with output type Option[Boolean] creates column of type Row") { val df = Seq( OptionBooleanData("bob", Some(true)), From b809a3fadd1acbb4d5f061f46a370ffc90348beb Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 30 Jul 2019 21:08:54 +0200 Subject: [PATCH 22/23] fix --- .../src/main/scala/org/apache/spark/sql/types/numerics.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala index 8f2844b557bf..a362afa2b31b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala @@ -84,7 +84,7 @@ object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering { if (x == Short.MinValue) { // if and only if x is Byte.MinValue, overflow can happen throw new ArithmeticException(s"- $x caused overflow.") } - (-x).toByte + (-x).toShort } } From ce3ed2b6e81b1973868b52aede2cbe6d102d0125 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 31 Jul 2019 16:48:28 +0200 Subject: [PATCH 23/23] address comments --- .../ArithmeticExpressionSuite.scala | 74 ++++++++++++++++++- .../apache/spark/sql/types/DecimalSuite.scala | 4 +- 2 files changed, 75 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index fd43f53a331c..d35fff1293ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -405,7 +405,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper assert(ctx2.inlinedMutableStates.size == 1) } - test("SPARK-24598: overflow on BigInt returns wrong result") { + test("SPARK-24598: overflow on long returns wrong result") { val maxLongLiteral = Literal(Long.MaxValue) val minLongLiteral = Literal(Long.MinValue) val e1 = Add(maxLongLiteral, Literal(1L)) @@ -428,4 +428,76 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(e6, 0L) } } + + test("SPARK-24598: overflow on integer returns wrong result") { + val maxIntLiteral = Literal(Int.MaxValue) + val minIntLiteral = Literal(Int.MinValue) + val e1 = Add(maxIntLiteral, Literal(1)) + val e2 = Subtract(maxIntLiteral, Literal(-1)) + val e3 = Multiply(maxIntLiteral, Literal(2)) + val e4 = Add(minIntLiteral, minIntLiteral) + val e5 = Subtract(minIntLiteral, maxIntLiteral) + val e6 = Multiply(minIntLiteral, minIntLiteral) + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") { + Seq(e1, e2, e3, e4, e5, e6).foreach { e => + checkExceptionInExpression[ArithmeticException](e, "overflow") + } + } + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") { + checkEvaluation(e1, 
Int.MinValue) + checkEvaluation(e2, Int.MinValue) + checkEvaluation(e3, -2) + checkEvaluation(e4, 0) + checkEvaluation(e5, 1) + checkEvaluation(e6, 0) + } + } + + test("SPARK-24598: overflow on short returns wrong result") { + val maxShortLiteral = Literal(Short.MaxValue) + val minShortLiteral = Literal(Short.MinValue) + val e1 = Add(maxShortLiteral, Literal(1.toShort)) + val e2 = Subtract(maxShortLiteral, Literal((-1).toShort)) + val e3 = Multiply(maxShortLiteral, Literal(2.toShort)) + val e4 = Add(minShortLiteral, minShortLiteral) + val e5 = Subtract(minShortLiteral, maxShortLiteral) + val e6 = Multiply(minShortLiteral, minShortLiteral) + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") { + Seq(e1, e2, e3, e4, e5, e6).foreach { e => + checkExceptionInExpression[ArithmeticException](e, "overflow") + } + } + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") { + checkEvaluation(e1, Short.MinValue) + checkEvaluation(e2, Short.MinValue) + checkEvaluation(e3, (-2).toShort) + checkEvaluation(e4, 0.toShort) + checkEvaluation(e5, 1.toShort) + checkEvaluation(e6, 0.toShort) + } + } + + test("SPARK-24598: overflow on byte returns wrong result") { + val maxByteLiteral = Literal(Byte.MaxValue) + val minByteLiteral = Literal(Byte.MinValue) + val e1 = Add(maxByteLiteral, Literal(1.toByte)) + val e2 = Subtract(maxByteLiteral, Literal((-1).toByte)) + val e3 = Multiply(maxByteLiteral, Literal(2.toByte)) + val e4 = Add(minByteLiteral, minByteLiteral) + val e5 = Subtract(minByteLiteral, maxByteLiteral) + val e6 = Multiply(minByteLiteral, minByteLiteral) + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "true") { + Seq(e1, e2, e3, e4, e5, e6).foreach { e => + checkExceptionInExpression[ArithmeticException](e, "overflow") + } + } + withSQLConf(SQLConf.ARITHMETIC_OPERATIONS_FAIL_ON_OVERFLOW.key -> "false") { + checkEvaluation(e1, Byte.MinValue) + checkEvaluation(e2, Byte.MinValue) + checkEvaluation(e3, (-2).toByte) + checkEvaluation(e4, 0.toByte) + checkEvaluation(e5, 1.toByte) + checkEvaluation(e6, 0.toByte) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 4d1a04151e68..d69bb2f0b6bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -94,8 +94,8 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong) checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue) checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue) - assert(Decimal(Double.MaxValue).toDouble == Double.MaxValue) - assert(Decimal(Double.MinValue).toDouble == Double.MinValue) + checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L) + checkValues(Decimal(Double.MinValue), Double.MinValue, 0L) } // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs
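The expected values in the checkOverflow=false branches of the tests above are just the JVM's wrapping semantics, so they can be sanity-checked in plain Scala (assumption: Spark's non-checking path matches host integer arithmetic, which is what the generated code emits):

    // Each assertion mirrors one of e1..e6 in the integer-overflow test above.
    assert(Int.MaxValue + 1 == Int.MinValue) // e1: Add(max, 1)
    assert(Int.MaxValue - (-1) == Int.MinValue) // e2: Subtract(max, -1)
    assert(Int.MaxValue * 2 == -2) // e3: Multiply(max, 2)
    assert(Int.MinValue + Int.MinValue == 0) // e4: Add(min, min)
    assert(Int.MinValue - Int.MaxValue == 1) // e5: Subtract(min, max)
    assert(Int.MinValue * Int.MinValue == 0) // e6: Multiply(min, min)

With spark.sql.arithmeticOperations.failOnOverFlow set to true, the same six expressions raise ArithmeticException instead, which is what the first branch of each test asserts.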