From 06858cd5a7e1b11cd5fd1edb206296d75550a8d4 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Mon, 11 Jun 2018 19:59:19 +0200
Subject: [PATCH 1/4] [SPARK-24495][SQL] EnsureRequirements returns wrong plan
 when reordering equal keys

---
 .../sql/execution/exchange/EnsureRequirements.scala   | 10 +++++++++-
 .../org/apache/spark/sql/execution/PlannerSuite.scala | 11 +++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index e3d28388c547..91a8f34cd971 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.exchange
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions._
@@ -227,9 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
     val leftKeysBuffer = ArrayBuffer[Expression]()
     val rightKeysBuffer = ArrayBuffer[Expression]()
+    val alreadyUsedIndexes = mutable.Set[Int]()
+    val keysAndIndexes = currentOrderOfKeys.zipWithIndex
 
     expectedOrderOfKeys.foreach(expression => {
-      val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+      val index = keysAndIndexes.find { case (e, idx) =>
+        // As the same key may be used many times, we need to filter out the occurrences
+        // we have already used.
+        e.semanticEquals(expression) && !alreadyUsedIndexes.contains(idx)
+      }.map(_._2).get
+      alreadyUsedIndexes += index
       leftKeysBuffer.append(leftKeys(index))
       rightKeysBuffer.append(rightKeys(index))
     })
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 98a50fbd52b4..d7496d492492 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -679,6 +679,17 @@ class PlannerSuite extends SharedSQLContext {
     }
     assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0))
   }
+
+  test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") {
+    withSQLConf(("spark.sql.shuffle.partitions", "1"),
+      ("spark.sql.constraintPropagation.enabled", "false"),
+      ("spark.sql.autoBroadcastJoinThreshold", "-1")) {
+      val df1 = spark.range(100)
+      val df2 = spark.range(100).select(($"id" * 2).as("b1"), (- $"id").as("b2"))
+      val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2")
+      assert(res.collect().sameElements(Array(Row(0, 0, 0))))
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
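The substance of patch 1 is the pairing loop in reorderJoinKeys: the old indexWhere call always resolved a key to its first semantically-equal occurrence, so when the same key appeared several times in a join condition, every occurrence mapped to the same input index and the opposite side's remaining keys were silently dropped. Below is a minimal standalone sketch of the fixed pairing logic, with Int keys and plain == standing in for Catalyst's Expression and semanticEquals (both simplifications are assumptions for illustration, not Spark's actual types):

    import scala.collection.mutable
    import scala.collection.mutable.ArrayBuffer

    def reorder[K](
        leftKeys: Seq[K],
        rightKeys: Seq[K],
        expectedOrderOfKeys: Seq[K],
        currentOrderOfKeys: Seq[K]): (Seq[K], Seq[K]) = {
      val leftBuf = ArrayBuffer[K]()
      val rightBuf = ArrayBuffer[K]()
      val alreadyUsedIndexes = mutable.Set[Int]()
      val keysAndIndexes = currentOrderOfKeys.zipWithIndex
      expectedOrderOfKeys.foreach { expected =>
        // Skip indexes that were already consumed, so each occurrence of a
        // duplicated key is paired with a distinct position.
        val index = keysAndIndexes.find { case (k, idx) =>
          k == expected && !alreadyUsedIndexes.contains(idx)
        }.map(_._2).get
        alreadyUsedIndexes += index
        leftBuf += leftKeys(index)
        rightBuf += rightKeys(index)
      }
      (leftBuf.toSeq, rightBuf.toSeq)
    }

    // Duplicated left key (1, 1) joined against right keys (10, 20): the old
    // indexWhere-based code resolved both slots to index 0 and produced
    // rights (10, 10); the fixed logic preserves both right keys.
    assert(reorder(Seq(1, 1), Seq(10, 20), Seq(1, 1), Seq(1, 1)) == (Seq(1, 1), Seq(10, 20)))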
From 341f1b28572df079d5ed6868c6983610d2bd951f Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 12 Jun 2018 12:58:27 +0200
Subject: [PATCH 2/4] address comments

---
 .../sql/execution/exchange/EnsureRequirements.scala   | 10 ++++++----
 .../org/apache/spark/sql/execution/PlannerSuite.scala |  2 +-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 91a8f34cd971..ad95879d86f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -228,16 +228,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
       currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
     val leftKeysBuffer = ArrayBuffer[Expression]()
     val rightKeysBuffer = ArrayBuffer[Expression]()
-    val alreadyUsedIndexes = mutable.Set[Int]()
+    val pickedIndexes = mutable.Set[Int]()
     val keysAndIndexes = currentOrderOfKeys.zipWithIndex
 
     expectedOrderOfKeys.foreach(expression => {
       val index = keysAndIndexes.find { case (e, idx) =>
         // As the same key may be used many times, we need to filter out the occurrences
         // we have already used.
-        e.semanticEquals(expression) && !alreadyUsedIndexes.contains(idx)
+        e.semanticEquals(expression) && !pickedIndexes.contains(idx)
       }.map(_._2).get
-      alreadyUsedIndexes += index
+      pickedIndexes += index
       leftKeysBuffer.append(leftKeys(index))
       rightKeysBuffer.append(rightKeys(index))
     })
@@ -278,7 +278,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
    * partitioning of the join nodes' children.
    */
   private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
-    plan.transformUp {
+    plan match {
       case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
         right) =>
         val (reorderedLeftKeys, reorderedRightKeys) =
@@ -296,6 +296,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         val (reorderedLeftKeys, reorderedRightKeys) =
           reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
         SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+
+      case other => other
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index d7496d492492..78e955684558 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -684,7 +684,7 @@ class PlannerSuite extends SharedSQLContext {
     withSQLConf(("spark.sql.shuffle.partitions", "1"),
       ("spark.sql.constraintPropagation.enabled", "false"),
       ("spark.sql.autoBroadcastJoinThreshold", "-1")) {
-      val df1 = spark.range(100)
+      val df1 = spark.range(100).repartition(2, $"id", $"id")
       val df2 = spark.range(100).select(($"id" * 2).as("b1"), (- $"id").as("b2"))
       val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2")
      assert(res.collect().sameElements(Array(Row(0, 0, 0))))
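Besides the rename to pickedIndexes, patch 2 narrows reorderJoinPredicates from plan.transformUp to a plain match with a case other => other fallback: EnsureRequirements already walks the whole plan and calls this method per node, so recursing again here would reorder the keys of child joins that had already been processed. A toy illustration of the difference, using a hypothetical Tree type instead of SparkPlan (all names here are invented for the example):

    sealed trait Tree
    case class Join(left: Tree, right: Tree, reordered: Boolean = false) extends Tree
    case object Leaf extends Tree

    // transformUp applies the partial function to every node, bottom-up.
    def transformUp(t: Tree)(pf: PartialFunction[Tree, Tree]): Tree = {
      val withNewChildren = t match {
        case Join(l, r, f) => Join(transformUp(l)(pf), transformUp(r)(pf), f)
        case Leaf => Leaf
      }
      pf.applyOrElse(withNewChildren, (x: Tree) => x)
    }

    val nested: Tree = Join(Join(Leaf, Leaf), Leaf)

    // transformUp rewrites every join in the subtree...
    val all = transformUp(nested) { case j: Join => j.copy(reordered = true) }
    assert(all == Join(Join(Leaf, Leaf, reordered = true), Leaf, reordered = true))

    // ...whereas a plain match only touches the node it was handed.
    val rootOnly = nested match {
      case j: Join => j.copy(reordered = true)
      case other => other
    }
    assert(rootOnly == Join(Join(Leaf, Leaf), Leaf, reordered = true))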
From 40abcff647748719dc775fd3cbce09661323ca8b Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 13 Jun 2018 16:01:06 +0200
Subject: [PATCH 3/4] address comments

---
 .../org/apache/spark/sql/JoinSuite.scala   | 11 ++++++++++
 .../spark/sql/execution/PlannerSuite.scala | 20 ++++++++++++-------
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 8fa747465cb1..e7641b1642cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -882,4 +882,15 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       checkAnswer(df, Row(3, 8, 7, 2) :: Row(3, 8, 4, 2) :: Nil)
     }
   }
+
+  test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+      SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      val df1 = spark.range(0, 100, 1, 2)
+      val df2 = spark.range(100).select($"id".as("b1"), (- $"id").as("b2"))
+      val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2")
+      checkAnswer(res, Row(0, 0, 0))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 78e955684558..0dfb4d5714f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -681,13 +681,19 @@ class PlannerSuite extends SharedSQLContext {
   }
 
   test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") {
-    withSQLConf(("spark.sql.shuffle.partitions", "1"),
-      ("spark.sql.constraintPropagation.enabled", "false"),
-      ("spark.sql.autoBroadcastJoinThreshold", "-1")) {
-      val df1 = spark.range(100).repartition(2, $"id", $"id")
-      val df2 = spark.range(100).select(($"id" * 2).as("b1"), (- $"id").as("b2"))
-      val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2")
-      assert(res.collect().sameElements(Array(Row(0, 0, 0))))
+    val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA),
+      outputPartitioning = HashPartitioning(exprA :: exprA :: Nil, 5))
+    val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB),
+      outputPartitioning = HashPartitioning(exprB :: Nil, 5))
+    val smjExec = SortMergeJoinExec(
+      exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2)
+
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
+    outputPlan match {
+      case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) =>
+        assert(leftKeys == Seq(exprA, exprA))
+        assert(rightKeys.contains(exprB) && rightKeys.contains(exprC))
+      case _ => fail()
     }
   }
 }
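Patch 3 splits the coverage: the end-to-end reproduction moves into JoinSuite, while the PlannerSuite test becomes a unit test on EnsureRequirements itself. With the left child hash-partitioned on (exprA, exprA), the rule reorders the sort-merge join keys to match that partitioning, and the assertion checks that both right-side keys survive the reordering. For contrast, here is a simplified model of the pre-patch behavior the assertion guards against (a hypothetical helper for illustration, not code from Spark):

    def reorderBuggy[K](
        leftKeys: Seq[K],
        rightKeys: Seq[K],
        expectedOrderOfKeys: Seq[K],
        currentOrderOfKeys: Seq[K]): (Seq[K], Seq[K]) = {
      expectedOrderOfKeys.map { expected =>
        // indexWhere always returns the FIRST match, even when that position
        // was already consumed by an earlier duplicate of the same key.
        val index = currentOrderOfKeys.indexWhere(_ == expected)
        (leftKeys(index), rightKeys(index))
      }.unzip
    }

    // Mirrors the PlannerSuite scenario: left keys (a, a), right keys (b, c).
    // The pre-patch logic pairs both occurrences of a with b and drops c,
    // which is exactly the outcome the contains(exprB) && contains(exprC)
    // assertion would catch.
    assert(reorderBuggy(Seq("a", "a"), Seq("b", "c"), Seq("a", "a"), Seq("a", "a"))
      == (Seq("a", "a"), Seq("b", "b")))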
From 6553c27bc23efbaed268556b0cb19405fa693de9 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Wed, 13 Jun 2018 20:20:44 +0200
Subject: [PATCH 4/4] address comments

---
 sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 4 ++--
 .../scala/org/apache/spark/sql/execution/PlannerSuite.scala  | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index e7641b1642cc..44767dfc9249 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -883,13 +883,13 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("SPARK-24495: EnsureRequirements can return wrong plan when reusing the same key in join") {
+  test("SPARK-24495: Join may return wrong result when having duplicated equal-join keys") {
     withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1",
       SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       val df1 = spark.range(0, 100, 1, 2)
       val df2 = spark.range(100).select($"id".as("b1"), (- $"id").as("b2"))
-      val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2")
+      val res = df1.join(df2, $"id" === $"b1" && $"id" === $"b2").select($"b1", $"b2", $"id")
       checkAnswer(res, Row(0, 0, 0))
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 0dfb4d5714f8..d26662db5015 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -692,7 +692,7 @@ class PlannerSuite extends SharedSQLContext {
     outputPlan match {
       case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) =>
         assert(leftKeys == Seq(exprA, exprA))
-        assert(rightKeys.contains(exprB) && rightKeys.contains(exprC))
+        assert(rightKeys == Seq(exprB, exprC))
       case _ => fail()
     }
   }
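For completeness, here is how the JoinSuite scenario looks interactively. This is a hypothetical spark-shell session sketch (the spark session and the $ column interpolator come from the shell); on a build without this patch, the join could return rows satisfying only id === b1 instead of the single correct Row(0, 0, 0):

    spark.conf.set("spark.sql.shuffle.partitions", "1")
    spark.conf.set("spark.sql.constraintPropagation.enabled", "false")
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")

    val df1 = spark.range(0, 100, 1, 2)
    val df2 = spark.range(100).select($"id".as("b1"), (-$"id").as("b2"))

    // id === b1 && id === b2 holds only at id = 0, so exactly one row is correct.
    df1.join(df2, $"id" === $"b1" && $"id" === $"b2").show()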