From fca71f04d8599ab868bb57bcf83afa0b75d004da Mon Sep 17 00:00:00 2001 From: Hisoka Date: Thu, 11 May 2023 09:27:09 +0800 Subject: [PATCH 1/5] [SPARK-40129][SQL] Fix Decimal multiply can produce the wrong answer --- .../sql/catalyst/expressions/arithmetic.scala | 53 ++++++++++++++----- .../org/apache/spark/sql/SQLQuerySuite.scala | 14 +++++ 2 files changed, 55 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 31d4d71cd40ad..a217113f9bee3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.{MathContext, RoundingMode} + import scala.math.{max, min} import org.apache.spark.sql.catalyst.InternalRow @@ -279,7 +281,7 @@ abstract class BinaryArithmetic extends BinaryOperator } /** Name of the function for this expression on a [[Decimal]] type. 
*/ - def decimalMethod: String = + def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): String = throw QueryExecutionErrors.notOverrideExpectedMethodsError("BinaryArithmetics", "decimalMethod", "genCode") @@ -298,6 +300,9 @@ abstract class BinaryArithmetic extends BinaryOperator override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case DecimalType.Fixed(precision, scale) => val errorContextCode = getContextOrNullCode(ctx, failOnError) + val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP) + val mathContextValue = JavaCode.global(ctx.addReferenceObj("mathContext", + mathContext, mathContext.getClass.getName), mathContext.getClass) val updateIsNull = if (failOnError) { "" } else { @@ -305,7 +310,7 @@ abstract class BinaryArithmetic extends BinaryOperator } nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - |${ev.value} = $eval1.$decimalMethod($eval2).toPrecision( + |${ev.value} = ${decimalMethod(mathContextValue, eval1, eval2)}.toPrecision( | $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode); |$updateIsNull """.stripMargin @@ -430,7 +435,8 @@ case class Add( override def symbol: String = "+" - override def decimalMethod: String = "$plus" + override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): + String = s"$value1.$$plus($value2)" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -516,7 +522,8 @@ case class Subtract( override def symbol: String = "-" - override def decimalMethod: String = "$minus" + override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): + String = s"$value1.$$minus($value2)" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -592,7 +599,10 @@ case class Multiply( override def inputType: AbstractDataType = NumericType override def symbol: String = "*" - override 
def decimalMethod: String = "$times" + + override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): + String = s"Decimal.apply($value1.toJavaBigDecimal()" + + s".multiply($value2.toJavaBigDecimal(), $mathContextValue))" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -697,11 +707,15 @@ trait DivModLike extends BinaryArithmetic { } val javaType = CodeGenerator.javaType(dataType) val errorContextCode = getContextOrNullCode(ctx, failOnError) + val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP) + val mathContextValue = JavaCode.global(ctx.addReferenceObj("mathContext", + mathContext, mathContext.getClass.getName), mathContext.getClass) val operation = super.dataType match { case DecimalType.Fixed(precision, scale) => val decimalValue = ctx.freshName("decimalValue") s""" - |Decimal $decimalValue = ${eval1.value}.$decimalMethod(${eval2.value}).toPrecision( + |Decimal $decimalValue = + |${decimalMethod(mathContextValue, eval1.value, eval2.value)}.toPrecision( | $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode); |if ($decimalValue != null) { | ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")}; @@ -793,7 +807,9 @@ case class Divide( override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) override def symbol: String = "/" - override def decimalMethod: String = "$div" + + override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): + String = s"$value1.$$div($value2)" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -868,7 +884,10 @@ case class IntegralDivide( override def dataType: DataType = LongType override def symbol: String = "/" - override def decimalMethod: String = "quot" + + override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): + String = s"$value1.quot($value2)" + override 
def decimalToDataTypeCodeGen(decimalResult: String): String = s"$decimalResult.toLong()" override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { @@ -936,7 +955,9 @@ case class Remainder( override def inputType: AbstractDataType = NumericType override def symbol: String = "%" - override def decimalMethod: String = "remainder" + + override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): + String = s"$value1.remainder($value2)" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -1019,7 +1040,8 @@ case class Pmod( override def nullable: Boolean = true - override def decimalMethod: String = "remainder" + override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): + String = s"$value1.remainder($value2)" // This follows Remainder rule override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { @@ -1077,14 +1099,21 @@ case class Pmod( } val remainder = ctx.freshName("remainder") val javaType = CodeGenerator.javaType(dataType) + val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP) + val mathContextValue = JavaCode.global(ctx.addReferenceObj("mathContext", + mathContext, mathContext.getClass.getName), mathContext.getClass) val errorContext = getContextOrNullCode(ctx) val result = dataType match { case DecimalType.Fixed(precision, scale) => val decimalAdd = "$plus" s""" - |$javaType $remainder = ${eval1.value}.$decimalMethod(${eval2.value}); + |$javaType $remainder = ${decimalMethod(mathContextValue, eval1.value, eval2.value)}; |if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) { - | ${ev.value}=($remainder.$decimalAdd(${eval2.value})).$decimalMethod(${eval2.value}); + | ${ev.value}= + |${ + decimalMethod(mathContextValue, s"($remainder.$decimalAdd(${eval2.value}))" + , eval2.value) + }; |} else { | ${ev.value}=$remainder; |} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 089464dd5696e..f7d68b18143c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.io.File +import java.math.RoundingMode import java.net.{MalformedURLException, URL} import java.sql.{Date, Timestamp} import java.time.{Duration, Period} @@ -1498,6 +1499,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Seq(Row(d))) } + test("SPARK-40129: Fix Decimal multiply can produce the wrong answer because it rounds twice") { + val sparkValue = Seq("9173594185998001607642838421.5479932913").toDF() + .selectExpr("CAST(value as DECIMAL(38,10)) as a") + .selectExpr("a * CAST(-12 as DECIMAL(38,10))").head().getDecimal(0) + + val l = new java.math.BigDecimal("9173594185998001607642838421.5479932913") + val r = new java.math.BigDecimal("-12.0000000000") + val prod = l.multiply(r) + val javaValue = prod.setScale(6, RoundingMode.HALF_UP) + + assert(sparkValue == javaValue) + } + test("precision smaller than scale") { checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00"))) checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00"))) From aeaf76b2dc1682c8571b503c3835df9f801967c4 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Sat, 13 May 2023 23:44:13 +0800 Subject: [PATCH 2/5] [SPARK-40129][SQL] add non codegen logic --- .../sql/catalyst/expressions/arithmetic.scala | 29 +++++++++++------ .../org/apache/spark/sql/SQLQuerySuite.scala | 32 ++++++++++++------- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a217113f9bee3..fa44fd798fc74 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -24,6 +24,7 @@ import scala.math.{max, min} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.BinaryArithmetic.getMathContextValue import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -300,9 +301,7 @@ abstract class BinaryArithmetic extends BinaryOperator override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case DecimalType.Fixed(precision, scale) => val errorContextCode = getContextOrNullCode(ctx, failOnError) - val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP) - val mathContextValue = JavaCode.global(ctx.addReferenceObj("mathContext", - mathContext, mathContext.getClass.getName), mathContext.getClass) + val mathContextValue = getMathContextValue(ctx) val updateIsNull = if (failOnError) { "" } else { @@ -410,6 +409,15 @@ abstract class BinaryArithmetic extends BinaryOperator } object BinaryArithmetic { + + def getMathContextValue(ctx: CodegenContext): GlobalValue = { + // SPARK-40129: We need to use `MathContext` with precision = `Decimal.MAX_PRECISION + 1`, + // to fix decimal multipy bug. See SPARK-40129 for more details. 
+ val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP) + JavaCode.global(ctx.addReferenceObj("mathContext", + mathContext, mathContext.getClass.getName), mathContext.getClass) + } + def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } @@ -625,7 +633,12 @@ case class Multiply( protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match { case DecimalType.Fixed(precision, scale) => - checkDecimalOverflow(numeric.times(input1, input2).asInstanceOf[Decimal], precision, scale) + // SPARK-40129: We need to use `MathContext` with precision = `Decimal.MAX_PRECISION + 1`, + // to fix decimal multipy bug. See SPARK-40129 for more details. + checkDecimalOverflow(Decimal(input1.asInstanceOf[Decimal].toJavaBigDecimal + .multiply(input2.asInstanceOf[Decimal].toJavaBigDecimal, + new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP))), + precision, scale) case _: IntegerType if failOnError => MathUtils.multiplyExact( input1.asInstanceOf[Int], @@ -707,9 +720,7 @@ trait DivModLike extends BinaryArithmetic { } val javaType = CodeGenerator.javaType(dataType) val errorContextCode = getContextOrNullCode(ctx, failOnError) - val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP) - val mathContextValue = JavaCode.global(ctx.addReferenceObj("mathContext", - mathContext, mathContext.getClass.getName), mathContext.getClass) + val mathContextValue = getMathContextValue(ctx) val operation = super.dataType match { case DecimalType.Fixed(precision, scale) => val decimalValue = ctx.freshName("decimalValue") @@ -1099,9 +1110,7 @@ case class Pmod( } val remainder = ctx.freshName("remainder") val javaType = CodeGenerator.javaType(dataType) - val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP) - val mathContextValue = JavaCode.global(ctx.addReferenceObj("mathContext", - mathContext, mathContext.getClass.getName), 
mathContext.getClass) + val mathContextValue = getMathContextValue(ctx) val errorContext = getContextOrNullCode(ctx) val result = dataType match { case DecimalType.Fixed(precision, scale) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f7d68b18143c0..7690d49275b81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,10 +28,10 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import org.apache.commons.io.FileUtils - import org.apache.spark.{AccumulatorSuite, SPARK_DOC_ROOT, SparkException} + import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} -import org.apache.spark.sql.catalyst.expressions.{GenericRow, Hex} +import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRow, Hex} import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial} import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite} @@ -1500,16 +1500,24 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("SPARK-40129: Fix Decimal multiply can produce the wrong answer because it rounds twice") { - val sparkValue = Seq("9173594185998001607642838421.5479932913").toDF() - .selectExpr("CAST(value as DECIMAL(38,10)) as a") - .selectExpr("a * CAST(-12 as DECIMAL(38,10))").head().getDecimal(0) - - val l = new java.math.BigDecimal("9173594185998001607642838421.5479932913") - val r = new java.math.BigDecimal("-12.0000000000") - val prod = l.multiply(r) - val javaValue = prod.setScale(6, RoundingMode.HALF_UP) - - assert(sparkValue == javaValue) + Seq((false, CodegenObjectFactoryMode.NO_CODEGEN), + (true, CodegenObjectFactoryMode.CODEGEN_ONLY)).foreach(v => { + 
withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> v._1.toString) { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> v._2.toString) { + val multiplicandStr = "9173594185998001607642838421.5479932913" + val sparkValue = Seq(multiplicandStr).toDF() + .selectExpr("CAST(value as DECIMAL(38,10)) as a") + .selectExpr("a * CAST(-12 as DECIMAL(38,10))").head().getDecimal(0) + + val l = new java.math.BigDecimal(multiplicandStr) + val r = new java.math.BigDecimal("-12.0000000000") + val prod = l.multiply(r) + val javaValue = prod.setScale(6, RoundingMode.HALF_UP) + + assert(sparkValue == javaValue) + } + } + }) } test("precision smaller than scale") { From 9fdd97b7f6e806b3033856685c46f5b4ee3d3b3c Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Sun, 14 May 2023 08:47:30 +0800 Subject: [PATCH 3/5] [SPARK-40129][SQL] fix code style --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7690d49275b81..55b15da5020b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,8 +28,8 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable import org.apache.commons.io.FileUtils -import org.apache.spark.{AccumulatorSuite, SPARK_DOC_ROOT, SparkException} +import org.apache.spark.{AccumulatorSuite, SPARK_DOC_ROOT, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRow, Hex} import org.apache.spark.sql.catalyst.expressions.Cast._ From d7c34a0d6e9c45d969318ef0148ae7b9266e7af6 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Wed, 5 Jul 2023 09:37:42 +0800 Subject: [PATCH 4/5] fix code style --- .../sql/catalyst/expressions/arithmetic.scala | 
16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fa44fd798fc74..fc721319dde2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -444,7 +444,7 @@ case class Add( override def symbol: String = "+" override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): - String = s"$value1.$$plus($value2)" + String = s"$value1.$$plus($value2)" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -531,7 +531,7 @@ case class Subtract( override def symbol: String = "-" override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): - String = s"$value1.$$minus($value2)" + String = s"$value1.$$minus($value2)" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -609,8 +609,8 @@ case class Multiply( override def symbol: String = "*" override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): - String = s"Decimal.apply($value1.toJavaBigDecimal()" + - s".multiply($value2.toJavaBigDecimal(), $mathContextValue))" + String = s"Decimal.apply($value1.toJavaBigDecimal()" + + s".multiply($value2.toJavaBigDecimal(), $mathContextValue))" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -820,7 +820,7 @@ case class Divide( override def symbol: String = "/" override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): - String = s"$value1.$$div($value2)" + String = s"$value1.$$div($value2)" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -897,7 +897,7 
@@ case class IntegralDivide( override def symbol: String = "/" override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): - String = s"$value1.quot($value2)" + String = s"$value1.quot($value2)" override def decimalToDataTypeCodeGen(decimalResult: String): String = s"$decimalResult.toLong()" @@ -968,7 +968,7 @@ case class Remainder( override def symbol: String = "%" override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): - String = s"$value1.remainder($value2)" + String = s"$value1.remainder($value2)" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -1052,7 +1052,7 @@ case class Pmod( override def nullable: Boolean = true override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): - String = s"$value1.remainder($value2)" + String = s"$value1.remainder($value2)" // This follows Remainder rule override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = { From 6db0ec5fe5f036feae8e0a3a8d89d10690b679c8 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Wed, 2 Aug 2023 10:57:42 +0800 Subject: [PATCH 5/5] update --- .../org/apache/spark/sql/types/Decimal.scala | 7 ++++++ .../sql/catalyst/expressions/arithmetic.scala | 24 +++++++++---------- .../org/apache/spark/sql/SQLQuerySuite.scala | 12 ++++------ 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala index afe73635a6824..f425dffcc38b9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -497,6 +497,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { def * (that: Decimal): Decimal = Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT)) + def multiply(that: Decimal, mathContext: MathContext): 
Decimal = + Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, mathContext)) + def / (that: Decimal): Decimal = if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode)) @@ -505,6 +508,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (that.isZero) null else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT)) + def remainder(that: Decimal, mathContext: MathContext): Decimal = + if (that.isZero) null + else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, mathContext)) + def quot(that: Decimal): Decimal = if (that.isZero) null else Decimal(toJavaBigDecimal.divideToIntegralValue(that.toJavaBigDecimal, MATH_CONTEXT)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fc721319dde2a..01de978baf534 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -410,10 +410,11 @@ abstract class BinaryArithmetic extends BinaryOperator object BinaryArithmetic { + // SPARK-40129: We need to use `MathContext` with precision = `DecimalType.MAX_PRECISION + 1`, + // to fix decimal multiply bug. See SPARK-40129 for more details. + val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP) + def getMathContextValue(ctx: CodegenContext): GlobalValue = { - // SPARK-40129: We need to use `MathContext` with precision = `Decimal.MAX_PRECISION + 1`, - // to fix decimal multipy bug. See SPARK-40129 for more details.
- val mathContext = new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP) JavaCode.global(ctx.addReferenceObj("mathContext", mathContext, mathContext.getClass.getName), mathContext.getClass) } @@ -609,8 +610,7 @@ case class Multiply( override def symbol: String = "*" override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): - String = s"Decimal.apply($value1.toJavaBigDecimal()" + - s".multiply($value2.toJavaBigDecimal(), $mathContextValue))" + String = s"$value1.multiply($value2, $mathContextValue)" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -635,10 +635,8 @@ case class Multiply( case DecimalType.Fixed(precision, scale) => // SPARK-40129: We need to use `MathContext` with precision = `Decimal.MAX_PRECISION + 1`, // to fix decimal multipy bug. See SPARK-40129 for more details. - checkDecimalOverflow(Decimal(input1.asInstanceOf[Decimal].toJavaBigDecimal - .multiply(input2.asInstanceOf[Decimal].toJavaBigDecimal, - new MathContext(DecimalType.MAX_PRECISION + 1, RoundingMode.HALF_UP))), - precision, scale) + checkDecimalOverflow(input1.asInstanceOf[Decimal].multiply( + input2.asInstanceOf[Decimal], BinaryArithmetic.mathContext), precision, scale) case _: IntegerType if failOnError => MathUtils.multiplyExact( input1.asInstanceOf[Int], @@ -968,7 +966,7 @@ case class Remainder( override def symbol: String = "%" override def decimalMethod(mathContextValue: GlobalValue, value1: String, value2: String): - String = s"$value1.remainder($value2)" + String = s"$value1.remainder($value2, $mathContextValue)" // scalastyle:off // The formula follows Hive which is based on the SQL standard and MS SQL: @@ -1012,10 +1010,10 @@ case class Remainder( val integral = PhysicalIntegralType.integral(i) (left, right) => integral.rem(left, right) - case d @ DecimalType.Fixed(precision, scale) => - val integral = PhysicalDecimalType(precision, scale).asIntegral.asInstanceOf[Integral[Any]] + case 
DecimalType.Fixed(precision, scale) => (left, right) => - checkDecimalOverflow(integral.rem(left, right).asInstanceOf[Decimal], precision, scale) + checkDecimalOverflow(left.asInstanceOf[Decimal].remainder(right.asInstanceOf[Decimal], + BinaryArithmetic.mathContext), precision, scale) } override def evalOperation(left: Any, right: Any): Any = mod(left, right) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a707da4aaccfb..68207afef1d98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import java.io.File -import java.math.RoundingMode import java.net.{MalformedURLException, URL} import java.sql.{Date, Timestamp} import java.time.{Duration, Period} @@ -1510,13 +1509,12 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val sparkValue = Seq(multiplicandStr).toDF() .selectExpr("CAST(value as DECIMAL(38,10)) as a") .selectExpr("a * CAST(-12 as DECIMAL(38,10))").head().getDecimal(0) + assert(sparkValue.toString == "-110083130231976019291714061058.575919") - val l = new java.math.BigDecimal(multiplicandStr) - val r = new java.math.BigDecimal("-12.0000000000") - val prod = l.multiply(r) - val javaValue = prod.setScale(6, RoundingMode.HALF_UP) - - assert(sparkValue == javaValue) + val sparkValue2 = Seq(multiplicandStr).toDF() + .selectExpr("CAST(value as DECIMAL(38,10)) as a") + .selectExpr("a / CAST(-12 as DECIMAL(38,10))").head().getDecimal(0) + assert(sparkValue2.toString == "-764466182166500133970236535.128999") } } })