Skip to content

Commit 779ac0f

Browse files
allisonwang-db authored and cloud-fan committed
[SPARK-33183][SQL][2.4] Fix Optimizer rule EliminateSorts and add a physical rule to remove redundant sorts
Backport #30093 for branch-2.4.

### What changes were proposed in this pull request?

This PR aims to fix a correctness bug in the optimizer rule EliminateSorts. It also adds a new physical rule to remove redundant sorts that cannot be eliminated in the Optimizer rule after the bugfix.

### Why are the changes needed?

A global sort should not be eliminated even if its child is ordered since we don't know if its child ordering is global or local. For example, in the following scenario, the first sort shouldn't be removed because it has a stronger guarantee than the second sort even if the sort orders are the same for both sorts.

```
Sort(orders, global = True, ...)
  Sort(orders, global = False, ...)
```

Since there is no straightforward way to identify whether a node's output ordering is local or global, we should not remove a global sort even if its child is already ordered.

### Does this PR introduce any user-facing change?

Yes

### How was this patch tested?

Unit tests

Closes #30194 from allisonwang-db/SPARK-33183-branch-2.4.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent f1c3041 commit 779ac0f

File tree

8 files changed

+253
-27
lines changed

8 files changed

+253
-27
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -879,13 +879,15 @@ object EliminateSorts extends Rule[LogicalPlan] {
879879

880880
/**
881881
* Removes redundant Sort operation. This can happen:
882-
* 1) if the child is already sorted
882+
* 1) if the Sort operator is a local sort and the child is already sorted
883883
* 2) if there is another Sort operator separated by 0...n Project/Filter operators
884884
*/
885885
object RemoveRedundantSorts extends Rule[LogicalPlan] {
886-
def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
887-
case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
888-
child
886+
def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
887+
888+
private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
889+
case Sort(orders, false, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) =>
890+
applyLocally.lift(child).getOrElse(child)
889891
case s @ Sort(_, _, child) => s.copy(child = recursiveRemoveSort(child))
890892
}
891893

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,12 @@ object SQLConf {
850850
.booleanConf
851851
.createWithDefault(true)
852852

853+
val REMOVE_REDUNDANT_SORTS_ENABLED = buildConf("spark.sql.execution.removeRedundantSorts")
854+
.internal()
855+
.doc("Whether to remove redundant physical sort node")
856+
.booleanConf
857+
.createWithDefault(true)
858+
853859
val STATE_STORE_PROVIDER_CLASS =
854860
buildConf("spark.sql.streaming.stateStore.providerClass")
855861
.internal()

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,11 @@ class EliminateSortsSuite extends PlanTest {
9292
val correctAnswer = distributedPlan.analyze
9393
comparePlans(optimized, correctAnswer)
9494
}
95+
96+
test("SPARK-33183: remove consecutive no-op sorts") {
97+
val plan = testRelation.orderBy().orderBy().orderBy()
98+
val optimized = Optimize.execute(plan.analyze)
99+
val correctAnswer = testRelation.analyze
100+
comparePlans(optimized, correctAnswer)
101+
}
95102
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,27 @@ class RemoveRedundantSortsSuite extends PlanTest {
3636

3737
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
3838

39-
test("remove redundant order by") {
39+
test("SPARK-33183: remove redundant sort by") {
4040
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
41-
val unnecessaryReordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
41+
val unnecessaryReordered = orderedPlan.limit(2).select('a).sortBy('a.asc, 'b.desc_nullsFirst)
4242
val optimized = Optimize.execute(unnecessaryReordered.analyze)
4343
val correctAnswer = orderedPlan.limit(2).select('a).analyze
44-
comparePlans(Optimize.execute(optimized), correctAnswer)
44+
comparePlans(optimized, correctAnswer)
45+
}
46+
47+
test("SPARK-33183: remove all redundant local sorts") {
48+
val orderedPlan = testRelation.sortBy('a.asc).orderBy('a.asc).sortBy('a.asc)
49+
val optimized = Optimize.execute(orderedPlan.analyze)
50+
val correctAnswer = testRelation.orderBy('a.asc).analyze
51+
comparePlans(optimized, correctAnswer)
52+
}
53+
54+
test("SPARK-33183: should not remove global sort") {
55+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst)
56+
val reordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst)
57+
val optimized = Optimize.execute(reordered.analyze)
58+
val correctAnswer = reordered.analyze
59+
comparePlans(optimized, correctAnswer)
4560
}
4661

4762
test("do not remove sort if the order is different") {
@@ -52,22 +67,39 @@ class RemoveRedundantSortsSuite extends PlanTest {
5267
comparePlans(optimized, correctAnswer)
5368
}
5469

55-
test("filters don't affect order") {
70+
test("SPARK-33183: remove top level local sort with filter operators") {
5671
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
57-
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
72+
val filteredAndReordered = orderedPlan.where('a > Literal(10)).sortBy('a.asc, 'b.desc)
5873
val optimized = Optimize.execute(filteredAndReordered.analyze)
5974
val correctAnswer = orderedPlan.where('a > Literal(10)).analyze
6075
comparePlans(optimized, correctAnswer)
6176
}
6277

63-
test("limits don't affect order") {
78+
test("SPARK-33183: keep top level global sort with filter operators") {
79+
val projectPlan = testRelation.select('a, 'b)
80+
val orderedPlan = projectPlan.orderBy('a.asc, 'b.desc)
81+
val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc)
82+
val optimized = Optimize.execute(filteredAndReordered.analyze)
83+
val correctAnswer = projectPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc).analyze
84+
comparePlans(optimized, correctAnswer)
85+
}
86+
87+
test("SPARK-33183: limits should not affect order for local sort") {
6488
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
65-
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
89+
val filteredAndReordered = orderedPlan.limit(Literal(10)).sortBy('a.asc, 'b.desc)
6690
val optimized = Optimize.execute(filteredAndReordered.analyze)
6791
val correctAnswer = orderedPlan.limit(Literal(10)).analyze
6892
comparePlans(optimized, correctAnswer)
6993
}
7094

95+
test("SPARK-33183: should not remove global sort with limit operators") {
96+
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc)
97+
val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc)
98+
val optimized = Optimize.execute(filteredAndReordered.analyze)
99+
val correctAnswer = filteredAndReordered.analyze
100+
comparePlans(optimized, correctAnswer)
101+
}
102+
71103
test("different sorts are not simplified if limit is in between") {
72104
val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10))
73105
.orderBy('a.asc)
@@ -76,11 +108,11 @@ class RemoveRedundantSortsSuite extends PlanTest {
76108
comparePlans(optimized, correctAnswer)
77109
}
78110

79-
test("range is already sorted") {
111+
test("SPARK-33183: should not remove global sort with range operator") {
80112
val inputPlan = Range(1L, 1000L, 1, 10)
81113
val orderedPlan = inputPlan.orderBy('id.asc)
82114
val optimized = Optimize.execute(orderedPlan.analyze)
83-
val correctAnswer = inputPlan.analyze
115+
val correctAnswer = orderedPlan.analyze
84116
comparePlans(optimized, correctAnswer)
85117

86118
val reversedPlan = inputPlan.orderBy('id.desc)
@@ -91,10 +123,18 @@ class RemoveRedundantSortsSuite extends PlanTest {
91123
val negativeStepInputPlan = Range(10L, 1L, -1, 10)
92124
val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc)
93125
val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze)
94-
val negativeStepCorrectAnswer = negativeStepInputPlan.analyze
126+
val negativeStepCorrectAnswer = negativeStepOrderedPlan.analyze
95127
comparePlans(negativeStepOptimized, negativeStepCorrectAnswer)
96128
}
97129

130+
test("SPARK-33183: remove local sort with range operator") {
131+
val inputPlan = Range(1L, 1000L, 1, 10)
132+
val orderedPlan = inputPlan.sortBy('id.asc)
133+
val optimized = Optimize.execute(orderedPlan.analyze)
134+
val correctAnswer = inputPlan.analyze
135+
comparePlans(optimized, correctAnswer)
136+
}
137+
98138
test("sort should not be removed when there is a node which doesn't guarantee any order") {
99139
val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc)
100140
val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc)
@@ -135,4 +175,39 @@ class RemoveRedundantSortsSuite extends PlanTest {
135175
.select(('b + 1).as('c)).orderBy('c.asc).analyze
136176
comparePlans(optimizedThrice, correctAnswerThrice)
137177
}
178+
179+
test("SPARK-33183: remove consecutive global sorts with the same ordering") {
180+
Seq(
181+
(testRelation.orderBy('a.asc).orderBy('a.asc), testRelation.orderBy('a.asc)),
182+
(testRelation.orderBy('a.asc, 'b.desc).orderBy('a.asc), testRelation.orderBy('a.asc))
183+
).foreach { case (ordered, answer) =>
184+
val optimized = Optimize.execute(ordered.analyze)
185+
comparePlans(optimized, answer.analyze)
186+
}
187+
}
188+
189+
test("SPARK-33183: remove consecutive local sorts with the same ordering") {
190+
val orderedPlan = testRelation.sortBy('a.asc).sortBy('a.asc).sortBy('a.asc)
191+
val optimized = Optimize.execute(orderedPlan.analyze)
192+
val correctAnswer = testRelation.sortBy('a.asc).analyze
193+
comparePlans(optimized, correctAnswer)
194+
}
195+
196+
test("SPARK-33183: remove consecutive local sorts with different ordering") {
197+
val orderedPlan = testRelation.sortBy('b.asc).sortBy('a.desc).sortBy('a.asc)
198+
val optimized = Optimize.execute(orderedPlan.analyze)
199+
val correctAnswer = testRelation.sortBy('a.asc).analyze
200+
comparePlans(optimized, correctAnswer)
201+
}
202+
203+
test("SPARK-33183: should keep global sort when child is a local sort with the same ordering") {
204+
val correctAnswer = testRelation.orderBy('a.asc).analyze
205+
Seq(
206+
testRelation.sortBy('a.asc).orderBy('a.asc),
207+
testRelation.orderBy('a.asc).sortBy('a.asc).orderBy('a.asc)
208+
).foreach { ordered =>
209+
val optimized = Optimize.execute(ordered.analyze)
210+
comparePlans(optimized, correctAnswer)
211+
}
212+
}
138213
}

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
9797
/** A sequence of rules that will be applied in order to the physical plan before execution. */
9898
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
9999
PlanSubqueries(sparkSession),
100+
RemoveRedundantSorts(sparkSession.sessionState.conf),
100101
EnsureRequirements(sparkSession.sessionState.conf),
101102
CollapseCodegenStages(sparkSession.sessionState.conf),
102103
ReuseExchange(sparkSession.sessionState.conf),
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import org.apache.spark.sql.catalyst.expressions.SortOrder
21+
import org.apache.spark.sql.catalyst.rules.Rule
22+
import org.apache.spark.sql.internal.SQLConf
23+
24+
/**
25+
* Remove redundant SortExec node from the spark plan. A sort node is redundant when
26+
* its child satisfies both its sort orders and its required child distribution. Note
27+
* this rule differs from the Optimizer rule EliminateSorts in that this rule also checks
28+
* if the child satisfies the required distribution so that it is safe to remove not only a
29+
* local sort but also a global sort when its child already satisfies required sort orders.
30+
*/
31+
case class RemoveRedundantSorts(conf: SQLConf) extends Rule[SparkPlan] {
32+
def apply(plan: SparkPlan): SparkPlan = {
33+
if (!conf.getConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED)) {
34+
plan
35+
} else {
36+
removeSorts(plan)
37+
}
38+
}
39+
40+
private def removeSorts(plan: SparkPlan): SparkPlan = plan transform {
41+
case s @ SortExec(orders, _, child, _)
42+
if SortOrder.orderingSatisfies(child.outputOrdering, orders) &&
43+
child.outputPartitioning.satisfies(s.requiredChildDistribution.head) =>
44+
child
45+
}
46+
}

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -230,19 +230,6 @@ class PlannerSuite extends SharedSQLContext {
230230
}
231231
}
232232

233-
test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") {
234-
val query = testData.select('key, 'value).sort('key.desc).cache()
235-
assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation])
236-
val resorted = query.sort('key.desc)
237-
assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty)
238-
assert(resorted.select('key).collect().map(_.getInt(0)).toSeq ==
239-
(1 to 100).reverse)
240-
// with a different order, the sort is needed
241-
val sortedAsc = query.sort('key)
242-
assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1)
243-
assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100))
244-
}
245-
246233
test("PartitioningCollection") {
247234
withTempView("normal", "small", "tiny") {
248235
testData.createOrReplaceTempView("normal")
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import org.apache.spark.sql.{DataFrame, QueryTest}
21+
import org.apache.spark.sql.internal.SQLConf
22+
import org.apache.spark.sql.test.SharedSparkSession
23+
24+
25+
class RemoveRedundantSortsSuite
26+
extends QueryTest
27+
with SharedSparkSession {
28+
import testImplicits._
29+
30+
private def checkNumSorts(df: DataFrame, count: Int): Unit = {
31+
val plan = df.queryExecution.executedPlan
32+
assert(plan.collect { case s: SortExec => s }.length == count)
33+
}
34+
35+
private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = {
36+
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") {
37+
val df = sql(query)
38+
checkNumSorts(df, enabledCount)
39+
val result = df.collect()
40+
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") {
41+
val df = sql(query)
42+
checkNumSorts(df, disabledCount)
43+
checkAnswer(df, result)
44+
}
45+
}
46+
}
47+
48+
test("remove redundant sorts with limit") {
49+
withTempView("t") {
50+
spark.range(100).select('id as "key").createOrReplaceTempView("t")
51+
val query =
52+
"""
53+
|SELECT key FROM
54+
| (SELECT key FROM t WHERE key > 10 ORDER BY key DESC LIMIT 10)
55+
|ORDER BY key DESC
56+
|""".stripMargin
57+
checkSorts(query, 0, 1)
58+
}
59+
}
60+
61+
test("remove redundant sorts with sort merge join") {
62+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
63+
withTempView("t1", "t2") {
64+
spark.range(1000).select('id as "key").createOrReplaceTempView("t1")
65+
spark.range(1000).select('id as "key").createOrReplaceTempView("t2")
66+
val query = """
67+
|SELECT t1.key FROM
68+
| (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1
69+
|JOIN
70+
| (SELECT key FROM t2 WHERE key > 50 ORDER BY key DESC LIMIT 100) t2
71+
|ON t1.key = t2.key
72+
|ORDER BY t1.key
73+
""".stripMargin
74+
75+
val queryAsc = query + " ASC"
76+
checkSorts(queryAsc, 2, 3)
77+
78+
// The top level sort should not be removed since the child output ordering is ASC and
79+
// the required ordering is DESC.
80+
val queryDesc = query + " DESC"
81+
checkSorts(queryDesc, 3, 3)
82+
}
83+
}
84+
}
85+
86+
test("cached sorted data doesn't need to be re-sorted") {
87+
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") {
88+
val df = spark.range(1000).select('id as "key").sort('key.desc).cache()
89+
val resorted = df.sort('key.desc)
90+
val sortedAsc = df.sort('key.asc)
91+
checkNumSorts(df, 0)
92+
checkNumSorts(resorted, 0)
93+
checkNumSorts(sortedAsc, 1)
94+
val result = resorted.collect()
95+
withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") {
96+
val resorted = df.sort('key.desc)
97+
checkNumSorts(resorted, 1)
98+
checkAnswer(resorted, result)
99+
}
100+
}
101+
}
102+
}

0 commit comments

Comments
 (0)