From 4f41230740aae3c55e1cbc4ff429d022b4df1174 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 10 Jun 2022 16:57:23 +0800 Subject: [PATCH] [SPARK-39419][SQL] Fix ArraySort to throw an exception when the comparator returns null MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Fixes `ArraySort` to throw an exception when the comparator returns `null`. Also updates the doc to follow the corrected behavior. ### Why are the changes needed? When the comparator of `ArraySort` returns `null`, currently it handles it as `0` (equal). According to the doc, ``` It returns -1, 0, or 1 as the first element is less than, equal to, or greater than the second element. If the comparator function returns other values (including null), the function will fail and raise an error. ``` It's fine to return non -1, 0, 1 integers to follow the Java convention (still need to update the doc, though), but it should throw an exception for `null` result. ### Does this PR introduce _any_ user-facing change? Yes, if a user uses a comparator that returns `null`, it will throw an error after this PR. The legacy flag `spark.sql.legacy.allowNullComparisonResultInArraySort` can be used to restore the legacy behavior that handles `null` as `0` (equal). ### How was this patch tested? Added some tests. --- .../main/resources/error/error-classes.json | 3 +++ .../expressions/higherOrderFunctions.scala | 26 +++++++++++++++---- .../sql/errors/QueryExecutionErrors.scala | 5 ++++ .../apache/spark/sql/internal/SQLConf.scala | 10 +++++++ .../HigherOrderFunctionsSuite.scala | 22 +++++++++++++++- 5 files changed, 60 insertions(+), 6 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 34588fae5a45..2e32482328a7 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -133,6 +133,9 @@ "message" : [ "PARTITION clause cannot contain the non-partition column: ." ], "sqlState" : "42000" }, + "NULL_COMPARISON_RESULT" : { + "message" : [ "The comparison result is null. If you want to handle null as 0 (equal), you can set \"spark.sql.legacy.allowNullComparisonResultInArraySort\" to \"true\"." ] + }, "PARSE_CHAR_MISSING_LENGTH" : { "message" : [ "DataType requires a length parameter, for example (10). Please specify the length." ], "sqlState" : "42000" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index fa444a670f28..d56e761bd2f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -357,9 +357,9 @@ case class ArrayTransform( Since 3.0.0 this function also sorts and returns the array based on the given comparator function. The comparator will take two arguments representing two elements of the array. - It returns -1, 0, or 1 as the first element is less than, equal to, or greater - than the second element. If the comparator function returns other - values (including null), the function will fail and raise an error. + It returns a negative integer, 0, or a positive integer as the first element is less than, + equal to, or greater than the second element. If the comparator function returns null, + the function will fail and raise an error. """, examples = """ Examples: @@ -375,9 +375,17 @@ case class ArrayTransform( // scalastyle:on line.size.limit case class ArraySort( argument: Expression, - function: Expression) + function: Expression, + allowNullComparisonResult: Boolean) extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback { + def this(argument: Expression, function: Expression) = { + this( + argument, + function, + SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT)) + } + def this(argument: Expression) = this(argument, ArraySort.defaultComparator) @transient lazy val elementType: DataType = @@ -416,7 +424,11 @@ case class ArraySort( (o1: Any, o2: Any) => { firstElemVar.value.set(o1) secondElemVar.value.set(o2) - f.eval(inputRow).asInstanceOf[Int] + val cmp = f.eval(inputRow) + if (!allowNullComparisonResult && cmp == null) { + throw QueryExecutionErrors.nullComparisonResultError() + } + cmp.asInstanceOf[Int] } } @@ -437,6 +449,10 @@ case class ArraySort( object ArraySort { + def apply(argument: Expression, function: Expression): ArraySort = { + new ArraySort(argument, function) + } + def comparator(left: Expression, right: Expression): Expression = { val lit0 = Literal(0) val lit1 = Literal(1) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 21fe0b926701..9e29acf04d2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2065,4 +2065,9 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase { s"add ${toSQLValue(amount, IntegerType)} $unit to " + s"${toSQLValue(DateTimeUtils.microsToInstant(micros), TimestampType)}")) } + + def nullComparisonResultError(): Throwable = { + new SparkException(errorClass = "NULL_COMPARISON_RESULT", + messageParameters = Array(), cause = null) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b6230f713838..7f41e463d89a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3748,6 +3748,16 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT = + buildConf("spark.sql.legacy.allowNullComparisonResultInArraySort") + .internal() + .doc("When set to false, `array_sort` function throws an error " + + "if the comparator function returns null. " + + "If set to true, it restores the legacy behavior that handles null as zero (equal).") + .version("3.2.2") + .booleanConf + .createWithDefault(false) + /** * Holds information about keys that have been deprecated. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index c0db6d8dc29f..b1c4c4414274 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -838,4 +838,24 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))), Seq(1d, 2d, Double.NaN, null)) } + + test("SPARK-39419: ArraySort should throw an exception when the comparator returns null") { + val comparator = { + val comp = ArraySort.comparator _ + (left: Expression, right: Expression) => + If(comp(left, right) === 0, Literal.create(null, IntegerType), comp(left, right)) + } + + withSQLConf( + SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "false") { + checkExceptionInExpression[SparkException]( + arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator), "The comparison result is null") + } + + withSQLConf( + SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "true") { + checkEvaluation(arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator), + Seq(1, 1, 2, 3)) + } + } }