--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -234,9 +234,15 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
 
     // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
     children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+      logDebug(s"Checking if sort of ${requiredOrdering} needed on ${child.simpleString}")
       if (requiredOrdering.nonEmpty) {
         // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
-        if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) {
+        val orderings = requiredOrdering zip child.outputOrdering
+        val needSort = orderings.length != requiredOrdering.length ||
+          orderings.exists { case (requiredOrder, childOrder) =>
+            !requiredOrder.semanticEquals(childOrder)
+          }
+        if (needSort) {
           SortExec(requiredOrdering, global = false, child = child)
         } else {
           child
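
A note on the core change, not part of the diff: the old check compared the required and actual orderings with !=, which is structural equality, so two orderings that resolve to the same column could still look different. A minimal sketch of the distinction, assuming Spark 2.x catalyst internals (AttributeReference.withName keeps the expression ID, and semanticEquals compares canonicalized expressions):

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.IntegerType

// Two references to the same underlying column whose names differ only in
// capitalization; withName keeps the exprId, so they denote the same attribute.
val k = AttributeReference("k", IntegerType)()
val upperK = k.withName("K")

val byLowerK = SortOrder(k, Ascending)
val byUpperK = SortOrder(upperK, Ascending)

// Structural equality sees different names, so the old != check would have
// inserted a redundant SortExec here.
assert(byLowerK != byUpperK)

// semanticEquals compares canonicalized expressions, which identify an
// attribute by exprId rather than by name, so the new check skips the sort.
assert(byLowerK.semanticEquals(byUpperK))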
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -18,14 +18,14 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{execution, Row}
+import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Descending, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange}
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
+import org.apache.spark.sql.execution.exchange._
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -37,6 +37,14 @@ class PlannerSuite extends SharedSQLContext {
 
   setupTestData()
 
+  private def sortCount(plan: SparkPlan): Int = {
+    plan match {
+      case SortExec(_, _, child, _) => 1 + sortCount(child)
+      case InMemoryTableScanExec(_, _, relation) => sortCount(relation.child)
+      case _ => plan.children.map(sortCount).sum
+    }
+  }
+
   private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
     val planner = spark.sessionState.planner
     import planner._
@@ -416,7 +424,7 @@
   }
 
   // This is a regression test for SPARK-11135
-  test("EnsureRequirements adds sort when required ordering isn't a prefix of existing ordering") {
+  test("EnsureRequirements adds sort when required ordering isn't prefix of existing ordering") {
     val orderingA = SortOrder(Literal(1), Ascending)
     val orderingB = SortOrder(Literal(2), Ascending)
     assert(orderingA != orderingB)
@@ -471,6 +479,81 @@
     }
   }
 
+  test("EnsureRequirements adds sort when ordering columns same but diff direction") {
+    val orderingA = SortOrder(Literal(1), Ascending)
+    val orderingB = SortOrder(Literal(1), Descending)
+    assert(orderingA != orderingB)
+    val inputPlan = DummySparkPlan(
+      children = DummySparkPlan(outputOrdering = Seq(orderingA)) :: Nil,
+      requiredChildOrdering = Seq(Seq(orderingB)),
+      requiredChildDistribution = Seq(UnspecifiedDistribution)
+    )
+    val outputPlan = EnsureRequirements(sqlContext.conf).apply(inputPlan)
+    assertDistributionRequirementsAreSatisfied(outputPlan)
+    if (outputPlan.collect { case s: SortExec => true }.isEmpty) {
+      fail(s"Sort should have been added:\n$outputPlan")
+    }
+  }
+
+  test("EnsureRequirements doesn't add sort with cached sorted table") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempTable("t1", "t2") {
+        val df = Seq(
+          (1, 1),
+          (3, 3)).toDF("k", "v")
+        val df2 = Seq(
+          (1, 2),
+          (3, 3)).toDF("k", "v")
+
+        df.filter("k > 0").repartition(2, df("k")).sortWithinPartitions(df("k"))
+          .registerTempTable("t1")
+        sqlContext.cacheTable("t1")
+        df2.filter("k > 0").repartition(2, df2("k")).sortWithinPartitions(df2("k"))
+          .registerTempTable("t2")
+        sqlContext.cacheTable("t2")
+
+        val joined = sqlContext.sql(
+          "select t2.v from t1 inner join t2 on t1.k = t2.k where t2.v < 10")
+        val outputPlan = joined.queryExecution.executedPlan
+        assert(
+          sortCount(outputPlan) == 2,
+          s"Extra sort should not have been added by SortMergeJoin:\n$outputPlan")
+
+        assert(joined.collect.toSeq == Seq(Row(2), Row(3)))
+      }
+    }
+  }
+
+  test("EnsureRequirements doesn't add sort with different column capitalization") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempTable("t1", "t2") {
+        val df = Seq(
+          (1, 1),
+          (3, 3)).toDF("k", "v")
+        val df2 = Seq(
+          (1, 2),
+          (3, 3)).toDF("k", "v")
+
+        df.filter("k > 0").repartition(2, df("k")).sortWithinPartitions(df("k"))
+          .registerTempTable("t1")
+        sqlContext.cacheTable("t1")
+        // upper case K
+        df2.filter("k > 0").repartition(2, df2("k")).sortWithinPartitions(df2("K"))
+          .registerTempTable("t2")
+        sqlContext.cacheTable("t2")
+
+        val joined = sqlContext.sql(
+          "select t2.v from t1 inner join t2 on t1.k = t2.k where t2.v < 10")
+        val outputPlan = joined.queryExecution.executedPlan
+        assert(
+          sortCount(outputPlan) == 2,
+          s"Extra sort should not have been added by SortMergeJoin:\n$outputPlan")
+
+        assert(joined.collect.toSeq == Seq(Row(2), Row(3)))
+      }
+    }
+  }
+
   // ---------------------------------------------------------------------------------------------
 
   test("Reuse exchanges") {
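
To make the prefix behavior noted in the code comment concrete, here is a hypothetical standalone walkthrough (not part of the PR) of the same needSort logic, using literal-based orderings in the style of the existing tests:

import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, SortOrder}

// Same shape as the new logic in EnsureRequirements: zip truncates to the
// shorter sequence, so a too-short child ordering trips the length check.
def needSort(requiredOrdering: Seq[SortOrder], childOrdering: Seq[SortOrder]): Boolean = {
  val orderings = requiredOrdering zip childOrdering
  orderings.length != requiredOrdering.length ||
    orderings.exists { case (requiredOrder, childOrder) =>
      !requiredOrder.semanticEquals(childOrder)
    }
}

val a = SortOrder(Literal(1), Ascending)
val b = SortOrder(Literal(2), Ascending)

// Required ordering [a] is a prefix of the child's [a, b]: no sort needed.
assert(!needSort(Seq(a), Seq(a, b)))

// Child ordering [a] is shorter than the required [a, b]: sort.
assert(needSort(Seq(a, b), Seq(a)))

// Same length but semantically different orderings: sort.
assert(needSort(Seq(a), Seq(b)))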