Skip to content

Commit 9fb4536

Browse files
allisonwang-db authored and cloud-fan committed
[SPARK-33183][SQL] Fix Optimizer rule EliminateSorts and add a physical rule to remove redundant sorts
### What changes were proposed in this pull request?

This PR fixes a correctness bug in the optimizer rule `EliminateSorts`. It also adds a new physical rule that removes the redundant sorts which the optimizer rule can no longer eliminate after the bugfix.

### Why are the changes needed?

A global sort should not be eliminated even if its child is ordered, since we don't know whether the child's ordering is global or local. For example, in the following scenario, the first sort shouldn't be removed because it has a stronger guarantee than the second sort, even though the sort orders are the same for both sorts.

```
Sort(orders, global = True, ...)
  Sort(orders, global = False, ...)
```

Since there is no straightforward way to identify whether a node's output ordering is local or global, we should not remove a global sort even if its child is already ordered.

### Does this PR introduce _any_ user-facing change?

Yes

### How was this patch tested?

Unit tests

Closes apache#30093 from allisonwang-db/fix-sort.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 528160f commit 9fb4536

File tree

8 files changed

+303
-28
lines changed

8 files changed

+303
-28
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,7 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
10201020
* Note that changes in the final output ordering may affect the file size (SPARK-32318).
10211021
* This rule handles the following cases:
10221022
* 1) if the sort order is empty or the sort order does not have any reference
1023-
* 2) if the child is already sorted
1023+
* 2) if the Sort operator is a local sort and the child is already sorted
10241024
* 3) if there is another Sort operator separated by 0...n Project, Filter, Repartition or
10251025
* RepartitionByExpression (with deterministic expressions) operators
10261026
* 4) if the Sort operator is within Join separated by 0...n Project, Filter, Repartition or
@@ -1031,12 +1031,18 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
10311031
* function is order irrelevant
10321032
*/
10331033
object EliminateSorts extends Rule[LogicalPlan] {
1034-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
1034+
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
1035+
1036+
private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
10351037
case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
10361038
val newOrders = orders.filterNot(_.child.foldable)
1037-
if (newOrders.isEmpty) child else s.copy(order = newOrders)
1038-
case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
1039-
child
1039+
if (newOrders.isEmpty) {
1040+
applyLocally.lift(child).getOrElse(child)
1041+
} else {
1042+
s.copy(order = newOrders)
1043+
}
1044+
case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
1045+
applyLocally.lift(child).getOrElse(child)
10401046
case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child))
10411047
case j @ Join(originLeft, originRight, _, cond, _) if cond.forall(_.deterministic) =>
10421048
j.copy(left = recursiveRemoveSort(originLeft), right = recursiveRemoveSort(originRight))

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,13 @@ object SQLConf {
12531253
.booleanConf
12541254
.createWithDefault(true)
12551255

1256+
// Internal boolean switch (enabled by default) for removing redundant physical sort nodes.
val REMOVE_REDUNDANT_SORTS_ENABLED =
  buildConf("spark.sql.execution.removeRedundantSorts")
    .internal()
    .doc("Whether to remove redundant physical sort node")
    .version("3.1.0")
    .booleanConf
    .createWithDefault(true)
1262+
12561263
val STATE_STORE_PROVIDER_CLASS =
12571264
buildConf("spark.sql.streaming.stateStore.providerClass")
12581265
.internal()

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala

Lines changed: 92 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,34 @@ class EliminateSortsSuite extends AnalysisTest {
9999
comparePlans(optimized, correctAnswer)
100100
}
101101

102-
test("remove redundant order by") {
102+
test("SPARK-33183: remove consecutive no-op sorts") {
  // Three stacked sorts with empty sort orders; the expected result is the bare relation.
  val stackedEmptySorts = testRelation.orderBy().orderBy().orderBy()
  val actual = Optimize.execute(stackedEmptySorts.analyze)
  comparePlans(actual, testRelation.analyze)
}
108+
109+
test("SPARK-33183: remove redundant sort by") {
103110
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
104-
val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
111+
val unnecessaryReordered = orderedPlan.limit(2).select('a).sortBy('a.asc, 'b.desc_nullsFirst)
105112
val optimized = Optimize.execute(unnecessaryReordered.analyze)
106113
val correctAnswer = orderedPlan.limit(2).select('a).analyze
107-
comparePlans(Optimize.execute(optimized), correctAnswer)
114+
comparePlans(optimized, correctAnswer)
115+
}
116+
117+
test("SPARK-33183: remove all redundant local sorts") {
  // sortBy / orderBy / sortBy on the same key: only the orderBy is expected to remain.
  val input = testRelation.sortBy('a.asc).orderBy('a.asc).sortBy('a.asc)
  val actual = Optimize.execute(input.analyze)
  comparePlans(actual, testRelation.orderBy('a.asc).analyze)
}
123+
124+
test("SPARK-33183: should not remove global sort") {
  val sortedInput = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
  val sortedAgain = sortedInput.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
  // The optimized plan is expected to equal the analyzed input, i.e. nothing is removed.
  val actual = Optimize.execute(sortedAgain.analyze)
  comparePlans(actual, sortedAgain.analyze)
}
109131

110132
test("do not remove sort if the order is different") {
@@ -115,22 +137,39 @@ class EliminateSortsSuite extends AnalysisTest {
115137
comparePlans(optimized, correctAnswer)
116138
}
117139

118-
test("filters don't affect order") {
140+
test("SPARK-33183: remove top level local sort with filter operators") {
119141
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
120-
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
142+
val filteredAndReordered = orderedPlan.where('a > Literal(10)).sortBy('a.asc, 'b.desc)
121143
val optimized = Optimize.execute(filteredAndReordered.analyze)
122144
val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
123145
comparePlans(optimized, correctAnswer)
124146
}
125147

126-
test("limits don't affect order") {
148+
test("SPARK-33183: keep top level global sort with filter operators") {
  val projected = testRelation.select('a, 'b)
  val input = projected
    .orderBy('a.asc, 'b.desc)
    .where('a > Literal(10))
    .orderBy('a.asc, 'b.desc)
  val actual = Optimize.execute(input.analyze)
  // Expected: the orderBy below the filter is dropped and the top-level orderBy is kept.
  val expected = projected.where('a > Literal(10)).orderBy('a.asc, 'b.desc).analyze
  comparePlans(actual, expected)
}
156+
157+
test("SPARK-33183: limits should not affect order for local sort") {
127158
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
128-
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
159+
val filteredAndReordered = orderedPlan.limit(Literal(10)).sortBy('a.asc, 'b.desc)
129160
val optimized = Optimize.execute(filteredAndReordered.analyze)
130161
val correctAnswer = orderedPlan.limit(Literal(10)).analyze
131162
comparePlans(optimized, correctAnswer)
132163
}
133164

165+
test("SPARK-33183: should not remove global sort with limit operators") {
  val sortedInput = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
  val limitedAndSorted = sortedInput.limit(Literal(10)).orderBy('a.asc, 'b.desc)
  // The optimized plan is expected to equal the analyzed input, i.e. both sorts stay.
  val actual = Optimize.execute(limitedAndSorted.analyze)
  comparePlans(actual, limitedAndSorted.analyze)
}
172+
134173
test("different sorts are not simplified if limit is in between") {
135174
val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10))
136175
.orderBy('a.asc)
@@ -139,11 +178,11 @@ class EliminateSortsSuite extends AnalysisTest {
139178
comparePlans(optimized, correctAnswer)
140179
}
141180

142-
test("range is already sorted") {
181+
test("SPARK-33183: should not remove global sort with range operator") {
143182
val inputPlan = Range(1L, 1000L, 1, 10)
144183
val orderedPlan = inputPlan.orderBy('id.asc)
145184
val optimized = Optimize.execute(orderedPlan.analyze)
146-
val correctAnswer = inputPlan.analyze
185+
val correctAnswer = orderedPlan.analyze
147186
comparePlans(optimized, correctAnswer)
148187

149188
val reversedPlan = inputPlan.orderBy('id.desc)
@@ -154,10 +193,18 @@ class EliminateSortsSuite extends AnalysisTest {
154193
val negativeStepInputPlan = Range(10L, 1L, -1, 10)
155194
val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc)
156195
val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze)
157-
val negativeStepCorrectAnswer = negativeStepInputPlan.analyze
196+
val negativeStepCorrectAnswer = negativeStepOrderedPlan.analyze
158197
comparePlans(negativeStepOptimized, negativeStepCorrectAnswer)
159198
}
160199

200+
test("SPARK-33183: remove local sort with range operator") {
  val rangePlan = Range(1L, 1000L, 1, 10)
  // A sortBy on 'id.asc over the range is expected to be removed, leaving just the range.
  val actual = Optimize.execute(rangePlan.sortBy('id.asc).analyze)
  comparePlans(actual, rangePlan.analyze)
}
207+
161208
test("sort should not be removed when there is a node which doesn't guarantee any order") {
162209
val orderedPlan = testRelation.select('a, 'b)
163210
val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc)
@@ -333,4 +380,39 @@ class EliminateSortsSuite extends AnalysisTest {
333380
val correctAnswer = PushDownOptimizer.execute(noOrderByPlan.analyze)
334381
comparePlans(optimized, correctAnswer)
335382
}
383+
384+
test("SPARK-33183: remove consecutive global sorts with the same ordering") {
  // Each pair is (input plan, expected plan after optimization).
  val cases = Seq(
    (testRelation.orderBy('a.asc).orderBy('a.asc), testRelation.orderBy('a.asc)),
    (testRelation.orderBy('a.asc, 'b.desc).orderBy('a.asc), testRelation.orderBy('a.asc)))
  for ((input, expected) <- cases) {
    val actual = Optimize.execute(input.analyze)
    comparePlans(actual, expected.analyze)
  }
}
393+
394+
test("SPARK-33183: remove consecutive local sorts with the same ordering") {
  // Three identical sortBy calls are expected to collapse into a single one.
  val input = testRelation.sortBy('a.asc).sortBy('a.asc).sortBy('a.asc)
  val actual = Optimize.execute(input.analyze)
  comparePlans(actual, testRelation.sortBy('a.asc).analyze)
}
400+
401+
test("SPARK-33183: remove consecutive local sorts with different ordering") {
  // Only the outermost sortBy ('a.asc) is expected to remain.
  val input = testRelation.sortBy('b.asc).sortBy('a.desc).sortBy('a.asc)
  val actual = Optimize.execute(input.analyze)
  comparePlans(actual, testRelation.sortBy('a.asc).analyze)
}
407+
408+
test("SPARK-33183: should keep global sort when child is a local sort with the same ordering") {
  // Every input below is expected to optimize to a single orderBy('a.asc).
  val expected = testRelation.orderBy('a.asc).analyze
  val inputs = Seq(
    testRelation.sortBy('a.asc).orderBy('a.asc),
    testRelation.orderBy('a.asc).sortBy('a.asc).orderBy('a.asc))
  for (input <- inputs) {
    val actual = Optimize.execute(input.analyze)
    comparePlans(actual, expected)
  }
}
336418
}

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ object QueryExecution {
343343
PlanDynamicPruningFilters,
344344
PlanSubqueries,
345345
RemoveRedundantProjects,
346+
RemoveRedundantSorts,
346347
EnsureRequirements,
347348
DisableUnnecessaryBucketedScan,
348349
ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.columnarRules),
sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantSorts.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import org.apache.spark.sql.catalyst.expressions.SortOrder
21+
import org.apache.spark.sql.catalyst.rules.Rule
22+
import org.apache.spark.sql.internal.SQLConf
23+
24+
/**
25+
* Remove redundant SortExec node from the spark plan. A sort node is redundant when
26+
* its child satisfies both its sort orders and its required child distribution. Note
27+
* this rule differs from the Optimizer rule EliminateSorts in that this rule also checks
28+
* if the child satisfies the required distribution so that it is safe to remove not only a
29+
* local sort but also a global sort when its child already satisfies required sort orders.
30+
*/
31+
object RemoveRedundantSorts extends Rule[SparkPlan] {
32+
def apply(plan: SparkPlan): SparkPlan = {
33+
if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED)) {
34+
plan
35+
} else {
36+
removeSorts(plan)
37+
}
38+
}
39+
40+
private def removeSorts(plan: SparkPlan): SparkPlan = plan transform {
41+
case s @ SortExec(orders, _, child, _)
42+
if SortOrder.orderingSatisfies(child.outputOrdering, orders) &&
43+
child.outputPartitioning.satisfies(s.requiredChildDistribution.head) =>
44+
child
45+
}
46+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,15 @@ case class AdaptiveSparkPlanExec(
8383
@transient private val optimizer = new AQEOptimizer(conf)
8484

8585
@transient private val removeRedundantProjects = RemoveRedundantProjects
86+
@transient private val removeRedundantSorts = RemoveRedundantSorts
8687
@transient private val ensureRequirements = EnsureRequirements
8788

8889
// A list of physical plan rules to be applied before creation of query stages. The physical
8990
// plan should reach a final status of query stages (i.e., no more addition or removal of
9091
// Exchange nodes) after running these rules.
9192
private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq(
9293
removeRedundantProjects,
94+
removeRedundantSorts,
9395
ensureRequirements
9496
) ++ context.session.sessionState.queryStagePrepRules
9597

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -234,19 +234,6 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
234234
}
235235
}
236236

237-
test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
238-
val query = testData.select('key, 'value).sort('key.desc).cache()
239-
assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation])
240-
val resorted = query.sort('key.desc)
241-
assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty)
242-
assert(resorted.select('key).collect().map(_.getInt(0)).toSeq ==
243-
(1 to 100).reverse)
244-
// with a different order, the sort is needed
245-
val sortedAsc = query.sort('key)
246-
assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1)
247-
assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100))
248-
}
249-
250237
test("PartitioningCollection") {
251238
withTempView("normal", "small", "tiny") {
252239
testData.createOrReplaceTempView("normal")

0 commit comments

Comments
 (0)