@@ -208,7 +208,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     children = withExchangeCoordinator(children, requiredChildDistributions)
 
     // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
-    children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+    ensureOrdering(
+      reorderJoinPredicatesForOrdering(operator.withNewChildren(children))
+    )
+  }
+
+  private def ensureOrdering(operator: SparkPlan): SparkPlan = {
+    var children: Seq[SparkPlan] = operator.children
+    children = children.zip(operator.requiredChildOrderings).map { case (child, requiredOrdering) =>
       // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort.
       if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) {
         child
@@ -243,24 +250,38 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     (leftKeysBuffer, rightKeysBuffer)
   }
 
-  private def reorderJoinKeys(
+  private def reorderJoinKeys[A](
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
-      leftPartitioning: Partitioning,
-      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+      leftChildDist: A,
+      rightChildDist: A): (Seq[Expression], Seq[Expression]) = {
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      leftPartitioning match {
+      leftChildDist match {
         case HashPartitioning(leftExpressions, _)
           if leftExpressions.length == leftKeys.length &&
            leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
           reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
 
-        case _ => rightPartitioning match {
+        case leftOrders: Seq[_]
+          if leftOrders.forall(_.isInstanceOf[Expression]) &&
+            leftOrders.length == leftKeys.length &&
+            leftKeys.forall { x =>
+              leftOrders.map(_.asInstanceOf[Expression]).exists(_.semanticEquals(x)) } =>
+          reorder(leftKeys, rightKeys, leftOrders.map(_.asInstanceOf[Expression]), leftKeys)
+
+        case _ => rightChildDist match {
           case HashPartitioning(rightExpressions, _)
             if rightExpressions.length == rightKeys.length &&
              rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
             reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
 
+          case rightOrders: Seq[_]
+            if rightOrders.forall(_.isInstanceOf[Expression]) &&
+              rightOrders.length == leftKeys.length &&
+              leftKeys.forall { x =>
+                rightOrders.map(_.asInstanceOf[Expression]).exists(_.semanticEquals(x)) } =>
+            reorder(leftKeys, rightKeys, rightOrders.map(_.asInstanceOf[Expression]), leftKeys)

Review comment (Contributor), on the rightOrders case above:

    reorder(leftKeys, rightKeys, rightOrders.map(_.asInstanceOf[Expression]), rightKeys)

and please add a UT which fails before correcting this and passes after.
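
Note that, besides the last argument, the guard on the added rightOrders case also compares rightOrders.length against leftKeys.length and tests membership with leftKeys.forall. A fully symmetric right-side case, assuming it should mirror the HashPartitioning branch directly above it, would look like the following sketch (not code from the PR):

    // Hypothetical fully symmetric variant: the length check, the membership
    // check, and the expected order all use rightKeys, mirroring the left side.
    case rightOrders: Seq[_]
      if rightOrders.forall(_.isInstanceOf[Expression]) &&
        rightOrders.length == rightKeys.length &&
        rightKeys.forall { x =>
          rightOrders.map(_.asInstanceOf[Expression]).exists(_.semanticEquals(x)) } =>
      reorder(leftKeys, rightKeys, rightOrders.map(_.asInstanceOf[Expression]), rightKeys)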

Reply (Contributor Author):
Should this UT test the reorderJoinKeys function? Or do you have something else in mind?

Reply (Contributor):
I meant a test like the one you added. But please first prove that the current solution is fine (I doubt it is; see #23267 (comment)). Once we ensure that the current change is safe, you can go ahead and address these comments. Thanks.


           case _ => (leftKeys, rightKeys)
         }
       }
@@ -276,7 +297,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
    * introduced). This rule will change the ordering of the join keys to match with the
    * partitioning of the join nodes' children.
    */
-  private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {

Review comment (Member @maropu, Jan 29, 2019):
For historical reasons (#19257 (comment)), this method was added as a workaround, so I feel it is complicated to extend it for this case... basically, IMO we need general logic here to cover this case and more. cc: @cloud-fan

+  private def reorderJoinPredicatesForPartitioning(plan: SparkPlan): SparkPlan = {
     plan match {
       case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
         val (reorderedLeftKeys, reorderedRightKeys) =
@@ -293,6 +314,21 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     }
   }
 
+  private def reorderJoinPredicatesForOrdering(plan: SparkPlan): SparkPlan = {

Review comment (Contributor):
I think we can avoid this and include this transformation in the former reorderJoinPredicates method, after the reorder for partitionings. I'd rather have a reorderJoinKeysForOrderings (or something similar) called there.
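
A sketch of the shape that suggestion might take; the helper name reorderJoinKeysForOrderings comes from the review, and the body here is assumed rather than taken from the PR:

    // Hypothetical single entry point: reorder the keys for partitioning
    // first, then for ordering, all before ensureDistributionAndOrdering runs.
    private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
      val byPartitioning = reorderJoinPredicatesForPartitioning(plan)
      reorderJoinKeysForOrderings(byPartitioning)
    }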


Reply (Contributor Author):
I am not sure this would work. The point here was to first reorder the join predicates for partitioning, then check the child outputPartitioning (which happens in ensureDistributionAndOrdering) to decide whether we need an Exchange, and only AFTER that reorder the join predicates again to satisfy the child outputOrdering and avoid an Exchange.
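
To make that sequencing concrete, here is a condensed sketch of the flow this PR sets up (paraphrased from the diff above, not a verbatim excerpt):

    // In apply(): reorder join keys against the children's outputPartitioning
    // first, so ensureDistributionAndOrdering can skip unnecessary exchanges.
    ensureDistributionAndOrdering(reorderJoinPredicatesForPartitioning(operator))

    // At the end of ensureDistributionAndOrdering(), once the exchange
    // decisions have been made, reorder the keys again against the children's
    // outputOrdering, so ensureOrdering can skip unnecessary sorts.
    ensureOrdering(reorderJoinPredicatesForOrdering(operator.withNewChildren(children)))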

+    plan match {
+      case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(
+            leftKeys,
+            rightKeys,
+            left.outputOrdering.map(_.child),
+            right.outputOrdering.map(_.child))
+        SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+
+      case other => other
+    }
+  }

   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
     case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
@@ -301,6 +337,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         case _ => operator
       }
     case operator: SparkPlan =>
-      ensureDistributionAndOrdering(reorderJoinPredicates(operator))
+      ensureDistributionAndOrdering(reorderJoinPredicatesForPartitioning(operator))
   }
 }

@@ -780,6 +780,26 @@ class PlannerSuite extends SharedSQLContext {
         classOf[PartitioningCollection])
     }
   }
+
+  test("SPARK-25401: Reorder the join predicates to match child output ordering") {
+    val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB),
+      outputPartitioning = HashPartitioning(exprB :: exprA :: Nil, 5))
+    val plan2 = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB),
+      outputPartitioning = HashPartitioning(exprB :: exprA :: Nil, 5))
+    val smjExec = SortMergeJoinExec(
+      exprB :: exprA :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan2)
+
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
+    outputPlan match {
+      case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) =>
+        assert(leftKeys == Seq(exprA, exprB))
+        assert(rightKeys == Seq(exprA, exprB))
+      case _ => fail()
+    }
+    if (outputPlan.collect { case s: SortExec => true }.nonEmpty) {
+      fail(s"No sorts should have been added:\n$outputPlan")
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements