diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala index afe2cfa81ffe..d91f262d75e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala @@ -35,7 +35,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // LeftSemi/LeftAnti over Project case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if pList.forall(_.deterministic) && + if pList.forall(_.deterministic) && !pList.exists(ScalarSubquery.hasCorrelatedScalarSubquery) && canPushThroughCondition(Seq(gChild), joinCond, rightOp) => if (joinCond.isEmpty) { @@ -52,101 +52,29 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { } // LeftSemi/LeftAnti over Aggregate - case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty && + case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _) + if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty && !agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) => - if (joinCond.isEmpty) { - // No join condition, just push down Join below Aggregate - agg.copy(child = Join(agg.child, rightOp, joinType, joinCond, hint)) - } else { - val aliasMap = PushDownPredicate.getAliasMap(agg) - - // For each join condition, expand the alias and check if the condition can be evaluated - // using attributes produced by the aggregate operator's child operator. - val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond => - val replaced = replaceAlias(cond, aliasMap) - cond.references.nonEmpty && - replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet) - } - - // Check if the remaining predicates do not contain columns from the right - // hand side of the join. Since the remaining predicates will be kept - // as a filter over aggregate, this check is necessary after the left semi - // or left anti join is moved below aggregate. The reason is, for this kind - // of join, we only output from the left leg of the join. - val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) - - if (pushDown.nonEmpty && rightOpColumns.isEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val replaced = replaceAlias(pushDownPredicate, aliasMap) - val newAgg = agg.copy(child = Join(agg.child, rightOp, joinType, Option(replaced), hint)) - // If there is no more filter to stay up, just return the Aggregate over Join. - // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)". - if (stayUp.isEmpty) { - newAgg - } else { - joinType match { - // In case of Left semi join, the part of the join condition which does not refer to - // to child attributes of the aggregate operator are kept as a Filter over window. - case LeftSemi => Filter(stayUp.reduce(And), newAgg) - // In case of left anti join, the join is pushed down when the entire join condition - // is eligible to be pushed down to preserve the semantics of left anti join. - case _ => join - } - } - } else { - // The join condition is not a subset of the Aggregate's GROUP BY columns, - // no push down. - join - } + val aliasMap = PushDownPredicate.getAliasMap(agg) + val canPushDownPredicate = (predicate: Expression) => { + val replaced = replaceAlias(predicate, aliasMap) + predicate.references.nonEmpty && + replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet) + } + val makeJoinCondition = (predicates: Seq[Expression]) => { + replaceAlias(predicates.reduce(And), aliasMap) } + pushDownJoin(join, canPushDownPredicate, makeJoinCondition) // LeftSemi/LeftAnti over Window - case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => - if (joinCond.isEmpty) { - // No join condition, just push down Join below Window - w.copy(child = Join(w.child, rightOp, joinType, joinCond, hint)) - } else { - val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ - rightOp.outputSet - - val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond => - cond.references.subsetOf(partitionAttrs) - } - - // Check if the remaining predicates do not contain columns from the right - // hand side of the join. Since the remaining predicates will be kept - // as a filter over window, this check is necessary after the left semi - // or left anti join is moved below window. The reason is, for this kind - // of join, we only output from the left leg of the join. - val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet) - - if (pushDown.nonEmpty && rightOpColumns.isEmpty) { - val predicate = pushDown.reduce(And) - val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(predicate), hint)) - if (stayUp.isEmpty) { - newPlan - } else { - joinType match { - // In case of Left semi join, the part of the join condition which does not refer to - // to partition attributes of the window operator are kept as a Filter over window. - case LeftSemi => Filter(stayUp.reduce(And), newPlan) - // In case of left anti join, the join is pushed down when the entire join condition - // is eligible to be pushed down to preserve the semantics of left anti join. - case _ => join - } - } - } else { - // The join condition is not a subset of the Window's PARTITION BY clause, - // no push down. - join - } - } + case join @ Join(w: Window, rightOp, LeftSemiOrAnti(_), _, _) + if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => + val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ rightOp.outputSet + pushDownJoin(join, _.references.subsetOf(partitionAttrs), _.reduce(And)) // LeftSemi/LeftAnti over Union - case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if canPushThroughCondition(union.children, joinCond, rightOp) => + case Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) + if canPushThroughCondition(union.children, joinCond, rightOp) => if (joinCond.isEmpty) { // Push down the Join below Union val newGrandChildren = union.children.map { Join(_, rightOp, joinType, joinCond, hint) } @@ -165,11 +93,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { } // LeftSemi/LeftAnti over UnaryNode - case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint) - if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) => - pushDownJoin(join, u.child) { joinCond => - u.withNewChildren(Seq(Join(u.child, rightOp, joinType, joinCond, hint))) - } + case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(_), _, _) + if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) => + val validAttrs = u.child.outputSet ++ rightOp.outputSet + pushDownJoin(join, _.references.subsetOf(validAttrs), _.reduce(And)) } /** @@ -192,35 +119,43 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { } } - private def pushDownJoin( join: Join, - grandchild: LogicalPlan)(insertJoin: Option[Expression] => LogicalPlan): LogicalPlan = { + canPushDownPredicate: Expression => Boolean, + makeJoinCondition: Seq[Expression] => Expression): LogicalPlan = { + assert(join.left.children.length == 1) + if (join.condition.isEmpty) { - insertJoin(None) + join.left.withNewChildren(Seq(join.copy(left = join.left.children.head))) } else { val (pushDown, stayUp) = splitConjunctivePredicates(join.condition.get) - .partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)} + .partition(canPushDownPredicate) + + // Check if the remaining predicates do not contain columns from the right hand side of the + // join. Since the remaining predicates will be kept as a filter over the operator under join, + // this check is necessary after the left-semi/anti join is pushed down. The reason is, for + // this kind of join, we only output from the left leg of the join. + val referRightSideCols = AttributeSet(stayUp.toSet).intersect(join.right.outputSet).nonEmpty - val rightOpColumns = AttributeSet(stayUp.toSet).intersect(join.right.outputSet) - if (pushDown.nonEmpty && rightOpColumns.isEmpty) { - val newChild = insertJoin(Option(pushDown.reduceLeft(And))) - if (stayUp.nonEmpty) { + if (pushDown.isEmpty || referRightSideCols) { + join + } else { + val newPlan = join.left.withNewChildren(Seq(join.copy( + left = join.left.children.head, condition = Some(makeJoinCondition(pushDown))))) + // If there is no more filter to stay up, return the new plan that has join pushed down. + if (stayUp.isEmpty) { + newPlan + } else { join.joinType match { // In case of Left semi join, the part of the join condition which does not refer to - // to attributes of the grandchild are kept as a Filter over window. - case LeftSemi => Filter(stayUp.reduce(And), newChild) - // In case of left anti join, the join is pushed down when the entire join condition - // is eligible to be pushed down to preserve the semantics of left anti join. + // to attributes of the grandchild are kept as a Filter above. + case LeftSemi => Filter(stayUp.reduce(And), newPlan) + // In case of left-anti join, the join is pushed down only when the entire join + // condition is eligible to be pushed down to preserve the semantics of left-anti join. case _ => join } - } else { - newChild } - } else { - join } } } } -