Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,11 @@ object Decimal {

private val BIG_DEC_ZERO = BigDecimal(0)

private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP)
// SPARK-45786 Using RoundingMode.HALF_UP with MathContext may cause inaccurate SQL results
// because TypeCoercion later rounds again. Instead, always round down and use 1 digit longer
// precision than DecimalType.MAX_PRECISION. Then, TypeCoercion will properly round up/down
// the last extra digit.
private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.DOWN)

private[sql] val ZERO = Decimal(0)
private[sql] val ONE = Decimal(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.math.RoundingMode
import java.sql.{Date, Timestamp}
import java.time.{Duration, Period}
import java.time.temporal.ChronoUnit
Expand Down Expand Up @@ -226,6 +227,108 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
}
}

test("SPARK-45786: Decimal multiply, divide, remainder, quot") {
// Some known cases
checkEvaluation(
Multiply(
Literal(Decimal(BigDecimal("-14120025096157587712113961295153.858047"), 38, 6)),
Literal(Decimal(BigDecimal("-0.4652"), 4, 4))
),
Decimal(BigDecimal("6568635674732509803675414794505.574763"))
)
checkEvaluation(
Multiply(
Literal(Decimal(BigDecimal("-240810500742726"), 15, 0)),
Literal(Decimal(BigDecimal("-5677.6988688550027099967697071"), 29, 25))
),
Decimal(BigDecimal("1367249507675382200.164877854336665327"))
)
checkEvaluation(
Divide(
Literal(Decimal(BigDecimal("-0.172787979"), 9, 9)),
Literal(Decimal(BigDecimal("533704665545018957788294905796.5"), 31, 1))
),
Decimal(BigDecimal("-3.237520E-31"))
)
checkEvaluation(
Divide(
Literal(Decimal(BigDecimal("-0.574302343618"), 12, 12)),
Literal(Decimal(BigDecimal("-795826820326278835912868.106"), 27, 3))
),
Decimal(BigDecimal("7.21642358550E-25"))
)

// Random tests
val rand = scala.util.Random
def makeNum(p: Int, s: Int): String = {
val int1 = rand.nextLong()
val int2 = rand.nextLong().abs
val frac1 = rand.nextLong().abs
val frac2 = rand.nextLong().abs
s"$int1$int2".take(p - s + (int1 >>> 63).toInt) + "." + s"$frac1$frac2".take(s)
}

(0 until 100).foreach { _ =>
val p1 = rand.nextInt(38) + 1 // 1 <= p1 <= 38
val s1 = rand.nextInt(p1 + 1) // 0 <= s1 <= p1
val p2 = rand.nextInt(38) + 1
val s2 = rand.nextInt(p2 + 1)

val n1 = makeNum(p1, s1)
val n2 = makeNum(p2, s2)

val mulActual = Multiply(
Literal(Decimal(BigDecimal(n1), p1, s1)),
Literal(Decimal(BigDecimal(n2), p2, s2))
)
val mulExact = new java.math.BigDecimal(n1).multiply(new java.math.BigDecimal(n2))

val divActual = Divide(
Literal(Decimal(BigDecimal(n1), p1, s1)),
Literal(Decimal(BigDecimal(n2), p2, s2))
)
val divExact = new java.math.BigDecimal(n1)
.divide(new java.math.BigDecimal(n2), 100, RoundingMode.DOWN)

val remActual = Remainder(
Literal(Decimal(BigDecimal(n1), p1, s1)),
Literal(Decimal(BigDecimal(n2), p2, s2))
)
val remExact = new java.math.BigDecimal(n1).remainder(new java.math.BigDecimal(n2))

val quotActual = IntegralDivide(
Literal(Decimal(BigDecimal(n1), p1, s1)),
Literal(Decimal(BigDecimal(n2), p2, s2))
)
val quotExact =
new java.math.BigDecimal(n1).divideToIntegralValue(new java.math.BigDecimal(n2))

Seq(true, false).foreach { allowPrecLoss =>
withSQLConf(SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key -> allowPrecLoss.toString) {
val mulResult = Decimal(mulExact.setScale(s1, RoundingMode.HALF_UP))
val mulExpected =
if (mulResult.precision > DecimalType.MAX_PRECISION) null else mulResult
checkEvaluation(mulActual, mulExpected)

val divResult = Decimal(divExact.setScale(s1, RoundingMode.HALF_UP))
val divExpected =
if (divResult.precision > DecimalType.MAX_PRECISION) null else divResult
checkEvaluation(divActual, divExpected)

val remResult = Decimal(remExact.setScale(s1, RoundingMode.HALF_UP))
val remExpected =
if (remResult.precision > DecimalType.MAX_PRECISION) null else remResult
checkEvaluation(remActual, remExpected)

val quotResult = Decimal(quotExact.setScale(s1, RoundingMode.HALF_UP))
val quotExpected =
if (quotResult.precision > DecimalType.MAX_PRECISION) null else quotResult
checkEvaluation(quotActual, quotExpected.toLong)
}
}
}
}

private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = {
testFunc(_.toDouble)
testFunc(Decimal(_))
Expand Down