diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8a18d55cb6437..44ffc4f749502 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -719,9 +719,11 @@ object LimitPushDown extends Rule[LogicalPlan] { private def pushLocalLimitThroughJoin(limitExpr: Expression, join: Join): Join = { join.joinType match { - case RightOuter => join.copy(right = maybePushLocalLimit(limitExpr, join.right)) - case LeftOuter => join.copy(left = maybePushLocalLimit(limitExpr, join.left)) - case _: InnerLike if join.condition.isEmpty => + case RightOuter if join.condition.nonEmpty => + join.copy(right = maybePushLocalLimit(limitExpr, join.right)) + case LeftOuter if join.condition.nonEmpty => + join.copy(left = maybePushLocalLimit(limitExpr, join.left)) + case _: InnerLike | RightOuter | LeftOuter | FullOuter if join.condition.isEmpty => join.copy( left = maybePushLocalLimit(limitExpr, join.left), right = maybePushLocalLimit(limitExpr, join.right)) @@ -743,15 +745,15 @@ object LimitPushDown extends Rule[LogicalPlan] { LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _)))) // Add extra limits below JOIN: - // 1. For LEFT OUTER and RIGHT OUTER JOIN, we push limits to the left and right sides, - // respectively. - // 2. For INNER and CROSS JOIN, we push limits to both the left and right sides if join - // condition is empty. + // 1. For LEFT OUTER and RIGHT OUTER JOIN, we push limits to the left and right sides + // respectively if join condition is not empty. + // 2. For INNER, CROSS, LEFT OUTER, RIGHT OUTER and FULL OUTER JOIN, we push limits to both + // the left and right sides if join condition is empty. // 3. For LEFT SEMI and LEFT ANTI JOIN, we push limits to the left side if join condition // is empty. 
- // It's not safe to push limits below FULL OUTER JOIN in the general case without a more - // invasive rewrite. We also need to ensure that this limit pushdown rule will not eventually - // introduce limits on both sides if it is applied multiple times. Therefore: + // It's not safe to push limits below FULL OUTER JOIN with join condition in the general case + // without a more invasive rewrite. We also need to ensure that this limit pushdown rule will + // not eventually introduce limits on both sides if it is applied multiple times. Therefore: // - If one side is already limited, stack another limit on top if the new limit is smaller. // The redundant limit will be collapsed by the CombineLimits rule. case LocalLimit(exp, join: Join) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index f309da6b4f5df..7cbc308182c61 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -372,7 +372,7 @@ class EliminateSortsSuite extends AnalysisTest { .limit(10) val optimized = Optimize.execute(joinPlan.analyze) val correctAnswer = LocalLimit(10, projectPlan) - .join(projectPlanB, LeftOuter) + .join(LocalLimit(10, projectPlanB), LeftOuter) .limit(10).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 9c093bda26366..02631c4cf61c9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -96,45 +96,75 @@ class 
LimitPushdownSuite extends PlanTest { // Outer join ---------------------------------------------------------------------------------- test("left outer join") { - val originalQuery = x.join(y, LeftOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze - comparePlans(optimized, correctAnswer) + Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition => + val originalQuery = x.join(y, LeftOuter, condition).limit(1).analyze + val optimized = if (condition.isEmpty) { + LocalLimit(1, x).join(LocalLimit(1, y), LeftOuter, condition).limit(1).analyze + } else { + LocalLimit(1, x).join(y, LeftOuter, condition).limit(1).analyze + } + comparePlans(Optimize.execute(originalQuery), optimized) + } } test("left outer join and left sides are limited") { - val originalQuery = x.limit(2).join(y, LeftOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, x).join(y, LeftOuter)).analyze - comparePlans(optimized, correctAnswer) + Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition => + val originalQuery = x.limit(2).join(y, LeftOuter, condition).limit(1).analyze + val optimized = if (condition.isEmpty) { + LocalLimit(1, x).join(LocalLimit(1, y), LeftOuter, condition).limit(1).analyze + } else { + LocalLimit(1, x).join(y, LeftOuter, condition).limit(1).analyze + } + comparePlans(Optimize.execute(originalQuery), optimized) + } } test("left outer join and right sides are limited") { - val originalQuery = x.join(y.limit(2), LeftOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, LocalLimit(1, x).join(Limit(2, y), LeftOuter)).analyze - comparePlans(optimized, correctAnswer) + Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition => + val originalQuery = x.join(y.limit(2), LeftOuter, condition).limit(1).analyze + val optimized = if (condition.isEmpty) { + 
LocalLimit(1, x).join(LocalLimit(1, y), LeftOuter, condition).limit(1).analyze + } else { + LocalLimit(1, x).join(Limit(2, y), LeftOuter, condition).limit(1).analyze + } + comparePlans( Optimize.execute(originalQuery), optimized) + } } test("right outer join") { - val originalQuery = x.join(y, RightOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze - comparePlans(optimized, correctAnswer) + Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition => + val originalQuery = x.join(y, RightOuter, condition).limit(1).analyze + val optimized = if (condition.isEmpty) { + LocalLimit(1, x).join(LocalLimit(1, y), RightOuter, condition).limit(1).analyze + } else { + x.join(LocalLimit(1, y), RightOuter, condition).limit(1).analyze + } + comparePlans(Optimize.execute(originalQuery), optimized) + } } test("right outer join and right sides are limited") { - val originalQuery = x.join(y.limit(2), RightOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, x.join(LocalLimit(1, y), RightOuter)).analyze - comparePlans(optimized, correctAnswer) + Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition => + val originalQuery = x.join(y.limit(2), RightOuter, condition).limit(1).analyze + val optimized = if (condition.isEmpty) { + LocalLimit(1, x).join(LocalLimit(1, y), RightOuter, condition).limit(1).analyze + } else { + x.join(LocalLimit(1, y), RightOuter, condition).limit(1).analyze + } + comparePlans(Optimize.execute(originalQuery), optimized) + } } test("right outer join and left sides are limited") { - val originalQuery = x.limit(2).join(y, RightOuter).limit(1) - val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Limit(1, Limit(2, x).join(LocalLimit(1, y), RightOuter)).analyze - comparePlans(optimized, correctAnswer) + Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition => + val 
originalQuery = x.limit(2).join(y, RightOuter, condition).limit(1).analyze + val optimized = if (condition.isEmpty) { + LocalLimit(1, x).join(LocalLimit(1, y), RightOuter, condition).limit(1).analyze + } else { + Limit(2, x).join(LocalLimit(1, y), RightOuter, condition).limit(1).analyze + } + comparePlans(Optimize.execute(originalQuery), optimized) + } } test("larger limits are not pushed on top of smaller ones in right outer join") { @@ -146,35 +176,59 @@ class LimitPushdownSuite extends PlanTest { test("full outer join where neither side is limited and both sides have same statistics") { assert(x.stats.sizeInBytes === y.stats.sizeInBytes) - val originalQuery = x.join(y, FullOuter).limit(1).analyze - val optimized = Optimize.execute(originalQuery) - // No pushdown for FULL OUTER JOINS. - comparePlans(optimized, originalQuery) + Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition => + val originalQuery = x.join(y, FullOuter, condition).limit(1).analyze + val optimized = if (condition.isEmpty) { + LocalLimit(1, x).join(LocalLimit(1, y), FullOuter, condition).limit(1).analyze + } else { + // No pushdown for FULL OUTER JOIN when a join condition is present. + originalQuery + } + comparePlans(Optimize.execute(originalQuery), optimized) + } } test("full outer join where neither side is limited and left side has larger statistics") { val xBig = testRelation.copy(data = Seq.fill(10)(null)).subquery("x") assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes) - val originalQuery = xBig.join(y, FullOuter).limit(1).analyze - val optimized = Optimize.execute(originalQuery) - // No pushdown for FULL OUTER JOINS. - comparePlans(optimized, originalQuery) + Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition => + val originalQuery = xBig.join(y, FullOuter, condition).limit(1).analyze + val optimized = if (condition.isEmpty) { + LocalLimit(1, xBig).join(LocalLimit(1, y), FullOuter, condition).limit(1).analyze + } else { + // No pushdown for FULL OUTER JOIN when a join condition is present. 
+ originalQuery + } + comparePlans(Optimize.execute(originalQuery), optimized) + } } test("full outer join where neither side is limited and right side has larger statistics") { val yBig = testRelation.copy(data = Seq.fill(10)(null)).subquery("y") assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes) - val originalQuery = x.join(yBig, FullOuter).limit(1).analyze - val optimized = Optimize.execute(originalQuery) - // No pushdown for FULL OUTER JOINS. - comparePlans(optimized, originalQuery) + Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition => + val originalQuery = x.join(yBig, FullOuter, condition).limit(1).analyze + val optimized = if (condition.isEmpty) { + LocalLimit(1, x).join(LocalLimit(1, yBig), FullOuter, condition).limit(1).analyze + } else { + // No pushdown for FULL OUTER JOIN when a join condition is present. + originalQuery + } + comparePlans(Optimize.execute(originalQuery), optimized) + } } test("full outer join where both sides are limited") { - val originalQuery = x.limit(2).join(y.limit(2), FullOuter).limit(1).analyze - val optimized = Optimize.execute(originalQuery) - // No pushdown for FULL OUTER JOINS. - comparePlans(optimized, originalQuery) + Seq(Some("x.a".attr === "y.b".attr), None).foreach { condition => + val originalQuery = x.limit(2).join(y.limit(2), FullOuter, condition).limit(1).analyze + val optimized = if (condition.isEmpty) { + LocalLimit(1, x).join(LocalLimit(1, y), FullOuter, condition).limit(1).analyze + } else { + // No pushdown for FULL OUTER JOIN when a join condition is present. + originalQuery + } + comparePlans(Optimize.execute(originalQuery), optimized) + } } test("SPARK-33433: Change Aggregate max rows to 1 if grouping is empty") {