diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a4b25cbd1d2ee..ced4f6c7c6fca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -211,7 +211,8 @@ abstract class Optimizer(catalogManager: CatalogManager) Batch("Join Reorder", FixedPoint(1), CostBasedJoinReorder) :+ Batch("Eliminate Sorts", Once, - EliminateSorts) :+ + EliminateSorts, + RemoveRedundantSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ // This batch must run after "Decimal Optimizations", as that one may change the @@ -769,11 +770,11 @@ object LimitPushDown extends Rule[LogicalPlan] { LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join))) // Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only. case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly => - val project = Project(a.aggregateExpressions, LocalLimit(le, a.child)) - project.setTagValue(Project.dataOrderIrrelevantTag, ()) - Limit(le, project) + val newAgg = EliminateSorts(a.copy(child = LocalLimit(le, a.child))).asInstanceOf[Aggregate] + Limit(le, Project(newAgg.aggregateExpressions, newAgg.child)) case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly => - Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le, a.child)))) + val newAgg = EliminateSorts(a.copy(child = LocalLimit(le, a.child))).asInstanceOf[Aggregate] + Limit(le, p.copy(child = Project(newAgg.aggregateExpressions, newAgg.child))) // Merge offset value and limit value into LocalLimit and pushes down LocalLimit through Offset. 
case LocalLimit(le, Offset(oe, grandChild)) => Offset(oe, LocalLimit(Add(le, oe), grandChild)) @@ -1555,38 +1556,30 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper { * Note that changes in the final output ordering may affect the file size (SPARK-32318). * This rule handles the following cases: * 1) if the sort order is empty or the sort order does not have any reference - * 2) if the Sort operator is a local sort and the child is already sorted - * 3) if there is another Sort operator separated by 0...n Project, Filter, Repartition or + * 2) if there is another Sort operator separated by 0...n Project, Filter, Repartition or * RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators - * 4) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or + * 3) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or * RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only * and the Join condition is deterministic - * 5) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or + * 4) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or * RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only * and the aggregate function is order irrelevant */ object EliminateSorts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( - _.containsPattern(SORT))(applyLocally) - - private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(SORT)) { case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) => val newOrders = orders.filterNot(_.child.foldable) if (newOrders.isEmpty) { - applyLocally.lift(child).getOrElse(child) + child } else { s.copy(order = 
newOrders) } - case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => - applyLocally.lift(child).getOrElse(child) case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global)) case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) => j.copy(left = recursiveRemoveSort(originLeft, true), right = recursiveRemoveSort(originRight, true)) case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) => g.copy(child = recursiveRemoveSort(originChild, true)) - case p: Project if p.getTagValue(Project.dataOrderIrrelevantTag).isDefined => - p.copy(child = recursiveRemoveSort(p.child, true)) } /** @@ -1602,12 +1595,6 @@ object EliminateSorts extends Rule[LogicalPlan] { plan match { case Sort(_, global, child) if canRemoveGlobalSort || !global => recursiveRemoveSort(child, canRemoveGlobalSort) - case Sort(sortOrder, true, child) => - // For this case, the upper sort is local so the ordering of present sort is unnecessary, - // so here we only preserve its output partitioning using `RepartitionByExpression`. - // We should use `None` as the optNumPartitions so AQE can coalesce shuffle partitions. - // This behavior is same with original global sort. 
- RepartitionByExpression(sortOrder, recursiveRemoveSort(child, true), None) case other if canEliminateSort(other) => other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort))) case other if canEliminateGlobalSort(other) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala new file mode 100644 index 0000000000000..204d2a34675bc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSorts.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RepartitionByExpression, Sort} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.SORT + +/** + * Remove redundant local [[Sort]] from the logical plan if its child is already sorted, and also + * rewrite global [[Sort]] under local [[Sort]] into [[RepartitionByExpression]]. 
+ */ +object RemoveRedundantSorts extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + recursiveRemoveSort(plan, optimizeGlobalSort = false) + } + + private def recursiveRemoveSort(plan: LogicalPlan, optimizeGlobalSort: Boolean): LogicalPlan = { + if (!plan.containsPattern(SORT)) { + return plan + } + plan match { + case s @ Sort(orders, false, child) => + if (SortOrder.orderingSatisfies(child.outputOrdering, orders)) { + recursiveRemoveSort(child, optimizeGlobalSort = false) + } else { + s.withNewChildren(Seq(recursiveRemoveSort(child, optimizeGlobalSort = true))) + } + + case s @ Sort(orders, true, child) => + val newChild = recursiveRemoveSort(child, optimizeGlobalSort = false) + if (optimizeGlobalSort) { + // For this case, the upper sort is local so the ordering of present sort is unnecessary, + // so here we only preserve its output partitioning using `RepartitionByExpression`. + // We should use `None` as the optNumPartitions so AQE can coalesce shuffle partitions. + // This behavior is the same as the original global sort. 
+ RepartitionByExpression(orders, newChild, None) + } else { + s.withNewChildren(Seq(newChild)) + } + + case _ => + plan.withNewChildren(plan.children.map(recursiveRemoveSort(_, optimizeGlobalSort = false))) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 65f4151c0c963..497f485b67fe2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -101,9 +101,6 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) object Project { val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] = TreeNodeTag[Seq[Attribute]]("hidden_output") - // Project with this tag means it doesn't care about the data order of its input. We only set - // this tag when the Project was converted from grouping-only Aggregate. 
- val dataOrderIrrelevantTag: TreeNodeTag[Unit] = TreeNodeTag[Unit]("data_order_irrelevant") def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf): Project = { assert(plan.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index c6312fa1b1aa1..5cfe4a7bf4623 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -39,7 +39,8 @@ class EliminateSortsSuite extends AnalysisTest { FoldablePropagation, LimitPushDown) :: Batch("Eliminate Sorts", Once, - EliminateSorts) :: + EliminateSorts, + RemoveRedundantSorts) :: Batch("Collapse Project", Once, CollapseProject) :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala index ce43edb79c127..3ca516463d367 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala @@ -223,15 +223,24 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write } // assert the outer most sort in the executed plan - assert(plan.collectFirst { - case s: SortExec => s - }.exists { - case SortExec(Seq( - SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _), - SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _) - ), false, _, _) => true - case _ => false - }, plan) + val sort = plan.collectFirst { case s: SortExec => s } + if (enabled) { + // With planned write, optimizer is more efficient and can eliminate the `SORT BY`. 
+ assert(sort.exists { + case SortExec(Seq( + SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _) + ), false, _, _) => true + case _ => false + }, plan) + } else { + assert(sort.exists { + case SortExec(Seq( + SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _), + SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _) + ), false, _, _) => true + case _ => false + }, plan) + } } } } @@ -270,15 +279,24 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write } // assert the outer most sort in the executed plan - assert(plan.collectFirst { - case s: SortExec => s - }.exists { - case SortExec(Seq( - SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _), - SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _) - ), false, _, _) => true - case _ => false - }, plan) + val sort = plan.collectFirst { case s: SortExec => s } + if (enabled) { + // With planned write, optimizer is more efficient and can eliminate the `SORT BY`. + assert(sort.exists { + case SortExec(Seq( + SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _) + ), false, _, _) => true + case _ => false + }, plan) + } else { + assert(sort.exists { + case SortExec(Seq( + SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _), + SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _) + ), false, _, _) => true + case _ => false + }, plan) + } } } }