Commit 06858cd

[SPARK-24495][SQL] EnsureRequirements returns wrong plan when reordering equal keys
1 parent e76b012

2 files changed: +20 -1 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala

Lines changed: 9 additions & 1 deletion
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.exchange
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions._
@@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
     val leftKeysBuffer = ArrayBuffer[Expression]()
     val rightKeysBuffer = ArrayBuffer[Expression]()
+    val alreadyUsedIndexes = mutable.Set[Int]()
+    val keysAndIndexes = currentOrderOfKeys.zipWithIndex
 
     expectedOrderOfKeys.foreach(expression => {
-      val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+      val index = keysAndIndexes.find { case (e, idx) =>
+        // As the same key may be used multiple times, we need to skip the
+        // occurrences that have already been used.
+        e.semanticEquals(expression) && !alreadyUsedIndexes.contains(idx)
+      }.map(_._2).get
+      alreadyUsedIndexes += index
       leftKeysBuffer.append(leftKeys(index))
       rightKeysBuffer.append(rightKeys(index))
     })
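
The fix in a nutshell: the old lookup used indexWhere, which always returns the first semantic match, so when the same key appears more than once in the join condition, every expected key resolves to the same index and the reordered key lists come out wrong. Below is a minimal, self-contained Scala sketch of the two behaviors (strings stand in for catalyst Expressions, == stands in for semanticEquals; all names are illustrative, not from the commit):

import scala.collection.mutable

// The same key appears twice in the join condition.
val currentOrderOfKeys  = Seq("id", "id")
val expectedOrderOfKeys = Seq("id", "id")

// Old logic: indexWhere always finds the first occurrence, so both expected
// keys map to index 0 and the second key pairing is lost.
val oldIndexes = expectedOrderOfKeys.map(e => currentOrderOfKeys.indexWhere(_ == e))
assert(oldIndexes == Seq(0, 0))

// Fixed logic: remember the indexes already consumed so each occurrence of a
// duplicated key is matched exactly once.
val alreadyUsed = mutable.Set[Int]()
val keysAndIndexes = currentOrderOfKeys.zipWithIndex
val newIndexes = expectedOrderOfKeys.map { e =>
  val idx = keysAndIndexes
    .find { case (k, i) => k == e && !alreadyUsed.contains(i) }
    .map(_._2).get
  alreadyUsed += idx
  idx
}
assert(newIndexes == Seq(0, 1))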

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 11 additions & 0 deletions
@@ -679,6 +679,17 @@ class PlannerSuite extends SharedSQLContext {
     }
     assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0))
   }
+
+  test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") {
+    withSQLConf(("spark.sql.shuffle.partitions", "1"),
+      ("spark.sql.constraintPropagation.enabled", "false"),
+      ("spark.sql.autoBroadcastJoinThreshold", "-1")) {
+      val df1 = spark.range(100)
+      val df2 = spark.range(100).select(($"id" * 2).as("b1"), (- $"id").as("b2"))
+      val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2")
+      assert(res.collect().sameElements(Array(Row(0, 0, 0))))
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
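
Why Row(0, 0, 0) is the expected result (reasoning inferred from the test data, not stated in the commit): df2 carries b1 = 2 * id and b2 = -id, so the condition id === b1 && id === b2 can only hold when every value is 0. Disabling constraint propagation and broadcast joins presumably keeps the optimizer from simplifying the duplicate-key condition away and forces a shuffle-based join, so the key-reordering path in EnsureRequirements is actually exercised. A standalone sketch of the arithmetic (hypothetical check, not part of the test suite):

// With b1 = 2 * n and b2 = -n, the join condition id == b1 && id == b2 can
// only hold when id == n == 0, so the join produces the single row (0, 0, 0).
val matches = for {
  id <- 0L until 100L   // df1: spark.range(100)
  n  <- 0L until 100L   // df2: spark.range(100) before deriving b1 and b2
  if id == 2 * n && id == -n
} yield (id, 2 * n, -n)
assert(matches == Seq((0L, 0L, 0L)))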
