diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 446571aa8409f..cfb5c5758e8cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -234,9 +234,15 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
 
     // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
     children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+      logDebug(s"Checking if sort of ${requiredOrdering} needed on ${child.simpleString}")
       if (requiredOrdering.nonEmpty) {
         // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
-        if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) {
+        val orderings = requiredOrdering zip child.outputOrdering
+        val needSort = orderings.length != requiredOrdering.length ||
+          orderings.exists { case (requiredOrder, childOrder) =>
+            !requiredOrder.semanticEquals(childOrder)
+          }
+        if (needSort) {
           SortExec(requiredOrdering, global = false, child = child)
         } else {
           child
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c96239e682018..33b2d170f9a54 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,14 +18,14 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Descending, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
+import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -37,6 +37,14 @@ class PlannerSuite extends SharedSQLContext {
 
   setupTestData()
 
+  private def sortCount(plan: SparkPlan): Int = {
+    plan match {
+      case SortExec(_, _, child, _) => 1 + sortCount(child)
+      case InMemoryTableScanExec(_, _, relation) => sortCount(relation.child)
+      case _ => plan.children.map(sortCount).sum
+    }
+  }
+
   private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
     val planner = spark.sessionState.planner
     import planner._
@@ -416,7 +424,7 @@ class PlannerSuite extends SharedSQLContext {
   }
 
   // This is a regression test for SPARK-11135
-  test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") {
+  test("EnsureRequirements adds sort when required ordering isn't prefix of existing ordering") {
     val orderingA = SortOrder(Literal(1), Ascending)
     val orderingB = SortOrder(Literal(2), Ascending)
     assert(orderingA != orderingB)
@@ -471,6 +479,81 @@
     }
   }
 
+  test("EnsureRequirements adds sort when ordering columns same but diff direction") {
+    val orderingA = SortOrder(Literal(1), Ascending)
+    val orderingB = SortOrder(Literal(1), Descending)
+    assert(orderingA != orderingB)
+    val inputPlan = DummySparkPlan(
+      children = DummySparkPlan(outputOrdering = Seq(orderingA)) :: Nil,
+      requiredChildOrdering = Seq(Seq(orderingB)),
+      requiredChildDistribution = Seq(UnspecifiedDistribution)
+    )
+    val outputPlan = EnsureRequirements(sqlContext.conf).apply(inputPlan)
+    assertDistributionRequirementsAreSatisfied(outputPlan)
+    if (outputPlan.collect { case s: SortExec => true }.isEmpty) {
+      fail(s"Sort should have been added:\n$outputPlan")
+    }
+  }
+
+  test("EnsureRequirements doesn't add sort with cached sorted table") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempTable("t1", "t2") {
+        val df = Seq(
+          (1, 1),
+          (3, 3)).toDF("k", "v")
+        val df2 = Seq(
+          (1, 2),
+          (3, 3)).toDF("k", "v")
+
+        df.filter("k > 0").repartition(2, df("k")).sortWithinPartitions(df("k"))
+          .registerTempTable("t1")
+        sqlContext.cacheTable("t1")
+        df2.filter("k > 0").repartition(2, df2("k")).sortWithinPartitions(df2("k"))
+          .registerTempTable("t2")
+        sqlContext.cacheTable("t2")
+
+        val joined = sqlContext.sql(
+          "select t2.v from t1 inner join t2 on t1.k = t2.k where t2.v < 10")
+        val outputPlan = joined.queryExecution.executedPlan
+        assert(
+          sortCount(outputPlan) == 2,
+          s"Extra sort should not have been added by SortMergeJoin:\n$outputPlan")
+
+        assert(joined.collect.toSeq == Seq(Row(2), Row(3)))
+      }
+    }
+  }
+
+  test("EnsureRequirements doesn't add sort with different column capitalization") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempTable("t1", "t2") {
+        val df = Seq(
+          (1, 1),
+          (3, 3)).toDF("k", "v")
+        val df2 = Seq(
+          (1, 2),
+          (3, 3)).toDF("k", "v")
+
+        df.filter("k > 0").repartition(2, df("k")).sortWithinPartitions(df("k"))
+          .registerTempTable("t1")
+        sqlContext.cacheTable("t1")
+        // upper case K
+        df2.filter("k > 0").repartition(2, df2("k")).sortWithinPartitions(df2("K"))
+          .registerTempTable("t2")
+        sqlContext.cacheTable("t2")
+
+        val joined = sqlContext.sql(
+          "select t2.v from t1 inner join t2 on t1.k = t2.k where t2.v < 10")
+        val outputPlan = joined.queryExecution.executedPlan
+        assert(
+          sortCount(outputPlan) == 2,
+          s"Extra sort should not have been added by SortMergeJoin:\n$outputPlan")
+
+        assert(joined.collect.toSeq == Seq(Row(2), Row(3)))
+      }
+    }
+  }
+
   // ---------------------------------------------------------------------------------------------
 
   test("Reuse exchanges") {
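For review context, here is a self-contained restatement of the new `needSort` predicate from the `EnsureRequirements` hunk. This is an illustrative sketch, not patch code: it uses plain values and an equality callback in place of Catalyst `SortOrder`/`semanticEquals`, and the helper name `same` is hypothetical.

```scala
// Hypothetical restatement of the patched check: a sort is still required
// either because the child's output ordering is shorter than the required
// ordering, or because some leading sort column differs.
def needSort[A](required: Seq[A], childOutput: Seq[A])(same: (A, A) => Boolean): Boolean = {
  val paired = required.zip(childOutput)
  // zip truncates to the shorter input, so a shorter `paired` means the child
  // does not produce enough leading sort columns and must be re-sorted.
  paired.length != required.length ||
    paired.exists { case (r, c) => !same(r, c) }
}

// The prefix case from the code comment: child sorted by [a, b], required [a] => no sort.
assert(!needSort(Seq("a"), Seq("a", "b"))(_ == _))
// Child ordering too short: required [a, b], child only [a] => sort.
assert(needSort(Seq("a", "b"), Seq("a"))(_ == _))
// Same column, different direction (modeled here as distinct values) => sort.
assert(needSort(Seq("a ASC"), Seq("a DESC"))(_ == _))
```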
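The reason `semanticEquals` (rather than `==`) makes the capitalization test pass: case-insensitive resolution of `df2("K")` yields an `AttributeReference` with the same `ExprId` but a different cosmetic name, and case-class equality compares that name while canonicalized comparison ignores it. A minimal sketch, assuming the Catalyst expression APIs this patch already uses; the attributes are constructed directly, and the names `k`/`upperK` are illustrative only:

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.IntegerType

// Two references to the same underlying column: withName keeps the ExprId and
// changes only the user-visible name, mimicking case-insensitive resolution.
val k = AttributeReference("k", IntegerType)()
val upperK = k.withName("K")

val required = SortOrder(k, Ascending)
val existing = SortOrder(upperK, Ascending)

// Case-class equality sees the differing names, which is why the old
// `requiredOrdering != child.outputOrdering.take(...)` check added a redundant sort.
assert(required != existing)

// semanticEquals compares canonicalized expressions, where attribute names are
// erased and only ExprIds matter, so the patched check skips the extra SortExec.
assert(required.semanticEquals(existing))
```

Canonicalization keeps the sort direction, so `SortOrder(k, Ascending)` and `SortOrder(k, Descending)` remain semantically different; that is what the new ascending-vs-descending test in the patch relies on.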