Skip to content

Commit 0725a34

Browse files
committed
Small assertion cleanup.
1 parent 5172ac5 commit 0725a34

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,19 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
209209

210210
// --- Unit tests of EnsureRequirements ---------------------------------------------------------
211211

212+
private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = {
213+
if (outputPlan.requiresChildrenToProduceSameNumberOfPartitions) {
214+
if (outputPlan.children.map(_.outputPartitioning.numPartitions).toSet.size != 1) {
215+
fail(s"Children did not produce the same number of partitions:\n$outputPlan")
216+
}
217+
}
218+
outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach {
219+
case (child, requiredDist) =>
220+
assert(child.outputPartitioning.satisfies(requiredDist),
221+
s"$child output partitioning does not satisfy $requiredDist:\n$outputPlan")
222+
}
223+
}
224+
212225
test("EnsureRequirements ensures that children produce same number of partitions when required") {
213226
val clustering = Literal(1) :: Nil
214227
val distribution = ClusteredDistribution(clustering)
@@ -222,7 +235,7 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
222235
requiredChildOrdering = Seq(Seq.empty, Seq.empty)
223236
)
224237
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
225-
assert (outputPlan.children.map(_.outputPartitioning.numPartitions).toSet.size === 1)
238+
assertDistributionRequirementsAreSatisfied(outputPlan)
226239
}
227240

228241
test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") {
@@ -238,6 +251,7 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
238251
requiredChildOrdering = Seq(outputOrdering, outputOrdering)
239252
)
240253
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
254+
assertDistributionRequirementsAreSatisfied(outputPlan)
241255
if (outputPlan.collect { case Exchange(_, _) => true }.nonEmpty) {
242256
fail(s"No Exchanges should have been added:\n$outputPlan")
243257
}

0 commit comments

Comments
 (0)