diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index f404621399ce..946475a1e975 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
-import org.apache.spark.sql.execution.joins.ReorderJoinPredicates
 import org.apache.spark.sql.types.{BinaryType, DateType, DecimalType, TimestampType, _}
 import org.apache.spark.util.Utils
 
@@ -104,7 +103,6 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
     python.ExtractPythonUDFs,
     PlanSubqueries(sparkSession),
-    new ReorderJoinPredicates,
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
     ReuseExchange(sparkSession.sessionState.conf),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 4e2ca37bc1a5..82f0b9f5cd06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -17,10 +17,14 @@
 
 package org.apache.spark.sql.execution.exchange
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
+  SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 
 /**
@@ -248,6 +252,75 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     operator.withNewChildren(children)
   }
 
+  /**
+   * When the physical operators are created for JOIN, the ordering of join keys is based on order
+   * in which the join keys appear in the user query. That might not match with the output
+   * partitioning of the join node's children (thus leading to extra sort / shuffle being
+   * introduced). This rule will change the ordering of the join keys to match with the
+   * partitioning of the join nodes' children.
+   */
+  def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
+    def reorderJoinKeys(
+        leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression],
+        leftPartitioning: Partitioning,
+        rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+
+      def reorder(expectedOrderOfKeys: Seq[Expression],
+          currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+        val leftKeysBuffer = ArrayBuffer[Expression]()
+        val rightKeysBuffer = ArrayBuffer[Expression]()
+
+        expectedOrderOfKeys.foreach(expression => {
+          val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+          leftKeysBuffer.append(leftKeys(index))
+          rightKeysBuffer.append(rightKeys(index))
+        })
+        (leftKeysBuffer, rightKeysBuffer)
+      }
+
+      if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
+        leftPartitioning match {
+          case HashPartitioning(leftExpressions, _)
+            if leftExpressions.length == leftKeys.length &&
+              leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
+            reorder(leftExpressions, leftKeys)
+
+          case _ => rightPartitioning match {
+            case HashPartitioning(rightExpressions, _)
+              if rightExpressions.length == rightKeys.length &&
+                rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
+              reorder(rightExpressions, rightKeys)
+
+            case _ => (leftKeys, rightKeys)
+          }
+        }
+      } else {
+        (leftKeys, rightKeys)
+      }
+    }
+
+    plan.transformUp {
+      case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
+        right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+          left, right)
+
+      case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
+          left, right)
+
+      case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+    }
+  }
+
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     case operator @ ShuffleExchangeExec(partitioning, child, _) =>
       child.children match {
@@ -255,6 +328,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
           if (childPartitioning.guarantees(partitioning)) child else operator
         case _ => operator
       }
-    case operator: SparkPlan => ensureDistributionAndOrdering(operator)
+    case operator: SparkPlan =>
+      ensureDistributionAndOrdering(reorderJoinPredicates(operator))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
deleted file mode 100644
index 534d8c5689c2..000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ReorderJoinPredicates.scala
+++ /dev/null
@@ -1,94 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.SparkPlan
-
-/**
- * When the physical operators are created for JOIN, the ordering of join keys is based on order
- * in which the join keys appear in the user query. That might not match with the output
- * partitioning of the join node's children (thus leading to extra sort / shuffle being
- * introduced). This rule will change the ordering of the join keys to match with the
- * partitioning of the join nodes' children.
- */
-class ReorderJoinPredicates extends Rule[SparkPlan] {
-  private def reorderJoinKeys(
-      leftKeys: Seq[Expression],
-      rightKeys: Seq[Expression],
-      leftPartitioning: Partitioning,
-      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
-
-    def reorder(
-        expectedOrderOfKeys: Seq[Expression],
-        currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
-      val leftKeysBuffer = ArrayBuffer[Expression]()
-      val rightKeysBuffer = ArrayBuffer[Expression]()
-
-      expectedOrderOfKeys.foreach(expression => {
-        val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
-        leftKeysBuffer.append(leftKeys(index))
-        rightKeysBuffer.append(rightKeys(index))
-      })
-      (leftKeysBuffer, rightKeysBuffer)
-    }
-
-    if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      leftPartitioning match {
-        case HashPartitioning(leftExpressions, _)
-          if leftExpressions.length == leftKeys.length &&
-            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
-          reorder(leftExpressions, leftKeys)
-
-        case _ => rightPartitioning match {
-          case HashPartitioning(rightExpressions, _)
-            if rightExpressions.length == rightKeys.length &&
-              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
-            reorder(rightExpressions, rightKeys)
-
-          case _ => (leftKeys, rightKeys)
-        }
-      }
-    } else {
-      (leftKeys, rightKeys)
-    }
-  }
-
-  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
-      val (reorderedLeftKeys, reorderedRightKeys) =
-        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
-      BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
-        left, right)
-
-    case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
-      val (reorderedLeftKeys, reorderedRightKeys) =
-        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
-      ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition,
-        left, right)
-
-    case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
-      val (reorderedLeftKeys, reorderedRightKeys) =
-        reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
-      SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index ab18905e2ddb..9025859e9106 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -602,6 +602,37 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     )
   }
 
+  test("SPARK-22042 ReorderJoinPredicates can break when child's partitioning is not decided") {
+    withTable("bucketed_table", "table1", "table2") {
+      df.write.format("parquet").saveAsTable("table1")
+      df.write.format("parquet").saveAsTable("table2")
+      df.write.format("parquet").bucketBy(8, "j", "k").saveAsTable("bucketed_table")
+
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+        checkAnswer(
+          sql("""
+            |SELECT ab.i, ab.j, ab.k, c.i, c.j, c.k
+            |FROM (
+            |  SELECT a.i, a.j, a.k
+            |  FROM bucketed_table a
+            |  JOIN table1 b
+            |  ON a.i = b.i
+            |) ab
+            |JOIN table2 c
+            |ON ab.i = c.i
+            |""".stripMargin),
+          sql("""
+            |SELECT a.i, a.j, a.k, c.i, c.j, c.k
+            |FROM bucketed_table a
+            |JOIN table1 b
+            |ON a.i = b.i
+            |JOIN table2 c
+            |ON a.i = c.i
+            |""".stripMargin))
+      }
+    }
+  }
+
   test("error if there exists any malformed bucket files") {
     withTable("bucketed_table") {
       df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
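
Reviewer note (not part of the patch): the test title above hints at why the rule moved — a child's output partitioning is only settled while EnsureRequirements itself runs, so reordering join keys in a separate preparation rule could observe a partitioning that is not yet decided. The key-reordering idea behind `reorderJoinKeys` can be sketched outside Spark with plain strings standing in for Catalyst `Expression`s; the object and helper names below are illustrative only, not code from the patch.

```scala
// Illustrative sketch only: reorder paired join keys so the left-side keys follow the
// order expected by an existing hash partitioning, keeping left/right pairs aligned.
object ReorderJoinKeysSketch {
  def reorder(
      expectedOrderOfKeys: Seq[String],
      leftKeys: Seq[String],
      rightKeys: Seq[String]): (Seq[String], Seq[String]) = {
    val pairs = expectedOrderOfKeys.map { key =>
      val index = leftKeys.indexWhere(_ == key) // the real rule uses semanticEquals
      (leftKeys(index), rightKeys(index))
    }
    (pairs.map(_._1), pairs.map(_._2))
  }

  def main(args: Array[String]): Unit = {
    // The query wrote "a.j = b.x AND a.k = b.y", but the left child is hash-partitioned by (k, j).
    println(reorder(Seq("k", "j"), leftKeys = Seq("j", "k"), rightKeys = Seq("x", "y")))
    // -> (List(k, j),List(y, x)): the join keys now line up with the child's partitioning,
    //    so that side needs no extra shuffle or sort.
  }
}
```

The patch applies exactly this pairing-preserving reorder to BroadcastHashJoinExec, ShuffledHashJoinExec and SortMergeJoinExec, and only when all keys are deterministic and the existing HashPartitioning covers the same key set.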