From 87d702275c7d3b7a934b99d8096873521af56143 Mon Sep 17 00:00:00 2001
From: Kazuyuki Tanimura
Date: Wed, 1 Nov 2023 02:01:10 -0700
Subject: [PATCH 1/7] Fix decimal rounding

---
 .../org/apache/spark/sql/types/Decimal.scala  |  5 +-
 .../ArithmeticExpressionSuite.scala           | 56 +++++++++++++++++++
 2 files changed, 59 insertions(+), 2 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 5652e5adda9d4..4fbf4bc4cb846 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -497,9 +497,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   def * (that: Decimal): Decimal =
     Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT))
 
+
   def / (that: Decimal): Decimal =
     if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal,
-      DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode))
+      DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode))
 
   def % (that: Decimal): Decimal =
     if (that.isZero) null
@@ -547,7 +548,7 @@ object Decimal {
 
   val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
 
-  private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
+  private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN)
 
   private[sql] val ZERO = Decimal(0)
   private[sql] val ONE = Decimal(1)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index e21793ab506c4..37ad93d2bbcf0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.math.RoundingMode
 import java.sql.{Date, Timestamp}
 import java.time.{Duration, Period}
 import java.time.temporal.ChronoUnit
@@ -225,6 +226,61 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
+  test("SPARK-XXXXX: Decimal multiply, divide, modulo squot") {
+    // Some known cases
+    checkEvaluation(
+      Multiply(
+        Literal(Decimal(BigDecimal("-14120025096157587712113961295153.858047"), 38, 6)),
+        Literal(Decimal(BigDecimal("-0.4652"), 4, 4))
+      ),
+      Decimal(BigDecimal("6568635674732509803675414794505.574763"))
+    )
+    checkEvaluation(
+      Multiply(
+        Literal(Decimal(BigDecimal("-240810500742726"), 15, 0)),
+        Literal(Decimal(BigDecimal("-5677.6988688550027099967697071"), 29, 25))
+      ),
+      Decimal(BigDecimal("1367249507675382200.164877854336665327"))
+    )
+
+    // Random tests
+    val rand = scala.util.Random
+    def makeNum(p: Int, s: Int): String = {
+      val int1 = rand.nextLong()
+      val int2 = rand.nextLong().abs
+      val frac1 = rand.nextLong().abs
+      val frac2 = rand.nextLong().abs
+      s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." +
+        s"$frac1$frac2".take(s)
+    }
+
+    (0 until 100).foreach { _ =>
+      val p1 = rand.nextInt(38) + 1 // 1 <= p1 <= 38
+      val s1 = rand.nextInt(p1 + 1) // 0 <= s1 <= p1
+      val p2 = rand.nextInt(38) + 1
+      val s2 = rand.nextInt(p2 + 1)
+
+      val n1 = makeNum(p1, s1)
+      val n2 = makeNum(p2, s2)
+
+      val mulActual = Multiply(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2))
+
+      Seq(true, false).foreach { allowPrecLoss =>
+        withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) {
+
+          val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2)
+          val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP))
+          val mulExpected =
+            if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult
+          checkEvaluation(mulActual, mulExpected)
+        }
+      }
+    }
+  }
+
   private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = {
     testFunc(_.toDouble)
     testFunc(Decimal(_))
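What the two Decimal.scala changes above do, shown standalone with plain java.math.BigDecimal (a minimal sketch, not part of the patch; MAX_PRECISION = 38 is assumed to match DecimalType.MAX_PRECISION): keep one guard digit and truncate, instead of rounding HALF_UP at full precision, so that the single HALF_UP rounding happens only later, when the result is cast to its final precision and scale.

    import java.math.{BigDecimal => JBigDecimal, MathContext, RoundingMode}

    object GuardDigitSketch extends App {
      val MAX_PRECISION = 38 // assumed value of DecimalType.MAX_PRECISION

      // First known case from the new test: the exact product needs more than
      // 38 significant digits.
      val a = new JBigDecimal("-14120025096157587712113961295153.858047")
      val b = new JBigDecimal("-0.4652")

      val oldCtx = new MathContext(MAX_PRECISION, RoundingMode.HALF_UP)     // old behavior
      val newCtx = new MathContext(MAX_PRECISION + 1, RoundingMode.DOWN)    // new behavior

      println(a.multiply(b))         // exact product
      println(a.multiply(b, oldCtx)) // already rounded once; a later cast rounds again
      println(a.multiply(b, newCtx)) // truncated with one guard digit; rounds once, later
    }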
From c12258a473a9fa012e53546b1ff28a80a965e5cb Mon Sep 17 00:00:00 2001
From: Kazuyuki Tanimura
Date: Thu, 2 Nov 2023 01:41:25 -0700
Subject: [PATCH 2/7] Fix decimal rounding

---
 .../expressions/ArithmeticExpressionSuite.scala | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 37ad93d2bbcf0..e75bd11cc4d3e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -268,14 +268,25 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       )
       val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2))
 
+      val divActual = Divide(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val divExact = new java.math.BigDecimal(n1).divide(new java.math.BigDecimal(n2))
+
       Seq(true, false).foreach { allowPrecLoss =>
         withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) {
-
           val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2)
           val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP))
           val mulExpected =
             if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult
           checkEvaluation(mulActual, mulExpected)
+
+          val divType = Divide(null, null).resultDecimalType(p1, s1, p2, s2)
+          val divResult = Decimal(divExact.setScale(divType.scale, RoundingMode.HALF_UP))
+          val divExpected =
+            if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult
+          checkEvaluation(divActual, divExpected)
         }
       }
     }
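One wrinkle in the divExact line added above: java.math.BigDecimal.divide with no scale and no MathContext must return the exact quotient, and it throws ArithmeticException whenever the decimal expansion is non-terminating, which random inputs can hit quickly. The next patch switches the reference computation to an explicit scale. A small standalone sketch of the failure mode (not part of the patch):

    import java.math.{BigDecimal => JBigDecimal, RoundingMode}

    object ExactDivideSketch extends App {
      val one = new JBigDecimal("1")
      val three = new JBigDecimal("3")

      // Throws: 1/3 has a non-terminating decimal expansion.
      try one.divide(three)
      catch { case e: ArithmeticException => println(s"exact divide: ${e.getMessage}") }

      // Bounded alternative, as the follow-up uses: explicit scale and rounding mode.
      println(one.divide(three, 100, RoundingMode.DOWN)) // 0.333... (100 fractional digits)
    }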
From ed8a254fa09e4ee9a36209e97e91bbd8553cde88 Mon Sep 17 00:00:00 2001
From: Kazuyuki Tanimura
Date: Thu, 2 Nov 2023 01:57:04 -0700
Subject: [PATCH 3/7] Fix decimal rounding

---
 .../ArithmeticExpressionSuite.scala           | 44 ++++++++++++++++++-
 1 file changed, 42 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index e75bd11cc4d3e..ee96765563acd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -226,7 +226,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
-  test("SPARK-XXXXX: Decimal multiply, divide, modulo squot") {
+  test("SPARK-XXXXX: Decimal multiply, divide, remainder, quot") {
     // Some known cases
     checkEvaluation(
       Multiply(
@@ -242,6 +242,20 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
       ),
       Decimal(BigDecimal("1367249507675382200.164877854336665327"))
     )
+    checkEvaluation(
+      Divide(
+        Literal(Decimal(BigDecimal("-0.172787979"), 9, 9)),
+        Literal(Decimal(BigDecimal("533704665545018957788294905796.5"), 31, 1))
+      ),
+      Decimal(BigDecimal("-3.237520E-31"))
+    )
+    checkEvaluation(
+      Divide(
+        Literal(Decimal(BigDecimal("-0.574302343618"), 12, 12)),
+        Literal(Decimal(BigDecimal("-795826820326278835912868.106"), 27, 3))
+      ),
+      Decimal(BigDecimal("7.21642358550E-25"))
+    )
 
     // Random tests
     val rand = scala.util.Random
@@ -272,7 +286,21 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
         Literal(Decimal(BigDecimal(n1), p1, s1)),
         Literal(Decimal(BigDecimal(n2), p2, s2))
      )
-      val divExact = new java.math.BigDecimal(n1).divide(new java.math.BigDecimal(n2))
+      val divExact = new java.math.BigDecimal(n1)
+        .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN)
+
+      val remActual = Remainder(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2))
+
+      val quotActual = IntegralDivide(
+        Literal(Decimal(BigDecimal(n1), p1, s1)),
+        Literal(Decimal(BigDecimal(n2), p2, s2))
+      )
+      val quotExact =
+        new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2))
 
       Seq(true, false).foreach { allowPrecLoss =>
         withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) {
@@ -287,6 +315,18 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
           val divExpected =
             if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult
           checkEvaluation(divActual, divExpected)
+
+          val remType = Remainder(null, null).resultDecimalType(p1, s1, p2, s2)
+          val remResult = Decimal(remExact.setScale(remType.scale, RoundingMode.HALF_UP))
+          val remExpected =
+            if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult
+          checkEvaluation(remActual, remExpected)
+
+          val quotType = IntegralDivide(null, null).resultDecimalType(p1, s1, p2, s2)
+          val quotResult = Decimal(quotExact.setScale(quotType.scale, RoundingMode.HALF_UP))
+          val quotExpected =
+            if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult
+          checkEvaluation(quotActual, quotExpected.toLong)
         }
       }
     }
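For the two new reference operations above, the semantics come straight from java.math.BigDecimal: divideToIntegralValue truncates toward zero and remainder keeps the dividend's sign, matching SQL's div and %. A standalone sketch with illustrative values (not part of the patch):

    import java.math.{BigDecimal => JBigDecimal}

    object IntegralOpsSketch extends App {
      val n = new JBigDecimal("-7")
      val d = new JBigDecimal("2")

      val q = n.divideToIntegralValue(d) // -3, fraction truncated toward zero
      val r = n.remainder(d)             // -1, sign follows the dividend
      println(s"q = $q, r = $r")

      // Invariant the test expectations rely on: n == q * d + r
      println(q.multiply(d).add(r)) // -7
    }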
From ffda48417f19d348666673b960e00198aa505d24 Mon Sep 17 00:00:00 2001
From: Kazuyuki Tanimura
Date: Fri, 3 Nov 2023 17:52:28 -0700
Subject: [PATCH 4/7] [SPARK-45786][SQL] Fix inaccurate Decimal multiplication
 and division results

---
 .../sql/catalyst/expressions/ArithmeticExpressionSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index ee96765563acd..568dcd10d1166 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -226,7 +226,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
-  test("SPARK-XXXXX: Decimal multiply, divide, remainder, quot") {
+  test("SPARK-45786: Decimal multiply, divide, remainder, quot") {
     // Some known cases
     checkEvaluation(
       Multiply(

From dc53e9d79bd8e5bf244bff558050880ad999d262 Mon Sep 17 00:00:00 2001
From: Kazuyuki Tanimura
Date: Fri, 3 Nov 2023 18:13:52 -0700
Subject: [PATCH 5/7] [SPARK-45786][SQL] Fix inaccurate Decimal multiplication
 and division results

---
 sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 4fbf4bc4cb846..3539b44b88b9a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -497,7 +497,6 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   def * (that: Decimal): Decimal =
     Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT))
 
-
   def / (that: Decimal): Decimal =
     if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal,
       DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode))
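The next commit documents why MATH_CONTEXT now truncates: rounding HALF_UP twice (once in Decimal, once more in TypeCoercion) can drift a digit away from the correctly rounded answer, while truncating with a guard digit leaves a single honest HALF_UP at the end. A standalone sketch of that double-rounding hazard (illustrative values, not from the patch):

    import java.math.{BigDecimal => JBigDecimal, RoundingMode}

    object DoubleRoundingSketch extends App {
      val x = new JBigDecimal("0.446")

      // HALF_UP twice drifts upward: 0.446 -> 0.45 -> 0.5
      val twice = x.setScale(2, RoundingMode.HALF_UP).setScale(1, RoundingMode.HALF_UP)
      // HALF_UP once is correct: 0.446 -> 0.4
      val once = x.setScale(1, RoundingMode.HALF_UP)
      // Truncate the intermediate, then round once: 0.446 -> 0.44 -> 0.4
      val guarded = x.setScale(2, RoundingMode.DOWN).setScale(1, RoundingMode.HALF_UP)

      println(s"twice = $twice, once = $once, guarded = $guarded") // 0.5, 0.4, 0.4
    }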
From 5d0fc4156fdf76de6bf5a9dfb9ffe9b3a02703d4 Mon Sep 17 00:00:00 2001
From: Kazuyuki Tanimura
Date: Fri, 3 Nov 2023 23:16:21 -0700
Subject: [PATCH 6/7] [SPARK-45786][SQL] Fix inaccurate Decimal multiplication
 and division results

---
 .../src/main/scala/org/apache/spark/sql/types/Decimal.scala | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 3539b44b88b9a..0bcbefaa54828 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -547,6 +547,10 @@ object Decimal {
 
   val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
 
+  // SPARK-45786 Using RoundingMode.HALF_UP with MathContext may cause inaccurate SQL results
+  // because TypeCoercion later rounds again. Instead, always round down and use 1 digit longer
+  // precision than DecimalType.MAX_PRECISION. Then, TypeCoercion will properly round up/down
+  // the last extra digit.
   private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN)
 
   private[sql] val ZERO = Decimal(0)

From 2c77cf336f7ee1b2f3c882b5ca1569720d67b017 Mon Sep 17 00:00:00 2001
From: Kazuyuki Tanimura
Date: Tue, 7 Nov 2023 01:28:30 -0800
Subject: [PATCH 7/7] [SPARK-45786][SQL] Fix inaccurate Decimal multiplication
 and division results

---
 .../ansi/decimalArithmeticOperations.sql.out  | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out
index 699c916fd8fdb..9593291fae21d 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/decimalArithmeticOperations.sql.out
@@ -155,7 +155,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "1000000000000000000000000000000000000.00000000000000000000000000000000000000"
+    "value" : "1000000000000000000000000000000000000.000000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -204,7 +204,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "10123456789012345678901234567890123456.00000000000000000000000000000000000000"
+    "value" : "10123456789012345678901234567890123456.000000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -229,7 +229,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "101234567890123456789012345678901234.56000000000000000000000000000000000000"
+    "value" : "101234567890123456789012345678901234.560000000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -254,7 +254,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "10123456789012345678901234567890123.45600000000000000000000000000000000000"
+    "value" : "10123456789012345678901234567890123.456000000000000000000000000000000000000"
  },
   "queryContext" : [ {
     "objectType" : "",
@@ -279,7 +279,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "1012345678901234567890123456789012.34560000000000000000000000000000000000"
+    "value" : "1012345678901234567890123456789012.345600000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -304,7 +304,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000"
+    "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
@@ -337,7 +337,7 @@ org.apache.spark.SparkArithmeticException
     "config" : "\"spark.sql.ansi.enabled\"",
     "precision" : "38",
     "scale" : "6",
-    "value" : "101234567890123456789012345678901.23456000000000000000000000000000000000"
+    "value" : "101234567890123456789012345678901.234560000000000000000000000000000000000"
   },
   "queryContext" : [ {
     "objectType" : "",
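The golden-file churn above is mechanical: Decimal./ now divides at DecimalType.MAX_SCALE + 1 instead of MAX_SCALE, so the unrounded quotient quoted in each out-of-range error message carries 39 fractional digits rather than 38, one extra trailing zero per "value". A standalone sketch of where such a string comes from (assuming MAX_SCALE = 38; the operands are illustrative, chosen only to reproduce the shape of the first updated value):

    import java.math.{BigDecimal => JBigDecimal, RoundingMode}

    object OverflowValueSketch extends App {
      val maxScale = 38 // assumed value of DecimalType.MAX_SCALE

      val n = new JBigDecimal("100000000000000000000000000000000000") // 1e35
      val d = new JBigDecimal("0.1")

      // Old message payload: quotient rendered with 38 fractional digits.
      println(n.divide(d, maxScale, RoundingMode.DOWN))
      // New message payload: 39 fractional digits, hence the extra trailing
      // zero in every updated "value" above.
      println(n.divide(d, maxScale + 1, RoundingMode.DOWN))
    }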