Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
Batch("Join Reorder", FixedPoint(1),
CostBasedJoinReorder) :+
Batch("Eliminate Sorts", Once,
EliminateSorts) :+
EliminateSorts,
RemoveRedundantSorts) :+
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) :+
// This batch must run after "Decimal Optimizations", as that one may change the
Expand Down Expand Up @@ -769,11 +770,11 @@ object LimitPushDown extends Rule[LogicalPlan] {
LocalLimit(exp, project.copy(child = pushLocalLimitThroughJoin(exp, join)))
// Push down limit 1 through Aggregate and turn Aggregate into Project if it is group only.
case Limit(le @ IntegerLiteral(1), a: Aggregate) if a.groupOnly =>
val project = Project(a.aggregateExpressions, LocalLimit(le, a.child))
project.setTagValue(Project.dataOrderIrrelevantTag, ())
Limit(le, project)
val newAgg = EliminateSorts(a.copy(child = LocalLimit(le, a.child))).asInstanceOf[Aggregate]
Limit(le, Project(newAgg.aggregateExpressions, newAgg.child))
case Limit(le @ IntegerLiteral(1), p @ Project(_, a: Aggregate)) if a.groupOnly =>
Limit(le, p.copy(child = Project(a.aggregateExpressions, LocalLimit(le, a.child))))
val newAgg = EliminateSorts(a.copy(child = LocalLimit(le, a.child))).asInstanceOf[Aggregate]
Limit(le, p.copy(child = Project(newAgg.aggregateExpressions, newAgg.child)))
// Merge offset value and limit value into LocalLimit and pushes down LocalLimit through Offset.
case LocalLimit(le, Offset(oe, grandChild)) =>
Offset(oe, LocalLimit(Add(le, oe), grandChild))
Expand Down Expand Up @@ -1555,38 +1556,30 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
* Note that changes in the final output ordering may affect the file size (SPARK-32318).
* This rule handles the following cases:
* 1) if the sort order is empty or the sort order does not have any reference
* 2) if the Sort operator is a local sort and the child is already sorted
* 3) if there is another Sort operator separated by 0...n Project, Filter, Repartition or
* 2) if there is another Sort operator separated by 0...n Project, Filter, Repartition or
* RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators
* 4) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or
* 3) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or
* RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only
* and the Join condition is deterministic
* 5) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or
* 4) if the Sort operator is within GroupBy separated by 0...n Project, Filter, Repartition or
Copy link
Contributor Author

@cloud-fan cloud-fan Dec 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part is still in EliminateSorts, so EliminateSorts is good enough for LimitPushDown

* RepartitionByExpression, RebalancePartitions (with deterministic expressions) operators only
* and the aggregate function is order irrelevant
*/
object EliminateSorts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(SORT))(applyLocally)

private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(_.containsPattern(SORT)) {
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
val newOrders = orders.filterNot(_.child.foldable)
if (newOrders.isEmpty) {
applyLocally.lift(child).getOrElse(child)
child
} else {
s.copy(order = newOrders)
}
case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the expensive part as it need to calculate the ordering of children.

applyLocally.lift(child).getOrElse(child)
case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global))
case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) =>
j.copy(left = recursiveRemoveSort(originLeft, true),
right = recursiveRemoveSort(originRight, true))
case g @ Aggregate(_, aggs, originChild) if isOrderIrrelevantAggs(aggs) =>
g.copy(child = recursiveRemoveSort(originChild, true))
case p: Project if p.getTagValue(Project.dataOrderIrrelevantTag).isDefined =>
p.copy(child = recursiveRemoveSort(p.child, true))
}

/**
Expand All @@ -1602,12 +1595,6 @@ object EliminateSorts extends Rule[LogicalPlan] {
plan match {
case Sort(_, global, child) if canRemoveGlobalSort || !global =>
recursiveRemoveSort(child, canRemoveGlobalSort)
case Sort(sortOrder, true, child) =>
// For this case, the upper sort is local so the ordering of present sort is unnecessary,
// so here we only preserve its output partitioning using `RepartitionByExpression`.
// We should use `None` as the optNumPartitions so AQE can coalesce shuffle partitions.
// This behavior is same with original global sort.
RepartitionByExpression(sortOrder, recursiveRemoveSort(child, true), None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, previously this rule looks into this global Sort's child to remove local and global Sort recursively without condition. But in the new RemoveRedundantSorts rule:

case s @ Sort(orders, true, child) =>
  val newChild = recursiveRemoveSort(child, optimizeGlobalSort = false)

recursiveRemoveSort in RemoveRedundantSorts only removes local Sort if its child is already sorted. Do we miss this optimization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

@viirya viirya Dec 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Said there are Sorts like

- Sort (local)
  - Sort (global)
    - Sort (local)

We reach:

case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global))

Previously we can get rid of the middle global Sort and the bottom local Sort by RepartitionByExpression(sortOrder, recursiveRemoveSort(child, true), None) and:

case Sort(_, global, child) if canRemoveGlobalSort || !global =>
  recursiveRemoveSort(child, canRemoveGlobalSort)

How does EliminateSorts still do it?
The code you point is same (not changed in this PR):

case s @ Sort(_, global, child) => s.copy(child = recursiveRemoveSort(child, global))

But in recursiveRemoveSort, as canRemoveGlobalSort is false, we don't get rid of the middle global Sort now (it will be done in RemoveRedundantSorts now).

Then the bottom local Sort under the rewritten RepartitionByExpression won't be optimized as it requires its child is sorted.

Do I miss or misread something?

Copy link
Contributor Author

@cloud-fan cloud-fan Dec 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After running EliminateSorts, the bottom sort is removed, then we run RemoveRedundantSorts which will turn the middle sort to local sort.

These two rules are in the same batch

case other if canEliminateSort(other) =>
other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort)))
case other if canEliminateGlobalSort(other) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RepartitionByExpression, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.SORT

/**
* Remove redundant local [[Sort]] from the logical plan if its child is already sorted, and also
* rewrite global [[Sort]] under local [[Sort]] into [[RepartitionByExpression]].
*/
object RemoveRedundantSorts extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
recursiveRemoveSort(plan, optimizeGlobalSort = false)
}

private def recursiveRemoveSort(plan: LogicalPlan, optimizeGlobalSort: Boolean): LogicalPlan = {
if (!plan.containsPattern(SORT)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we pull out this to apply method ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should put it here to skip some children of a plan node.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plan.containsPattern contains the bitset of children..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we traverse down a tree, we still need to apply the skipping for each plan node that has more than one children.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I see, make sense. Here we traverse the tree manually

return plan
}
plan match {
case s @ Sort(orders, false, child) =>
if (SortOrder.orderingSatisfies(child.outputOrdering, orders)) {
recursiveRemoveSort(child, optimizeGlobalSort = false)
} else {
s.withNewChildren(Seq(recursiveRemoveSort(child, optimizeGlobalSort = true)))
}

case s @ Sort(orders, true, child) =>
val newChild = recursiveRemoveSort(child, optimizeGlobalSort = false)
if (optimizeGlobalSort) {
// For this case, the upper sort is local so the ordering of present sort is unnecessary,
// so here we only preserve its output partitioning using `RepartitionByExpression`.
// We should use `None` as the optNumPartitions so AQE can coalesce shuffle partitions.
// This behavior is same with original global sort.
RepartitionByExpression(orders, newChild, None)
} else {
s.withNewChildren(Seq(newChild))
}

case _ =>
plan.withNewChildren(plan.children.map(recursiveRemoveSort(_, optimizeGlobalSort = false)))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,6 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan)

object Project {
val hiddenOutputTag: TreeNodeTag[Seq[Attribute]] = TreeNodeTag[Seq[Attribute]]("hidden_output")
// Project with this tag means it doesn't care about the data order of its input. We only set
// this tag when the Project was converted from grouping-only Aggregate.
val dataOrderIrrelevantTag: TreeNodeTag[Unit] = TreeNodeTag[Unit]("data_order_irrelevant")

def matchSchema(plan: LogicalPlan, schema: StructType, conf: SQLConf): Project = {
assert(plan.resolved)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ class EliminateSortsSuite extends AnalysisTest {
FoldablePropagation,
LimitPushDown) ::
Batch("Eliminate Sorts", Once,
EliminateSorts) ::
EliminateSorts,
RemoveRedundantSorts) ::
Batch("Collapse Project", Once,
CollapseProject) :: Nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,24 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
}

// assert the outer most sort in the executed plan
assert(plan.collectFirst {
case s: SortExec => s
}.exists {
case SortExec(Seq(
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _),
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _)
), false, _, _) => true
case _ => false
}, plan)
val sort = plan.collectFirst { case s: SortExec => s }
if (enabled) {
// With planned write, optimizer is more efficient and can eliminate the `SORT BY`.
assert(sort.exists {
case SortExec(Seq(
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _)
), false, _, _) => true
case _ => false
}, plan)
} else {
assert(sort.exists {
case SortExec(Seq(
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _),
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _)
), false, _, _) => true
case _ => false
}, plan)
}
}
}
}
Expand Down Expand Up @@ -270,15 +279,24 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
}

// assert the outer most sort in the executed plan
assert(plan.collectFirst {
case s: SortExec => s
}.exists {
case SortExec(Seq(
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _),
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _)
), false, _, _) => true
case _ => false
}, plan)
val sort = plan.collectFirst { case s: SortExec => s }
if (enabled) {
// With planned write, optimizer is more efficient and can eliminate the `SORT BY`.
assert(sort.exists {
case SortExec(Seq(
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _)
), false, _, _) => true
case _ => false
}, plan)
} else {
assert(sort.exists {
case SortExec(Seq(
SortOrder(AttributeReference("value", StringType, _, _), Ascending, NullsFirst, _),
SortOrder(AttributeReference("key", IntegerType, _, _), Ascending, NullsFirst, _)
), false, _, _) => true
case _ => false
}, plan)
}
}
}
}
Expand Down