sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -497,6 +497,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   def * (that: Decimal): Decimal =
     Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT))
 
+  def multiply(that: Decimal, mathContext: MathContext): Decimal =
+    Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, mathContext))
+
   def / (that: Decimal): Decimal =
     if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal,
       DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode))
@@ -505,6 +508,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
     if (that.isZero) null
     else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT))
 
+  def remainder(that: Decimal, mathContext: MathContext): Decimal =
+    if (that.isZero) null
+    else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, mathContext))
+
   def quot(that: Decimal): Decimal =
     if (that.isZero) null
     else Decimal(toJavaBigDecimal.divideToIntegralValue(that.toJavaBigDecimal, MATH_CONTEXT))
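The new `multiply` and `remainder` overloads simply forward a caller-supplied `MathContext` to `java.math.BigDecimal` instead of always using the class-level `MATH_CONTEXT`. A minimal sketch of how a caller might use the widened context (not part of the diff; the values are illustrative and taken from the regression test below):

```scala
import java.math.{MathContext, RoundingMode}
import org.apache.spark.sql.types.Decimal

val a = Decimal(BigDecimal("9173594185998001607642838421.5479932913"))
val b = Decimal(BigDecimal("-12"))

// One digit more than DecimalType.MAX_PRECISION (38): the 40-digit exact
// product keeps 39 significant digits, so only the final toPrecision call
// rounds it down to the result scale.
val mc = new MathContext(39, RoundingMode.HALF_UP)
val product = a.multiply(b, mc)
```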
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -17,11 +17,14 @@

 package org.apache.spark.sql.catalyst.expressions
 
+import java.math.{MathContext, RoundingMode}
+
 import scala.math.{max, min}
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic.getMathContextValue
 import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLType}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -279,7 +282,7 @@ abstract class BinaryArithmetic extends BinaryOperator
   }
 
   /** Name of the function for this expression on a [[Decimal]] type. */
-  def decimalMethod: String =
+  def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): String =
     throw QueryExecutionErrors.notOverrideExpectedMethodsError("BinaryArithmetics",
       "decimalMethod", "genCode")
 
@@ -298,14 +301,15 @@
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
     case DecimalType.Fixed(precision, scale) =>
       val errorContextCode = getContextOrNullCode(ctx, failOnError)
+      val mathContextValue = getMathContextValue(ctx)
       val updateIsNull = if (failOnError) {
         ""
       } else {
         s"${ev.isNull} = ${ev.value} == null;"
       }
       nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
         s"""
-           |${ev.value} = $eval1.$decimalMethod($eval2).toPrecision(
+           |${ev.value} = ${decimalMethod(mathContextValue, eval1, eval2)}.toPrecision(
            |  $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode);
            |$updateIsNull
          """.stripMargin
@@ -405,6 +409,16 @@
 }
 
 object BinaryArithmetic {
+
+  // SPARK-40129: We need to use a `MathContext` with precision = `DecimalType.MAX_PRECISION + 1`
+  // to fix the decimal multiply double-rounding bug. See SPARK-40129 for more details.
+  val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP)
+
+  def getMathContextValue(ctx: CodegenContext): GlobalValue = {
+    JavaCode.global(ctx.addReferenceObj("mathContext",
+      mathContext, mathContext.getClass.getName), mathContext.getClass)
+  }
+
   def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
 }
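Why `MAX_PRECISION + 1`: rounding a 40-digit exact product to 38 significant digits and then again to the result scale can push the last digit away from the correctly rounded value. A self-contained sketch with plain `java.math.BigDecimal` (not part of the diff; the numbers come from the SPARK-40129 report):

```scala
import java.math.{MathContext, RoundingMode, BigDecimal => JBigDecimal}

val a = new JBigDecimal("9173594185998001607642838421.5479932913")
val b = new JBigDecimal("-12")
// Exact product (40 significant digits): -110083130231976019291714061058.5759194956

// Double rounding: 38-digit context first (...058.57591950), then scale 6.
val twice = a.multiply(b, new MathContext(38, RoundingMode.HALF_UP))
  .setScale(6, RoundingMode.HALF_UP) // -110083130231976019291714061058.575920 (off by one ulp)

// One extra intermediate digit preserves the information needed for the final rounding.
val once = a.multiply(b, new MathContext(39, RoundingMode.HALF_UP))
  .setScale(6, RoundingMode.HALF_UP) // -110083130231976019291714061058.575919 (correct)
```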

@@ -430,7 +444,8 @@ case class Add(

   override def symbol: String = "+"
 
-  override def decimalMethod: String = "$plus"
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.$$plus($value2)"
 
   // scalastyle:off
   // The formula follows Hive which is based on the SQL standard and MS SQL:
@@ -516,7 +531,8 @@ case class Subtract(

   override def symbol: String = "-"
 
-  override def decimalMethod: String = "$minus"
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.$$minus($value2)"
 
   // scalastyle:off
   // The formula follows Hive which is based on the SQL standard and MS SQL:
@@ -592,7 +608,9 @@ case class Multiply(
   override def inputType: AbstractDataType = NumericType
 
   override def symbol: String = "*"
-  override def decimalMethod: String = "$times"
+
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.multiply($value2, $mathContextValue)"
 
   // scalastyle:off
   // The formula follows Hive which is based on the SQL standard and MS SQL:
@@ -615,7 +633,10 @@

   protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
     case DecimalType.Fixed(precision, scale) =>
-      checkDecimalOverflow(numeric.times(input1, input2).asInstanceOf[Decimal], precision, scale)
+      // SPARK-40129: We need to use a `MathContext` with precision =
+      // `DecimalType.MAX_PRECISION + 1` to fix the decimal multiply double-rounding bug.
+      checkDecimalOverflow(input1.asInstanceOf[Decimal].multiply(
+        input2.asInstanceOf[Decimal], BinaryArithmetic.mathContext), precision, scale)
     case _: IntegerType if failOnError =>
       MathUtils.multiplyExact(
         input1.asInstanceOf[Int],
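On the interpreted path, the old code went through `numeric.times`, i.e. `Decimal`'s `*` operator, which already rounds the product with the 38-digit `MATH_CONTEXT` before `checkDecimalOverflow` rounds it again to the result scale. A sketch of the difference (not part of the diff; values illustrative):

```scala
import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic
import org.apache.spark.sql.types.Decimal

val a = Decimal(BigDecimal("9173594185998001607642838421.5479932913"))
val b = Decimal(BigDecimal("-12"))

val old = a * b // intermediate already rounded once to 38 digits by MATH_CONTEXT
val fixed = a.multiply(b, BinaryArithmetic.mathContext) // 39-digit intermediate, rounded once
```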
@@ -697,11 +718,13 @@ trait DivModLike extends BinaryArithmetic {
     }
     val javaType = CodeGenerator.javaType(dataType)
     val errorContextCode = getContextOrNullCode(ctx, failOnError)
+    val mathContextValue = getMathContextValue(ctx)
     val operation = super.dataType match {
       case DecimalType.Fixed(precision, scale) =>
         val decimalValue = ctx.freshName("decimalValue")
         s"""
-           |Decimal $decimalValue = ${eval1.value}.$decimalMethod(${eval2.value}).toPrecision(
+           |Decimal $decimalValue =
+           |  ${decimalMethod(mathContextValue, eval1.value, eval2.value)}.toPrecision(
            |  $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode);
            |if ($decimalValue != null) {
            |  ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")};
@@ -793,7 +816,9 @@ case class Divide(
   override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
 
   override def symbol: String = "/"
-  override def decimalMethod: String = "$div"
+
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.$$div($value2)"
 
   // scalastyle:off
   // The formula follows Hive which is based on the SQL standard and MS SQL:
@@ -868,7 +893,10 @@ case class IntegralDivide(
   override def dataType: DataType = LongType
 
   override def symbol: String = "/"
-  override def decimalMethod: String = "quot"
+
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.quot($value2)"
+
   override def decimalToDataTypeCodeGen(decimalResult: String): String = s"$decimalResult.toLong()"
 
   override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
@@ -936,7 +964,9 @@ case class Remainder(
   override def inputType: AbstractDataType = NumericType
 
   override def symbol: String = "%"
-  override def decimalMethod: String = "remainder"
+
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.remainder($value2, $mathContextValue)"
 
   // scalastyle:off
   // The formula follows Hive which is based on the SQL standard and MS SQL:
@@ -980,10 +1010,10 @@ case class Remainder(
       val integral = PhysicalIntegralType.integral(i)
       (left, right) => integral.rem(left, right)
 
-    case d @ DecimalType.Fixed(precision, scale) =>
-      val integral = PhysicalDecimalType(precision, scale).asIntegral.asInstanceOf[Integral[Any]]
+    case DecimalType.Fixed(precision, scale) =>
       (left, right) =>
-        checkDecimalOverflow(integral.rem(left, right).asInstanceOf[Decimal], precision, scale)
+        checkDecimalOverflow(left.asInstanceOf[Decimal].remainder(right.asInstanceOf[Decimal],
+          BinaryArithmetic.mathContext), precision, scale)
   }
 
   override def evalOperation(left: Any, right: Any): Any = mod(left, right)
@@ -1019,7 +1049,8 @@ case class Pmod(

   override def nullable: Boolean = true
 
-  override def decimalMethod: String = "remainder"
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.remainder($value2)"
 
   // This follows Remainder rule
   override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
@@ -1077,14 +1108,19 @@ case class Pmod(
     }
     val remainder = ctx.freshName("remainder")
     val javaType = CodeGenerator.javaType(dataType)
+    val mathContextValue = getMathContextValue(ctx)
     val errorContext = getContextOrNullCode(ctx)
     val result = dataType match {
       case DecimalType.Fixed(precision, scale) =>
         val decimalAdd = "$plus"
         s"""
-           |$javaType $remainder = ${eval1.value}.$decimalMethod(${eval2.value});
+           |$javaType $remainder = ${decimalMethod(mathContextValue, eval1.value, eval2.value)};
            |if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
-           |  ${ev.value}=($remainder.$decimalAdd(${eval2.value})).$decimalMethod(${eval2.value});
+           |  ${ev.value}=
+           |${
+             decimalMethod(mathContextValue, s"($remainder.$decimalAdd(${eval2.value}))",
+               eval2.value)
+           };
            |} else {
            |  ${ev.value}=$remainder;
            |}
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (21 additions, 1 deletion)
@@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils
 import org.apache.spark.{AccumulatorSuite, SPARK_DOC_ROOT, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Hex}
+import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRow, Hex}
 import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
 import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite}
@@ -1500,6 +1500,26 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       Seq(Row(d)))
   }
 
+  test("SPARK-40129: Decimal multiply can produce the wrong answer because it rounds twice") {
+    Seq((false, CodegenObjectFactoryMode.NO_CODEGEN),
+        (true, CodegenObjectFactoryMode.CODEGEN_ONLY)).foreach { case (wholeStage, factoryMode) =>
+      withSQLConf(
+          SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStage.toString,
+          SQLConf.CODEGEN_FACTORY_MODE.key -> factoryMode.toString) {
+        val multiplicandStr = "9173594185998001607642838421.5479932913"
+        val sparkValue = Seq(multiplicandStr).toDF()
+          .selectExpr("CAST(value as DECIMAL(38,10)) as a")
+          .selectExpr("a * CAST(-12 as DECIMAL(38,10))").head().getDecimal(0)
+        assert(sparkValue.toString == "-110083130231976019291714061058.575919")
+
+        val sparkValue2 = Seq(multiplicandStr).toDF()
+          .selectExpr("CAST(value as DECIMAL(38,10)) as a")
+          .selectExpr("a / CAST(-12 as DECIMAL(38,10))").head().getDecimal(0)
+        assert(sparkValue2.toString == "-764466182166500133970236535.128999")
+      }
+    }
+  }
+
test("precision smaller than scale") {
checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00")))
checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00")))
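To reproduce the bug outside the suite, a spark-shell sketch (not part of the diff; assumes a `spark` session, and before this patch the result should end in ...575920 instead of the correctly rounded ...575919):

```scala
spark.sql(
  """SELECT CAST('9173594185998001607642838421.5479932913' AS DECIMAL(38,10))
    |     * CAST(-12 AS DECIMAL(38,10))""".stripMargin).show(false)
// With this patch: -110083130231976019291714061058.575919
```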