@@ -28,21 +28,30 @@ import org.apache.spark.sql.catalyst.plans.physical._
2828import org .apache .spark .sql .catalyst .rules .Rule
2929import org .apache .spark .util .MutablePair
3030
31+ object Exchange {
32+ /** Returns true when the ordering expressions are a subset of the key. */
33+ def canSortWithShuffle (partitioning : Partitioning , desiredOrdering : Seq [SortOrder ]): Boolean = {
34+ desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
35+ }
36+ }
37+
3138/**
32- * Shuffle data according to a new partition rule, and sort inside each partition if necessary.
33- * @param newPartitioning The new partitioning way that required by parent
34- * @param sort Whether we will sort inside each partition
35- * @param child Child operator
39+ * :: DeveloperApi ::
40+ * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
41+ * resulting partition based on expressions from the partition key. It is invalid to construct an
42+ * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key.
3643 */
3744@ DeveloperApi
3845case class Exchange (
3946 newPartitioning : Partitioning ,
40- sort : Boolean ,
47+ newOrdering : Seq [ SortOrder ] ,
4148 child : SparkPlan )
4249 extends UnaryNode {
4350
4451 override def outputPartitioning : Partitioning = newPartitioning
4552
53+ override def outputOrdering : Seq [SortOrder ] = newOrdering
54+
4655 override def output : Seq [Attribute ] = child.output
4756
4857 /** We must copy rows when sort based shuffle is on */
@@ -51,6 +60,20 @@ case class Exchange(
5160 private val bypassMergeThreshold =
5261 child.sqlContext.sparkContext.conf.getInt(" spark.shuffle.sort.bypassMergeThreshold" , 200 )
5362
63+ private val keyOrdering = {
64+ if (newOrdering.nonEmpty) {
65+ val key = newPartitioning.keyExpressions
66+ val boundOrdering = newOrdering.map { o =>
67+ val ordinal = key.indexOf(o.child)
68+ if (ordinal == - 1 ) sys.error(s " Invalid ordering on $o requested for $newPartitioning" )
69+ o.copy(child = BoundReference (ordinal, o.child.dataType, o.child.nullable))
70+ }
71+ new RowOrdering (boundOrdering)
72+ } else {
73+ null // Ordering will not be used
74+ }
75+ }
76+
5477 override def execute (): RDD [Row ] = attachTree(this , " execute" ) {
5578 newPartitioning match {
5679 case HashPartitioning (expressions, numPartitions) =>
@@ -62,7 +85,9 @@ case class Exchange(
6285 // we can avoid the defensive copies to improve performance. In the long run, we probably
6386 // want to include information in shuffle dependencies to indicate whether elements in the
6487 // source RDD should be copied.
65- val rdd = if ((sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || sort) {
88+ val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold
89+
90+ val rdd = if (willMergeSort || newOrdering.nonEmpty) {
6691 child.execute().mapPartitions { iter =>
6792 val hashExpressions = newMutableProjection(expressions, child.output)()
6893 iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -75,21 +100,17 @@ case class Exchange(
75100 }
76101 }
77102 val part = new HashPartitioner (numPartitions)
78- val shuffled = sort match {
79- case false => new ShuffledRDD [Row , Row , Row ](rdd, part)
80- case true =>
81- val sortingExpressions = expressions.zipWithIndex.map {
82- case (exp, index) =>
83- new SortOrder (BoundReference (index, exp.dataType, exp.nullable), Ascending )
84- }
85- val ordering = new RowOrdering (sortingExpressions, child.output)
86- new ShuffledRDD [Row , Row , Row ](rdd, part).setKeyOrdering(ordering)
87- }
103+ val shuffled =
104+ if (newOrdering.nonEmpty) {
105+ new ShuffledRDD [Row , Row , Row ](rdd, part).setKeyOrdering(keyOrdering)
106+ } else {
107+ new ShuffledRDD [Row , Row , Row ](rdd, part)
108+ }
88109 shuffled.setSerializer(new SparkSqlSerializer (new SparkConf (false )))
89110 shuffled.map(_._2)
90111
91112 case RangePartitioning (sortingExpressions, numPartitions) =>
92- val rdd = if (sortBasedShuffleOn) {
113+ val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty ) {
93114 child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null ))}
94115 } else {
95116 child.execute().mapPartitions { iter =>
@@ -102,7 +123,12 @@ case class Exchange(
102123 implicit val ordering = new RowOrdering (sortingExpressions, child.output)
103124
104125 val part = new RangePartitioner (numPartitions, rdd, ascending = true )
105- val shuffled = new ShuffledRDD [Row , Null , Null ](rdd, part)
126+ val shuffled =
127+ if (newOrdering.nonEmpty) {
128+ new ShuffledRDD [Row , Null , Null ](rdd, part).setKeyOrdering(keyOrdering)
129+ } else {
130+ new ShuffledRDD [Row , Null , Null ](rdd, part)
131+ }
106132 shuffled.setSerializer(new SparkSqlSerializer (new SparkConf (false )))
107133
108134 shuffled.map(_._1)
@@ -135,27 +161,35 @@ case class Exchange(
135161 * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning ]]
136162 * of input data meets the
137163 * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution ]] requirements for
138- * each operator by inserting [[Exchange ]] Operators where required.
164+ * each operator by inserting [[Exchange ]] Operators where required. Also ensure that the
165+ * required input partition ordering requirements are met.
139166 */
140- private [sql] case class AddExchange (sqlContext : SQLContext ) extends Rule [SparkPlan ] {
167+ private [sql] case class EnsureRequirements (sqlContext : SQLContext ) extends Rule [SparkPlan ] {
141168 // TODO: Determine the number of partitions.
142169 def numPartitions : Int = sqlContext.conf.numShufflePartitions
143170
144171 def apply (plan : SparkPlan ): SparkPlan = plan.transformUp {
145172 case operator : SparkPlan =>
146- // Check if every child's outputPartitioning satisfies the corresponding
173+ // True iff every child's outputPartitioning satisfies the corresponding
147174 // required data distribution.
148175 def meetsRequirements : Boolean =
149- ! operator.requiredChildDistribution.zip(operator.children).map {
176+ operator.requiredChildDistribution.zip(operator.children).forall {
150177 case (required, child) =>
151178 val valid = child.outputPartitioning.satisfies(required)
152179 logDebug(
153180 s " ${if (valid) " Valid" else " Invalid" } distribution, " +
154181 s " required: $required current: ${child.outputPartitioning}" )
155182 valid
156- }.exists(! _)
183+ }
184+
185+ // True iff any of the children are incorrectly sorted.
186+ def needsAnySort : Boolean =
187+ operator.requiredChildOrdering.zip(operator.children).exists {
188+ case (required, child) => required.nonEmpty && required != child
189+ }
190+
157191
158- // Check if outputPartitionings of children are compatible with each other.
192+ // True iff outputPartitionings of children are compatible with each other.
159193 // It is possible that every child satisfies its required data distribution
160194 // but two children have incompatible outputPartitionings. For example,
161195 // A dataset is range partitioned by "a.asc" (RangePartitioning) and another
@@ -172,40 +206,61 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl
172206 case Seq (a,b) => a compatibleWith b
173207 }.exists(! _)
174208
175- // Check if the partitioning we want to ensure is the same as the child's output
176- // partitioning. If so, we do not need to add the Exchange operator.
177- def addExchangeIfNecessary (
209+ // Adds Exchange or Sort operators as required
210+ def addOperatorsIfNecessary (
178211 partitioning : Partitioning ,
179- child : SparkPlan ,
180- rowOrdering : Option [Ordering [Row ]] = None ): SparkPlan = {
181- val needSort = child.outputOrdering != rowOrdering
182- if (child.outputPartitioning != partitioning || needSort) {
183- // TODO: if only needSort, we need only sort each partition instead of an Exchange
184- Exchange (partitioning, sort = needSort, child)
212+ rowOrdering : Seq [SortOrder ],
213+ child : SparkPlan ): SparkPlan = {
214+ val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
215+ val needsShuffle = child.outputPartitioning != partitioning
216+ val canSortWithShuffle = Exchange .canSortWithShuffle(partitioning, rowOrdering)
217+
218+ if (needSort && needsShuffle && canSortWithShuffle) {
219+ Exchange (partitioning, rowOrdering, child)
185220 } else {
186- child
221+ val withShuffle = if (needsShuffle) {
222+ Exchange (partitioning, Nil , child)
223+ } else {
224+ child
225+ }
226+
227+ val withSort = if (needSort) {
228+ Sort (rowOrdering, global = false , withShuffle)
229+ } else {
230+ withShuffle
231+ }
232+
233+ withSort
187234 }
188235 }
189236
190- if (meetsRequirements && compatible) {
237+ if (meetsRequirements && compatible && ! needsAnySort ) {
191238 operator
192239 } else {
193240 // At least one child does not satisfies its required data distribution or
194241 // at least one child's outputPartitioning is not compatible with another child's
195242 // outputPartitioning. In this case, we need to add Exchange operators.
196- val repartitionedChildren = operator.requiredChildDistribution.zip(
197- operator.children.zip(operator.requiredChildOrdering)
198- ).map {
199- case (AllTuples , (child, _)) =>
200- addExchangeIfNecessary(SinglePartition , child)
201- case (ClusteredDistribution (clustering), (child, rowOrdering)) =>
202- addExchangeIfNecessary(HashPartitioning (clustering, numPartitions), child, rowOrdering)
203- case (OrderedDistribution (ordering), (child, None )) =>
204- addExchangeIfNecessary(RangePartitioning (ordering, numPartitions), child)
205- case (UnspecifiedDistribution , (child, _)) => child
206- case (dist, _) => sys.error(s " Don't know how to ensure $dist" )
243+ val requirements =
244+ (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
245+
246+ val fixedChildren = requirements.zipped.map {
247+ case (AllTuples , rowOrdering, child) =>
248+ addOperatorsIfNecessary(SinglePartition , rowOrdering, child)
249+ case (ClusteredDistribution (clustering), rowOrdering, child) =>
250+ addOperatorsIfNecessary(HashPartitioning (clustering, numPartitions), rowOrdering, child)
251+ case (OrderedDistribution (ordering), rowOrdering, child) =>
252+ addOperatorsIfNecessary(RangePartitioning (ordering, numPartitions), Nil , child)
253+
254+ case (UnspecifiedDistribution , Seq (), child) =>
255+ child
256+ case (UnspecifiedDistribution , rowOrdering, child) =>
257+ Sort (rowOrdering, global = false , child)
258+
259+ case (dist, ordering, _) =>
260+ sys.error(s " Don't know how to ensure $dist with ordering $ordering" )
207261 }
208- operator.withNewChildren(repartitionedChildren)
262+
263+ operator.withNewChildren(fixedChildren)
209264 }
210265 }
211266}
0 commit comments