Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -252,54 +252,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
operator.withNewChildren(children)
}

/**
* When the physical operators are created for JOIN, the ordering of join keys is based on order
* in which the join keys appear in the user query. That might not match with the output
* partitioning of the join node's children (thus leading to extra sort / shuffle being
* introduced). This rule will change the ordering of the join keys to match with the
* partitioning of the join nodes' children.
*/
def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
def reorderJoinKeys(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {

def reorder(expectedOrderOfKeys: Seq[Expression],
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val leftKeysBuffer = ArrayBuffer[Expression]()
val rightKeysBuffer = ArrayBuffer[Expression]()
private def reorder(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
expectedOrderOfKeys: Seq[Expression],
currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
val leftKeysBuffer = ArrayBuffer[Expression]()
val rightKeysBuffer = ArrayBuffer[Expression]()

expectedOrderOfKeys.foreach(expression => {
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
leftKeysBuffer.append(leftKeys(index))
rightKeysBuffer.append(rightKeys(index))
})
(leftKeysBuffer, rightKeysBuffer)
}
expectedOrderOfKeys.foreach(expression => {
val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
leftKeysBuffer.append(leftKeys(index))
rightKeysBuffer.append(rightKeys(index))
})
(leftKeysBuffer, rightKeysBuffer)
}

if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
leftPartitioning match {
case HashPartitioning(leftExpressions, _)
if leftExpressions.length == leftKeys.length &&
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
reorder(leftExpressions, leftKeys)
private def reorderJoinKeys(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
leftPartitioning: Partitioning,
rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
leftPartitioning match {
case HashPartitioning(leftExpressions, _)
if leftExpressions.length == leftKeys.length &&
leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
reorder(leftKeys, rightKeys, leftExpressions, leftKeys)

case _ => rightPartitioning match {
case HashPartitioning(rightExpressions, _)
if rightExpressions.length == rightKeys.length &&
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
reorder(rightExpressions, rightKeys)
case _ => rightPartitioning match {
case HashPartitioning(rightExpressions, _)
if rightExpressions.length == rightKeys.length &&
rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
reorder(leftKeys, rightKeys, rightExpressions, rightKeys)

case _ => (leftKeys, rightKeys)
}
case _ => (leftKeys, rightKeys)
}
} else {
(leftKeys, rightKeys)
}
} else {
(leftKeys, rightKeys)
}
}

/**
* When the physical operators are created for JOIN, the ordering of join keys is based on order
* in which the join keys appear in the user query. That might not match with the output
* partitioning of the join node's children (thus leading to extra sort / shuffle being
* introduced). This rule will change the ordering of the join keys to match with the
* partitioning of the join nodes' children.
*/
private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
right) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -620,15 +620,15 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
|) ab
|JOIN table2 c
|ON ab.i = c.i
|""".stripMargin),
""".stripMargin),
sql("""
|SELECT a.i, a.j, a.k, c.i, c.j, c.k
|FROM bucketed_table a
|JOIN table1 b
|ON a.i = b.i
|JOIN table2 c
|ON a.i = c.i
|""".stripMargin))
""".stripMargin))
}
}
}
Expand Down