Skip to content

Commit ad4823c

Browse files
dilipbiswalcloud-fan
authored andcommitted
[SPARK-19712][SQL] Pushing Left Semi and Left Anti joins through Project, Aggregate, Window, Union etc.
## What changes were proposed in this pull request? This PR adds support for pushing down LeftSemi and LeftAnti joins below operators such as Project, Aggregate, Window, Union etc. This is the initial piece of work that will be needed for the subsequent work of moving the subquery rewrites to the beginning of optimization phase. The larger PR is [here](#23211) . This PR addresses the comment at [link](#23211 (comment)). ## How was this patch tested? Added a new test suite LeftSemiAntiJoinPushDownSuite. Closes #23750 from dilipbiswal/SPARK-19712-pushleftsemi. Authored-by: Dilip Biswal <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 382d5a8 commit ad4823c

File tree

5 files changed

+505
-16
lines changed

5 files changed

+505
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
9595
EliminateOuterJoin,
9696
PushPredicateThroughJoin,
9797
PushDownPredicate,
98+
PushDownLeftSemiAntiJoin,
9899
LimitPushDown,
99100
ColumnPruning,
100101
InferFiltersFromConstraints,
@@ -1012,24 +1013,13 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
10121013
// This also applies to Aggregate.
10131014
case Filter(condition, project @ Project(fields, grandChild))
10141015
if fields.forall(_.deterministic) && canPushThroughCondition(grandChild, condition) =>
1015-
1016-
// Create a map of Aliases to their values from the child projection.
1017-
// e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
1018-
val aliasMap = AttributeMap(fields.collect {
1019-
case a: Alias => (a.toAttribute, a.child)
1020-
})
1021-
1016+
val aliasMap = getAliasMap(project)
10221017
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
10231018

10241019
case filter @ Filter(condition, aggregate: Aggregate)
10251020
if aggregate.aggregateExpressions.forall(_.deterministic)
10261021
&& aggregate.groupingExpressions.nonEmpty =>
1027-
// Find all the aliased expressions in the aggregate list that don't include any actual
1028-
// AggregateExpression, and create a map from the alias to the expression
1029-
val aliasMap = AttributeMap(aggregate.aggregateExpressions.collect {
1030-
case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
1031-
(a.toAttribute, a.child)
1032-
})
1022+
val aliasMap = getAliasMap(aggregate)
10331023

10341024
// For each filter, expand the alias and check if the filter can be evaluated using
10351025
// attributes produced by the aggregate operator's child operator.
@@ -1127,7 +1117,23 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
11271117
}
11281118
}
11291119

1130-
private def canPushThrough(p: UnaryNode): Boolean = p match {
1120+
def getAliasMap(plan: Project): AttributeMap[Expression] = {
1121+
// Create a map of Aliases to their values from the child projection.
1122+
// e.g., 'SELECT a + b AS c, d ...' produces Map(c -> a + b).
1123+
AttributeMap(plan.projectList.collect { case a: Alias => (a.toAttribute, a.child) })
1124+
}
1125+
1126+
def getAliasMap(plan: Aggregate): AttributeMap[Expression] = {
1127+
// Find all the aliased expressions in the aggregate list that don't include any actual
1128+
// AggregateExpression, and create a map from the alias to the expression
1129+
val aliasMap = plan.aggregateExpressions.collect {
1130+
case a: Alias if a.child.find(_.isInstanceOf[AggregateExpression]).isEmpty =>
1131+
(a.toAttribute, a.child)
1132+
}
1133+
AttributeMap(aliasMap)
1134+
}
1135+
1136+
def canPushThrough(p: UnaryNode): Boolean = p match {
11311137
// Note that some operators (e.g. project, aggregate, union) are being handled separately
11321138
// (earlier in this rule).
11331139
case _: AppendColumns => true
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.optimizer
19+
20+
import org.apache.spark.sql.catalyst.expressions._
21+
import org.apache.spark.sql.catalyst.plans._
22+
import org.apache.spark.sql.catalyst.plans.logical._
23+
import org.apache.spark.sql.catalyst.rules.Rule
24+
25+
/**
26+
* This rule is a variant of [[PushDownPredicate]] which can handle
27+
* pushing down Left semi and Left Anti joins below the following operators.
28+
* 1) Project
29+
* 2) Window
30+
* 3) Union
31+
* 4) Aggregate
32+
* 5) Other permissible unary operators. please see [[PushDownPredicate.canPushThrough]].
33+
*/
34+
object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
35+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
36+
// LeftSemi/LeftAnti over Project
37+
case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
38+
if pList.forall(_.deterministic) &&
39+
!pList.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
40+
canPushThroughCondition(Seq(gChild), joinCond, rightOp) =>
41+
if (joinCond.isEmpty) {
42+
// No join condition, just push down the Join below Project
43+
p.copy(child = Join(gChild, rightOp, joinType, joinCond, hint))
44+
} else {
45+
val aliasMap = PushDownPredicate.getAliasMap(p)
46+
val newJoinCond = if (aliasMap.nonEmpty) {
47+
Option(replaceAlias(joinCond.get, aliasMap))
48+
} else {
49+
joinCond
50+
}
51+
p.copy(child = Join(gChild, rightOp, joinType, newJoinCond, hint))
52+
}
53+
54+
// LeftSemi/LeftAnti over Aggregate
55+
case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
56+
if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
57+
!agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
58+
if (joinCond.isEmpty) {
59+
// No join condition, just push down Join below Aggregate
60+
agg.copy(child = Join(agg.child, rightOp, joinType, joinCond, hint))
61+
} else {
62+
val aliasMap = PushDownPredicate.getAliasMap(agg)
63+
64+
// For each join condition, expand the alias and check if the condition can be evaluated
65+
// using attributes produced by the aggregate operator's child operator.
66+
val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
67+
val replaced = replaceAlias(cond, aliasMap)
68+
cond.references.nonEmpty &&
69+
replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet)
70+
}
71+
72+
// Check if the remaining predicates do not contain columns from the right
73+
// hand side of the join. Since the remaining predicates will be kept
74+
// as a filter over aggregate, this check is necessary after the left semi
75+
// or left anti join is moved below aggregate. The reason is, for this kind
76+
// of join, we only output from the left leg of the join.
77+
val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
78+
79+
if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
80+
val pushDownPredicate = pushDown.reduce(And)
81+
val replaced = replaceAlias(pushDownPredicate, aliasMap)
82+
val newAgg = agg.copy(child = Join(agg.child, rightOp, joinType, Option(replaced), hint))
83+
// If there is no more filter to stay up, just return the Aggregate over Join.
84+
// Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)".
85+
if (stayUp.isEmpty) newAgg else Filter(stayUp.reduce(And), newAgg)
86+
} else {
87+
// The join condition is not a subset of the Aggregate's GROUP BY columns,
88+
// no push down.
89+
join
90+
}
91+
}
92+
93+
// LeftSemi/LeftAnti over Window
94+
case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
95+
if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
96+
if (joinCond.isEmpty) {
97+
// No join condition, just push down Join below Window
98+
w.copy(child = Join(w.child, rightOp, joinType, joinCond, hint))
99+
} else {
100+
val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++
101+
rightOp.outputSet
102+
103+
val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
104+
cond.references.subsetOf(partitionAttrs)
105+
}
106+
107+
// Check if the remaining predicates do not contain columns from the right
108+
// hand side of the join. Since the remaining predicates will be kept
109+
// as a filter over window, this check is necessary after the left semi
110+
// or left anti join is moved below window. The reason is, for this kind
111+
// of join, we only output from the left leg of the join.
112+
val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)
113+
114+
if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
115+
val predicate = pushDown.reduce(And)
116+
val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(predicate), hint))
117+
if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan)
118+
} else {
119+
// The join condition is not a subset of the Window's PARTITION BY clause,
120+
// no push down.
121+
join
122+
}
123+
}
124+
125+
// LeftSemi/LeftAnti over Union
126+
case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
127+
if canPushThroughCondition(union.children, joinCond, rightOp) =>
128+
if (joinCond.isEmpty) {
129+
// Push down the Join below Union
130+
val newGrandChildren = union.children.map { Join(_, rightOp, joinType, joinCond, hint) }
131+
union.withNewChildren(newGrandChildren)
132+
} else {
133+
val output = union.output
134+
val newGrandChildren = union.children.map { grandchild =>
135+
val newCond = joinCond.get transform {
136+
case e if output.exists(_.semanticEquals(e)) =>
137+
grandchild.output(output.indexWhere(_.semanticEquals(e)))
138+
}
139+
assert(newCond.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet))
140+
Join(grandchild, rightOp, joinType, Option(newCond), hint)
141+
}
142+
union.withNewChildren(newGrandChildren)
143+
}
144+
145+
// LeftSemi/LeftAnti over UnaryNode
146+
case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
147+
if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
148+
pushDownJoin(join, u.child) { joinCond =>
149+
u.withNewChildren(Seq(Join(u.child, rightOp, joinType, joinCond, hint)))
150+
}
151+
}
152+
153+
/**
154+
* Check if we can safely push a join through a project or union by making sure that attributes
155+
* referred in join condition do not contain the same attributes as the plan they are moved
156+
* into. This can happen when both sides of join refers to the same source (self join). This
157+
* function makes sure that the join condition refers to attributes that are not ambiguous (i.e
158+
* present in both the legs of the join) or else the resultant plan will be invalid.
159+
*/
160+
private def canPushThroughCondition(
161+
plans: Seq[LogicalPlan],
162+
condition: Option[Expression],
163+
rightOp: LogicalPlan): Boolean = {
164+
val attributes = AttributeSet(plans.flatMap(_.output))
165+
if (condition.isDefined) {
166+
val matched = condition.get.references.intersect(rightOp.outputSet).intersect(attributes)
167+
matched.isEmpty
168+
} else {
169+
true
170+
}
171+
}
172+
173+
174+
private def pushDownJoin(
175+
join: Join,
176+
grandchild: LogicalPlan)(insertJoin: Option[Expression] => LogicalPlan): LogicalPlan = {
177+
if (join.condition.isEmpty) {
178+
insertJoin(None)
179+
} else {
180+
val (pushDown, stayUp) = splitConjunctivePredicates(join.condition.get)
181+
.partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)}
182+
183+
val rightOpColumns = AttributeSet(stayUp.toSet).intersect(join.right.outputSet)
184+
if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
185+
val newChild = insertJoin(Option(pushDown.reduceLeft(And)))
186+
if (stayUp.nonEmpty) {
187+
Filter(stayUp.reduceLeft(And), newChild)
188+
} else {
189+
newChild
190+
}
191+
} else {
192+
join
193+
}
194+
}
195+
}
196+
}
197+

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,10 @@ object LeftExistence {
114114
case _ => None
115115
}
116116
}
117+
118+
object LeftSemiOrAnti {
119+
def unapply(joinType: JoinType): Option[JoinType] = joinType match {
120+
case LeftSemi | LeftAnti => Some(joinType)
121+
case _ => None
122+
}
123+
}

0 commit comments

Comments
 (0)