Commit c9fb231

Rewrite exchange to better handle this case.

1 parent adcc742 commit c9fb231
6 files changed: +73 −45 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 59 additions & 45 deletions
@@ -201,62 +201,76 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
    */
 private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
-  def numPartitions: Int = sqlContext.conf.numShufflePartitions
+  private def numPartitions: Int = sqlContext.conf.numShufflePartitions
 
-  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator: SparkPlan =>
-      // Adds Exchange or Sort operators as required
-      def addOperatorsIfNecessary(
-          partitioning: Partitioning,
-          rowOrdering: Seq[SortOrder],
-          child: SparkPlan): SparkPlan = {
-
-        def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
-          if (!child.outputPartitioning.guarantees(partitioning)) {
-            Exchange(partitioning, child)
-          } else {
-            child
-          }
-        }
-
-        def addSortIfNecessary(child: SparkPlan): SparkPlan = {
+  private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = {
+    requiredDistribution match {
+      case AllTuples => SinglePartition
+      case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+      case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
+      case dist => sys.error(s"Do not know how to satisfy distribution $dist")
+    }
+  }
 
-          if (rowOrdering.nonEmpty) {
-            // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort.
-            val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
-            if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
-              sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
-            } else {
+  private def ensureChildNumPartitionsAgreementIfNecessary(operator: SparkPlan): SparkPlan = {
+    if (operator.requiresChildrenToProduceSameNumberOfPartitions) {
+      if (operator.children.map(_.outputPartitioning.numPartitions).distinct.size > 1) {
+        val newChildren = operator.children.zip(operator.requiredChildDistribution).map {
+          case (child, requiredDistribution) =>
+            val targetPartitioning = canonicalPartitioning(requiredDistribution)
+            if (child.outputPartitioning.guarantees(targetPartitioning)) {
               child
+            } else {
+              Exchange(targetPartitioning, child)
             }
-          } else {
-            child
-          }
         }
-
-        addSortIfNecessary(addShuffleIfNecessary(child))
+        operator.withNewChildren(newChildren)
+      } else {
+        operator
       }
+    } else {
+      operator
+    }
+  }
 
-      val requirements =
-        (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
+  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
 
-      val fixedChildren = requirements.zipped.map {
-        case (AllTuples, rowOrdering, child) =>
-          addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
-        case (ClusteredDistribution(clustering), rowOrdering, child) =>
-          addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
-        case (OrderedDistribution(ordering), rowOrdering, child) =>
-          addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
+    def addShuffleIfNecessary(child: SparkPlan, requiredDistribution: Distribution): SparkPlan = {
+      if (child.outputPartitioning.satisfies(requiredDistribution)) {
+        child
+      } else {
+        Exchange(canonicalPartitioning(requiredDistribution), child)
+      }
+    }
 
-        case (UnspecifiedDistribution, Seq(), child) =>
+    def addSortIfNecessary(child: SparkPlan, requiredOrdering: Seq[SortOrder]): SparkPlan = {
+      if (requiredOrdering.nonEmpty) {
+        // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
+        val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min
+        if (minSize == 0 || requiredOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
+          sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, global = false, child)
+        } else {
           child
-        case (UnspecifiedDistribution, rowOrdering, child) =>
-          sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
-
-        case (dist, ordering, _) =>
-          sys.error(s"Don't know how to ensure $dist with ordering $ordering")
+        }
+      } else {
+        child
       }
+    }
+
+    val children = operator.children
+    val requiredChildDistribution = operator.requiredChildDistribution
+    val requiredChildOrdering = operator.requiredChildOrdering
+    assert(children.length == requiredChildDistribution.length)
+    assert(children.length == requiredChildOrdering.length)
+    val newChildren = (children, requiredChildDistribution, requiredChildOrdering).zipped.map {
+      case (child, requiredDistribution, requiredOrdering) =>
+        addSortIfNecessary(addShuffleIfNecessary(child, requiredDistribution), requiredOrdering)
+    }
+    operator.withNewChildren(newChildren)
+  }
 
-      operator.withNewChildren(fixedChildren)
+  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+    case operator: SparkPlan =>
+      ensureDistributionAndOrdering(ensureChildNumPartitionsAgreementIfNecessary(operator))
   }
 }
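Editor's note: the new ensureChildNumPartitionsAgreementIfNecessary step matters for operators such as the hash joins below, which zip their children's RDDs partition-by-partition (e.g. right.execute().zipPartitions(left.execute())) and therefore break when the two sides report different partition counts. The following is a minimal standalone sketch of that check; Plan, exchange, and ensureAgreement are toy stand-ins for illustration, not Spark's API.

object NumPartitionsAgreementSketch {

  // Toy stand-in for a SparkPlan: a name plus its output partition count.
  final case class Plan(name: String, numPartitions: Int)

  // Toy stand-in for Exchange(targetPartitioning, child): forces a shuffle
  // of `child` into `target` partitions.
  def exchange(child: Plan, target: Int): Plan =
    Plan(s"Exchange(${child.name})", target)

  // Mirrors the rule's core test: children disagree iff their output
  // partition counts are not all identical (`distinct.size > 1`); only
  // children that do not already match the target are re-shuffled.
  def ensureAgreement(children: Seq[Plan], target: Int): Seq[Plan] =
    if (children.map(_.numPartitions).distinct.size > 1) {
      children.map(c => if (c.numPartitions == target) c else exchange(c, target))
    } else {
      children
    }

  def main(args: Array[String]): Unit = {
    // A join whose sides are both hash-partitioned on the join keys but into
    // different numbers of partitions; zipping their partitions would fail.
    val children = Seq(Plan("left", 10), Plan("right", 200))
    println(ensureAgreement(children, target = 200))
    // -> List(Plan(Exchange(left),200), Plan(right,200))
  }
}

The real rule additionally routes through canonicalPartitioning(requiredDistribution), so the injected Exchange satisfies the operator's required distribution rather than merely matching a partition count.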

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 6 additions & 0 deletions
@@ -109,6 +109,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   /** Specifies sort order for each partition requirements on the input data for this operator. */
   def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
 
+  /**
+   * Specifies whether this operator requires all of its children to produce the same number of
+   * output partitions.
+   */
+  def requiresChildrenToProduceSameNumberOfPartitions: Boolean = false
+
   /** Specifies whether this operator outputs UnsafeRows */
   def outputsUnsafeRows: Boolean = false

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ case class LeftSemiJoinHash(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  override def requiresChildrenToProduceSameNumberOfPartitions: Boolean = true
+
   protected override def doExecute(): RDD[InternalRow] = {
     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
       if (condition.isEmpty) {

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ case class ShuffledHashJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  override def requiresChildrenToProduceSameNumberOfPartitions: Boolean = true
+
   protected override def doExecute(): RDD[InternalRow] = {
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
       val hashed = HashedRelation(buildIter, buildSideKeyGenerator)

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@ case class ShuffledHashOuterJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  override def requiresChildrenToProduceSameNumberOfPartitions: Boolean = true
+
   override def outputPartitioning: Partitioning = joinType match {
     case LeftOuter => left.outputPartitioning
     case RightOuter => right.outputPartitioning

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 2 additions & 0 deletions
@@ -48,6 +48,8 @@ case class SortMergeJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  override def requiresChildrenToProduceSameNumberOfPartitions: Boolean = true
+
   override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
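Editor's note: SortMergeJoin also declares requiredChildOrdering, which is what the rewritten addSortIfNecessary in Exchange.scala consumes. Its prefix test ("if child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort") can be modeled in isolation. The sketch below uses plain strings in place of Catalyst SortOrder expressions; needsSort is a toy helper, not part of Spark.

// Toy model of addSortIfNecessary's decision: a sort is inserted when either
// ordering is empty/unknown (minSize == 0) or the two orderings disagree on
// their common prefix.
def needsSort(requiredOrdering: Seq[String], outputOrdering: Seq[String]): Boolean = {
  val minSize = Seq(requiredOrdering.size, outputOrdering.size).min
  minSize == 0 || requiredOrdering.take(minSize) != outputOrdering.take(minSize)
}

assert(!needsSort(Seq("a"), Seq("a", "b")))  // output [a, b] already satisfies [a]: no sort
assert(needsSort(Seq("b"), Seq("a", "b")))   // common prefix differs: sort
assert(needsSort(Seq("a"), Seq()))           // child ordering unknown: sort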
