|
| 1 | +/* |
| 2 | + * Licensed to the Apache Software Foundation (ASF) under one or more |
| 3 | + * contributor license agreements. See the NOTICE file distributed with |
| 4 | + * this work for additional information regarding copyright ownership. |
| 5 | + * The ASF licenses this file to You under the Apache License, Version 2.0 |
| 6 | + * (the "License"); you may not use this file except in compliance with |
| 7 | + * the License. You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + */ |
| 17 | + |
| 18 | +package org.apache.spark.sql.catalyst.optimizer |
| 19 | + |
| 20 | +import org.apache.spark.sql.catalyst.expressions._ |
| 21 | +import org.apache.spark.sql.catalyst.plans._ |
| 22 | +import org.apache.spark.sql.catalyst.plans.logical._ |
| 23 | +import org.apache.spark.sql.catalyst.rules.Rule |
| 24 | + |
/**
 * This rule is a variant of [[PushDownPredicate]] which can handle
 * pushing down Left semi and Left Anti joins below the following operators.
 *  1) Project
 *  2) Window
 *  3) Union
 *  4) Aggregate
 *  5) Other permissible unary operators. Please see [[PushDownPredicate.canPushThrough]].
 *
 * A join is only pushed down when the attributes referenced by its condition stay unambiguous
 * after the move (see `canPushThroughCondition`). For Aggregate, Window and the generic unary
 * operators the condition may be split: the conjuncts that can be evaluated below the operator
 * are pushed down together with the join, and the remaining ones are kept as a [[Filter]] above
 * the operator. That split is only performed for left semi joins; a left anti join is pushed
 * down only when its whole condition can be pushed down.
 */
object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // LeftSemi/LeftAnti over Project
    case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
        if pList.forall(_.deterministic) &&
          !pList.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
          canPushThroughCondition(Seq(gChild), joinCond, rightOp) =>
      // With no join condition the join simply moves below the Project. Otherwise the aliases
      // introduced by the Project are expanded so that the condition can be evaluated on the
      // Project's child.
      val newJoinCond = joinCond.map { cond =>
        val aliasMap = PushDownPredicate.getAliasMap(p)
        if (aliasMap.nonEmpty) replaceAlias(cond, aliasMap) else cond
      }
      p.copy(child = Join(gChild, rightOp, joinType, newJoinCond, hint))

    // LeftSemi/LeftAnti over Aggregate
    case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), joinCond, _)
        if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
          !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
          canPushThroughCondition(agg.children, joinCond, rightOp) =>
      val aliasMap = PushDownPredicate.getAliasMap(agg)
      val pushableAttrs = agg.child.outputSet ++ rightOp.outputSet
      // A conjunct can be pushed down if, once the aliases are expanded, it can be evaluated
      // using only attributes produced by the aggregate's child and the right side of the join.
      val canPushDownPredicate = (predicate: Expression) =>
        predicate.references.nonEmpty &&
          replaceAlias(predicate, aliasMap).references.subsetOf(pushableAttrs)
      val makeJoinCondition = (predicates: Seq[Expression]) =>
        replaceAlias(predicates.reduce(And), aliasMap)
      pushDownJoin(join, canPushDownPredicate, makeJoinCondition)

    // LeftSemi/LeftAnti over Window
    case join @ Join(w: Window, rightOp, LeftSemiOrAnti(_), joinCond, _)
        if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) &&
          canPushThroughCondition(w.children, joinCond, rightOp) =>
      // Only conjuncts expressed over the PARTITION BY attributes (and the right side of the
      // join) may move below the Window.
      val partitionAttrs =
        AttributeSet(w.partitionSpec.flatMap(_.references)) ++ rightOp.outputSet
      pushDownJoin(join, _.references.subsetOf(partitionAttrs), _.reduce(And))

    // LeftSemi/LeftAnti over Union
    case Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
        if canPushThroughCondition(union.children, joinCond, rightOp) =>
      val output = union.output
      // Push one copy of the join below every child of the Union. The condition (if any) is
      // rewritten per child: each reference to a Union output attribute is replaced by the
      // attribute at the same position in that child's output.
      val newGrandChildren = union.children.map { grandchild =>
        val newCond = joinCond.map { cond =>
          val rewritten = cond transform {
            case e if output.exists(_.semanticEquals(e)) =>
              grandchild.output(output.indexWhere(_.semanticEquals(e)))
          }
          assert(rewritten.references.subsetOf(grandchild.outputSet ++ rightOp.outputSet))
          rewritten
        }
        Join(grandchild, rightOp, joinType, newCond, hint)
      }
      union.withNewChildren(newGrandChildren)

    // LeftSemi/LeftAnti over UnaryNode
    case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(_), joinCond, _)
        if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) &&
          canPushThroughCondition(u.children, joinCond, rightOp) =>
      val pushableAttrs = u.child.outputSet ++ rightOp.outputSet
      pushDownJoin(join, _.references.subsetOf(pushableAttrs), _.reduce(And))
  }

  /**
   * Check if we can safely push a join through an operator by making sure that the attributes
   * referred to in the join condition do not contain the same attributes as the plan they are
   * moved into. This can happen when both sides of the join refer to the same source (self join).
   * This function makes sure that the join condition refers to attributes that are not ambiguous
   * (i.e. present in both legs of the join), or else the resultant plan will be invalid.
   *
   * @param plans the plans that will become the new left side(s) of the pushed-down join
   * @param condition the join condition, if any
   * @param rightOp the right side of the join
   * @return true if there is no condition, or if no attribute referenced by the condition is
   *         produced both by `rightOp` and by one of `plans`
   */
  private def canPushThroughCondition(
      plans: Seq[LogicalPlan],
      condition: Option[Expression],
      rightOp: LogicalPlan): Boolean = {
    condition.forall { cond =>
      val attributes = AttributeSet(plans.flatMap(_.output))
      cond.references.intersect(rightOp.outputSet).intersect(attributes).isEmpty
    }
  }

  /**
   * Push a left semi / left anti join below its left child, which must have exactly one child.
   *
   * The join condition is split into conjuncts. Conjuncts accepted by `canPushDownPredicate` are
   * combined by `makeJoinCondition` and become the condition of the pushed-down join.
   *
   * The original join is returned unchanged when:
   *  - no conjunct can be pushed down, or
   *  - a conjunct that has to stay up references the right side of the join (a left semi/anti
   *    join only outputs the left leg, so such a predicate could not be evaluated above), or
   *  - the join is a left anti join and some conjunct has to stay up. An anti join keeps a left
   *    row when no right row satisfies the whole condition; turning part of the condition into
   *    a Filter above would additionally drop the left rows for which that part is false.
   *
   * @param join the left semi / left anti join to push down
   * @param canPushDownPredicate decides whether a single conjunct can be evaluated below
   *                             `join.left`
   * @param makeJoinCondition builds the new join condition from the (non-empty) pushable
   *                          conjuncts
   * @return the rewritten plan, or `join` itself if the push-down is not possible
   */
  private def pushDownJoin(
      join: Join,
      canPushDownPredicate: Expression => Boolean,
      makeJoinCondition: Seq[Expression] => Expression): LogicalPlan = {
    assert(join.left.children.length == 1)
    val grandchild = join.left.children.head

    join.condition match {
      case None =>
        // No join condition, just push down the join below the left child.
        join.left.withNewChildren(Seq(join.copy(left = grandchild)))

      case Some(condition) =>
        val (pushDown, stayUp) =
          splitConjunctivePredicates(condition).partition(canPushDownPredicate)
        val referRightSideCols =
          AttributeSet(stayUp.toSet).intersect(join.right.outputSet).nonEmpty

        if (pushDown.isEmpty || referRightSideCols) {
          join
        } else if (stayUp.isEmpty) {
          pushedDownPlan(join, grandchild, makeJoinCondition(pushDown))
        } else {
          join.joinType match {
            // For a left semi join the conjuncts that only refer to the left side can be
            // evaluated afterwards as a Filter on top of the new plan.
            case LeftSemi =>
              Filter(
                stayUp.reduce(And),
                pushedDownPlan(join, grandchild, makeJoinCondition(pushDown)))
            // For a left anti join a partial push-down is not semantics-preserving.
            case _ => join
          }
        }
    }
  }

  /** Rebuild `join.left` on top of the join moved onto `grandchild` with the given condition. */
  private def pushedDownPlan(
      join: Join,
      grandchild: LogicalPlan,
      condition: Expression): LogicalPlan = {
    join.left.withNewChildren(Seq(join.copy(left = grandchild, condition = Option(condition))))
  }
}
| 197 | + |
0 commit comments