diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 833ecc0a3c09..275566589f54 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -295,6 +295,11 @@
       "UDF class doesn't implement any UDF interface"
     ]
   },
+  "NULL_COMPARISON_RESULT" : {
+    "message" : [
+      "The comparison result is null. If you want to handle null as 0 (equal), you can set \"spark.sql.legacy.allowNullComparisonResultInArraySort\" to \"true\"."
+    ]
+  },
   "PARSE_CHAR_MISSING_LENGTH" : {
     "message" : [
       "DataType requires a length parameter, for example (10). Please specify the length."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 79b76f799d94..135a423b38a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -357,9 +357,9 @@ case class ArrayTransform(
     Since 3.0.0 this function also sorts and returns the array based on the
     given comparator function. The comparator will take two arguments representing
     two elements of the array.
-    It returns -1, 0, or 1 as the first element is less than, equal to, or greater
-    than the second element. If the comparator function returns other
-    values (including null), the function will fail and raise an error.
+    It returns a negative integer, 0, or a positive integer as the first element is less than,
+    equal to, or greater than the second element. If the comparator function returns null,
+    the function will fail and raise an error.
     """,
   examples = """
     Examples:
@@ -375,9 +375,18 @@ case class ArrayTransform(
 // scalastyle:on line.size.limit
 case class ArraySort(
     argument: Expression,
-    function: Expression)
+    function: Expression,
+    allowNullComparisonResult: Boolean)
   extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
 
+  // Reads the legacy null-handling flag from SQLConf when the expression is constructed.
+  def this(argument: Expression, function: Expression) = {
+    this(
+      argument,
+      function,
+      SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT))
+  }
+
   def this(argument: Expression) = this(argument, ArraySort.defaultComparator)
 
   @transient lazy val elementType: DataType =
@@ -416,7 +425,12 @@ case class ArraySort(
     (o1: Any, o2: Any) => {
       firstElemVar.value.set(o1)
       secondElemVar.value.set(o2)
-      f.eval(inputRow).asInstanceOf[Int]
+      val cmp = f.eval(inputRow)
+      if (!allowNullComparisonResult && cmp == null) {
+        throw QueryExecutionErrors.nullComparisonResultError()
+      }
+      // In legacy mode a null result unboxes to 0 here, i.e. the elements compare as equal.
+      cmp.asInstanceOf[Int]
     }
   }
 
@@ -437,6 +451,11 @@ case class ArraySort(
 
 object ArraySort {
 
+  // Preserves the two-argument factory now that the case class has a third field.
+  def apply(argument: Expression, function: Expression): ArraySort = {
+    new ArraySort(argument, function)
+  }
+
   def comparator(left: Expression, right: Expression): Expression = {
     val lit0 = Literal(0)
     val lit1 = Literal(1)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index cd258e3649a7..2b573b2385cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -2021,4 +2021,9 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
     new SparkException(
       errorClass = "MULTI_VALUE_SUBQUERY_ERROR", messageParameters = Array(plan), cause = null)
   }
+
+  def nullComparisonResultError(): Throwable = {
+    new SparkException(errorClass = "NULL_COMPARISON_RESULT",
+      messageParameters = Array(), cause = null)
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8c7702efd47f..5e1f39561599 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3818,6 +3818,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT =
+    buildConf("spark.sql.legacy.allowNullComparisonResultInArraySort")
+      .internal()
+      .doc("When set to false, `array_sort` function throws an error " +
+        "if the comparator function returns null. " +
+        "If set to true, it restores the legacy behavior that handles null as zero (equal).")
+      .version("3.2.2")
+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index c0db6d8dc29f..b1c4c4414274 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -838,4 +838,24 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
       Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))),
       Seq(1d, 2d, Double.NaN, null))
   }
+
+  test("SPARK-39419: ArraySort should throw an exception when the comparator returns null") {
+    val comparator = {
+      val comp = ArraySort.comparator _
+      (left: Expression, right: Expression) =>
+        If(comp(left, right) === 0, Literal.create(null, IntegerType), comp(left, right))
+    }
+
+    withSQLConf(
+        SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "false") {
+      checkExceptionInExpression[SparkException](
+        arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator), "The comparison result is null")
+    }
+
+    withSQLConf(
+        SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "true") {
+      checkEvaluation(arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator),
+        Seq(1, 1, 2, 3))
+    }
+  }
 }