Skip to content

Commit fdadc4b

Browse files
mgaido91 authored and gatorsmile committed
[SPARK-24495][SQL] EnsureRequirements returns wrong plan when reordering equal keys
## What changes were proposed in this pull request?

The `reorder` method of `EnsureRequirements` currently assumes that each key appears only once in the join condition. This is not always the case: when the same key is used more than once, `reorder` returns a wrong plan, which in turn produces a wrong query result.

## How was this patch tested?

Added unit tests.

Author: Marco Gaido <[email protected]>

Closes #21529 from mgaido91/SPARK-24495.
1 parent 534065e commit fdadc4b

File tree

3 files changed

+40
-2
lines changed

3 files changed

+40
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.execution.exchange
1919

20+
import scala.collection.mutable
2021
import scala.collection.mutable.ArrayBuffer
2122

2223
import org.apache.spark.sql.catalyst.expressions._
@@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
227228
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
228229
val leftKeysBuffer = ArrayBuffer[Expression]()
229230
val rightKeysBuffer = ArrayBuffer[Expression]()
231+
val pickedIndexes = mutable.Set[Int]()
232+
val keysAndIndexes = currentOrderOfKeys.zipWithIndex
230233

231234
expectedOrderOfKeys.foreach(expression => {
232-
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
235+
val index = keysAndIndexes.find { case (e, idx) =>
236+
// As we may have the same key used many times, we need to filter out its occurrence we
237+
// have already used.
238+
e.semanticEquals(expression) && !pickedIndexes.contains(idx)
239+
}.map(_._2).get
240+
pickedIndexes += index
233241
leftKeysBuffer.append(leftKeys(index))
234242
rightKeysBuffer.append(rightKeys(index))
235243
})
@@ -270,7 +278,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
270278
* partitioning of the join nodes' children.
271279
*/
272280
private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
273-
plan.transformUp {
281+
plan match {
274282
case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
275283
right) =>
276284
val (reorderedLeftKeys, reorderedRightKeys) =
@@ -288,6 +296,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
288296
val (reorderedLeftKeys, reorderedRightKeys) =
289297
reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
290298
SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
299+
300+
case other => other
291301
}
292302
}
293303

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,4 +882,15 @@ class JoinSuite extends QueryTest with SharedSQLContext {
882882
checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil)
883883
}
884884
}
885+
886+
test("SPARK-24495: Join may return wrong result when having duplicated equal-join keys") {
887+
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1",
888+
SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false",
889+
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
890+
val df1 = spark.range(0, 100, 1, 2)
891+
val df2 = spark.range(100).select($"id".as("b1"), (- $"id").as("b2"))
892+
val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id")
893+
checkAnswer(res, Row(0, 0, 0))
894+
}
895+
}
885896
}

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,23 @@ class PlannerSuite extends SharedSQLContext {
680680
assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0))
681681
}
682682

683+
test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") {
684+
val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA),
685+
outputPartitioning = HashPartitioning(exprA :: exprA :: Nil, 5))
686+
val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB),
687+
outputPartitioning = HashPartitioning(exprB :: Nil, 5))
688+
val smjExec = SortMergeJoinExec(
689+
exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2)
690+
691+
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
692+
outputPlan match {
693+
case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) =>
694+
assert(leftKeys == Seq(exprA, exprA))
695+
assert(rightKeys == Seq(exprB, exprC))
696+
case _ => fail()
697+
}
698+
}
699+
683700
test("SPARK-24500: create union with stream of children") {
684701
val df = Union(Stream(
685702
Range(1, 1, 1, 1),

0 commit comments

Comments
 (0)