@@ -201,62 +201,76 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
  */
 private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
-  def numPartitions: Int = sqlContext.conf.numShufflePartitions
+  private def numPartitions: Int = sqlContext.conf.numShufflePartitions
 
-  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    case operator: SparkPlan =>
-      // Adds Exchange or Sort operators as required
-      def addOperatorsIfNecessary(
-          partitioning: Partitioning,
-          rowOrdering: Seq[SortOrder],
-          child: SparkPlan): SparkPlan = {
-
-        def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
-          if (!child.outputPartitioning.guarantees(partitioning)) {
-            Exchange(partitioning, child)
-          } else {
-            child
-          }
-        }
-
-        def addSortIfNecessary(child: SparkPlan): SparkPlan = {
+  private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = {
+    requiredDistribution match {
+      case AllTuples => SinglePartition
+      case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+      case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
+      case dist => sys.error(s"Do not know how to satisfy distribution $dist")
+    }
+  }
 
-          if (rowOrdering.nonEmpty) {
-            // If child.outputOrdering is [a, b] and rowOrdering is [a], we do not need to sort.
-            val minSize = Seq(rowOrdering.size, child.outputOrdering.size).min
-            if (minSize == 0 || rowOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
-              sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
-            } else {
+  private def ensureChildNumPartitionsAgreementIfNecessary(operator: SparkPlan): SparkPlan = {
+    if (operator.requiresChildrenToProduceSameNumberOfPartitions) {
+      if (operator.children.map(_.outputPartitioning.numPartitions).distinct.size > 1) {
+        val newChildren = operator.children.zip(operator.requiredChildDistribution).map {
+          case (child, requiredDistribution) =>
+            val targetPartitioning = canonicalPartitioning(requiredDistribution)
+            if (child.outputPartitioning.guarantees(targetPartitioning)) {
               child
+            } else {
+              Exchange(targetPartitioning, child)
             }
-          } else {
-            child
-          }
         }
-
-        addSortIfNecessary(addShuffleIfNecessary(child))
+        operator.withNewChildren(newChildren)
+      } else {
+        operator
       }
+    } else {
+      operator
+    }
+  }
 
-      val requirements =
-        (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
+  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
 
-      val fixedChildren = requirements.zipped.map {
-        case (AllTuples, rowOrdering, child) =>
-          addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
-        case (ClusteredDistribution(clustering), rowOrdering, child) =>
-          addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
-        case (OrderedDistribution(ordering), rowOrdering, child) =>
-          addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
+    def addShuffleIfNecessary(child: SparkPlan, requiredDistribution: Distribution): SparkPlan = {
+      if (child.outputPartitioning.satisfies(requiredDistribution)) {
+        child
+      } else {
+        Exchange(canonicalPartitioning(requiredDistribution), child)
+      }
+    }
 
-        case (UnspecifiedDistribution, Seq(), child) =>
+    def addSortIfNecessary(child: SparkPlan, requiredOrdering: Seq[SortOrder]): SparkPlan = {
+      if (requiredOrdering.nonEmpty) {
+        // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
+        val minSize = Seq(requiredOrdering.size, child.outputOrdering.size).min
+        if (minSize == 0 || requiredOrdering.take(minSize) != child.outputOrdering.take(minSize)) {
+          sqlContext.planner.BasicOperators.getSortOperator(requiredOrdering, global = false, child)
+        } else {
           child
-        case (UnspecifiedDistribution, rowOrdering, child) =>
-          sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
-
-        case (dist, ordering, _) =>
-          sys.error(s"Don't know how to ensure $dist with ordering $ordering")
+        }
+      } else {
+        child
       }
+    }
+
+    val children = operator.children
+    val requiredChildDistribution = operator.requiredChildDistribution
+    val requiredChildOrdering = operator.requiredChildOrdering
+    assert(children.length == requiredChildDistribution.length)
+    assert(children.length == requiredChildOrdering.length)
+    val newChildren = (children, requiredChildDistribution, requiredChildOrdering).zipped.map {
+      case (child, requiredDistribution, requiredOrdering) =>
+        addSortIfNecessary(addShuffleIfNecessary(child, requiredDistribution), requiredOrdering)
+    }
+    operator.withNewChildren(newChildren)
+  }
 
-      operator.withNewChildren(fixedChildren)
+  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+    case operator: SparkPlan =>
+      ensureDistributionAndOrdering(ensureChildNumPartitionsAgreementIfNecessary(operator))
   }
 }
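
The sort-elision logic in `addSortIfNecessary` only compares the common prefix of the required and actual orderings. Below is a minimal standalone sketch of that check (plain Scala, no Spark dependencies; `needsSort` and the `String` stand-ins for `SortOrder` are illustrative names, not part of this patch):

```scala
object SortElisionSketch {
  // Returns true when a sort must be inserted to satisfy requiredOrdering,
  // mirroring the condition inside addSortIfNecessary above.
  def needsSort(requiredOrdering: Seq[String], outputOrdering: Seq[String]): Boolean = {
    if (requiredOrdering.isEmpty) {
      false // no ordering required, nothing to do
    } else {
      val minSize = math.min(requiredOrdering.size, outputOrdering.size)
      // minSize == 0 here means the child reports no output ordering at all.
      minSize == 0 || requiredOrdering.take(minSize) != outputOrdering.take(minSize)
    }
  }

  def main(args: Array[String]): Unit = {
    // Child already sorted by [a, b]; requiring [a] is a matching prefix, so no sort.
    assert(!needsSort(Seq("a"), Seq("a", "b")))
    // Child sorted by [b]; requiring [a] mismatches, so a sort is inserted.
    assert(needsSort(Seq("a"), Seq("b")))
    // Child reports no ordering; any non-empty requirement forces a sort.
    assert(needsSort(Seq("a"), Seq.empty))
  }
}
```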
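The refactored rule also makes the per-child pipeline explicit: a shuffle is inserted first if the required distribution is not satisfied, and a sort is then planned on top of the (possibly shuffled) child. A toy model of that composition, under assumed simplified types (`Scan`/`Exchange`/`Sort` and both predicates here are stand-ins, not Spark's operators):

```scala
object EnsureRequirementsSketch {
  sealed trait Plan
  // Flags on the leaf say whether it already satisfies the required
  // distribution/ordering; they stand in for the real satisfies/prefix checks.
  case class Scan(table: String, partitionedOk: Boolean, sortedOk: Boolean) extends Plan
  case class Exchange(child: Plan) extends Plan
  case class Sort(child: Plan) extends Plan

  def satisfiesDistribution(p: Plan): Boolean = p match {
    case Scan(_, ok, _) => ok
    case Exchange(_)    => true                      // a fresh shuffle satisfies by construction
    case Sort(c)        => satisfiesDistribution(c)  // sorting preserves partitioning
  }

  def satisfiesOrdering(p: Plan): Boolean = p match {
    case Scan(_, _, ok) => ok
    case Exchange(_)    => false                     // a shuffle destroys any ordering
    case Sort(_)        => true
  }

  // Mirrors addSortIfNecessary(addShuffleIfNecessary(child, ...), ...):
  // shuffle first, then sort on top, each step a no-op when already satisfied.
  def fixChild(child: Plan): Plan = {
    val shuffled = if (satisfiesDistribution(child)) child else Exchange(child)
    if (satisfiesOrdering(shuffled)) shuffled else Sort(shuffled)
  }

  def main(args: Array[String]): Unit = {
    val bare = Scan("t", partitionedOk = false, sortedOk = false)
    assert(fixChild(bare) == Sort(Exchange(bare))) // both operators inserted
    val good = Scan("t", partitionedOk = true, sortedOk = true)
    assert(fixChild(good) == good)                 // nothing inserted
  }
}
```

Ordering the steps this way matters because the sort must see the shuffled output; a sort planned below the exchange would be undone by the shuffle.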