Skip to content

Commit 9b8521e

Browse files
ueshin authored and gatorsmile committed
[SPARK-25068][SQL] Add exists function.
## What changes were proposed in this pull request?

This PR adds the `exists` function, which tests whether a predicate holds for one or more elements in the array.

```sql
> SELECT exists(array(1, 2, 3), x -> x % 2 == 0);
true
```

## How was this patch tested?

Added tests.

Closes #22052 from ueshin/issues/SPARK-25068/exists.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Xiao Li <[email protected]>
1 parent fec67ed commit 9b8521e

File tree

6 files changed

+205
-0
lines changed

6 files changed

+205
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ object FunctionRegistry {
444444
expression[ArrayTransform]("transform"),
445445
expression[MapFilter]("map_filter"),
446446
expression[ArrayFilter]("filter"),
447+
expression[ArrayExists]("exists"),
447448
expression[ArrayAggregate]("aggregate"),
448449
CreateStruct.registryEntry,
449450

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,53 @@ case class ArrayFilter(
356356
override def prettyName: String = "filter"
357357
}
358358

359+
/**
360+
* Tests whether a predicate holds for one or more elements in the array.
*
* `input` is the array argument and `function` is the lambda applied to each element; the
* lambda is expected to return a boolean (see `expectingFunctionType`). Elements are visited
* in index order and evaluation stops at the first element for which the lambda yields true,
* so an array with no elements evaluates to false.
361+
*/
362+
@ExpressionDescription(usage =
363+
"_FUNC_(expr, pred) - Tests whether a predicate holds for one or more elements in the array.",
364+
examples = """
365+
Examples:
366+
> SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 0);
367+
true
368+
""",
369+
since = "2.4.0")
370+
case class ArrayExists(
371+
input: Expression,
372+
function: Expression)
373+
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
374+
375+
// The result is declared nullable exactly when the input array expression is nullable.
override def nullable: Boolean = input.nullable
376+
377+
// This expression always has boolean type.
override def dataType: DataType = BooleanType
378+
379+
// The lambda must itself produce a boolean, i.e. act as a predicate.
override def expectingFunctionType: AbstractDataType = BooleanType
380+
381+
// Binds the lambda: `f` is handed the current `function` together with a single argument
// descriptor obtained from HigherOrderFunction.arrayArgumentType(input.dataType), and the
// function it returns replaces `function` in a copy of this expression.
// NOTE(review): the (DataType, Boolean) pair is presumably the element type and its
// nullability -- confirm against HigherOrderFunction.arrayArgumentType.
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = {
382+
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
383+
copy(function = f(function, elem :: Nil))
384+
}
385+
386+
// Destructures `function` to get hold of its single argument variable. The pattern only
// matches a LambdaFunction with exactly one NamedLambdaVariable argument; because the val is
// lazy, the match is performed on first access rather than at construction time.
@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
387+
388+
// Evaluates the predicate over the elements of `value`, which is cast to ArrayData.
// Returns true as soon as one element satisfies the predicate, false otherwise.
// NOTE(review): judging by the method name and by the tests in this change (a null array
// yields null), this is presumably not called for a null input -- confirm in the base trait.
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
389+
val arr = value.asInstanceOf[ArrayData]
390+
val f = functionForEval
391+
var exists = false
392+
var i = 0
393+
// `!exists` in the condition stops the scan at the first matching element.
while (i < arr.numElements && !exists) {
394+
// Publish the current element through the lambda variable before evaluating the lambda.
elementVar.value.set(arr.get(i, elementVar.dataType))
395+
// A null predicate result unboxes to false in this cast, so it counts as "no match".
if (f.eval(inputRow).asInstanceOf[Boolean]) {
396+
exists = true
397+
}
398+
i += 1
399+
}
400+
exists
401+
}
402+
403+
// Name under which this expression is displayed.
override def prettyName: String = "exists"
404+
}
405+
359406
/**
360407
* Applies a binary operator to a start value and all elements in the array.
361408
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,43 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
202202
Seq(Seq(1, 3), null, Seq(5)))
203203
}
204204

205+
test("ArrayExists") {
  // Builds `ArrayExists(expr, lambda)`, where the lambda wraps `f` and takes a single argument
  // typed after the elements of `expr`.
  def exists(expr: Expression, f: Expression => Expression): Expression = {
    val arrayType = expr.dataType.asInstanceOf[ArrayType]
    val predicate = createLambda(arrayType.elementType, arrayType.containsNull, f)
    ArrayExists(expr, predicate)
  }

  // Checks every (array, predicate, expected result) case, in the order given.
  def verifyAll(cases: Seq[(Expression, Expression => Expression, Any)]): Unit = {
    cases.foreach { case (array, predicate, expected) =>
      checkEvaluation(exists(array, predicate), expected)
    }
  }

  // Integer arrays: without nulls, with a null element, and a null array.
  val intsNoNulls = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
  val intsWithNull =
    Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
  val nullIntArray = Literal.create(null, ArrayType(IntegerType, containsNull = false))

  val isEven: Expression => Expression = elem => elem % 2 === 0
  val isNullOrOdd: Expression => Expression = elem => elem.isNull || elem % 2 === 1

  verifyAll(Seq(
    (intsNoNulls, isEven, true),
    (intsNoNulls, isNullOrOdd, true),
    (intsWithNull, isEven, false),
    (intsWithNull, isNullOrOdd, true),
    (nullIntArray, isEven, null),
    (nullIntArray, isNullOrOdd, null)))

  // String arrays: without nulls, with a null element, and a null array.
  val stringsNoNulls =
    Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false))
  val stringsWithNull =
    Literal.create(Seq(null, "b", "c"), ArrayType(StringType, containsNull = true))
  val nullStringArray = Literal.create(null, ArrayType(StringType, containsNull = false))

  val startsWithA: Expression => Expression = elem => elem.startsWith("a")

  verifyAll(Seq(
    (stringsNoNulls, startsWithA, true),
    (stringsWithNull, startsWithA, false),
    (nullStringArray, startsWithA, null)))

  // `exists` nested inside `transform`: the null inner array maps to a null result.
  val nestedInts = Literal.create(
    Seq(Seq(1, 2, 3), null, Seq(4, 5)),
    ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
  val existsPerInnerArray = transform(nestedInts, inner => exists(inner, isNullOrOdd))
  checkEvaluation(existsPerInnerArray, Seq(true, null, true))
}
241+
205242
test("ArrayAggregate") {
206243
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
207244
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))

sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,9 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as
4545

4646
-- Aggregate a null array
4747
select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v;
48+
49+
-- Check for element existence: for each row of `nested`, is any element of `ys` greater than 30?
50+
select exists(ys, y -> y > 30) as v from nested;
51+
52+
-- Check for element existence in a null array; the recorded result is NULL rather than false
53+
select exists(cast(null as array<int>), y -> y > 30) as v;

sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,21 @@ select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) a
145145
struct<v:int>
146146
-- !query 14 output
147147
NULL
148+
149+
150+
-- !query 15
151+
select exists(ys, y -> y > 30) as v from nested
152+
-- !query 15 schema
153+
struct<v:boolean>
154+
-- !query 15 output
155+
false
156+
true
157+
true
158+
159+
160+
-- !query 16
161+
select exists(cast(null as array<int>), y -> y > 30) as v
162+
-- !query 16 schema
163+
struct<v:boolean>
164+
-- !query 16 output
165+
NULL

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1996,6 +1996,102 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
19961996
assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
19971997
}
19981998

1999+
test("exists function - array for primitive type not containing null") {
  // Rows: two populated int arrays, an empty array, and a null array.
  val df = Seq(
    Seq(1, 9, 8, 7),
    Seq(5, 9, 7),
    Seq.empty,
    null
  ).toDF("i")

  // One expected row per input row, in input order.
  val expectedRows = Seq(Row(true), Row(false), Row(false), Row(null))

  // Evaluates the expression against `df` and compares it with the expected rows.
  def verify(): Unit = checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), expectedRows)

  // First pass: local relation, the Project will be evaluated without codegen.
  verify()

  // Second pass: cached relation, the Project will be evaluated with codegen.
  df.cache()
  verify()
}
2022+
2023+
test("exists function - array for primitive type containing null") {
  // Rows: two boxed-int arrays that contain null elements, an empty array, and a null array.
  val df = Seq[Seq[Integer]](
    Seq(1, 9, 8, null, 7),
    Seq(5, null, null, 9, 7, null),
    Seq.empty,
    null
  ).toDF("i")

  // One expected row per input row, in input order.
  val expectedRows = Seq(Row(true), Row(false), Row(false), Row(null))

  // Evaluates the expression against `df` and compares it with the expected rows.
  def verify(): Unit = checkAnswer(df.selectExpr("exists(i, x -> x % 2 == 0)"), expectedRows)

  // First pass: local relation, the Project will be evaluated without codegen.
  verify()

  // Second pass: cached relation, the Project will be evaluated with codegen.
  df.cache()
  verify()
}
2046+
2047+
test("exists function - array for non-primitive type") {
  // Rows: a string array without nulls, one with null elements, an empty array, a null array.
  val df = Seq(
    Seq("c", "a", "b"),
    Seq("b", null, "c", null),
    Seq.empty,
    null
  ).toDF("s")

  // One expected row per input row, in input order.
  val expectedRows = Seq(Row(false), Row(true), Row(false), Row(null))

  // Evaluates the expression against `df` and compares it with the expected rows.
  def verify(): Unit = checkAnswer(df.selectExpr("exists(s, x -> x is null)"), expectedRows)

  // First pass: local relation, the Project will be evaluated without codegen.
  verify()

  // Second pass: cached relation, the Project will be evaluated with codegen.
  df.cache()
  verify()
}
2070+
2071+
test("exists function - invalid") {
  // `s` is an array-of-string column and `i` is an int column.
  val df = Seq(
    (Seq("c", "a", "b"), 1),
    (Seq("b", null, "c", null), 2),
    (Seq.empty, 3),
    (null, 4)
  ).toDF("s", "i")

  // Each invalid expression is paired with the fragment its AnalysisException message must
  // contain: wrong lambda arity, non-array first argument, non-boolean lambda result.
  val invalidCases = Seq(
    "exists(s, (x, y) -> x + y)" ->
      "The number of lambda function arguments '2' does not match",
    "exists(i, x -> x)" ->
      "data type mismatch: argument 1 requires array type",
    "exists(s, x -> x)" ->
      "data type mismatch: argument 2 requires boolean type")

  invalidCases.foreach { case (invalidExpr, expectedFragment) =>
    val error = intercept[AnalysisException] {
      df.selectExpr(invalidExpr)
    }
    assert(error.getMessage.contains(expectedFragment))
  }
}
2094+
19992095
test("aggregate function - array for primitive type not containing null") {
20002096
val df = Seq(
20012097
Seq(1, 9, 8, 7),

0 commit comments

Comments
 (0)