Skip to content

Commit d6040ea

Browse files
committed
Add tests, and small clean-up of the NOT IN pathway
1 parent 13bedc0 commit d6040ea

File tree

2 files changed

+137
-5
lines changed

2 files changed

+137
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,15 +116,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
116116
// (a1,a2,...) = (b1,b2,...)
117117
// to
118118
// (a1=b1 OR isnull(a1=b1)) AND (a2=b2 OR isnull(a2=b2)) AND ...
119-
val joinConds = splitConjunctivePredicates(joinCond.get)
119+
val baseJoinConds = splitConjunctivePredicates(joinCond.get)
120+
val nullAwareJoinConds = baseJoinConds.map(c => Or(c, IsNull(c)))
120121
// After that, add back the correlated join predicate(s) in the subquery
121122
// Example:
122123
// SELECT ... FROM A WHERE A.A1 NOT IN (SELECT B.B1 FROM B WHERE B.B2 = A.A2 AND B.B3 > 1)
123124
// will have the final conditions in the LEFT ANTI as
124-
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2)
125-
val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And)
125+
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
126+
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
126127
// Deduplicate conflicting attributes if any.
127-
dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs)))
128+
dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
128129
case (p, predicate) =>
129130
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
130131
Project(p.output, Filter(newCond.get, inputPlan))

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
package org.apache.spark.sql
1919

20+
import org.scalatest.GivenWhenThen
21+
2022
import org.apache.spark.sql.catalyst.plans.logical.Join
2123
import org.apache.spark.sql.test.SharedSQLContext
2224

23-
class SubquerySuite extends QueryTest with SharedSQLContext {
25+
class SubquerySuite extends QueryTest with SharedSQLContext with GivenWhenThen {
2426
import testImplicits._
2527

2628
setupTestData()
@@ -275,6 +277,135 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
275277

276278
}
277279

280+
// ``col NOT IN expr'' is quite difficult to reason about. There are many edge cases, some of the
281+
// rules are not very intuitive, and precedence and treatment of null values is somewhat
282+
// unintuitive. To make this simpler to understand, I've come up with a plain English way of
283+
// describing the expected behavior of this query.
284+
//
285+
// - If the subquery is empty (i.e. returns no rows), the row should be returned, regardless of
286+
// whether the filtered columns include nulls.
287+
// - If the subquery contains a result with all nulls, then the row should not be returned.
288+
// - If for all non-null filter columns there exists a row in the subquery in which each column
289+
// either
290+
// 1. is equal to the corresponding filter column or
291+
// 2. is null
292+
// then the row should not be returned. (This includes the case where all filter columns are
293+
// null.)
294+
// - Otherwise, the row should be returned.
295+
//
296+
// Using these rules, we can come up with a set of test cases for single-column and multi-column
297+
// NOT IN test cases.
298+
test("NOT IN single column with nulls predicate subquery") {
299+
// Test cases for single-column ``WHERE a NOT IN (SELECT c FROM r ...)'':
300+
// | # | does subquery include null? | is a null? | a = c? | row with a included in result? |
301+
// | 1 | empty | | | yes |
302+
// | 2 | yes | | | no |
303+
// | 3 | no | yes | | no |
304+
// | 4 | no | no | yes | no |
305+
// | 5 | no | no | no | yes |
306+
Seq(row((null, 5.0)), row((3, 3.0))).toDF("a", "b").createOrReplaceTempView("m")
307+
Seq(row((2, 3.0)), row((2, 3.0)), row((null, 5.0))).toDF("c", "d").createOrReplaceTempView("s")
308+
309+
// Single-column test cases
310+
val subqueryIsEmpty = "d > 6.0"
311+
val cIncludesNull = "d = 5.0"
312+
val cDoesNotMatchA = "d = 3.0"
313+
val cMatchesA = "d = 5.0"
314+
val aIsNull = "b = 5.0"
315+
val aIsNotNull = "b = 3.0"
316+
317+
val includesNullRow = Row(null, 5.0) :: Nil
318+
val includesNotNullRow = Row(3, 3.0) :: Nil
319+
val doesNotIncludeRow = Nil
320+
321+
val singleColumnTestCases = Seq(
322+
("Case 1a (subquery is empty)", subqueryIsEmpty, aIsNull, includesNullRow),
323+
("Case 1b (subquery is empty)", subqueryIsEmpty, aIsNotNull, includesNotNullRow),
324+
("Case 2a (subquery includes null)", cIncludesNull, aIsNull, doesNotIncludeRow),
325+
("Case 2b (subquery includes null)", cIncludesNull, aIsNotNull, doesNotIncludeRow),
326+
("Case 3 (probe column is null)", cDoesNotMatchA, aIsNull, doesNotIncludeRow),
327+
("Case 4 (there is a match)", cMatchesA, aIsNotNull, doesNotIncludeRow),
328+
("Case 5 (there is no match)", cDoesNotMatchA, aIsNotNull, includesNotNullRow))
329+
330+
for ((given, sClause, mClause, expectedOutput) <- singleColumnTestCases) {
331+
Given(given)
332+
val query = s"SELECT * FROM m WHERE $mClause AND a NOT IN (SELECT c FROM s WHERE $sClause)"
333+
checkAnswer(sql(query), expectedOutput)
334+
}
335+
336+
// Correlated subqueries should also be handled properly. The addition of the correlated
337+
// subquery changes the query from case 2/3/4 to case 1. Because of this, the row from l should
338+
// be included in the output.
339+
val correlatedSubqueryTestCases = Seq(
340+
("Case 2a->1 (subquery had nulls)", cIncludesNull, aIsNull, includesNullRow),
341+
("Case 2b->1 (subquery had nulls)", cIncludesNull, aIsNotNull, includesNotNullRow),
342+
("Case 3->1 (probe column was null)", cMatchesA, aIsNull, includesNullRow),
343+
("Case 4->1 (there was a match)", cMatchesA, aIsNotNull, includesNotNullRow))
344+
for ((given, sClause, mClause, expectedOutput) <- correlatedSubqueryTestCases) {
345+
Given(given)
346+
// scalastyle:off
347+
val query =
348+
s"SELECT * FROM m WHERE $mClause AND a NOT IN (SELECT c FROM s WHERE $sClause AND c < b - 10)"
349+
// scalastyle:on
350+
checkAnswer(sql(query), expectedOutput)
351+
}
352+
}
353+
354+
test("NOT IN multi column with nulls predicate subquery") {
355+
// scalastyle:off
356+
// Test cases for multi-column ``WHERE a NOT IN (SELECT c FROM r ...)'':
357+
// | # | does subquery include null? | do filter columns contain null? | a = c? | b = d? | row included in result? |
358+
// | 1 | empty | * | * | * | yes |
359+
// | 2 | 1+ row has null for all columns | * | * | * | no |
360+
// | 3 | no row has null for all columns | (yes, yes) | * | * | no |
361+
// | 4 | no | (no, yes) | yes | * | no |
362+
// | 5 | no row has null for all columns | (no, yes) | no | * | yes |
363+
// | 6 | no | (no, no) | yes | yes | no |
364+
// | 7 | no | (no, no) | _ | _ | yes |
365+
//
366+
// This can clearly be generalized, but only these cases are tested here.
367+
// scalastyle:on
368+
369+
Seq(row((null, null)), row((3, 5.0)), row((2, null)), row((2, 3.0))).toDF("a", "b")
370+
.createOrReplaceTempView("m")
371+
Seq(row((null, null)), row((2, 3.0)), row((3, null))).toDF("c", "d")
372+
.createOrReplaceTempView("s")
373+
374+
val subqueryIsEmpty = "c > 200" // Returns ()
375+
val dIsNull = "c = 3" // Returns (3, null)
376+
val cAndDAreNull = "c IS NULL AND d IS NULL" // Returns (null, null)
377+
val cAndDAreNotNull = "c = 2" // Returns (2, 3.0)
378+
379+
val aAndBAreNull = "a IS NULL AND b IS NULL" // Returns (null, null)
380+
val aAndBAreNotNull = "a = 3" // Returns (3, 5.0)
381+
val aAndBMatch = "a = 2 AND b = 3.0" // Returns (2, 3.0)
382+
val aIsNotNull = "a = 2" // Returns (2, null), (2, 3.0)
383+
384+
val includesNullRow = Row(null, null) :: Nil
385+
val includesSemiNullAndNotNullRow = Row(2, null) :: Row(2, 3.0) :: Nil
386+
val includesPartiallyNullRow = Row(2, null) :: Nil
387+
val includesNotNullRow = Row(3, 5.0) :: Nil
388+
val doesNotIncludeRow = Nil
389+
val multiColumnTestCases = Seq(
390+
("Case 1a (subquery is empty)", subqueryIsEmpty, aAndBAreNull, includesNullRow),
391+
("Case 1b (subquery is empty)", subqueryIsEmpty, aIsNotNull, includesSemiNullAndNotNullRow),
392+
("Case 2a (subquery contains null)", cAndDAreNull, aAndBAreNull, doesNotIncludeRow),
393+
("Case 2b (subquery contains null)", cAndDAreNull, aAndBAreNotNull, doesNotIncludeRow),
394+
("Case 3 (probe columns are all null)", dIsNull, aAndBAreNull, doesNotIncludeRow),
395+
("Case 4 (null column, match)", cAndDAreNotNull, aIsNotNull, doesNotIncludeRow),
396+
("Case 5 (null column, no match)", dIsNull, aIsNotNull, includesSemiNullAndNotNullRow),
397+
("Case 6 (no null column, match)", cAndDAreNotNull, aAndBMatch, doesNotIncludeRow),
398+
("Case 7 (no null column, no match)", cAndDAreNotNull, aAndBAreNotNull, includesNotNullRow))
399+
400+
for ((given, sClause, mClause, expectedOutput) <- multiColumnTestCases) {
401+
Given(given)
402+
val query =
403+
s"SELECT * FROM m WHERE $mClause AND (a, b) NOT IN (SELECT c, d FROM s WHERE $sClause)"
404+
checkAnswer(sql(query), expectedOutput)
405+
}
406+
}
407+
408+
278409
test("IN predicate subquery within OR") {
279410
checkAnswer(
280411
sql("select * from l where l.a in (select c from r)" +

0 commit comments

Comments
 (0)