From 11c266b4db7c515c24aa9112a07655cb0a09f850 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 16 Jul 2021 23:42:18 +0800 Subject: [PATCH 1/4] Push down limit 1 through Aggregate if it is group only. --- .../sql/catalyst/optimizer/Optimizer.scala | 5 ++++ .../optimizer/LimitPushdownSuite.scala | 25 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c79fd7a87a83..981c1bc3a2c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -647,6 +647,11 @@ object LimitPushDown extends Rule[LogicalPlan] { // There is a Project between LocalLimit and Join if they do not have the same output. case LocalLimit(exp, project @ Project(_, join: Join)) => LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join))) + // Push down limit 1 through Aggregate if it is group only. + case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly => + Limit(le, a.copy(child = LocalLimit(le, a.child))) + case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly => + Limit(le, p.copy(child = a.copy(child = LocalLimit(le, a.child)))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index c2503e362c8c..47af5725fd4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -239,4 +239,29 @@ class LimitPushdownSuite extends PlanTest { Limit(5, LocalLimit(5, x).join(y, LeftOuter, joinCondition).select("x.a".attr)).analyze comparePlans(optimized, correctAnswer) } + + test("SPARK-36183: Push down limit 1 through Aggregate if it is group only") { + // Push down when it is group only and limit 1. + comparePlans( + Optimize.execute(x.groupBy("x.a".attr)("x.a".attr).limit(1).analyze), + LocalLimit(1, x).groupBy("x.a".attr)("x.a".attr).limit(1).analyze) + + comparePlans( + Optimize.execute(x.groupBy("x.a".attr)("x.a".attr).select("x.a".attr).limit(1).analyze), + LocalLimit(1, x).groupBy("x.a".attr)("x.a".attr).select("x.a".attr).limit(1).analyze) + + comparePlans( + Optimize.execute(x.union(y).groupBy("x.a".attr)("x.a".attr).limit(1).analyze), + LocalLimit(1, LocalLimit(1, x).union(LocalLimit(1, y))) + .groupBy("x.a".attr)("x.a".attr).limit(1).analyze) + + // No push down + comparePlans( + Optimize.execute(x.groupBy("x.a".attr)("x.a".attr).limit(2).analyze), + x.groupBy("x.a".attr)("x.a".attr).limit(2).analyze) + + comparePlans( + Optimize.execute(x.groupBy("x.a".attr)("x.a".attr, count("x.a".attr)).limit(1).analyze), + x.groupBy("x.a".attr)("x.a".attr, count("x.a".attr)).limit(1).analyze) + } } From c630a6344f8df95cea739f71ae81821555bde8d1 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 19 Jul 2021 21:46:23 +0800 Subject: [PATCH 2/4] fix --- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++-- .../spark/sql/catalyst/optimizer/LimitPushdownSuite.scala | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 981c1bc3a2c9..049820b54449 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -649,9 +649,9 @@ object LimitPushDown extends Rule[LogicalPlan] { LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join))) // Push down limit 1 through Aggregate if it is group only. case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly => - Limit(le, a.copy(child = LocalLimit(le, a.child))) + Limit(le, Project(a.output, LocalLimit(le, a.child))) case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly => - Limit(le, p.copy(child = a.copy(child = LocalLimit(le, a.child)))) + Limit(le, p.copy(child = Project(a.output, LocalLimit(le, a.child)))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index 47af5725fd4d..848416b09813 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -244,16 +244,15 @@ class LimitPushdownSuite extends PlanTest { // Push down when it is group only and limit 1. comparePlans( Optimize.execute(x.groupBy("x.a".attr)("x.a".attr).limit(1).analyze), - LocalLimit(1, x).groupBy("x.a".attr)("x.a".attr).limit(1).analyze) + LocalLimit(1, x).select("x.a".attr).limit(1).analyze) comparePlans( Optimize.execute(x.groupBy("x.a".attr)("x.a".attr).select("x.a".attr).limit(1).analyze), - LocalLimit(1, x).groupBy("x.a".attr)("x.a".attr).select("x.a".attr).limit(1).analyze) + LocalLimit(1, x).select("x.a".attr).select("x.a".attr).limit(1).analyze) comparePlans( Optimize.execute(x.union(y).groupBy("x.a".attr)("x.a".attr).limit(1).analyze), - LocalLimit(1, LocalLimit(1, x).union(LocalLimit(1, y))) - .groupBy("x.a".attr)("x.a".attr).limit(1).analyze) + LocalLimit(1, LocalLimit(1, x).union(LocalLimit(1, y))).select("x.a".attr).limit(1).analyze) // No push down comparePlans( From ec3df298047861d8b4afd3fb764a4ec853d0d3ea Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 19 Jul 2021 23:36:24 +0800 Subject: [PATCH 3/4] Update comment --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 049820b54449..db901a6f7d1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -647,7 +647,7 @@ object LimitPushDown extends Rule[LogicalPlan] { // There is a Project between LocalLimit and Join if they do not have the same output. case LocalLimit(exp, project @ Project(_, join: Join)) => LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join))) - // Push down limit 1 through Aggregate if it is group only. + // Push down limit 1 and turn Aggregate into Project through Aggregate if it is group only. case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly => Limit(le, Project(a.output, LocalLimit(le, a.child))) case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly => From 1d60d648aea63000366aa1a51bb5d62dc29e27dd Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Tue, 20 Jul 2021 15:49:42 +0800 Subject: [PATCH 4/4] Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala Co-authored-by: Wenchen Fan --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index db901a6f7d1f..f4da7870a132 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -647,7 +647,7 @@ object LimitPushDown extends Rule[LogicalPlan] { // There is a Project between LocalLimit and Join if they do not have the same output. case LocalLimit(exp, project @ Project(_, join: Join)) => LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join))) - // Push down limit 1 and turn Aggregate into Project through Aggregate if it is group only. + // Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only. case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly => Limit(le, Project(a.output, LocalLimit(le, a.child))) case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly =>