Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// LeftSemi/LeftAnti over Project
case Join(p @ Project(pList, gChild), rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
if pList.forall(_.deterministic) &&
if pList.forall(_.deterministic) &&
!pList.exists(ScalarSubquery.hasCorrelatedScalarSubquery) &&
canPushThroughCondition(Seq(gChild), joinCond, rightOp) =>
if (joinCond.isEmpty) {
Expand All @@ -52,101 +52,29 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
}

// LeftSemi/LeftAnti over Aggregate
case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
case join @ Join(agg: Aggregate, rightOp, LeftSemiOrAnti(_), _, _)
if agg.aggregateExpressions.forall(_.deterministic) && agg.groupingExpressions.nonEmpty &&
!agg.aggregateExpressions.exists(ScalarSubquery.hasCorrelatedScalarSubquery) =>
if (joinCond.isEmpty) {
// No join condition, just push down Join below Aggregate
agg.copy(child = Join(agg.child, rightOp, joinType, joinCond, hint))
} else {
val aliasMap = PushDownPredicate.getAliasMap(agg)

// For each join condition, expand the alias and check if the condition can be evaluated
// using attributes produced by the aggregate operator's child operator.
val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
val replaced = replaceAlias(cond, aliasMap)
cond.references.nonEmpty &&
replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet)
}

// Check if the remaining predicates do not contain columns from the right
// hand side of the join. Since the remaining predicates will be kept
// as a filter over aggregate, this check is necessary after the left semi
// or left anti join is moved below aggregate. The reason is, for this kind
// of join, we only output from the left leg of the join.
val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)

if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val replaced = replaceAlias(pushDownPredicate, aliasMap)
val newAgg = agg.copy(child = Join(agg.child, rightOp, joinType, Option(replaced), hint))
// If there is no more filter to stay up, just return the Aggregate over Join.
// Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)".
if (stayUp.isEmpty) {
newAgg
} else {
joinType match {
// In case of Left semi join, the part of the join condition which does not refer to
// to child attributes of the aggregate operator are kept as a Filter over window.
case LeftSemi => Filter(stayUp.reduce(And), newAgg)
// In case of left anti join, the join is pushed down when the entire join condition
// is eligible to be pushed down to preserve the semantics of left anti join.
case _ => join
}
}
} else {
// The join condition is not a subset of the Aggregate's GROUP BY columns,
// no push down.
join
}
val aliasMap = PushDownPredicate.getAliasMap(agg)
val canPushDownPredicate = (predicate: Expression) => {
val replaced = replaceAlias(predicate, aliasMap)
predicate.references.nonEmpty &&
replaced.references.subsetOf(agg.child.outputSet ++ rightOp.outputSet)
}
val makeJoinCondition = (predicates: Seq[Expression]) => {
replaceAlias(predicates.reduce(And), aliasMap)
}
pushDownJoin(join, canPushDownPredicate, makeJoinCondition)

// LeftSemi/LeftAnti over Window
case join @ Join(w: Window, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
if (joinCond.isEmpty) {
// No join condition, just push down Join below Window
w.copy(child = Join(w.child, rightOp, joinType, joinCond, hint))
} else {
val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++
rightOp.outputSet

val (pushDown, stayUp) = splitConjunctivePredicates(joinCond.get).partition { cond =>
cond.references.subsetOf(partitionAttrs)
}

// Check if the remaining predicates do not contain columns from the right
// hand side of the join. Since the remaining predicates will be kept
// as a filter over window, this check is necessary after the left semi
// or left anti join is moved below window. The reason is, for this kind
// of join, we only output from the left leg of the join.
val rightOpColumns = AttributeSet(stayUp.toSet).intersect(rightOp.outputSet)

if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
val predicate = pushDown.reduce(And)
val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(predicate), hint))
if (stayUp.isEmpty) {
newPlan
} else {
joinType match {
// In case of Left semi join, the part of the join condition which does not refer to
// to partition attributes of the window operator are kept as a Filter over window.
case LeftSemi => Filter(stayUp.reduce(And), newPlan)
// In case of left anti join, the join is pushed down when the entire join condition
// is eligible to be pushed down to preserve the semantics of left anti join.
case _ => join
}
}
} else {
// The join condition is not a subset of the Window's PARTITION BY clause,
// no push down.
join
}
}
case join @ Join(w: Window, rightOp, LeftSemiOrAnti(_), _, _)
if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) ++ rightOp.outputSet
pushDownJoin(join, _.references.subsetOf(partitionAttrs), _.reduce(And))

// LeftSemi/LeftAnti over Union
case join @ Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
if canPushThroughCondition(union.children, joinCond, rightOp) =>
case Join(union: Union, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
if canPushThroughCondition(union.children, joinCond, rightOp) =>
if (joinCond.isEmpty) {
// Push down the Join below Union
val newGrandChildren = union.children.map { Join(_, rightOp, joinType, joinCond, hint) }
Expand All @@ -165,11 +93,10 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
}

// LeftSemi/LeftAnti over UnaryNode
case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(joinType), joinCond, hint)
if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
pushDownJoin(join, u.child) { joinCond =>
u.withNewChildren(Seq(Join(u.child, rightOp, joinType, joinCond, hint)))
}
case join @ Join(u: UnaryNode, rightOp, LeftSemiOrAnti(_), _, _)
if PushDownPredicate.canPushThrough(u) && u.expressions.forall(_.deterministic) =>
val validAttrs = u.child.outputSet ++ rightOp.outputSet
pushDownJoin(join, _.references.subsetOf(validAttrs), _.reduce(And))
}

/**
Expand All @@ -192,35 +119,43 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper {
}
}


private def pushDownJoin(
join: Join,
grandchild: LogicalPlan)(insertJoin: Option[Expression] => LogicalPlan): LogicalPlan = {
canPushDownPredicate: Expression => Boolean,
makeJoinCondition: Seq[Expression] => Expression): LogicalPlan = {
assert(join.left.children.length == 1)

if (join.condition.isEmpty) {
insertJoin(None)
join.left.withNewChildren(Seq(join.copy(left = join.left.children.head)))
} else {
val (pushDown, stayUp) = splitConjunctivePredicates(join.condition.get)
.partition {_.references.subsetOf(grandchild.outputSet ++ join.right.outputSet)}
.partition(canPushDownPredicate)

// Check if the remaining predicates do not contain columns from the right hand side of the
// join. Since the remaining predicates will be kept as a filter over the operator under join,
// this check is necessary after the left-semi/anti join is pushed down. The reason is, for
// this kind of join, we only output from the left leg of the join.
val referRightSideCols = AttributeSet(stayUp.toSet).intersect(join.right.outputSet).nonEmpty

val rightOpColumns = AttributeSet(stayUp.toSet).intersect(join.right.outputSet)
if (pushDown.nonEmpty && rightOpColumns.isEmpty) {
val newChild = insertJoin(Option(pushDown.reduceLeft(And)))
if (stayUp.nonEmpty) {
if (pushDown.isEmpty || referRightSideCols) {
join
} else {
val newPlan = join.left.withNewChildren(Seq(join.copy(
left = join.left.children.head, condition = Some(makeJoinCondition(pushDown)))))
// If there is no more filter to stay up, return the new plan that has join pushed down.
if (stayUp.isEmpty) {
newPlan
} else {
join.joinType match {
// In case of Left semi join, the part of the join condition which does not refer to
// to attributes of the grandchild are kept as a Filter over window.
case LeftSemi => Filter(stayUp.reduce(And), newChild)
// In case of left anti join, the join is pushed down when the entire join condition
// is eligible to be pushed down to preserve the semantics of left anti join.
// to attributes of the grandchild are kept as a Filter above.
case LeftSemi => Filter(stayUp.reduce(And), newPlan)
// In case of left-anti join, the join is pushed down only when the entire join
// condition is eligible to be pushed down to preserve the semantics of left-anti join.
case _ => join
}
} else {
newChild
}
} else {
join
}
}
}
}