diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala index bc868df3dbb0..afe2cfa81ffe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushDownLeftSemiAntiJoin.scala @@ -82,7 +82,18 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { val newAgg = agg.copy(child = Join(agg.child, rightOp, joinType, Option(replaced), hint)) // If there is no more filter to stay up, just return the Aggregate over Join. // Otherwise, create "Filter(stayUp) <- Aggregate <- Join(pushDownPredicate)". - if (stayUp.isEmpty) newAgg else Filter(stayUp.reduce(And), newAgg) + if (stayUp.isEmpty) { + newAgg + } else { + joinType match { + // In case of Left semi join, the part of the join condition which does not refer to + // child attributes of the aggregate operator is kept as a Filter over the aggregate. + case LeftSemi => Filter(stayUp.reduce(And), newAgg) + // In case of left anti join, the join is pushed down when the entire join condition + // is eligible to be pushed down to preserve the semantics of left anti join. + case _ => join + } + } } else { // The join condition is not a subset of the Aggregate's GROUP BY columns, // no push down. 
@@ -114,7 +125,18 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { if (pushDown.nonEmpty && rightOpColumns.isEmpty) { val predicate = pushDown.reduce(And) val newPlan = w.copy(child = Join(w.child, rightOp, joinType, Option(predicate), hint)) - if (stayUp.isEmpty) newPlan else Filter(stayUp.reduce(And), newPlan) + if (stayUp.isEmpty) { + newPlan + } else { + joinType match { + // In case of Left semi join, the part of the join condition which does not refer to + // partition attributes of the window operator is kept as a Filter over the window. + case LeftSemi => Filter(stayUp.reduce(And), newPlan) + // In case of left anti join, the join is pushed down when the entire join condition + // is eligible to be pushed down to preserve the semantics of left anti join. + case _ => join + } + } } else { // The join condition is not a subset of the Window's PARTITION BY clause, // no push down. @@ -184,7 +206,14 @@ object PushDownLeftSemiAntiJoin extends Rule[LogicalPlan] with PredicateHelper { if (pushDown.nonEmpty && rightOpColumns.isEmpty) { val newChild = insertJoin(Option(pushDown.reduceLeft(And))) if (stayUp.nonEmpty) { - Filter(stayUp.reduceLeft(And), newChild) + join.joinType match { + // In case of Left semi join, the part of the join condition which does not refer to + // attributes of the grandchild is kept as a Filter over the new child plan. + case LeftSemi => Filter(stayUp.reduce(And), newChild) + // In case of left anti join, the join is pushed down when the entire join condition + // is eligible to be pushed down to preserve the semantics of left anti join. 
+ case _ => join + } } else { newChild } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala index 1a0231ed2d99..185568d334ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala @@ -117,7 +117,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, originalQuery.analyze) } - test("Aggregate: LeftSemiAnti join partial pushdown") { + test("Aggregate: LeftSemi join partial pushdown") { val originalQuery = testRelation .groupBy('b)('b, sum('c).as('sum)) .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 10)) @@ -132,6 +132,15 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Aggregate: LeftAnti join no pushdown") { + val originalQuery = testRelation + .groupBy('b)('b, sum('c).as('sum)) + .join(testRelation1, joinType = LeftAnti, condition = Some('b === 'd && 'sum === 10)) + + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } + test("LeftSemiAnti join over aggregate - no pushdown") { val originalQuery = testRelation .groupBy('b)('b, sum('c).as('sum)) @@ -174,7 +183,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Window: LeftSemiAnti partial pushdown") { + test("Window: LeftSemi partial pushdown") { // Attributes from join condition which does not refer to the window partition spec // are kept up in the plan as a Filter operator above Window. 
val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) @@ -195,6 +204,25 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Window: LeftAnti no pushdown") { + // For LeftAnti, a join condition that only partly refers to the window partition spec + // is not split: the whole join is expected to stay above the Window operator. + val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + + val originalQuery = testRelation + .select('a, 'b, 'c, winExpr.as('window)) + .join(testRelation1, joinType = LeftAnti, condition = Some('a === 'd && 'b > 5)) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = testRelation + .select('a, 'b, 'c) + .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil) + .join(testRelation1, joinType = LeftAnti, condition = Some('a === 'd && 'b > 5)) + .select('a, 'b, 'c, 'window).analyze + comparePlans(optimized, correctAnswer) + } + test("Union: LeftSemiAnti join pushdown") { val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int) @@ -251,7 +279,7 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("Unary: LeftSemiAnti join pushdown - partial pushdown") { + test("Unary: LeftSemi join pushdown - partial pushdown") { val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) val originalQuery = testRelationWithArrayType .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col")) @@ -267,6 +295,16 @@ class LeftSemiPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Unary: LeftAnti join pushdown - no pushdown") { + val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) + val originalQuery = testRelationWithArrayType + .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col")) + .join(testRelation1, joinType = LeftAnti, condition = Some('b 
=== 'd && 'b === 'out_col)) + + val optimized = Optimize.execute(originalQuery.analyze) + comparePlans(optimized, originalQuery.analyze) + } + test("Unary: LeftSemiAnti join pushdown - no pushdown") { val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) val originalQuery = testRelationWithArrayType