diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 99bc45fa9e9e8..271510b50839e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
 import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, PruneShuffleAndSort, ReuseExchange}
 import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
@@ -285,6 +285,7 @@ object QueryExecution {
       PlanDynamicPruningFilters(sparkSession),
       PlanSubqueries(sparkSession),
       EnsureRequirements(sparkSession.sessionState.conf),
+      PruneShuffleAndSort(),
       ApplyColumnarRulesAndInsertTransitions(sparkSession.sessionState.conf,
         sparkSession.sessionState.columnarRules),
       CollapseCodegenStages(sparkSession.sessionState.conf),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 28ef793ed62db..d506aeb7a8b58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -216,12 +216,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
-    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
-      child.outputPartitioning match {
-        case lower: HashPartitioning if upper.semanticEquals(lower) => child
-        case _ => operator
-      }
     case operator: SparkPlan =>
       ensureDistributionAndOrdering(reorderJoinPredicates(operator))
   }
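Note: the special case deleted above is not dropped; it moves into the new PruneShuffleAndSort rule introduced below, which runs after EnsureRequirements so it can also prune exchanges that rule inserts. For reference, a minimal sketch of the pattern the relocated logic targets (hypothetical snippet, not part of the patch; assumes a SparkSession named spark is in scope):

    import org.apache.spark.sql.functions.col

    // Both calls hash-partition on the same key with the same default number of
    // partitions, so the second exchange is semantically equal to the first and
    // should be prunable.
    val df = spark.range(100).repartition(col("id"))
    df.repartition(col("id")).explain()  // expect a single Exchange in the plan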
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/PruneShuffleAndSort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/PruneShuffleAndSort.scala
new file mode 100644
index 0000000000000..e0daf91d4008a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/PruneShuffleAndSort.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import org.apache.spark.sql.catalyst.expressions.SortOrder
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{SortExec, SparkPlan}
+
+/**
+ * Removes shuffles and sorts that become redundant after new ones are introduced by
+ * [[Rule]]s for [[SparkPlan]]s, such as [[EnsureRequirements]].
+ */
+case class PruneShuffleAndSort() extends Rule[SparkPlan] {
+
+  override def apply(plan: SparkPlan): SparkPlan = {
+    plan.transformUp {
+      case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
+        child.outputPartitioning match {
+          case lower: HashPartitioning if upper.semanticEquals(lower) => child
+          case PartitioningCollection(partitionings) =>
+            if (partitionings.exists {
+              case lower: HashPartitioning => upper.semanticEquals(lower)
+              case _ => false
+            }) {
+              child
+            } else {
+              operator
+            }
+          case _ => operator
+        }
+      case SortExec(upper, false, child, _)
+          if SortOrder.orderingSatisfies(child.outputOrdering, upper) => child
+      case subPlan: SparkPlan => subPlan
+    }
+  }
+}
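A rough sketch of exercising the new rule in isolation, mirroring the PlannerSuite change below (hypothetical snippet, not part of the patch; DummySparkPlan is the existing helper in that suite and is assumed to be in scope):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
    import org.apache.spark.sql.execution.exchange.{PruneShuffleAndSort, ShuffleExchangeExec}

    val partitioning = HashPartitioning(Literal(1) :: Nil, 5)
    val inputPlan = ShuffleExchangeExec(
      partitioning,
      DummySparkPlan(outputPartitioning = partitioning))

    // The child already reports an identical HashPartitioning, so the rule
    // should replace the exchange with its child.
    val outputPlan = PruneShuffleAndSort().apply(inputPlan)
    assert(outputPlan.collect { case e: ShuffleExchangeExec => e }.isEmpty)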
df1("id")).queryExecution.executedPlan + assert(numSorts(outputPlan0) == 2) + assert(numShuffles(outputPlan0) == 3, "user defined numPartitions shouldn't be eliminated") + + // non global sort order and partitioning should be reusable after left join + val outputPlan1 = df1.join(df2, Seq("id"), "left") + .repartition(df1("id")) + .sortWithinPartitions(df1("id")) + .queryExecution.executedPlan + assert(numSorts(outputPlan1) == 2) + assert(numShuffles(outputPlan1) == 2) + + // non global sort order and partitioning should be reusable after inner join + val outputPlan2 = df1.join(df2, Seq("id")) + .repartition(df1("id")) + .sortWithinPartitions(df1("id")) + .queryExecution.executedPlan + assert(numSorts(outputPlan2) == 2) + assert(numShuffles(outputPlan2) == 2) + + // global sort should not be removed + val outputPlan3 = df1.join(df2, Seq("id")) + .orderBy(df1("id")) + .queryExecution.executedPlan + assert(numSorts(outputPlan3) == 3) + assert(numShuffles(outputPlan3) == 3) + } + test("SPARK-24500: create union with stream of children") { val df = Union(Stream( Range(1, 1, 1, 1),