diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index afe73635a6824..f425dffcc38b9 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -497,6 +497,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
   def * (that: Decimal): Decimal =
     Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT))
 
+  def multiply(that: Decimal, mathContext: MathContext): Decimal =
+    Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, mathContext))
+
   def / (that: Decimal): Decimal =
     if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal,
       DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode))
@@ -505,6 +508,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
     if (that.isZero) null
     else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT))
 
+  def remainder(that: Decimal, mathContext: MathContext): Decimal =
+    if (that.isZero) null
+    else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, mathContext))
+
   def quot(that: Decimal): Decimal =
     if (that.isZero) null
     else Decimal(toJavaBigDecimal.divideToIntegralValue(that.toJavaBigDecimal, MATH_CONTEXT))
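The two overloads above let callers supply their own `MathContext` instead of the fixed 38-digit `MATH_CONTEXT`. A minimal usage sketch (values borrowed from the regression test at the end of this patch; the guard-digit context mirrors the one `BinaryArithmetic` passes in below):

```scala
import java.math.{MathContext, RoundingMode}
import org.apache.spark.sql.types.{Decimal, DecimalType}

// One guard digit beyond the maximum result precision (38 + 1 = 39),
// the same context BinaryArithmetic.mathContext supplies below.
val mc = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP)

val a = Decimal("9173594185998001607642838421.5479932913")
val b = Decimal("-12")
val product = a.multiply(b, mc) // rounded to at most 39 significant digits
val rest = a.remainder(b, mc)   // returns null when the divisor is zero, like %
```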
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 31d4d71cd40ad..01de978baf534 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -17,11 +17,14 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import java.math.{MathContext, RoundingMode}
+
 import scala.math.{max, min}
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic.getMathContextValue
 import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLType}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -279,7 +282,7 @@ abstract class BinaryArithmetic extends BinaryOperator
   }
 
-  /** Name of the function for this expression on a [[Decimal]] type. */
-  def decimalMethod: String =
+  /** Builds the Java code for this expression's operation on two [[Decimal]] values. */
+  def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): String =
     throw QueryExecutionErrors.notOverrideExpectedMethodsError("BinaryArithmetics",
       "decimalMethod", "genCode")
@@ -298,6 +301,7 @@ abstract class BinaryArithmetic extends BinaryOperator
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
     case DecimalType.Fixed(precision, scale) =>
       val errorContextCode = getContextOrNullCode(ctx, failOnError)
+      val mathContextValue = getMathContextValue(ctx)
       val updateIsNull = if (failOnError) {
         ""
       } else {
@@ -305,7 +309,7 @@ abstract class BinaryArithmetic extends BinaryOperator
       }
       nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
         s"""
-           |${ev.value} = $eval1.$decimalMethod($eval2).toPrecision(
+           |${ev.value} = ${decimalMethod(mathContextValue, eval1, eval2)}.toPrecision(
            |  $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode);
            |$updateIsNull
          """.stripMargin
@@ -405,6 +409,16 @@ abstract class BinaryArithmetic extends BinaryOperator
 }
 
 object BinaryArithmetic {
+
+  // SPARK-40129: use a `MathContext` with precision = `DecimalType.MAX_PRECISION + 1` so that
+  // intermediate decimal results keep a guard digit and are rounded only once, by `toPrecision`.
+  val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP)
+
+  def getMathContextValue(ctx: CodegenContext): GlobalValue = {
+    JavaCode.global(ctx.addReferenceObj("mathContext",
+      mathContext, mathContext.getClass.getName), mathContext.getClass)
+  }
+
   def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
 }
 
@@ -430,7 +444,8 @@ case class Add(
 
   override def symbol: String = "+"
 
-  override def decimalMethod: String = "$plus"
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.$$plus($value2)"
 
   // scalastyle:off
   // The formula follows Hive which is based on the SQL standard and MS SQL:
@@ -516,7 +531,8 @@ case class Subtract(
 
   override def symbol: String = "-"
 
-  override def decimalMethod: String = "$minus"
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.$$minus($value2)"
 
   // scalastyle:off
   // The formula follows Hive which is based on the SQL standard and MS SQL:
@@ -592,7 +608,9 @@ case class Multiply(
   override def inputType: AbstractDataType = NumericType
 
   override def symbol: String = "*"
-  override def decimalMethod: String = "$times"
+
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.multiply($value2, $mathContextValue)"
 
   // scalastyle:off
   // The formula follows Hive which is based on the SQL standard and MS SQL:
@@ -615,7 +633,10 @@ case class Multiply(
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
     case DecimalType.Fixed(precision, scale) =>
-      checkDecimalOverflow(numeric.times(input1, input2).asInstanceOf[Decimal], precision, scale)
+      // SPARK-40129: multiply with `BinaryArithmetic.mathContext` (precision 39) so the
+      // product is rounded only once, when `checkDecimalOverflow` applies the result scale.
+      checkDecimalOverflow(input1.asInstanceOf[Decimal].multiply(
+        input2.asInstanceOf[Decimal], BinaryArithmetic.mathContext), precision, scale)
     case _: IntegerType if failOnError =>
       MathUtils.multiplyExact(
         input1.asInstanceOf[Int],
         input2.asInstanceOf[Int],
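The `Multiply` changes above are the heart of the fix. The double rounding they remove is easiest to see with plain `java.math.BigDecimal`; this is a sketch of the mechanism only, since Spark's real path applies the result scale through `Decimal.toPrecision` rather than `setScale`:

```scala
import java.math.{BigDecimal => JBigDecimal, MathContext, RoundingMode}

val a = new JBigDecimal("9173594185998001607642838421.5479932913")
val b = new JBigDecimal("-12")

// The exact product is -110083130231976019291714061058.5759194956:
// 40 significant digits, two more than DecimalType.MAX_PRECISION = 38.

// Old behaviour: round to 38 significant digits (.57591949|56 becomes
// .57591950), then to the result scale of 6 (.575919|50 becomes .575920).
val roundedTwice = a.multiply(b, new MathContext(38, RoundingMode.HALF_UP))
  .setScale(6, RoundingMode.HALF_UP)
// -110083130231976019291714061058.575920  (last digit is wrong)

// New behaviour: keep one guard digit so only the final rounding matters.
val roundedOnce = a.multiply(b, new MathContext(39, RoundingMode.HALF_UP))
  .setScale(6, RoundingMode.HALF_UP)
// -110083130231976019291714061058.575919

assert(roundedOnce == a.multiply(b).setScale(6, RoundingMode.HALF_UP))
```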
@@ -697,11 +718,13 @@ trait DivModLike extends BinaryArithmetic {
     }
     val javaType = CodeGenerator.javaType(dataType)
     val errorContextCode = getContextOrNullCode(ctx, failOnError)
+    val mathContextValue = getMathContextValue(ctx)
     val operation = super.dataType match {
       case DecimalType.Fixed(precision, scale) =>
         val decimalValue = ctx.freshName("decimalValue")
         s"""
-           |Decimal $decimalValue = ${eval1.value}.$decimalMethod(${eval2.value}).toPrecision(
+           |Decimal $decimalValue =
+           |${decimalMethod(mathContextValue, eval1.value, eval2.value)}.toPrecision(
            |  $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode);
            |if ($decimalValue != null) {
            |  ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")};
@@ -793,7 +816,9 @@ case class Divide(
   override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
 
   override def symbol: String = "/"
-  override def decimalMethod: String = "$div"
+
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.$$div($value2)"
 
   // scalastyle:off
   // The formula follows Hive which is based on the SQL standard and MS SQL:
@@ -868,7 +893,10 @@ case class IntegralDivide(
   override def dataType: DataType = LongType
 
   override def symbol: String = "/"
-  override def decimalMethod: String = "quot"
+
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.quot($value2)"
+
   override def decimalToDataTypeCodeGen(decimalResult: String): String = s"$decimalResult.toLong()"
 
   override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
@@ -936,7 +964,9 @@ case class Remainder(
   override def inputType: AbstractDataType = NumericType
 
   override def symbol: String = "%"
-  override def decimalMethod: String = "remainder"
+
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.remainder($value2, $mathContextValue)"
 
   // scalastyle:off
   // The formula follows Hive which is based on the SQL standard and MS SQL:
@@ -980,10 +1010,10 @@ case class Remainder(
       val integral = PhysicalIntegralType.integral(i)
       (left, right) => integral.rem(left, right)
-    case d @ DecimalType.Fixed(precision, scale) =>
-      val integral = PhysicalDecimalType(precision, scale).asIntegral.asInstanceOf[Integral[Any]]
+    case DecimalType.Fixed(precision, scale) =>
       (left, right) =>
-        checkDecimalOverflow(integral.rem(left, right).asInstanceOf[Decimal], precision, scale)
+        checkDecimalOverflow(left.asInstanceOf[Decimal].remainder(right.asInstanceOf[Decimal],
+          BinaryArithmetic.mathContext), precision, scale)
   }
 
   override def evalOperation(left: Any, right: Any): Any = mod(left, right)
@@ -1019,7 +1049,8 @@ case class Pmod(
 
   override def nullable: Boolean = true
 
-  override def decimalMethod: String = "remainder"
+  override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String):
+    String = s"$value1.remainder($value2)"
 
   // This follows Remainder rule
   override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
@@ -1077,14 +1108,19 @@ case class Pmod(
     }
     val remainder = ctx.freshName("remainder")
     val javaType = CodeGenerator.javaType(dataType)
+    val mathContextValue = getMathContextValue(ctx)
     val errorContext = getContextOrNullCode(ctx)
     val result = dataType match {
       case DecimalType.Fixed(precision, scale) =>
         val decimalAdd = "$plus"
         s"""
-           |$javaType $remainder = ${eval1.value}.$decimalMethod(${eval2.value});
+           |$javaType $remainder = ${decimalMethod(mathContextValue, eval1.value, eval2.value)};
            |if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
-           |  ${ev.value}=($remainder.$decimalAdd(${eval2.value})).$decimalMethod(${eval2.value});
+           |  ${ev.value} = ${
+                decimalMethod(mathContextValue,
+                  s"($remainder.$decimalAdd(${eval2.value}))", eval2.value)
+              };
            |} else {
            |  ${ev.value}=$remainder;
            |}
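For reference, the positive-modulus algorithm that the generated `Pmod` code above implements, restated as a standalone sketch (`pmod` here is an illustrative stand-in, not Spark's implementation):

```scala
// The remainder keeps the dividend's sign; pmod shifts a negative remainder
// by adding the divisor and taking the remainder once more, exactly the
// $remainder/$decimalAdd sequence in the generated code above.
def pmod(a: BigDecimal, n: BigDecimal): BigDecimal = {
  val r = a % n
  if (r.signum < 0) (r + n) % n else r
}

pmod(BigDecimal(-7), BigDecimal(3)) // 2, whereas -7 % 3 == -1
```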
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index cfeccbdf648c2..68207afef1d98 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -31,7 +31,7 @@ import org.apache.commons.io.FileUtils
 import org.apache.spark.{AccumulatorSuite, SPARK_DOC_ROOT, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Hex}
+import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRow, Hex}
 import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
 import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite}
@@ -1500,6 +1500,26 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       Seq(Row(d)))
   }
 
+  test("SPARK-40129: Decimal multiply can produce the wrong answer because it rounds twice") {
+    Seq((false, CodegenObjectFactoryMode.NO_CODEGEN),
+        (true, CodegenObjectFactoryMode.CODEGEN_ONLY))
+      .foreach { case (wholeStageEnabled, factoryMode) =>
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageEnabled.toString,
+            SQLConf.CODEGEN_FACTORY_MODE.key -> factoryMode.toString) {
+          val multiplicandStr = "9173594185998001607642838421.5479932913"
+          val product = Seq(multiplicandStr).toDF()
+            .selectExpr("CAST(value AS DECIMAL(38,10)) AS a")
+            .selectExpr("a * CAST(-12 AS DECIMAL(38,10))").head().getDecimal(0)
+          assert(product.toString == "-110083130231976019291714061058.575919")
+
+          val quotient = Seq(multiplicandStr).toDF()
+            .selectExpr("CAST(value AS DECIMAL(38,10)) AS a")
+            .selectExpr("a / CAST(-12 AS DECIMAL(38,10))").head().getDecimal(0)
+          assert(quotient.toString == "-764466182166500133970236535.128999")
+        }
+      }
+  }
+
   test("precision smaller than scale") {
     checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00")))
     checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00")))
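The expected strings in the test can be verified independently by rounding the exact results once, at the result scale of 6 (plain `java.math.BigDecimal`, outside Spark):

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

val a = new JBigDecimal("9173594185998001607642838421.5479932913")
val b = new JBigDecimal("-12")

a.multiply(b).setScale(6, RoundingMode.HALF_UP)
// -110083130231976019291714061058.575919

a.divide(b, 6, RoundingMode.HALF_UP)
// -764466182166500133970236535.128999
```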