diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 82f0b9f5cd060..c8e236be28b42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -252,54 +252,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { operator.withNewChildren(children) } - /** - * When the physical operators are created for JOIN, the ordering of join keys is based on order - * in which the join keys appear in the user query. That might not match with the output - * partitioning of the join node's children (thus leading to extra sort / shuffle being - * introduced). This rule will change the ordering of the join keys to match with the - * partitioning of the join nodes' children. - */ - def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { - def reorderJoinKeys( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - leftPartitioning: Partitioning, - rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { - - def reorder(expectedOrderOfKeys: Seq[Expression], - currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { - val leftKeysBuffer = ArrayBuffer[Expression]() - val rightKeysBuffer = ArrayBuffer[Expression]() + private def reorder( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + expectedOrderOfKeys: Seq[Expression], + currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + val leftKeysBuffer = ArrayBuffer[Expression]() + val rightKeysBuffer = ArrayBuffer[Expression]() - expectedOrderOfKeys.foreach(expression => { - val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) - leftKeysBuffer.append(leftKeys(index)) - rightKeysBuffer.append(rightKeys(index)) - }) - (leftKeysBuffer, rightKeysBuffer) - } + expectedOrderOfKeys.foreach(expression => { + val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression)) + leftKeysBuffer.append(leftKeys(index)) + rightKeysBuffer.append(rightKeys(index)) + }) + (leftKeysBuffer, rightKeysBuffer) + } - if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { - leftPartitioning match { - case HashPartitioning(leftExpressions, _) - if leftExpressions.length == leftKeys.length && - leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => - reorder(leftExpressions, leftKeys) + private def reorderJoinKeys( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftPartitioning: Partitioning, + rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { + if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { + leftPartitioning match { + case HashPartitioning(leftExpressions, _) + if leftExpressions.length == leftKeys.length && + leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => + reorder(leftKeys, rightKeys, leftExpressions, leftKeys) - case _ => rightPartitioning match { - case HashPartitioning(rightExpressions, _) - if rightExpressions.length == rightKeys.length && - rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => - reorder(rightExpressions, rightKeys) + case _ => rightPartitioning match { + case HashPartitioning(rightExpressions, _) + if rightExpressions.length == rightKeys.length && + rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => + reorder(leftKeys, rightKeys, rightExpressions, rightKeys) - case _ => (leftKeys, rightKeys) - } + case _ => (leftKeys, rightKeys) } - } else { - (leftKeys, rightKeys) } + } else { + (leftKeys, rightKeys) } + } + /** + * When the physical operators are created for JOIN, the ordering of join keys is based on order + * in which the join keys appear in the user query. That might not match with the output + * partitioning of the join node's children (thus leading to extra sort / shuffle being + * introduced). This rule will change the ordering of the join keys to match with the + * partitioning of the join nodes' children. + */ + private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { plan.transformUp { case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 9025859e91066..fb61fa716b946 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -620,7 +620,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { |) ab |JOIN table2 c |ON ab.i = c.i - |""".stripMargin), + """.stripMargin), sql(""" |SELECT a.i, a.j, a.k, c.i, c.j, c.k |FROM bucketed_table a @@ -628,7 +628,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { |ON a.i = b.i |JOIN table2 c |ON a.i = c.i - |""".stripMargin)) + """.stripMargin)) } } }