@@ -209,6 +209,19 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
209209
210210 // --- Unit tests of EnsureRequirements ---------------------------------------------------------
211211
212+ private def assertDistributionRequirementsAreSatisfied (outputPlan : SparkPlan ): Unit = {
213+ if (outputPlan.requiresChildrenToProduceSameNumberOfPartitions) {
214+ if (outputPlan.children.map(_.outputPartitioning.numPartitions).toSet.size != 1 ) {
215+ fail(s " Children did not produce the same number of partitions: \n $outputPlan" )
216+ }
217+ }
218+ outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach {
219+ case (child, requiredDist) =>
220+ assert(child.outputPartitioning.satisfies(requiredDist),
221+ s " $child output partitioning does not satisfy $requiredDist: \n $outputPlan" )
222+ }
223+ }
224+
212225 test(" EnsureRequirements ensures that children produce same number of partitions when required" ) {
213226 val clustering = Literal (1 ) :: Nil
214227 val distribution = ClusteredDistribution (clustering)
@@ -222,7 +235,7 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
222235 requiredChildOrdering = Seq (Seq .empty, Seq .empty)
223236 )
224237 val outputPlan = EnsureRequirements (sqlContext).apply(inputPlan)
225- assert (outputPlan.children.map(_.outputPartitioning.numPartitions).toSet.size === 1 )
238+ assertDistributionRequirementsAreSatisfied (outputPlan)
226239 }
227240
228241 test(" EnsureRequirements should not repartition if only ordering requirement is unsatisfied" ) {
@@ -238,6 +251,7 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
238251 requiredChildOrdering = Seq (outputOrdering, outputOrdering)
239252 )
240253 val outputPlan = EnsureRequirements (sqlContext).apply(inputPlan)
254+ assertDistributionRequirementsAreSatisfied(outputPlan)
241255 if (outputPlan.collect { case Exchange (_, _) => true }.nonEmpty) {
242256 fail(s " No Exchanges should have been added: \n $outputPlan" )
243257 }
0 commit comments