Skip to content

Commit 21c27d5

Browse files
agubichevcloud-fan
authored andcommitted
[SPARK-36191][SQL] Handle limit and order by in correlated scalar (lateral) subqueries
### What changes were proposed in this pull request? Handle LIMIT/ORDER BY in the correlated scalar (lateral) subqueries by rewriting them using ROW_NUMBER() window function. ### Why are the changes needed? Extends our coverage of subqueries ### Does this PR introduce _any_ user-facing change? Users are able to run more subqueries now ### How was this patch tested? Unit tests and query tests. Results of query tests are verified against PostgreSQL. ### Was this patch authored or co-authored using generative AI tooling? No Closes #42705 from agubichev/SPARK-36191-limit. Authored-by: Andrey Gubichev <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 9798244 commit 21c27d5

File tree

13 files changed

+1074
-49
lines changed

13 files changed

+1074
-49
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
14181418
failOnInvalidOuterReference(g)
14191419
checkPlan(g.child, aggregated, canContainOuter)
14201420

1421+
// Correlated subquery can have a LIMIT clause
1422+
case l @ Limit(_, input) =>
1423+
failOnInvalidOuterReference(l)
1424+
checkPlan(input, aggregated, canContainOuter)
1425+
14211426
// Category 4: Any other operators not in the above 3 categories
14221427
// cannot be on a correlation path, that is they are allowed only
14231428
// under a correlation point but they and their descendant operators

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,39 @@ object DecorrelateInnerQuery extends PredicateHelper {
655655
val newProject = Project(newProjectList ++ referencesToAdd, newChild)
656656
(newProject, joinCond, outerReferenceMap)
657657

658+
case Limit(limit, input) =>
659+
// LIMIT K (with potential ORDER BY) is decorrelated by computing K rows per every
660+
// domain value via a row_number() window function. For example, for a subquery
661+
// (SELECT T2.a FROM T2 WHERE T2.b = OuterReference(x) ORDER BY T2.c LIMIT 3)
662+
// -- we need to get top 3 values of T2.a (ordering by T2.c) for every value of x.
663+
// Following our general decorrelation procedure, 'x' is then replaced by T2.b, so the
664+
// subquery is decorrelated as:
665+
// SELECT * FROM (
666+
// SELECT T2.a, row_number() OVER (PARTITION BY T2.b ORDER BY T2.c) AS rn FROM T2)
667+
// WHERE rn <= 3
668+
val (child, ordering) = input match {
669+
case Sort(order, _, child) => (child, order)
670+
case _ => (input, Seq())
671+
}
672+
val (newChild, joinCond, outerReferenceMap) =
673+
decorrelate(child, parentOuterReferences, aggregated = true, underSetOp)
674+
val collectedChildOuterReferences = collectOuterReferencesInPlanTree(child)
675+
// Add outer references to the PARTITION BY clause
676+
val partitionFields = collectedChildOuterReferences.map(outerReferenceMap(_)).toSeq
677+
val orderByFields = replaceOuterReferences(ordering, outerReferenceMap)
678+
679+
val rowNumber = WindowExpression(RowNumber(),
680+
WindowSpecDefinition(partitionFields, orderByFields,
681+
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
682+
val rowNumberAlias = Alias(rowNumber, "rn")()
683+
// Window function computes row_number() when partitioning by correlated references,
684+
// and projects all the other fields from the input.
685+
val window = Window(Seq(rowNumberAlias),
686+
partitionFields, orderByFields, newChild)
687+
val filter = Filter(LessThanOrEqual(rowNumberAlias.toAttribute, limit), window)
688+
val project = Project(newChild.output, filter)
689+
(project, joinCond, outerReferenceMap)
690+
658691
case w @ Window(projectList, partitionSpec, orderSpec, child) =>
659692
val outerReferences = collectOuterReferences(w.expressions)
660693
assert(outerReferences.isEmpty, s"Correlated column is not allowed in window " +

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,14 +1027,6 @@ class AnalysisErrorSuite extends AnalysisTest {
10271027
LocalRelation(a))
10281028
assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil)
10291029

1030-
val plan4 = Filter(
1031-
Exists(
1032-
Limit(1,
1033-
Filter(EqualTo(UnresolvedAttribute("a"), b), LocalRelation(b)))
1034-
),
1035-
LocalRelation(a))
1036-
assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil)
1037-
10381030
val plan5 = Filter(
10391031
Exists(
10401032
Sample(0.0, 0.5, false, 1L,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,25 @@ class DecorrelateInnerQuerySuite extends PlanTest {
5959
joinCond.zip(conditions).foreach(e => compareExpressions(e._1, e._2))
6060
}
6161

62+
private def check(
63+
outputPlan: LogicalPlan,
64+
joinCond: Seq[Expression],
65+
correctAnswer: LogicalPlan,
66+
conditions: Seq[Expression]): Unit = {
67+
assert(!hasOuterReferences(outputPlan))
68+
comparePlans(outputPlan, correctAnswer)
69+
assert(joinCond.length == conditions.length)
70+
joinCond.zip(conditions).foreach(e => compareExpressions(e._1, e._2))
71+
}
72+
73+
// For tests involving window functions: extract and return the ROW_NUMBER function
74+
// from the 'input' plan.
75+
private def getRowNumberFunc(input: LogicalPlan): Alias = {
76+
val windowFunction = input.collect({ case w: Window => w }).head
77+
windowFunction.expressions.collect(
78+
{ case w: Alias if w.child.isInstanceOf[WindowExpression] => w }).head
79+
}
80+
6281
test("filter with correlated equality predicates only") {
6382
val outerPlan = testRelation2
6483
val innerPlan =
@@ -625,4 +644,96 @@ class DecorrelateInnerQuerySuite extends PlanTest {
625644
}
626645
assert(e.getMessage.contains("Correlated column is not allowed in"))
627646
}
647+
648+
test("SPARK-36191: limit in the correlated subquery") {
649+
val outerPlan = testRelation
650+
val innerPlan =
651+
Project(Seq(x),
652+
Limit(1, Filter(OuterReference(a) === x,
653+
testRelation2)))
654+
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan.select())
655+
656+
val alias = getRowNumberFunc(outputPlan)
657+
658+
val correctAnswer = Project(Seq(x), Project(Seq(x, y, z),
659+
Filter(GreaterThanOrEqual(1, alias.toAttribute),
660+
Window(Seq(alias), Seq(x), Nil, testRelation2))))
661+
check(outputPlan, joinCond, correctAnswer, Seq(x === a))
662+
}
663+
664+
test("SPARK-36191: limit and order by in the correlated subquery") {
665+
val outerPlan = testRelation
666+
val innerPlan =
667+
Project(Seq(x),
668+
Limit(5, Sort(Seq(SortOrder(x, Ascending)), true,
669+
Filter(OuterReference(a) > x,
670+
testRelation2))))
671+
672+
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan.select())
673+
674+
val alias = getRowNumberFunc(outputPlan)
675+
val rowNumber = WindowExpression(RowNumber(),
676+
WindowSpecDefinition(Seq(a), Seq(SortOrder(x, Ascending)),
677+
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
678+
val rowNumberAlias = Alias(rowNumber, alias.name)()
679+
680+
val correctAnswer = Project(Seq(x, a), Project(Seq(a, x, y, z),
681+
Filter(LessThanOrEqual(rowNumberAlias.toAttribute, 5),
682+
Window(Seq(rowNumberAlias), Seq(a), Seq(SortOrder(x, Ascending)),
683+
Filter(GreaterThan(a, x),
684+
DomainJoin(Seq(a), testRelation2))))))
685+
check(outputPlan, joinCond, correctAnswer, Seq(a <=> a))
686+
}
687+
688+
test("SPARK-36191: limit and order by in the correlated subquery with aggregation") {
689+
val outerPlan = testRelation
690+
val minY = Alias(min(y), "min_y")()
691+
692+
val innerPlan =
693+
Project(Seq(x),
694+
Limit(5, Sort(Seq(SortOrder(minY.toAttribute, Ascending)), true,
695+
Aggregate(Seq(x), Seq(minY, x),
696+
Filter(OuterReference(a) > x,
697+
testRelation2)))))
698+
699+
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan.select())
700+
701+
val alias = getRowNumberFunc(outputPlan)
702+
val rowNumber = WindowExpression(RowNumber(),
703+
WindowSpecDefinition(Seq(a), Seq(SortOrder(minY.toAttribute, Ascending)),
704+
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
705+
val rowNumberAlias = Alias(rowNumber, alias.name)()
706+
val correctAnswer = Project(Seq(x, a), Project(Seq(minY.toAttribute, x, a),
707+
Filter(LessThanOrEqual(rowNumberAlias.toAttribute, 5),
708+
Window(Seq(rowNumberAlias), Seq(a),
709+
Seq(SortOrder(minY.toAttribute, Ascending)),
710+
Aggregate(Seq(x, a), Seq(minY, x, a),
711+
Filter(GreaterThan(a, x),
712+
DomainJoin(Seq(a), testRelation2)))))))
713+
check(outputPlan, joinCond, correctAnswer, Seq(a <=> a))
714+
715+
}
716+
717+
test("SPARK-36191: order by with correlated attribute") {
718+
val outerPlan = testRelation
719+
val innerPlan =
720+
Project(Seq(x),
721+
Limit(5, Sort(Seq(SortOrder(OuterReference(a), Ascending)), true,
722+
Filter(OuterReference(a) > x,
723+
testRelation2))))
724+
val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan.select())
725+
726+
val alias = getRowNumberFunc(outputPlan)
727+
val rowNumber = WindowExpression(RowNumber(),
728+
WindowSpecDefinition(Seq(a), Seq(SortOrder(a, Ascending)),
729+
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
730+
val rowNumberAlias = Alias(rowNumber, alias.name)()
731+
732+
val correctAnswer = Project(Seq(x, a), Project(Seq(a, x, y, z),
733+
Filter(LessThanOrEqual(rowNumberAlias.toAttribute, 5),
734+
Window(Seq(rowNumberAlias), Seq(a), Seq(SortOrder(a, Ascending)),
735+
Filter(GreaterThan(a, x),
736+
DomainJoin(Seq(a), testRelation2))))))
737+
check(outputPlan, joinCond, correctAnswer, Seq(a <=> a))
738+
}
628739
}

sql/core/src/test/resources/sql-tests/analyzer-results/join-lateral.sql.out

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2861,6 +2861,161 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
28612861
}
28622862

28632863

2864+
-- !query
2865+
select * from t1 join lateral (select * from t2 where t1.c1 = t2.c1 and t1.c2 < t2.c2 limit 1)
2866+
-- !query analysis
2867+
Project [c1#x, c2#x, c1#x, c2#x]
2868+
+- LateralJoin lateral-subquery#x [c1#x && c2#x], Inner
2869+
: +- SubqueryAlias __auto_generated_subquery_name
2870+
: +- GlobalLimit 1
2871+
: +- LocalLimit 1
2872+
: +- Project [c1#x, c2#x]
2873+
: +- Filter ((outer(c1#x) = c1#x) AND (outer(c2#x) < c2#x))
2874+
: +- SubqueryAlias spark_catalog.default.t2
2875+
: +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
2876+
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2877+
: +- LocalRelation [col1#x, col2#x]
2878+
+- SubqueryAlias spark_catalog.default.t1
2879+
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
2880+
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2881+
+- LocalRelation [col1#x, col2#x]
2882+
2883+
2884+
-- !query
2885+
select * from t1 join lateral (select * from t4 where t1.c1 <= t4.c1 order by t4.c2 limit 10)
2886+
-- !query analysis
2887+
Project [c1#x, c2#x, c1#x, c2#x]
2888+
+- LateralJoin lateral-subquery#x [c1#x], Inner
2889+
: +- SubqueryAlias __auto_generated_subquery_name
2890+
: +- GlobalLimit 10
2891+
: +- LocalLimit 10
2892+
: +- Sort [c2#x ASC NULLS FIRST], true
2893+
: +- Project [c1#x, c2#x]
2894+
: +- Filter (outer(c1#x) <= c1#x)
2895+
: +- SubqueryAlias spark_catalog.default.t4
2896+
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
2897+
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2898+
: +- LocalRelation [col1#x, col2#x]
2899+
+- SubqueryAlias spark_catalog.default.t1
2900+
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
2901+
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2902+
+- LocalRelation [col1#x, col2#x]
2903+
2904+
2905+
-- !query
2906+
select * from t1 join lateral (select c1, min(c2) as m
2907+
from t2 where t1.c1 = t2.c1 and t1.c2 < t2.c2
2908+
group by t2.c1
2909+
order by m)
2910+
-- !query analysis
2911+
Project [c1#x, c2#x, c1#x, m#x]
2912+
+- LateralJoin lateral-subquery#x [c1#x && c2#x], Inner
2913+
: +- SubqueryAlias __auto_generated_subquery_name
2914+
: +- Sort [m#x ASC NULLS FIRST], true
2915+
: +- Aggregate [c1#x], [c1#x, min(c2#x) AS m#x]
2916+
: +- Filter ((outer(c1#x) = c1#x) AND (outer(c2#x) < c2#x))
2917+
: +- SubqueryAlias spark_catalog.default.t2
2918+
: +- View (`spark_catalog`.`default`.`t2`, [c1#x,c2#x])
2919+
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2920+
: +- LocalRelation [col1#x, col2#x]
2921+
+- SubqueryAlias spark_catalog.default.t1
2922+
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
2923+
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2924+
+- LocalRelation [col1#x, col2#x]
2925+
2926+
2927+
-- !query
2928+
select * from t1 join lateral (select c1, min(c2) as m
2929+
from t4 where t1.c1 = t4.c1
2930+
group by t4.c1
2931+
limit 1)
2932+
-- !query analysis
2933+
Project [c1#x, c2#x, c1#x, m#x]
2934+
+- LateralJoin lateral-subquery#x [c1#x], Inner
2935+
: +- SubqueryAlias __auto_generated_subquery_name
2936+
: +- GlobalLimit 1
2937+
: +- LocalLimit 1
2938+
: +- Aggregate [c1#x], [c1#x, min(c2#x) AS m#x]
2939+
: +- Filter (outer(c1#x) = c1#x)
2940+
: +- SubqueryAlias spark_catalog.default.t4
2941+
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
2942+
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2943+
: +- LocalRelation [col1#x, col2#x]
2944+
+- SubqueryAlias spark_catalog.default.t1
2945+
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
2946+
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2947+
+- LocalRelation [col1#x, col2#x]
2948+
2949+
2950+
-- !query
2951+
select * from t1 join lateral
2952+
((select t4.c2 from t4 where t1.c1 <= t4.c1 order by t4.c2 limit 1)
2953+
union all
2954+
(select t4.c1 from t4 where t1.c1 = t4.c1 order by t4.c1 limit 3))
2955+
-- !query analysis
2956+
Project [c1#x, c2#x, c2#x]
2957+
+- LateralJoin lateral-subquery#x [c1#x && c1#x], Inner
2958+
: +- SubqueryAlias __auto_generated_subquery_name
2959+
: +- Union false, false
2960+
: :- GlobalLimit 1
2961+
: : +- LocalLimit 1
2962+
: : +- Sort [c2#x ASC NULLS FIRST], true
2963+
: : +- Project [c2#x]
2964+
: : +- Filter (outer(c1#x) <= c1#x)
2965+
: : +- SubqueryAlias spark_catalog.default.t4
2966+
: : +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
2967+
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2968+
: : +- LocalRelation [col1#x, col2#x]
2969+
: +- GlobalLimit 3
2970+
: +- LocalLimit 3
2971+
: +- Sort [c1#x ASC NULLS FIRST], true
2972+
: +- Project [c1#x]
2973+
: +- Filter (outer(c1#x) = c1#x)
2974+
: +- SubqueryAlias spark_catalog.default.t4
2975+
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
2976+
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2977+
: +- LocalRelation [col1#x, col2#x]
2978+
+- SubqueryAlias spark_catalog.default.t1
2979+
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
2980+
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
2981+
+- LocalRelation [col1#x, col2#x]
2982+
2983+
2984+
-- !query
2985+
select * from t1 join lateral
2986+
(select * from
2987+
((select t4.c2 as t from t4 where t1.c1 <= t4.c1)
2988+
union all
2989+
(select t4.c1 as t from t4 where t1.c1 = t4.c1)) as foo
2990+
order by foo.t limit 5)
2991+
-- !query analysis
2992+
Project [c1#x, c2#x, t#x]
2993+
+- LateralJoin lateral-subquery#x [c1#x && c1#x], Inner
2994+
: +- SubqueryAlias __auto_generated_subquery_name
2995+
: +- GlobalLimit 5
2996+
: +- LocalLimit 5
2997+
: +- Sort [t#x ASC NULLS FIRST], true
2998+
: +- Project [t#x]
2999+
: +- SubqueryAlias foo
3000+
: +- Union false, false
3001+
: :- Project [c2#x AS t#x]
3002+
: : +- Filter (outer(c1#x) <= c1#x)
3003+
: : +- SubqueryAlias spark_catalog.default.t4
3004+
: : +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
3005+
: : +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
3006+
: : +- LocalRelation [col1#x, col2#x]
3007+
: +- Project [c1#x AS t#x]
3008+
: +- Filter (outer(c1#x) = c1#x)
3009+
: +- SubqueryAlias spark_catalog.default.t4
3010+
: +- View (`spark_catalog`.`default`.`t4`, [c1#x,c2#x])
3011+
: +- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
3012+
: +- LocalRelation [col1#x, col2#x]
3013+
+- SubqueryAlias spark_catalog.default.t1
3014+
+- View (`spark_catalog`.`default`.`t1`, [c1#x,c2#x])
3015+
+- Project [cast(col1#x as int) AS c1#x, cast(col2#x as int) AS c2#x]
3016+
+- LocalRelation [col1#x, col2#x]
3017+
3018+
28643019
-- !query
28653020
DROP VIEW t1
28663021
-- !query analysis

0 commit comments

Comments
 (0)