@@ -499,7 +499,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {

def / (that: Decimal): Decimal =
if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal,
-      DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode))
+      DecimalType.MAX_SCALE + 1, MATH_CONTEXT.getRoundingMode))

def % (that: Decimal): Decimal =
if (that.isZero) null
@@ -547,7 +547,11 @@ object Decimal {

val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)

- private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
+ // SPARK-45786 Using RoundingMode.HALF_UP with MathContext may cause inaccurate SQL results
+ // because TypeCoercion later rounds again. Instead, always round down and use 1 digit longer
+ // precision than DecimalType.MAX_PRECISION. Then, TypeCoercion will properly round up/down
+ // the last extra digit.
+ private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN)

private[sql] val ZERO = Decimal(0)
private[sql] val ONE = Decimal(1)
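
Note: the following is a minimal, non-Spark sketch of the double-rounding problem the new comment describes, using java.math.BigDecimal with a small scale in place of DecimalType.MAX_PRECISION; the value 2.4445 and the target scale 2 are illustrative assumptions, not taken from Spark.

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// Pretend 2.4445 is the exact arithmetic result and TypeCoercion wants scale 2.
val exact = new JBigDecimal("2.4445")

// Old behavior: HALF_UP inside Decimal, then HALF_UP again in TypeCoercion.
val doubleRounded = exact
  .setScale(3, RoundingMode.HALF_UP)   // 2.445 (tie rounds up)
  .setScale(2, RoundingMode.HALF_UP)   // 2.45  -> over-rounded by one ulp

// New behavior: keep one extra digit, round DOWN, let the final step do HALF_UP.
val singleRounded = exact
  .setScale(3, RoundingMode.DOWN)      // 2.444 (extra digit kept, nothing rounded yet)
  .setScale(2, RoundingMode.HALF_UP)   // 2.44  -> matches rounding the exact value once

assert(exact.setScale(2, RoundingMode.HALF_UP).compareTo(singleRounded) == 0)

Truncating with one extra digit preserves exactly the information the final HALF_UP needs, so the result equals a single HALF_UP of the exact value, whereas two successive HALF_UP roundings can drift by one unit in the last place; the divide above carries DecimalType.MAX_SCALE + 1 digits for the same reason.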
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.math.RoundingMode
import java.sql.{Date, Timestamp}
import java.time.{Duration, Period}
import java.time.temporal.ChronoUnit
@@ -225,6 +226,112 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
}
}

test("SPARK-45786: Decimal multiply, divide, remainder, quot") {
@LuciferYang (Contributor) commented on Nov 16, 2023:
This test fails when spark.sql.ansi.enabled is set to true:

https://github.com/apache/spark/actions/runs/6885072758/job/18728675619


You can reproduce the issue locally by executing SPARK_ANSI_SQL_MODE=true build/sbt clean "catalyst/testOnly org.apache.spark.sql.catalyst.expressions.ArithmeticExpressionSuite"

@kazuyukitanimura Can you take a look at this issue?

Also cc @dongjoon-hyun: since this patch has been backported to branch-3.4, I'm not sure whether this will affect the Spark 3.4.2 release.

@kazuyukitanimura (Contributor, Author) replied:
Thanks @LuciferYang
Yes, this test assumes the default spark.sql.ansi.enabled=false. The default behavior does not throw an exception on overflow, but ANSI mode does. Since this is a random-value test, we may generate combinations that overflow.

Cause: org.apache.spark.SparkArithmeticException: [NUMERIC_VALUE_OUT_OF_RANGE] 431393072276642444045219979063553045.571 cannot be represented as Decimal(38, 4). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error, and return NULL instead. SQLSTATE: 22003

Sorry, I wasn't aware that there is a GHA job for spark.sql.ansi.enabled=true. I can modify the test to ignore those cases.
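
For context, a minimal sketch (plain java.math.BigDecimal, not Spark APIs) of why the value in the error above cannot fit into Decimal(38, 4); with the default configuration Spark returns NULL for such a result, while ANSI mode raises NUMERIC_VALUE_OUT_OF_RANGE, as the message says.

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// Value copied from the error message above: 36 integer digits, so even after
// rounding to scale 4 it needs 40 digits of precision, which exceeds 38.
val v = new JBigDecimal("431393072276642444045219979063553045.571")
val rounded = v.setScale(4, RoundingMode.HALF_UP)
println(s"precision = ${rounded.precision}")   // 40, so Decimal(38, 4) cannot hold it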

@kazuyukitanimura (Contributor, Author), in a follow-up:
Addressed #43853

// Some known cases
checkEvaluation(
Multiply(
Literal(Decimal(BigDecimal("-14120025096157587712113961295153.858047"), 38, 6)),
Literal(Decimal(BigDecimal("-0.4652"), 4, 4))
),
Decimal(BigDecimal("6568635674732509803675414794505.574763"))
)
checkEvaluation(
Multiply(
Literal(Decimal(BigDecimal("-240810500742726"), 15, 0)),
Literal(Decimal(BigDecimal("-5677.6988688550027099967697071"), 29, 25))
),
Decimal(BigDecimal("1367249507675382200.164877854336665327"))
)
checkEvaluation(
Divide(
Literal(Decimal(BigDecimal("-0.172787979"), 9, 9)),
Literal(Decimal(BigDecimal("533704665545018957788294905796.5"), 31, 1))
),
Decimal(BigDecimal("-3.237520E-31"))
)
checkEvaluation(
Divide(
Literal(Decimal(BigDecimal("-0.574302343618"), 12, 12)),
Literal(Decimal(BigDecimal("-795826820326278835912868.106"), 27, 3))
),
Decimal(BigDecimal("7.21642358550E-25"))
)

// Random tests
val rand = scala.util.Random
def makeNum(p: Int, s: Int): String = {
val int1 = rand.nextLong()
val int2 = rand.nextLong().abs
val frac1 = rand.nextLong().abs
val frac2 = rand.nextLong().abs
s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." + s"$frac1$frac2".take(s)
}

(0 until 100).foreach { _ =>
val p1 = rand.nextInt(38) + 1 // 1 <= p1 <= 38
val s1 = rand.nextInt(p1 + 1) // 0 <= s1 <= p1
val p2 = rand.nextInt(38) + 1
val s2 = rand.nextInt(p2 + 1)

val n1 = makeNum(p1, s1)
val n2 = makeNum(p2, s2)

val mulActual = Multiply(
Literal(Decimal(BigDecimal(n1), p1, s1)),
Literal(Decimal(BigDecimal(n2), p2, s2))
)
val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2))

val divActual = Divide(
Literal(Decimal(BigDecimal(n1), p1, s1)),
Literal(Decimal(BigDecimal(n2), p2, s2))
)
val divExact = new java.math.BigDecimal(n1)
.divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN)

val remActual = Remainder(
Literal(Decimal(BigDecimal(n1), p1, s1)),
Literal(Decimal(BigDecimal(n2), p2, s2))
)
val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2))

val quotActual = IntegralDivide(
Literal(Decimal(BigDecimal(n1), p1, s1)),
Literal(Decimal(BigDecimal(n2), p2, s2))
)
val quotExact =
new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2))

Seq(true, false).foreach { allowPrecLoss =>
withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) {
val mulType = Multiply(null, null).resultDecimalType(p1, s1, p2, s2)
val mulResult = Decimal(mulExact.setScale(mulType.scale, RoundingMode.HALF_UP))
val mulExpected =
if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult
checkEvaluation(mulActual, mulExpected)

val divType = Divide(null, null).resultDecimalType(p1, s1, p2, s2)
val divResult = Decimal(divExact.setScale(divType.scale, RoundingMode.HALF_UP))
val divExpected =
if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult
checkEvaluation(divActual, divExpected)

val remType = Remainder(null, null).resultDecimalType(p1, s1, p2, s2)
val remResult = Decimal(remExact.setScale(remType.scale, RoundingMode.HALF_UP))
val remExpected =
if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult
checkEvaluation(remActual, remExpected)

val quotType = IntegralDivide(null, null).resultDecimalType(p1, s1, p2, s2)
val quotResult = Decimal(quotExact.setScale(quotType.scale, RoundingMode.HALF_UP))
val quotExpected =
if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult
checkEvaluation(quotActual, quotExpected.toLong)
}
}
}
}

private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = {
testFunc(_.toDouble)
testFunc(Decimal(_))
@@ -155,7 +155,7 @@ org.apache.spark.SparkArithmeticException
"config" : "\"spark.sql.ansi.enabled\"",
"precision" : "38",
"scale" : "6",
"value" : "1000000000000000000000000000000000000.00000000000000000000000000000000000000"
"value" : "1000000000000000000000000000000000000.000000000000000000000000000000000000000"
},
"queryContext" : [ {
"objectType" : "",
@@ -204,7 +204,7 @@
"config" : "\"spark.sql.ansi.enabled\"",
"precision" : "38",
"scale" : "6",
"value" : "10123456789012345678901234567890123456.00000000000000000000000000000000000000"
"value" : "10123456789012345678901234567890123456.000000000000000000000000000000000000000"
},
"queryContext" : [ {
"objectType" : "",
@@ -229,7 +229,7 @@ org.apache.spark.SparkArithmeticException
"config" : "\"spark.sql.ansi.enabled\"",
"precision" : "38",
"scale" : "6",
"value" : "101234567890123456789012345678901234.56000000000000000000000000000000000000"
"value" : "101234567890123456789012345678901234.560000000000000000000000000000000000000"
},
"queryContext" : [ {
"objectType" : "",
@@ -254,7 +254,7 @@ org.apache.spark.SparkArithmeticException
"config" : "\"spark.sql.ansi.enabled\"",
"precision" : "38",
"scale" : "6",
"value" : "10123456789012345678901234567890123.45600000000000000000000000000000000000"
"value" : "10123456789012345678901234567890123.456000000000000000000000000000000000000"
},
"queryContext" : [ {
"objectType" : "",
@@ -279,7 +279,7 @@ org.apache.spark.SparkArithmeticException
"config" : "\"spark.sql.ansi.enabled\"",
"precision" : "38",
"scale" : "6",
"value" : "1012345678901234567890123456789012.34560000000000000000000000000000000000"
"value" : "1012345678901234567890123456789012.345600000000000000000000000000000000000"
},
"queryContext" : [ {
"objectType" : "",
@@ -304,7 +304,7 @@ org.apache.spark.SparkArithmeticException
"config" : "\"spark.sql.ansi.enabled\"",
"precision" : "38",
"scale" : "6",
"value" : "101234567890123456789012345678901.23456000000000000000000000000000000000"
"value" : "101234567890123456789012345678901.234560000000000000000000000000000000000"
},
"queryContext" : [ {
"objectType" : "",
@@ -337,7 +337,7 @@ org.apache.spark.SparkArithmeticException
"config" : "\"spark.sql.ansi.enabled\"",
"precision" : "38",
"scale" : "6",
"value" : "101234567890123456789012345678901.23456000000000000000000000000000000000"
"value" : "101234567890123456789012345678901.234560000000000000000000000000000000000"
},
"queryContext" : [ {
"objectType" : "",