diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9eecf81684ce..b818be556819 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -190,7 +190,7 @@ case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant with SupportQueryContext { - protected val evalMode: EvalMode.Value + protected[sql] val evalMode: EvalMode.Value private lazy val internalDataType: DataType = (left.dataType, right.dataType) match { case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) => @@ -237,6 +237,8 @@ abstract class BinaryArithmetic extends BinaryOperator } } + override def otherCopyArgs: Seq[AnyRef] = evalMode :: Nil + final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC) override def initQueryContext(): Option[QueryContext] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 89890ea08641..dc6c596da685 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.types._ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic with CommutativeExpression { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + protected[sql] override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType @@ -86,7 +86,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic with CommutativeExpression { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + protected[sql] override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType @@ -133,7 +133,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic with CommutativeExpression { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + protected[sql] override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = IntegralType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 89f0b95f5c18..7edbac472eb7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -1086,4 +1086,23 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(IntegralDivide(Literal(Duration.ofDays(1)), Literal(Duration.ofHours(-5))), -4L) } + + test("SPARK-48016: makeCopy should include the evalMode") { + val originalLeft = Literal(1) + val originalRight = Literal(0) + val newLeft = Literal(1.0) + val newRight = Literal(0.0) + Seq( + (left: Expression, right: Expression) => Divide(left, right), + (left: Expression, right: Expression) => Add(left, right), + (left: Expression, right: Expression) => Subtract(left, right), + (left: Expression, right: Expression) => Multiply(left, right) + ).foreach { binaryOp => + val original = binaryOp(originalLeft, originalRight) + val copy = original.makeCopy(Array(newLeft, newRight)).asInstanceOf[BinaryArithmetic] + assert(copy.evalMode === original.evalMode) + assert(copy.left == newLeft) + assert(copy.right == newRight) + } + } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out index ef17f6b50b90..9b8306dfe077 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/try_arithmetic.sql.out @@ -211,6 +211,13 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x] +- OneRowRelation +-- !query +SELECT try_divide(1, decimal(0)) +-- !query analysis +Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x] ++- OneRowRelation + + -- !query SELECT try_divide(interval 2 year, 2) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out index ef17f6b50b90..9b8306dfe077 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/try_arithmetic.sql.out @@ -211,6 +211,13 @@ Project [try_divide(1, (1.0 / 0.0)) AS try_divide(1, (1.0 / 0.0))#x] +- OneRowRelation +-- !query +SELECT try_divide(1, decimal(0)) +-- !query analysis +Project [try_divide(1, cast(0 as decimal(10,0))) AS try_divide(1, 0)#x] ++- OneRowRelation + + -- !query SELECT try_divide(interval 2 year, 2) -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql index 55907b6701e5..82e9f9d09c3e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/try_arithmetic.sql @@ -38,6 +38,7 @@ SELECT try_divide(0, 0); SELECT try_divide(1, (2147483647 + 1)); SELECT try_divide(1L, (9223372036854775807L + 1L)); SELECT try_divide(1, 1.0 / 0.0); +SELECT try_divide(1, decimal(0)); -- Interval / Numeric SELECT try_divide(interval 2 year, 2); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out index adb6550e8083..6d9a2f2df1a8 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/try_arithmetic.sql.out @@ -341,6 +341,14 @@ org.apache.spark.SparkArithmeticException } +-- !query +SELECT try_divide(1, decimal(0)) +-- !query schema +struct +-- !query output +NULL + + -- !query SELECT try_divide(interval 2 year, 2) -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out index fa83652da0ed..316ef473ca44 100644 --- a/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/try_arithmetic.sql.out @@ -249,6 +249,14 @@ struct NULL +-- !query +SELECT try_divide(1, decimal(0)) +-- !query schema +struct +-- !query output +NULL + + -- !query SELECT try_divide(interval 2 year, 2) -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala index d9c3848d3b6b..abb4c2ae948d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/functions/V2FunctionBenchmark.scala @@ -104,7 +104,7 @@ object V2FunctionBenchmark extends SqlBasedBenchmark { left: Expression, right: Expression, override val nullable: Boolean) extends BinaryArithmetic { - protected override val evalMode: EvalMode.Value = EvalMode.LEGACY + protected[sql] override val evalMode: EvalMode.Value = EvalMode.LEGACY override def inputType: AbstractDataType = NumericType override def symbol: String = "+"