From 9ddb7e8ce5fad91f012f2b81b4d6107d848d3c43 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 26 Apr 2024 15:44:18 -0700 Subject: [PATCH 1/2] fix --- .../sql/catalyst/expressions/arithmetic.scala | 4 +++- .../expressions/bitwiseExpressions.scala | 6 +++--- .../ArithmeticExpressionSuite.scala | 19 +++++++++++++++++++ .../ansi/try_arithmetic.sql.out | 7 +++++++ .../analyzer-results/try_arithmetic.sql.out | 7 +++++++ .../sql-tests/inputs/try_arithmetic.sql | 1 + .../results/ansi/try_arithmetic.sql.out | 8 ++++++++ .../sql-tests/results/try_arithmetic.sql.out | 8 ++++++++ .../functions/V2FunctionBenchmark.scala | 2 +- 9 files changed, 57 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9eecf81684ce..b818be556819 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -190,7 +190,7 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant with SupportQueryContext { - protected val evalMode: EvalMode.Value + protected[sql] val evalMode: EvalMode.Value private lazy val internalDataType: DataType = (left.dataType, right.dataType) match { case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => @@ -237,6 +237,8 @@ abstract class BinaryArithmetic extends BinaryOperator } } + override def otherCopyArgs: Seq[AnyRef] = evalMode :: Nil + final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC) override def initQueryContext(): Option[QueryContext] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 89890ea08641..dc6c596da685 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.types._ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic with CommutativeExpression { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + protected[sql] override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType @@ -86,7 +86,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic with CommutativeExpression { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + protected[sql] override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType @@ -133,7 +133,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic with CommutativeExpression { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + protected[sql] override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 89f0b95f5c18..946a65273276 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -1086,4 +1086,23 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(IntegralDivide(Literal(Duration.ofDays(1)), Literal(Duration.ofHours(-5))), -4L) } + + test("BinaryArithmetic: makeCopy should include the evalMode") { + val originalLeft = Literal(1) + val originalRight = Literal(0) + val newLeft = Literal(1.0) + val newRight = Literal(0.0) + Seq( + (left: Expression, right: Expression) => Divide(left, right), + (left: Expression, right: Expression) => Add(left, right), + (left: Expression, right: Expression) => Subtract(left, right), + (left: Expression, right: Expression) => Multiply(left, right) + ).foreach { binaryOp => + val original = binaryOp(originalLeft, originalRight) + val copy = original.makeCopy(Array(newLeft, newRight)).asInstanceOf[BinaryArithmetic] + assert(copy.evalMode === original.evalMode) + assert(copy.left == newLeft) + assert(copy.right == newRight) + } + } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out index ef17f6b50b90..9b8306dfe077 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out @@ -211,6 +211,13 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x] +- OneRowRelation +-- !query +SELECT try_divide(1, decimal(0)) +-- !query analysis +Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x] ++- OneRowRelation + + -- !query SELECT try_divide(interval 2 year, 2) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out index ef17f6b50b90..9b8306dfe077 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out @@ -211,6 +211,13 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x] +- OneRowRelation +-- !query +SELECT try_divide(1, decimal(0)) +-- !query analysis +Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x] ++- OneRowRelation + + -- !query SELECT try_divide(interval 2 year, 2) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql index 55907b6701e5..82e9f9d09c3e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql @@ -38,6 +38,7 @@ SELECT try_divide(0, 0); SELECT try_divide(1, (2147483647 + 1)); SELECT try_divide(1L, (9223372036854775807L + 1L)); SELECT try_divide(1, 1.0 / 0.0); +SELECT try_divide(1, decimal(0)); -- Interval / Numeric SELECT try_divide(interval 2 year, 2); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out index adb6550e8083..6d9a2f2df1a8 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out @@ -341,6 +341,14 @@ org.apache.spark.SparkArithmeticException } +-- !query +SELECT try_divide(1, decimal(0)) +-- !query schema +struct +-- !query output +NULL + + -- !query SELECT try_divide(interval 2 year, 2) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out index fa83652da0ed..316ef473ca44 100644 --- a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out @@ -249,6 +249,14 @@ struct NULL +-- !query +SELECT try_divide(1, decimal(0)) +-- !query schema +struct +-- !query output +NULL + + -- !query SELECT try_divide(interval 2 year, 2) -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala index d9c3848d3b6b..abb4c2ae948d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala @@ -104,7 +104,7 @@ object V2FunctionBenchmark extends SqlBasedBenchmark { left: Expression, right: Expression, override val nullable: Boolean) extends BinaryArithmetic { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + protected[sql] override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = NumericType override def symbol: String = "+" From 2d6b0465d12861b49bd1c165700278bf797b0d5c Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 26 Apr 2024 15:51:15 -0700 Subject: [PATCH 2/2] add jira --- .../sql/catalyst/expressions/ArithmeticExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 946a65273276..7edbac472eb7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -1087,7 +1087,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper Literal(Duration.ofHours(-5))), -4L) } - test("BinaryArithmetic: makeCopy should include the evalMode") { + test("SPARK-48016: makeCopy should include the evalMode") { val originalLeft = Literal(1) val originalRight = Literal(0) val newLeft = Literal(1.0)