From af5f6cfa1042d8aca84a61424b36a04b49bf2807 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 11 Dec 2023 23:39:06 -0800
Subject: [PATCH] Still remove Sort after converting Aggregate to Project

LimitPushDown rewrites a grouping-only Aggregate under a limit of 1 into a
Project. EliminateSorts removes Sorts below order-irrelevant Aggregates, but
it has no such case for a plain Project, so the rewrite caused the Sort to
be kept.

Tag every Project produced by that rewrite with
Project.dataOrderIrrelevantTag, including the one built when another Project
sits between the Limit and the Aggregate, and let EliminateSorts remove
Sorts beneath tagged Projects.
---
 .../org/apache/spark/sql/catalyst/dsl/package.scala  |  2 ++
 .../spark/sql/catalyst/optimizer/Optimizer.scala     | 14 ++++++++++++--
 .../plans/logical/basicLogicalOperators.scala        |  3 +++
 .../sql/catalyst/optimizer/EliminateSortsSuite.scala | 12 ++++++++++++
 4 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 30d4c2dbb409f..eb3047700215d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -395,6 +395,8 @@ package object dsl {
 
       def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
 
+      def localLimit(limitExpr: Expression): LogicalPlan = LocalLimit(limitExpr, logicalPlan)
+
       def offset(offsetExpr: Expression): LogicalPlan = Offset(offsetExpr, logicalPlan)
 
       def join(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 960f5e532c08c..a4b25cbd1d2ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -769,7 +769,14 @@ object LimitPushDown extends Rule[LogicalPlan] {
       LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join)))
     // Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only.
     case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly =>
-      Limit(le, Project(a.aggregateExpressions, LocalLimit(le, a.child)))
+      // Grouping-only Aggregate ignores input order; tag the Project so EliminateSorts still
+      // removes Sort operators below it.
+      val project = Project(a.aggregateExpressions, LocalLimit(le, a.child))
+      project.setTagValue(Project.dataOrderIrrelevantTag, ())
+      Limit(le, project)
     case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly =>
-      Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le, a.child))))
+      // Tag the converted Project here as well so that both conversion paths behave the same.
+      val project = Project(a.aggregateExpressions, LocalLimit(le, a.child))
+      project.setTagValue(Project.dataOrderIrrelevantTag, ())
+      Limit(le, p.copy(child = project))
     // Merge offset value and limit value into LocalLimit and pushes down LocalLimit through Offset.
@@ -1583,6 +1590,9 @@ object EliminateSorts extends Rule[LogicalPlan] {
         right = recursiveRemoveSort(originRight, true))
     case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
       g.copy(child = recursiveRemoveSort(originChild, true))
+    // Projects converted from grouping-only Aggregates are tagged as order-insensitive.
+    case p: Project if p.getTagValue(Project.dataOrderIrrelevantTag).isDefined =>
+      p.copy(child = recursiveRemoveSort(p.child, true))
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 497f485b67fe2..65f4151c0c963 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -101,6 +101,9 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)
 
 object Project {
   val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] = TreeNodeTag[Seq[Attribute]]("hidden_output")
+  // Project with this tag means it doesn't care about the data order of its input. We only set
+  // this tag when the Project was converted from grouping-only Aggregate.
+  val dataOrderIrrelevantTag: TreeNodeTag[Unit] = TreeNodeTag[Unit]("data_order_irrelevant")
 
   def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf): Project = {
     assert(plan.resolved)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
index 7cbc308182c61..c6312fa1b1aa1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
@@ -478,4 +478,16 @@ class EliminateSortsSuite extends AnalysisTest {
 
     comparePlans(Optimize.execute(originalPlan.analyze), correctAnswer.analyze)
   }
+
+  test("SPARK-46378: Still remove Sort after converting Aggregate to Project") {
+    val originalPlan = testRelation.orderBy($"a".asc)
+      .groupBy($"a")($"a")
+      .limit(1)
+
+    val correctAnswer = testRelation.localLimit(1)
+      .select($"a")
+      .limit(1)
+
+    comparePlans(Optimize.execute(originalPlan.analyze), correctAnswer.analyze)
+  }
 }