diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 07a2c47cff082..d5816f4a9ed01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -564,7 +564,11 @@ object Decimal { private val BIG_DEC_ZERO = BigDecimal(0) - private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP) + // SPARK-45786 Using RoundingMode.HALF_UP with MathContext may cause inaccurate SQL results + // because TypeCoercion later rounds again. Instead, always round down and use 1 digit longer + // precision than DecimalType.MAX_PRECISION. Then, TypeCoercion will properly round up/down + // the last extra digit. + private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN) private[sql] val ZERO = Decimal(0) private[sql] val ONE = Decimal(1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index e76ff0b439007..5b9074650b9bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.RoundingMode import java.sql.{Date, Timestamp} import java.time.{Duration, Period} import java.time.temporal.ChronoUnit @@ -226,6 +227,108 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } + test("SPARK-45786: Decimal multiply, divide, remainder, quot") { + // Some known cases + checkEvaluation( + Multiply( + Literal(Decimal(BigDecimal("-14120025096157587712113961295153.858047"), 38, 6)), + Literal(Decimal(BigDecimal("-0.4652"), 4, 4)) + ), + Decimal(BigDecimal("6568635674732509803675414794505.574763")) + ) + checkEvaluation( + Multiply( + Literal(Decimal(BigDecimal("-240810500742726"), 15, 0)), + Literal(Decimal(BigDecimal("-5677.6988688550027099967697071"), 29, 25)) + ), + Decimal(BigDecimal("1367249507675382200.164877854336665327")) + ) + checkEvaluation( + Divide( + Literal(Decimal(BigDecimal("-0.172787979"), 9, 9)), + Literal(Decimal(BigDecimal("533704665545018957788294905796.5"), 31, 1)) + ), + Decimal(BigDecimal("-3.237520E-31")) + ) + checkEvaluation( + Divide( + Literal(Decimal(BigDecimal("-0.574302343618"), 12, 12)), + Literal(Decimal(BigDecimal("-795826820326278835912868.106"), 27, 3)) + ), + Decimal(BigDecimal("7.21642358550E-25")) + ) + + // Random tests + val rand = scala.util.Random + def makeNum(p: Int, s: Int): String = { + val int1 = rand.nextLong() + val int2 = rand.nextLong().abs + val frac1 = rand.nextLong().abs + val frac2 = rand.nextLong().abs + s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." + s"$frac1$frac2".take(s) + } + + (0 until 100).foreach { _ => + val p1 = rand.nextInt(38) + 1 // 1 <= p1 <= 38 + val s1 = rand.nextInt(p1 + 1) // 0 <= s1 <= p1 + val p2 = rand.nextInt(38) + 1 + val s2 = rand.nextInt(p2 + 1) + + val n1 = makeNum(p1, s1) + val n2 = makeNum(p2, s2) + + val mulActual = Multiply( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2)) + + val divActual = Divide( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val divExact = new java.math.BigDecimal(n1) + .divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN) + + val remActual = Remainder( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2)) + + val quotActual = IntegralDivide( + Literal(Decimal(BigDecimal(n1), p1, s1)), + Literal(Decimal(BigDecimal(n2), p2, s2)) + ) + val quotExact = + new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2)) + + Seq(true, false).foreach { allowPrecLoss => + withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) { + val mulResult = Decimal(mulExact.setScale(s1, RoundingMode.HALF_UP)) + val mulExpected = + if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult + checkEvaluation(mulActual, mulExpected) + + val divResult = Decimal(divExact.setScale(s1, RoundingMode.HALF_UP)) + val divExpected = + if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult + checkEvaluation(divActual, divExpected) + + val remResult = Decimal(remExact.setScale(s1, RoundingMode.HALF_UP)) + val remExpected = + if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult + checkEvaluation(remActual, remExpected) + + val quotResult = Decimal(quotExact.setScale(s1, RoundingMode.HALF_UP)) + val quotExpected = + if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult + checkEvaluation(quotActual, quotExpected.toLong) + } + } + } + } + private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = { testFunc(_.toDouble) testFunc(Decimal(_))