Skip to content

Commit e9c91ba

Browse files
wangzhenhuacloud-fan
authored andcommitted
[SPARK-20010][SQL] Sort information is lost after sort merge join
## What changes were proposed in this pull request? After sort merge join for inner join, now we only keep left key ordering. However, after inner join, right key has the same value and order as left key. So if we need another smj on right key, we will unnecessarily add a sort which causes additional cost. As a more complicated example, A join B on A.key = B.key join C on B.key = C.key join D on A.key = D.key. We will unnecessarily add a sort on B.key when join {A, B} and C, and add a sort on A.key when join {A, B, C} and D. To fix this, we need to propagate all sorted information (equivalent expressions) from bottom up through `outputOrdering` and `SortOrder`. ## How was this patch tested? Test cases are added. Author: wangzhenhua <[email protected]> Closes #17339 from wzhfy/sortEnhance.
1 parent 10691d3 commit e9c91ba

File tree

9 files changed

+81
-18
lines changed

9 files changed

+81
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -966,9 +966,9 @@ class Analyzer(
966966
case s @ Sort(orders, global, child)
967967
if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) =>
968968
val newOrders = orders map {
969-
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) =>
969+
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
970970
if (index > 0 && index <= child.output.size) {
971-
SortOrder(child.output(index - 1), direction, nullOrdering)
971+
SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty)
972972
} else {
973973
s.failAnalysis(
974974
s"ORDER BY position $index is not in select list " +

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan]
3636
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
3737
case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
3838
val newOrders = s.order.map {
39-
case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
39+
case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
4040
val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
4141
withOrigin(order.origin)(order.copy(child = newOrdinal))
4242
case other => other

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ package object dsl {
109109
def cast(to: DataType): Expression = Cast(expr, to)
110110

111111
def asc: SortOrder = SortOrder(expr, Ascending)
112-
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast)
112+
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
113113
def desc: SortOrder = SortOrder(expr, Descending)
114-
def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst)
114+
def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty)
115115
def as(alias: String): NamedExpression = Alias(expr, alias)()
116116
def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()
117117
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,15 @@ case object NullsLast extends NullOrdering{
5353
/**
5454
* An expression that can be used to sort a tuple. This class extends expression primarily so that
5555
* transformations over expression will descend into its child.
56+
* `sameOrderExpressions` is a set of expressions with the same sort order as the child. It is
57+
* derived from equivalence relation in an operator, e.g. left/right keys of an inner sort merge
58+
* join.
5659
*/
57-
case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering)
60+
case class SortOrder(
61+
child: Expression,
62+
direction: SortDirection,
63+
nullOrdering: NullOrdering,
64+
sameOrderExpressions: Set[Expression])
5865
extends UnaryExpression with Unevaluable {
5966

6067
/** Sort order is not foldable because we don't have an eval for it. */
@@ -75,11 +82,19 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering:
7582
override def sql: String = child.sql + " " + direction.sql + " " + nullOrdering.sql
7683

7784
def isAscending: Boolean = direction == Ascending
85+
86+
def satisfies(required: SortOrder): Boolean = {
87+
(sameOrderExpressions + child).exists(required.child.semanticEquals) &&
88+
direction == required.direction && nullOrdering == required.nullOrdering
89+
}
7890
}
7991

8092
object SortOrder {
81-
def apply(child: Expression, direction: SortDirection): SortOrder = {
82-
new SortOrder(child, direction, direction.defaultNullOrdering)
93+
def apply(
94+
child: Expression,
95+
direction: SortDirection,
96+
sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = {
97+
new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions)
8398
}
8499
}
85100

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
12291229
} else {
12301230
direction.defaultNullOrdering
12311231
}
1232-
SortOrder(expression(ctx.expression), direction, nullOrdering)
1232+
SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty)
12331233
}
12341234

12351235
/**

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ class Column(val expr: Expression) extends Logging {
10371037
* @group expr_ops
10381038
* @since 2.1.0
10391039
*/
1040-
def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) }
1040+
def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) }
10411041

10421042
/**
10431043
* Returns a descending ordering used in sorting, where null values appear after non-null values.
@@ -1052,7 +1052,7 @@ class Column(val expr: Expression) extends Logging {
10521052
* @group expr_ops
10531053
* @since 2.1.0
10541054
*/
1055-
def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) }
1055+
def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) }
10561056

10571057
/**
10581058
* Returns an ascending ordering used in sorting.
@@ -1082,7 +1082,7 @@ class Column(val expr: Expression) extends Logging {
10821082
* @group expr_ops
10831083
* @since 2.1.0
10841084
*/
1085-
def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) }
1085+
def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) }
10861086

10871087
/**
10881088
* Returns an ordering used in sorting, where null values appear after non-null values.
@@ -1097,7 +1097,7 @@ class Column(val expr: Expression) extends Logging {
10971097
* @group expr_ops
10981098
* @since 2.1.0
10991099
*/
1100-
def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) }
1100+
def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) }
11011101

11021102
/**
11031103
* Prints the expression to the console for debugging purpose.

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
241241
} else {
242242
requiredOrdering.zip(child.outputOrdering).forall {
243243
case (requiredOrder, childOutputOrder) =>
244-
requiredOrder.semanticEquals(childOutputOrder)
244+
childOutputOrder.satisfies(requiredOrder)
245245
}
246246
}
247247

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,37 @@ case class SortMergeJoinExec(
8181
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
8282

8383
override def outputOrdering: Seq[SortOrder] = joinType match {
84+
// For inner join, orders of both sides keys should be kept.
85+
case Inner =>
86+
val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering)
87+
val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering)
88+
leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) =>
89+
// Also add the right key and its `sameOrderExpressions`
90+
SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey
91+
.sameOrderExpressions)
92+
}
8493
// For left and right outer joins, the output is ordered by the streamed input's join keys.
85-
case LeftOuter => requiredOrders(leftKeys)
86-
case RightOuter => requiredOrders(rightKeys)
94+
case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering)
95+
case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering)
8796
// There are null rows in both streams, so there is no order.
8897
case FullOuter => Nil
89-
case _: InnerLike | LeftExistence(_) => requiredOrders(leftKeys)
98+
case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering)
9099
case x =>
91100
throw new IllegalArgumentException(
92101
s"${getClass.getSimpleName} should not take $x as the JoinType")
93102
}
94103

104+
/**
105+
* For SMJ, child's output must have been sorted on key or expressions with the same order as
106+
* key, so we can get ordering for key from child's output ordering.
107+
*/
108+
private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder])
109+
: Seq[SortOrder] = {
110+
keys.zip(childOutputOrdering).map { case (key, childOrder) =>
111+
SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key)
112+
}
113+
}
114+
95115
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
96116
requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
97117

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,14 +477,18 @@ class PlannerSuite extends SharedSQLContext {
477477

478478
private val exprA = Literal(1)
479479
private val exprB = Literal(2)
480+
private val exprC = Literal(3)
480481
private val orderingA = SortOrder(exprA, Ascending)
481482
private val orderingB = SortOrder(exprB, Ascending)
483+
private val orderingC = SortOrder(exprC, Ascending)
482484
private val planA = DummySparkPlan(outputOrdering = Seq(orderingA),
483485
outputPartitioning = HashPartitioning(exprA :: Nil, 5))
484486
private val planB = DummySparkPlan(outputOrdering = Seq(orderingB),
485487
outputPartitioning = HashPartitioning(exprB :: Nil, 5))
488+
private val planC = DummySparkPlan(outputOrdering = Seq(orderingC),
489+
outputPartitioning = HashPartitioning(exprC :: Nil, 5))
486490

487-
assert(orderingA != orderingB)
491+
assert(orderingA != orderingB && orderingA != orderingC && orderingB != orderingC)
488492

489493
private def assertSortRequirementsAreSatisfied(
490494
childPlan: SparkPlan,
@@ -508,6 +512,30 @@ class PlannerSuite extends SharedSQLContext {
508512
}
509513
}
510514

515+
test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") {
516+
val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB)
517+
// Both left and right keys should be sorted after the SMJ.
518+
Seq(orderingA, orderingB).foreach { ordering =>
519+
assertSortRequirementsAreSatisfied(
520+
childPlan = innerSmj,
521+
requiredOrdering = Seq(ordering),
522+
shouldHaveSort = false)
523+
}
524+
}
525+
526+
test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " +
527+
"child SMJ") {
528+
val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB)
529+
val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC)
530+
// After the second SMJ, exprA, exprB and exprC should all be sorted.
531+
Seq(orderingA, orderingB, orderingC).foreach { ordering =>
532+
assertSortRequirementsAreSatisfied(
533+
childPlan = parentSmj,
534+
requiredOrdering = Seq(ordering),
535+
shouldHaveSort = false)
536+
}
537+
}
538+
511539
test("EnsureRequirements for sort operator after left outer sort merge join") {
512540
// Only left key is sorted after left outer SMJ (thus doesn't need a sort).
513541
val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB)

0 commit comments

Comments
 (0)