2 files changed: +20 −1

File tree:
  main/scala/org/apache/spark/sql/execution/exchange
  test/scala/org/apache/spark/sql/execution

Changes in main/scala/org/apache/spark/sql/execution/exchange:

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.exchange
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions._
@@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
     val leftKeysBuffer = ArrayBuffer[Expression]()
     val rightKeysBuffer = ArrayBuffer[Expression]()
+    val alreadyUsedIndexes = mutable.Set[Int]()
+    val keysAndIndexes = currentOrderOfKeys.zipWithIndex
 
     expectedOrderOfKeys.foreach(expression => {
-      val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+      val index = keysAndIndexes.find { case (e, idx) =>
+        // The same key may be used many times, so skip the occurrences that
+        // have already been matched.
+        e.semanticEquals(expression) && !alreadyUsedIndexes.contains(idx)
+      }.map(_._2).get
+      alreadyUsedIndexes += index
       leftKeysBuffer.append(leftKeys(index))
       rightKeysBuffer.append(rightKeys(index))
     })
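Why this change matters: `indexWhere` always returns the first matching index, so when the same join key appears more than once, every occurrence resolves to the same position and the key paired with the later occurrence is silently dropped. The following self-contained sketch is not the Spark source; the `Key` type and the `reorderOld`/`reorderFixed` helpers are hypothetical stand-ins for Catalyst expressions and the reordering in `EnsureRequirements`, used only to contrast the old and new lookup:

import scala.collection.mutable

object ReorderSketch {
  // Hypothetical stand-in for a Catalyst Expression with semanticEquals.
  case class Key(name: String) {
    def semanticEquals(other: Key): Boolean = name == other.name
  }

  // Old behavior: indexWhere always returns the FIRST matching index, so for
  // duplicated keys such as Seq(id, id) both occurrences map to index 0.
  def reorderOld(expected: Seq[Key], current: Seq[Key]): Seq[Int] =
    expected.map(e => current.indexWhere(_.semanticEquals(e)))

  // Fixed behavior: remember indexes already consumed so that each
  // occurrence of a duplicated key is matched exactly once.
  def reorderFixed(expected: Seq[Key], current: Seq[Key]): Seq[Int] = {
    val used = mutable.Set[Int]()
    val indexed = current.zipWithIndex
    expected.map { e =>
      val idx = indexed.find { case (k, i) =>
        k.semanticEquals(e) && !used.contains(i)
      }.map(_._2).get
      used += idx
      idx
    }
  }

  def main(args: Array[String]): Unit = {
    val keys = Seq(Key("id"), Key("id")) // the same key used twice
    println(reorderOld(keys, keys))      // List(0, 0): index 1 is never chosen
    println(reorderFixed(keys, keys))    // List(0, 1): both occurrences kept
  }
}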
Changes in test/scala/org/apache/spark/sql/execution:

@@ -679,6 +679,17 @@ class PlannerSuite extends SharedSQLContext {
     }
     assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0))
   }
+
+  test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") {
+    withSQLConf(("spark.sql.shuffle.partitions", "1"),
+      ("spark.sql.constraintPropagation.enabled", "false"),
+      ("spark.sql.autoBroadcastJoinThreshold", "-1")) {
+      val df1 = spark.range(100)
+      val df2 = spark.range(100).select(($"id" * 2).as("b1"), (-$"id").as("b2"))
+      val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2")
+      assert(res.collect().sameElements(Array(Row(0, 0, 0))))
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
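For context, the regression test encodes the following repro. Here is a rough spark-shell version of it, assuming a local `SparkSession` named `spark`; it uses `org.apache.spark.sql.functions.col` in place of the test's `$` interpolator, and the configs mirror the `withSQLConf` settings above:

import org.apache.spark.sql.functions.col

// Disable broadcast joins and use a single shuffle partition so the
// key reordering in EnsureRequirements actually runs.
spark.conf.set("spark.sql.shuffle.partitions", "1")
spark.conf.set("spark.sql.constraintPropagation.enabled", "false")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

val df1 = spark.range(100)
val df2 = spark.range(100).select((col("id") * 2).as("b1"), (-col("id")).as("b2"))

// `id` is reused as a join key against both b1 and b2; only id = 0
// satisfies id == 2 * id and id == -id, so exactly one row (0, 0, 0)
// is expected. Before this fix, the duplicated key could make the
// reordered plan return wrong results.
val res = df1.join(df2, col("id") === col("b1") && col("id") === col("b2"))
res.show()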