From 0babef485971ca6e2d81acea4c53aeaff9614859 Mon Sep 17 00:00:00 2001
From: Kazuyuki Tanimura
Date: Thu, 16 Nov 2023 10:15:08 -0800
Subject: [PATCH 1/2] [SPARK-45786][SQL][FOLLOWUP][TEST] Fix Decimal random number tests with ANSI enabled

---
 .../ArithmeticExpressionSuite.scala           | 21 +++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 568dcd10d1166..c7dc465723295 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -308,25 +308,38 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
         val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP))
         val mulExpected = if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult
-        checkEvaluation(mulActual, mulExpected)
+        tryCheckEvaluation(mulActual, mulExpected)
         val divType = Divide(null, null).resultDecimalType(p1, s1, p2, s2)
         val divResult = Decimal(divExact.setScale(divType.scale, RoundingMode.HALF_UP))
         val divExpected = if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult
-        checkEvaluation(divActual, divExpected)
+        tryCheckEvaluation(divActual, divExpected)
         val remType = Remainder(null, null).resultDecimalType(p1, s1, p2, s2)
         val remResult = Decimal(remExact.setScale(remType.scale, RoundingMode.HALF_UP))
         val remExpected = if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult
-        checkEvaluation(remActual, remExpected)
+        tryCheckEvaluation(remActual, remExpected)
         val quotType = IntegralDivide(null, null).resultDecimalType(p1, s1, p2, s2)
         val quotResult = Decimal(quotExact.setScale(quotType.scale, RoundingMode.HALF_UP))
         val quotExpected = if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult
-        checkEvaluation(quotActual, quotExpected.toLong)
+        tryCheckEvaluation(quotActual, quotExpected.toLong)
+      }
+    }
+
+    def tryCheckEvaluation(actual: BinaryArithmetic, expected: Any): Unit = {
+      try {
+        checkEvaluation(actual, expected)
+      }
+      catch {
+        // Ignore NUMERIC_VALUE_OUT_OF_RANGE when ANSI is enabled
+        case e: org.scalatest.exceptions.TestFailedException
+          if e.cause.exists(c => c.isInstanceOf[SparkArithmeticException] &&
+            c.asInstanceOf[SparkArithmeticException].getErrorClass
+              == "NUMERIC_VALUE_OUT_OF_RANGE") && SQLConf.get.ansiEnabled =>
       }
     }
   }

From 9e14251e9b05f543424917d1aa024fb85412dd18 Mon Sep 17 00:00:00 2001
From: Kazuyuki Tanimura
Date: Thu, 16 Nov 2023 23:18:28 -0800
Subject: [PATCH 2/2] address review comments

---
 .../ArithmeticExpressionSuite.scala           | 23 +++++++++--------------
 1 file changed, 9 insertions(+), 14 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index c7dc465723295..2dc7e82f77226 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -308,40 +308,35 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
         val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP))
         val mulExpected = if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult
-        tryCheckEvaluation(mulActual, mulExpected)
+        checkEvaluationOrException(mulActual, mulExpected)
         val divType = Divide(null, null).resultDecimalType(p1, s1, p2, s2)
         val divResult = Decimal(divExact.setScale(divType.scale, RoundingMode.HALF_UP))
         val divExpected = if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult
-        tryCheckEvaluation(divActual, divExpected)
+        checkEvaluationOrException(divActual, divExpected)
         val remType = Remainder(null, null).resultDecimalType(p1, s1, p2, s2)
         val remResult = Decimal(remExact.setScale(remType.scale, RoundingMode.HALF_UP))
         val remExpected = if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult
-        tryCheckEvaluation(remActual, remExpected)
+        checkEvaluationOrException(remActual, remExpected)
         val quotType = IntegralDivide(null, null).resultDecimalType(p1, s1, p2, s2)
         val quotResult = Decimal(quotExact.setScale(quotType.scale, RoundingMode.HALF_UP))
         val quotExpected = if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult
-        tryCheckEvaluation(quotActual, quotExpected.toLong)
+        checkEvaluationOrException(quotActual, quotExpected.toLong)
       }
     }
 
-    def tryCheckEvaluation(actual: BinaryArithmetic, expected: Any): Unit = {
-      try {
+    def checkEvaluationOrException(actual: BinaryArithmetic, expected: Any): Unit =
+      if (SQLConf.get.ansiEnabled && expected == null) {
+        checkExceptionInExpression[SparkArithmeticException](actual,
+          "NUMERIC_VALUE_OUT_OF_RANGE")
+      } else {
         checkEvaluation(actual, expected)
       }
-      catch {
-        // Ignore NUMERIC_VALUE_OUT_OF_RANGE when ANSI is enabled
-        case e: org.scalatest.exceptions.TestFailedException
-          if e.cause.exists(c => c.isInstanceOf[SparkArithmeticException] &&
-            c.asInstanceOf[SparkArithmeticException].getErrorClass
-              == "NUMERIC_VALUE_OUT_OF_RANGE") && SQLConf.get.ansiEnabled =>
-      }
-    }
   }
 }
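
Note: the check that the second patch settles on (expect a NUMERIC_VALUE_OUT_OF_RANGE error when ANSI mode is on and the expected value is null, otherwise compare the evaluated value) can be sketched outside Spark's ExpressionEvalHelper as a small standalone Scala program. This is only an illustration of the pattern: OverflowException, the evaluate thunk, and the ansiEnabled flag below are stand-ins for SparkArithmeticException, Catalyst expression evaluation, and SQLConf.get.ansiEnabled, not Spark APIs.

  import scala.util.{Failure, Try}

  object CheckOrExceptionSketch {
    // Stand-in for SparkArithmeticException carrying an error class.
    final case class OverflowException(errorClass: String) extends RuntimeException(errorClass)

    // Mirrors the branching of checkEvaluationOrException: with ANSI enabled and a null
    // expected value (the exact result does not fit the result type), an overflow error
    // with error class NUMERIC_VALUE_OUT_OF_RANGE must be raised; otherwise the evaluated
    // value must equal the expected one.
    def checkEvaluationOrException[T](evaluate: => T, expected: Any, ansiEnabled: Boolean): Unit =
      if (ansiEnabled && expected == null) {
        Try(evaluate) match {
          case Failure(e: OverflowException) =>
            assert(e.errorClass == "NUMERIC_VALUE_OUT_OF_RANGE", s"unexpected error: ${e.errorClass}")
          case other =>
            throw new AssertionError(s"expected NUMERIC_VALUE_OUT_OF_RANGE, got $other")
        }
      } else {
        val actualValue = evaluate
        assert(actualValue == expected, s"expected $expected but got $actualValue")
      }

    def main(args: Array[String]): Unit = {
      def overflowingMul: BigDecimal = throw OverflowException("NUMERIC_VALUE_OUT_OF_RANGE")

      // ANSI on with a null expectation: the overflow error is asserted, not swallowed.
      checkEvaluationOrException(overflowingMul, expected = null, ansiEnabled = true)
      // ANSI off (or in-range result): the value itself is checked.
      checkEvaluationOrException(BigDecimal("7.5") * 2, expected = BigDecimal("15.0"), ansiEnabled = false)
      println("sketch passed")
    }
  }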