
Commit dbc38d3

[SPARK-33472][SQL] Adjust RemoveRedundantSorts rule order
This PR switches the order of the rules `RemoveRedundantSorts` and `EnsureRequirements` so that `EnsureRequirements` is invoked before `RemoveRedundantSorts`, avoiding an `IllegalArgumentException` when instantiating `PartitioningCollection`.

`RemoveRedundantSorts` uses a SparkPlan's `outputPartitioning` to check whether a sort node is redundant. Previously, the rule ran before `EnsureRequirements`. Since `PartitioningCollection` requires the left and right partitionings to have the same number of partitions, which is not necessarily true before `EnsureRequirements` is applied, the rule could fail with the following exception:

```
IllegalArgumentException: requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions.
```

Does this PR introduce any user-facing change? No.

How was this patch tested? Unit test.

Closes apache#30373 from allisonwang-db/sort-follow-up.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit a03c540)
Signed-off-by: allisonwang-db <[email protected]>
1 parent 3772bfa commit dbc38d3
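
To make the failure mode concrete, here is a minimal sketch (not from the PR) of how the `PartitioningCollection` invariant fires. It assumes the constructor-level `require` check described in the exception above, and uses `UnknownPartitioning` purely for illustration:

```scala
import org.apache.spark.sql.catalyst.plans.physical.{PartitioningCollection, UnknownPartitioning}

// Two child partitionings with mismatched partition counts, as a join node
// may report before EnsureRequirements has inserted the necessary shuffles.
val left = UnknownPartitioning(2)
val right = UnknownPartitioning(3)

// Throws IllegalArgumentException: requirement failed: PartitioningCollection
// requires all of its partitionings have the same numPartitions.
val combined = PartitioningCollection(Seq(left, right))
```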

File tree: 3 files changed (+36, -2 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 3 additions & 1 deletion
```diff
@@ -97,8 +97,10 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
     PlanSubqueries(sparkSession),
-    RemoveRedundantSorts(sparkSession.sessionState.conf),
     EnsureRequirements(sparkSession.sessionState.conf),
+    // `RemoveRedundantSorts` needs to be added after `EnsureRequirements` to guarantee the same
+    // number of partitions when instantiating PartitioningCollection.
+    RemoveRedundantSorts(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
     ReuseExchange(sparkSession.sessionState.conf),
     ReuseSubquery(sparkSession.sessionState.conf))
```
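
For orientation (not part of the diff): these preparation rules run when the executed plan is materialized. A minimal sketch, assuming a `SparkSession` named `spark` is in scope, of the two plan phases involved:

```scala
val qe = spark.range(0, 100, 1, 2).sort("id").queryExecution

// Physical plan straight from the planner: preparation rules (including
// EnsureRequirements and RemoveRedundantSorts) have not run yet.
val unprepared = qe.sparkPlan

// Plan after `preparations` has been applied in order; partition requirements
// are enforced here, so outputPartitioning is safe to inspect.
val prepared = qe.executedPlan
```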

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 6 additions & 1 deletion
```diff
@@ -91,7 +91,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
   def longMetric(name: String): SQLMetric = metrics(name)
 
   // TODO: Move to `DistributedPlan`
-  /** Specifies how data is partitioned across different nodes in the cluster. */
+  /**
+   * Specifies how data is partitioned across different nodes in the cluster.
+   * Note this method may fail if it is invoked before `EnsureRequirements` is applied
+   * since `PartitioningCollection` requires all its partitionings to have
+   * the same number of partitions.
+   */
   def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
 
   /**
```
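
The caveat added to the scaladoc matters in practice for binary nodes such as `SortMergeJoinExec`, whose inner-join `outputPartitioning` wraps both children's partitionings in a `PartitioningCollection`. A hedged sketch of the failure the note warns about; it assumes broadcast joins are disabled and that temp views `t1`/`t2` exist, as set up in the test below:

```scala
import org.apache.spark.sql.execution.joins.SortMergeJoinExec

val df = sql("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key ORDER BY t1.key")

// Before EnsureRequirements: no shuffles have been inserted yet, so the
// join's children may disagree on numPartitions.
val join = df.queryExecution.sparkPlan.collect { case j: SortMergeJoinExec => j }.head

// For an inner join this builds a PartitioningCollection over both sides and
// can throw: "requirement failed: PartitioningCollection requires all of its
// partitionings have the same numPartitions."
join.outputPartitioning
```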

sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala

Lines changed: 27 additions & 0 deletions
```diff
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -99,4 +101,29 @@
       }
     }
   }
+
+  test("SPARK-33472: shuffled join with different left and right side partition numbers") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("t1", "t2") {
+        spark.range(0, 100, 1, 2).select('id as "key").createOrReplaceTempView("t1")
+        (0 to 100).toDF("key").createOrReplaceTempView("t2")
+
+        val query = """
+          |SELECT t1.key
+          |FROM t1 JOIN t2 ON t1.key = t2.key
+          |WHERE t1.key > 10 AND t2.key < 50
+          |ORDER BY t1.key ASC
+        """.stripMargin
+
+        val df = sql(query)
+        val sparkPlan = df.queryExecution.sparkPlan
+        val join = sparkPlan.collect { case j: SortMergeJoinExec => j }.head
+        val leftPartitioning = join.left.outputPartitioning
+        assert(leftPartitioning.isInstanceOf[RangePartitioning])
+        assert(leftPartitioning.numPartitions == 2)
+        assert(join.right.outputPartitioning == UnknownPartitioning(0))
+        checkSorts(query, 3, 3)
+      }
+    }
+  }
 }
```
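
A usage note on the final assertion: `checkSorts(query, 3, 3)` calls a helper defined earlier in this suite; the equal counts suggest all three sorts in this query survive even with the rule enabled. The real implementation lives in `RemoveRedundantSortsSuite`; below is a hedged sketch of what such a helper plausibly does, assuming the `SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED` key toggles the rule:

```scala
// Hedged sketch, not the suite's actual code: count SortExec nodes in the
// executed plan with the removal rule enabled and then disabled.
private def checkSorts(query: String, enabledCount: Int, disabledCount: Int): Unit = {
  withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") {
    val sorts = sql(query).queryExecution.executedPlan.collect { case s: SortExec => s }
    assert(sorts.length == enabledCount)
  }
  withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") {
    val sorts = sql(query).queryExecution.executedPlan.collect { case s: SortExec => s }
    assert(sorts.length == disabledCount)
  }
}
```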
