@@ -252,54 +252,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
252252 operator.withNewChildren(children)
253253 }
254254
255- /**
256- * When the physical operators are created for JOIN, the ordering of join keys is based on order
257- * in which the join keys appear in the user query. That might not match with the output
258- * partitioning of the join node's children (thus leading to extra sort / shuffle being
259- * introduced). This rule will change the ordering of the join keys to match with the
260- * partitioning of the join nodes' children.
261- */
262- def reorderJoinPredicates (plan : SparkPlan ): SparkPlan = {
263- def reorderJoinKeys (
264- leftKeys : Seq [Expression ],
265- rightKeys : Seq [Expression ],
266- leftPartitioning : Partitioning ,
267- rightPartitioning : Partitioning ): (Seq [Expression ], Seq [Expression ]) = {
268-
269- def reorder (expectedOrderOfKeys : Seq [Expression ],
270- currentOrderOfKeys : Seq [Expression ]): (Seq [Expression ], Seq [Expression ]) = {
271- val leftKeysBuffer = ArrayBuffer [Expression ]()
272- val rightKeysBuffer = ArrayBuffer [Expression ]()
255+ private def reorder (
256+ leftKeys : Seq [Expression ],
257+ rightKeys : Seq [Expression ],
258+ expectedOrderOfKeys : Seq [Expression ],
259+ currentOrderOfKeys : Seq [Expression ]): (Seq [Expression ], Seq [Expression ]) = {
260+ val leftKeysBuffer = ArrayBuffer [Expression ]()
261+ val rightKeysBuffer = ArrayBuffer [Expression ]()
273262
274- expectedOrderOfKeys.foreach(expression => {
275- val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
276- leftKeysBuffer.append(leftKeys(index))
277- rightKeysBuffer.append(rightKeys(index))
278- })
279- (leftKeysBuffer, rightKeysBuffer)
280- }
263+ expectedOrderOfKeys.foreach(expression => {
264+ val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
265+ leftKeysBuffer.append(leftKeys(index))
266+ rightKeysBuffer.append(rightKeys(index))
267+ })
268+ (leftKeysBuffer, rightKeysBuffer)
269+ }
281270
282- if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
283- leftPartitioning match {
284- case HashPartitioning (leftExpressions, _)
285- if leftExpressions.length == leftKeys.length &&
286- leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
287- reorder(leftExpressions, leftKeys)
271+ private def reorderJoinKeys (
272+ leftKeys : Seq [Expression ],
273+ rightKeys : Seq [Expression ],
274+ leftPartitioning : Partitioning ,
275+ rightPartitioning : Partitioning ): (Seq [Expression ], Seq [Expression ]) = {
276+ if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
277+ leftPartitioning match {
278+ case HashPartitioning (leftExpressions, _)
279+ if leftExpressions.length == leftKeys.length &&
280+ leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
281+ reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
288282
289- case _ => rightPartitioning match {
290- case HashPartitioning (rightExpressions, _)
291- if rightExpressions.length == rightKeys.length &&
292- rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
293- reorder(rightExpressions, rightKeys)
283+ case _ => rightPartitioning match {
284+ case HashPartitioning (rightExpressions, _)
285+ if rightExpressions.length == rightKeys.length &&
286+ rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
287+ reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
294288
295- case _ => (leftKeys, rightKeys)
296- }
289+ case _ => (leftKeys, rightKeys)
297290 }
298- } else {
299- (leftKeys, rightKeys)
300291 }
292+ } else {
293+ (leftKeys, rightKeys)
301294 }
295+ }
302296
297+ /**
298+ * When the physical operators are created for JOIN, the ordering of join keys is based on order
299+ * in which the join keys appear in the user query. That might not match with the output
300+ * partitioning of the join node's children (thus leading to extra sort / shuffle being
301+ * introduced). This rule will change the ordering of the join keys to match with the
302+ * partitioning of the join nodes' children.
303+ */
304+ private def reorderJoinPredicates (plan : SparkPlan ): SparkPlan = {
303305 plan.transformUp {
304306 case BroadcastHashJoinExec (leftKeys, rightKeys, joinType, buildSide, condition, left,
305307 right) =>
0 commit comments