@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
2121import org .apache .spark .rdd .RDD
2222import org .apache .spark .sql .TestData ._
2323import org .apache .spark .sql .catalyst .InternalRow
24- import org .apache .spark .sql .catalyst .expressions .{Ascending , Literal , Attribute , SortOrder }
24+ import org .apache .spark .sql .catalyst .expressions .{Ascending , Attribute , Literal , SortOrder }
2525import org .apache .spark .sql .catalyst .plans ._
2626import org .apache .spark .sql .catalyst .plans .logical .LogicalPlan
2727import org .apache .spark .sql .catalyst .plans .physical ._
@@ -210,9 +210,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
210210 // --- Unit tests of EnsureRequirements ---------------------------------------------------------
211211
212212 private def assertDistributionRequirementsAreSatisfied (outputPlan : SparkPlan ): Unit = {
213- if (outputPlan.requiresChildrenToProduceSameNumberOfPartitions) {
214- if (outputPlan.children.map(_.outputPartitioning.numPartitions).toSet.size != 1 ) {
215- fail(s " Children did not produce the same number of partitions: \n $outputPlan" )
213+ if (outputPlan.requiresChildPartitioningsToBeCompatible) {
214+ val childPartitionings = outputPlan.children.map(_.outputPartitioning)
215+ if (! Partitioning .allCompatible(childPartitionings)) {
216+ fail(s " Partitionings are not compatible: $childPartitionings" )
216217 }
217218 }
218219 outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach {
@@ -222,15 +223,50 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
222223 }
223224 }
224225
225- test(" EnsureRequirements ensures that children produce same number of partitions when required" ) {
226+ test(" EnsureRequirements ensures that child partitionings guarantee each other, if required" ) {
227+ // Consider an operator that requires inputs that are clustered by two expressions (e.g.
228+ // sort merge join where there are multiple columns in the equi-join condition)
229+ val clusteringA = Literal (1 ) :: Nil
230+ val clusteringB = Literal (2 ) :: Nil
231+ val distribution = ClusteredDistribution (clusteringA ++ clusteringB)
232+ // Say that the left and right inputs are each partitioned by _one_ of the two join columns:
233+ val leftPartitioning = HashPartitioning (clusteringA, 1 )
234+ val rightPartitioning = HashPartitioning (clusteringB, 1 )
235+ // Individually, each input's partitioning satisfies the clustering distribution:
236+ assert(leftPartitioning.satisfies(distribution))
237+ assert(rightPartitioning.satisfies(distribution))
238+ // However, these partitionings are not compatible with each other, so we still need to
239+ // repartition both inputs prior to performing the join:
240+ assert(! leftPartitioning.guarantees(rightPartitioning))
241+ assert(! rightPartitioning.guarantees(leftPartitioning))
242+ val inputPlan = DummyPlan (
243+ children = Seq (
244+ DummyPlan (outputPartitioning = HashPartitioning (clusteringA, 1 )),
245+ DummyPlan (outputPartitioning = HashPartitioning (clusteringB, 1 ))
246+ ),
247+ requiresChildPartitioningsToBeCompatible = true ,
248+ requiredChildDistribution = Seq (distribution, distribution),
249+ requiredChildOrdering = Seq (Seq .empty, Seq .empty)
250+ )
251+ val outputPlan = EnsureRequirements (sqlContext).apply(inputPlan)
252+ assertDistributionRequirementsAreSatisfied(outputPlan)
253+ if (outputPlan.collect { case Exchange (_, _) => true }.isEmpty) {
254+ fail(s " Exchanges should have been added: \n $outputPlan" )
255+ }
256+ }
257+
258+ test(" EnsureRequirements ensures that children produce same number of partitions, if required" ) {
259+ // This is similar to the previous test, except it checks that partitionings are not compatible
260+ // unless they produce the same number of partitions. This requirement is also enforced via
261+ // assertions in Exchange.
226262 val clustering = Literal (1 ) :: Nil
227263 val distribution = ClusteredDistribution (clustering)
228264 val inputPlan = DummyPlan (
229265 children = Seq (
230266 DummyPlan (outputPartitioning = HashPartitioning (clustering, 1 )),
231267 DummyPlan (outputPartitioning = HashPartitioning (clustering, 2 ))
232268 ),
233- requiresChildrenToProduceSameNumberOfPartitions = true ,
269+ requiresChildPartitioningsToBeCompatible = true ,
234270 requiredChildDistribution = Seq (distribution, distribution),
235271 requiredChildOrdering = Seq (Seq .empty, Seq .empty)
236272 )
@@ -239,14 +275,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
239275 }
240276
241277 test(" EnsureRequirements should not repartition if only ordering requirement is unsatisfied" ) {
278+ // Consider an operator that imposes both output distribution and ordering requirements on its
279+ // children, such as sort merge join. If the distribution requirements are satisfied but
280+ // the output ordering requirements are unsatisfied, then the planner should only add sorts and
281+ // should not need to add additional shuffles / exchanges.
242282 val outputOrdering = Seq (SortOrder (Literal (1 ), Ascending ))
243283 val distribution = ClusteredDistribution (Literal (1 ) :: Nil )
244284 val inputPlan = DummyPlan (
245285 children = Seq (
246286 DummyPlan (outputPartitioning = SinglePartition ),
247287 DummyPlan (outputPartitioning = SinglePartition )
248288 ),
249- requiresChildrenToProduceSameNumberOfPartitions = true ,
289+ requiresChildPartitioningsToBeCompatible = true ,
250290 requiredChildDistribution = Seq (distribution, distribution),
251291 requiredChildOrdering = Seq (outputOrdering, outputOrdering)
252292 )
@@ -265,7 +305,7 @@ private case class DummyPlan(
265305 override val children : Seq [SparkPlan ] = Nil ,
266306 override val outputOrdering : Seq [SortOrder ] = Nil ,
267307 override val outputPartitioning : Partitioning = UnknownPartitioning (0 ),
268- override val requiresChildrenToProduceSameNumberOfPartitions : Boolean = false ,
308+ override val requiresChildPartitioningsToBeCompatible : Boolean = false ,
269309 override val requiredChildDistribution : Seq [Distribution ] = Nil ,
270310 override val requiredChildOrdering : Seq [Seq [SortOrder ]] = Nil
271311 ) extends SparkPlan {
0 commit comments