diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c79fd7a87a83..f4da7870a132 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -647,6 +647,11 @@ object LimitPushDown extends Rule[LogicalPlan] { // There is a Project between LocalLimit and Join if they do not have the same output. case LocalLimit(exp, project @ Project(_, join: Join)) => LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join))) + // Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only. + case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly => + Limit(le, Project(a.output, LocalLimit(le, a.child))) + case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly => + Limit(le, p.copy(child = Project(a.output, LocalLimit(le, a.child)))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index c2503e362c8c..848416b09813 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -239,4 +239,28 @@ class LimitPushdownSuite extends PlanTest { Limit(5, LocalLimit(5, x).join(y, LeftOuter, joinCondition).select("x.a".attr)).analyze comparePlans(optimized, correctAnswer) } + + test("SPARK-36183: Push down limit 1 through Aggregate if it is group only") { + // Push down when it is group only and limit 1. + comparePlans( + Optimize.execute(x.groupBy("x.a".attr)("x.a".attr).limit(1).analyze), + LocalLimit(1, x).select("x.a".attr).limit(1).analyze) + + comparePlans( + Optimize.execute(x.groupBy("x.a".attr)("x.a".attr).select("x.a".attr).limit(1).analyze), + LocalLimit(1, x).select("x.a".attr).select("x.a".attr).limit(1).analyze) + + comparePlans( + Optimize.execute(x.union(y).groupBy("x.a".attr)("x.a".attr).limit(1).analyze), + LocalLimit(1, LocalLimit(1, x).union(LocalLimit(1, y))).select("x.a".attr).limit(1).analyze) + + // No push down + comparePlans( + Optimize.execute(x.groupBy("x.a".attr)("x.a".attr).limit(2).analyze), + x.groupBy("x.a".attr)("x.a".attr).limit(2).analyze) + + comparePlans( + Optimize.execute(x.groupBy("x.a".attr)("x.a".attr, count("x.a".attr)).limit(1).analyze), + x.groupBy("x.a".attr)("x.a".attr, count("x.a".attr)).limit(1).analyze) + } }