Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions core/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@
"message" : [ "PARTITION clause cannot contain the non-partition column: <columnName>." ],
"sqlState" : "42000"
},
"NULL_COMPARISON_RESULT" : {
"message" : [ "The comparison result is null. If you want to handle null as 0 (equal), you can set \"spark.sql.legacy.allowNullComparisonResultInArraySort\" to \"true\"." ]
},
"PARSE_CHAR_MISSING_LENGTH" : {
"message" : [ "DataType <type> requires a length parameter, for example <type>(10). Please specify the length." ],
"sqlState" : "42000"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -357,9 +357,9 @@ case class ArrayTransform(
Since 3.0.0 this function also sorts and returns the array based on the
given comparator function. The comparator will take two arguments representing
two elements of the array.
It returns -1, 0, or 1 as the first element is less than, equal to, or greater
than the second element. If the comparator function returns other
values (including null), the function will fail and raise an error.
It returns a negative integer, 0, or a positive integer as the first element is less than,
equal to, or greater than the second element. If the comparator function returns null,
the function will fail and raise an error.
""",
examples = """
Examples:
Expand All @@ -375,9 +375,17 @@ case class ArrayTransform(
// scalastyle:on line.size.limit
case class ArraySort(
argument: Expression,
function: Expression)
function: Expression,
allowNullComparisonResult: Boolean)
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {

def this(argument: Expression, function: Expression) = {
this(
argument,
function,
SQLConf.get.getConf(SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT))
}

def this(argument: Expression) = this(argument, ArraySort.defaultComparator)

@transient lazy val elementType: DataType =
Expand Down Expand Up @@ -416,7 +424,11 @@ case class ArraySort(
(o1: Any, o2: Any) => {
firstElemVar.value.set(o1)
secondElemVar.value.set(o2)
f.eval(inputRow).asInstanceOf[Int]
val cmp = f.eval(inputRow)
if (!allowNullComparisonResult && cmp == null) {
throw QueryExecutionErrors.nullComparisonResultError()
}
cmp.asInstanceOf[Int]
}
}

Expand All @@ -437,6 +449,10 @@ case class ArraySort(

object ArraySort {

def apply(argument: Expression, function: Expression): ArraySort = {
new ArraySort(argument, function)
}

def comparator(left: Expression, right: Expression): Expression = {
val lit0 = Literal(0)
val lit1 = Literal(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2065,4 +2065,9 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
s"add ${toSQLValue(amount, IntegerType)} $unit to " +
s"${toSQLValue(DateTimeUtils.microsToInstant(micros), TimestampType)}"))
}

def nullComparisonResultError(): Throwable = {
new SparkException(errorClass = "NULL_COMPARISON_RESULT",
messageParameters = Array(), cause = null)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3748,6 +3748,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT =
buildConf("spark.sql.legacy.allowNullComparisonResultInArraySort")
.internal()
.doc("When set to false, `array_sort` function throws an error " +
"if the comparator function returns null. " +
"If set to true, it restores the legacy behavior that handles null as zero (equal).")
.version("3.2.2")
.booleanConf
.createWithDefault(false)

/**
* Holds information about keys that have been deprecated.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -838,4 +838,24 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
Literal.create(Seq(Double.NaN, 1d, 2d, null), ArrayType(DoubleType))),
Seq(1d, 2d, Double.NaN, null))
}

test("SPARK-39419: ArraySort should throw an exception when the comparator returns null") {
val comparator = {
val comp = ArraySort.comparator _
(left: Expression, right: Expression) =>
If(comp(left, right) === 0, Literal.create(null, IntegerType), comp(left, right))
}

withSQLConf(
SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "false") {
checkExceptionInExpression[SparkException](
arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator), "The comparison result is null")
}

withSQLConf(
SQLConf.LEGACY_ALLOW_NULL_COMPARISON_RESULT_IN_ARRAY_SORT.key -> "true") {
checkEvaluation(arraySort(Literal.create(Seq(3, 1, 1, 2)), comparator),
Seq(1, 1, 2, 3))
}
}
}