Skip to content

Commit 1c38e97

Browse files
committed
convert same order expressions to Seq
1 parent 2b39165 commit 1c38e97

File tree

7 files changed

+44
-15
lines changed

7 files changed

+44
-15
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1817,7 +1817,7 @@ class Analyzer(override val catalogManager: CatalogManager)
18171817
val newOrders = orders map {
18181818
case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) =>
18191819
if (index > 0 && index <= child.output.size) {
1820-
SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty)
1820+
SortOrder(child.output(index - 1), direction, nullOrdering, Seq.empty)
18211821
} else {
18221822
s.failAnalysis(
18231823
s"ORDER BY position $index is not in select list " +

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,9 @@ package object dsl {
131131
}
132132

133133
def asc: SortOrder = SortOrder(expr, Ascending)
134-
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty)
134+
def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Seq.empty)
135135
def desc: SortOrder = SortOrder(expr, Descending)
136-
def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty)
136+
def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Seq.empty)
137137
def as(alias: String): NamedExpression = Alias(expr, alias)()
138138
def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)()
139139
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ case class SortOrder(
6363
child: Expression,
6464
direction: SortDirection,
6565
nullOrdering: NullOrdering,
66-
sameOrderExpressions: Set[Expression])
66+
sameOrderExpressions: Seq[Expression])
6767
extends Expression with Unevaluable {
6868

69-
override def children: Seq[Expression] = child +: sameOrderExpressions.toSeq
69+
override def children: Seq[Expression] = child +: sameOrderExpressions
7070

7171
override def checkInputDataTypes(): TypeCheckResult = {
7272
if (RowOrdering.isOrderable(dataType)) {
@@ -94,7 +94,7 @@ object SortOrder {
9494
def apply(
9595
child: Expression,
9696
direction: SortDirection,
97-
sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = {
97+
sameOrderExpressions: Seq[Expression] = Seq.empty): SortOrder = {
9898
new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions)
9999
}
100100

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1893,7 +1893,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg
18931893
} else {
18941894
direction.defaultNullOrdering
18951895
}
1896-
SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty)
1896+
SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty)
18971897
}
18981898

18991899
/**

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1226,7 +1226,7 @@ class Column(val expr: Expression) extends Logging {
12261226
* @group expr_ops
12271227
* @since 2.1.0
12281228
*/
1229-
def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) }
1229+
def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Seq.empty) }
12301230

12311231
/**
12321232
* Returns a sort expression based on the descending order of the column,
@@ -1242,7 +1242,7 @@ class Column(val expr: Expression) extends Logging {
12421242
* @group expr_ops
12431243
* @since 2.1.0
12441244
*/
1245-
def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) }
1245+
def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Seq.empty) }
12461246

12471247
/**
12481248
* Returns a sort expression based on ascending order of the column.
@@ -1273,7 +1273,7 @@ class Column(val expr: Expression) extends Logging {
12731273
* @group expr_ops
12741274
* @since 2.1.0
12751275
*/
1276-
def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) }
1276+
def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Seq.empty) }
12771277

12781278
/**
12791279
* Returns a sort expression based on ascending order of the column,
@@ -1289,7 +1289,7 @@ class Column(val expr: Expression) extends Logging {
12891289
* @group expr_ops
12901290
* @since 2.1.0
12911291
*/
1292-
def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) }
1292+
def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Seq.empty) }
12931293

12941294
/**
12951295
* Prints the expression to the console for debugging purposes.

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ case class SortMergeJoinExec(
6868
val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering)
6969
val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering)
7070
leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) =>
71-
// Also add the right key and its `sameOrderExpressions`
72-
SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey
73-
.sameOrderExpressions)
71+
// Also add expressions from right side sort order
72+
val sameOrderExpressions = ExpressionSet(lKey.children ++ rKey.children) - lKey.child
73+
SortOrder(lKey.child, Ascending, sameOrderExpressions.toSeq)
7474
}
7575
// For left and right outer joins, the output is ordered by the streamed input's join keys.
7676
case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering)
@@ -96,7 +96,8 @@ case class SortMergeJoinExec(
9696
val requiredOrdering = requiredOrders(keys)
9797
if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) {
9898
keys.zip(childOutputOrdering).map { case (key, childOrder) =>
99-
SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key)
99+
val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key
100+
SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq)
100101
}
101102
} else {
102103
requiredOrdering

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,6 +1090,34 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
10901090
}
10911091
}
10921092

1093+
test("sort order doesn't have repeated expressions") {
1094+
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
1095+
withTempView("t1") {
1096+
withTempView("t2") {
1097+
spark.range(10).repartition($"id").createTempView("t1")
1098+
spark.range(20).repartition($"id").createTempView("t2")
1099+
val planned = sql(
1100+
"""
1101+
| SELECT t12.id, t1.id
1102+
| FROM (SELECT t1.id FROM t1, t2 WHERE t1.id * 2 = t2.id) t12, t1
1103+
| where 2 * t12.id = t1.id
1104+
""".stripMargin).queryExecution.executedPlan
1105+
1106+
// t12 is already sorted on `t1.id * 2`, and we need to sort it on `2 * t12.id`
1107+
// for the 2nd join, so the sort on t12 can be avoided.
1108+
val sortNodes = planned.collect { case s: SortExec => s }
1109+
assert(sortNodes.size == 3)
1110+
val outputOrdering = planned.outputOrdering
1111+
assert(outputOrdering.size == 1)
1112+
// Sort order should have 3 children, not 4, because t1.id*2 and 2*t1.id are the same
1113+
assert(outputOrdering.head.children.size == 3)
1114+
assert(outputOrdering.head.children.count(_.isInstanceOf[AttributeReference]) == 2)
1115+
assert(outputOrdering.head.children.count(_.isInstanceOf[Multiply]) == 1)
1116+
}
1117+
}
1118+
}
1119+
}
1120+
10931121
test("aliases to expressions should not be replaced") {
10941122
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
10951123
withTempView("df1", "df2") {

0 commit comments

Comments
 (0)