Skip to content

Commit aba523c

Browse files
ueshin authored and dongjoon-hyun committed
[SPARK-39419][SQL][3.3] Fix ArraySort to throw an exception when the comparator returns null
### What changes were proposed in this pull request?

Backport of #36812.

Fixes `ArraySort` to throw an exception when the comparator returns `null`. Also updates the doc to follow the corrected behavior.

### Why are the changes needed?

When the comparator of `ArraySort` returns `null`, currently it handles it as `0` (equal).

According to the doc,

```
It returns -1, 0, or 1 as the first element is less than, equal to, or greater than
the second element. If the comparator function returns other values (including null),
the function will fail and raise an error.
```

It's fine to return non -1, 0, 1 integers to follow the Java convention (still need to update the doc, though), but it should throw an exception for `null` result.

### Does this PR introduce _any_ user-facing change?

Yes, if a user uses a comparator that returns `null`, it will throw an error after this PR.

The legacy flag `spark.sql.legacy.allowNullComparisonResultInArraySort` can be used to restore the legacy behavior that handles `null` as `0` (equal).

### How was this patch tested?

Added some tests.

Closes #36834 from ueshin/issues/SPARK-39419/3.3/array_sort.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 0361ee8 commit aba523c

File tree

5 files changed

+60
-6
lines changed

5 files changed

+60
-6
lines changed

core/src/main/resources/error/error-classes.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@
133133
"message" : [ "PARTITION clause cannot contain the non-partition column: <columnName>." ],
134134
"sqlState" : "42000"
135135
},
136+
"NULL_COMPARISON_RESULT" : {
137+
"message" : [ "The comparison result is null. If you want to handle null as 0 (equal), you can set \"spark.sql.legacy.allowNullComparisonResultInArraySort\" to \"true\"." ]
138+
},
136139
"PARSE_CHAR_MISSING_LENGTH" : {
137140
"message" : [ "DataType <type> requires a length parameter, for example <type>(10). Please specify the length." ],
138141
"sqlState" : "42000"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,9 @@ case class ArrayTransform(
357357
Since 3.0.0 this function also sorts and returns the array based on the
358358
given comparator function. The comparator will take two arguments representing
359359
two elements of the array.
360-
It returns -1, 0, or 1 as the first element is less than, equal to, or greater
361-
than the second element. If the comparator function returns other
362-
values (including null), the function will fail and raise an error.
360+
It returns a negative integer, 0, or a positive integer as the first element is less than,
361+
equal to, or greater than the second element. If the comparator function returns null,
362+
the function will fail and raise an error.
363363
""",
364364
examples = """
365365
Examples:
@@ -375,9 +375,17 @@ case class ArrayTransform(
375375
// scalastyle:on line.size.limit
376376
case class ArraySort(
377377
argument: Expression,
378-
function: Expression)
378+
function: Expression,
379+
allowNullComparisonResult: Boolean)
379380
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
380381

382+
def this(argument: Expression, function: Expression) = {
383+
this(
384+
argument,
385+
function,
386+
SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT))
387+
}
388+
381389
def this(argument: Expression) = this(argument, ArraySort.defaultComparator)
382390

383391
@transient lazy val elementType: DataType =
@@ -416,7 +424,11 @@ case class ArraySort(
416424
(o1: Any, o2: Any) => {
417425
firstElemVar.value.set(o1)
418426
secondElemVar.value.set(o2)
419-
f.eval(inputRow).asInstanceOf[Int]
427+
val cmp = f.eval(inputRow)
428+
if (!allowNullComparisonResult && cmp == null) {
429+
throw QueryExecutionErrors.nullComparisonResultError()
430+
}
431+
cmp.asInstanceOf[Int]
420432
}
421433
}
422434

@@ -437,6 +449,10 @@ case class ArraySort(
437449

438450
object ArraySort {
439451

452+
def apply(argument: Expression, function: Expression): ArraySort = {
453+
new ArraySort(argument, function)
454+
}
455+
440456
def comparator(left: Expression, right: Expression): Expression = {
441457
val lit0 = Literal(0)
442458
val lit1 = Literal(1)

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,4 +2065,9 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
20652065
s"add ${toSQLValue(amount, IntegerType)} $unit to " +
20662066
s"${toSQLValue(DateTimeUtils.microsToInstant(micros), TimestampType)}"))
20672067
}
2068+
2069+
def nullComparisonResultError(): Throwable = {
2070+
new SparkException(errorClass = "NULL_COMPARISON_RESULT",
2071+
messageParameters = Array(), cause = null)
2072+
}
20682073
}

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3748,6 +3748,16 @@ object SQLConf {
37483748
.booleanConf
37493749
.createWithDefault(false)
37503750

3751+
val LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT =
3752+
buildConf("spark.sql.legacy.allowNullComparisonResultInArraySort")
3753+
.internal()
3754+
.doc("When set to false, `array_sort` function throws an error " +
3755+
"if the comparator function returns null. " +
3756+
"If set to true, it restores the legacy behavior that handles null as zero (equal).")
3757+
.version("3.2.2")
3758+
.booleanConf
3759+
.createWithDefault(false)
3760+
37513761
/**
37523762
* Holds information about keys that have been deprecated.
37533763
*

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.SparkFunSuite
20+
import org.apache.spark.{SparkException, SparkFunSuite}
2121
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2222
import org.apache.spark.sql.internal.SQLConf
2323
import org.apache.spark.sql.types._
@@ -838,4 +838,24 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
838838
Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))),
839839
Seq(1d, 2d, Double.NaN, null))
840840
}
841+
842+
test("SPARK-39419: ArraySort should throw an exception when the comparator returns null") {
843+
val comparator = {
844+
val comp = ArraySort.comparator _
845+
(left: Expression, right: Expression) =>
846+
If(comp(left, right) === 0, Literal.create(null, IntegerType), comp(left, right))
847+
}
848+
849+
withSQLConf(
850+
SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "false") {
851+
checkExceptionInExpression[SparkException](
852+
arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator), "The comparison result is null")
853+
}
854+
855+
withSQLConf(
856+
SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "true") {
857+
checkEvaluation(arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator),
858+
Seq(1, 1, 2, 3))
859+
}
860+
}
841861
}

0 commit comments

Comments
 (0)