diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 357d11c39f4e0..b72d85be594d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -721,9 +721,9 @@ object LimitPushDown extends Rule[LogicalPlan] {
       LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join)))
     // Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only.
     case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly =>
-      Limit(le, Project(a.output, LocalLimit(le, a.child)))
+      Limit(le, Project(a.aggregateExpressions, LocalLimit(le, a.child)))
     case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly =>
-      Limit(le, p.copy(child = Project(a.output, LocalLimit(le, a.child))))
+      Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le, a.child))))
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
index ee7f872514985..4cfc90a7d32fd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
@@ -254,6 +254,13 @@ class LimitPushdownSuite extends PlanTest {
       Optimize.execute(x.union(y).groupBy("x.a".attr)("x.a".attr).limit(1).analyze),
       LocalLimit(1, LocalLimit(1, x).union(LocalLimit(1, y))).select("x.a".attr).limit(1).analyze)
 
+    comparePlans(
+      Optimize.execute(
+        x.groupBy("x.a".attr)("x.a".attr)
+          .select("x.a".attr.as("a1"), "x.a".attr.as("a2")).limit(1).analyze),
+      LocalLimit(1, x).select("x.a".attr)
+        .select("x.a".attr.as("a1"), "x.a".attr.as("a2")).limit(1).analyze)
+
     // No push down
     comparePlans(
       Optimize.execute(x.groupBy("x.a".attr)("x.a".attr).limit(2).analyze),
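
Illustrative sketch (not part of the patch). The diff itself does not state the motivating query, but one plausible reading of why projecting `a.output` is unsafe for a group-only Aggregate: when `aggregateExpressions` contain an Alias, `output` exposes the alias's new exprId, which the Aggregate's child never produces, whereas `aggregateExpressions` keeps the Alias node that maps the child's attribute to that exprId. The names below (`a`, `child`, `agg`, `rewritten`) are hypothetical and chosen only for this sketch.

    import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal}
    import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalLimit, LocalRelation, Project}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val child = LocalRelation(a)

    // Group-only aggregate that aliases its grouping column, i.e. SELECT a AS a1 FROM t GROUP BY a.
    val agg = Aggregate(Seq(a), Seq(Alias(a, "a1")()), child)

    // agg.output is Seq(a1#<newId>), an attribute that `child` does not emit, so the old rewrite
    // Project(agg.output, LocalLimit(1, child)) would reference a missing input attribute.
    // Projecting the aggregate expressions instead preserves the mapping a#<oldId> -> a1#<newId>:
    val rewritten = Project(agg.aggregateExpressions, LocalLimit(Literal(1), child))

The added test covers the second rule case (a Project on top of the group-only Aggregate) and checks that the Aggregate is still rewritten into a Project over LocalLimit(1, child) with the outer aliases left intact.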