Commit a1c12b9

Add failing test to demonstrate allCompatible bug
1 parent 0725a34 commit a1c12b9

8 files changed: +72 -16 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 16 additions & 0 deletions

@@ -95,6 +95,22 @@ sealed trait Partitioning {
   def guarantees(other: Partitioning): Boolean
 }
 
+object Partitioning {
+  def allCompatible(partitionings: Seq[Partitioning]): Boolean = {
+    // Note: this assumes transitivity
+    partitionings.sliding(2).map {
+      case Seq(a) => true
+      case Seq(a, b) =>
+        if (a.numPartitions != b.numPartitions) {
+          assert(!a.guarantees(b) && !b.guarantees(a))
+          false
+        } else {
+          a.guarantees(b) && b.guarantees(a)
+        }
+    }.forall(_ == true)
+  }
+}
+
 case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
   override def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
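
For readers skimming the commit, here is a minimal, self-contained sketch of the pairwise check that allCompatible performs. Part below is an illustrative stand-in for Spark's Partitioning, not the real API; the internal assert is dropped and the numPartitions comparison is folded into the window check:

object AllCompatibleSketch {
  // Stand-in for Partitioning: only an identical partitioning guarantees
  // the same placement of rows across partitions.
  case class Part(exprs: Seq[String], numPartitions: Int) {
    def guarantees(other: Part): Boolean = this == other
  }

  // Same sliding(2) shape as allCompatible above: compare adjacent pairs.
  def allCompatible(partitionings: Seq[Part]): Boolean =
    partitionings.sliding(2).forall {
      case Seq(_) => true // a single partitioning is trivially compatible
      case Seq(a, b) =>
        a.numPartitions == b.numPartitions && a.guarantees(b) && b.guarantees(a)
    }

  def main(args: Array[String]): Unit = {
    val byX = Part(Seq("x"), 8)
    val byY = Part(Seq("y"), 8)
    println(allCompatible(Seq(byX, byX))) // true
    println(allCompatible(Seq(byX, byY))) // false: same count, different clustering
  }
}

Because the check walks windows of size two, it only ever compares adjacent partitionings; the "Note: this assumes transitivity" comment in the diff is what licenses extending those pairwise results to the whole sequence.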

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 1 addition & 1 deletion

@@ -213,7 +213,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
   }
 
   private def ensureChildNumPartitionsAgreementIfNecessary(operator: SparkPlan): SparkPlan = {
-    if (operator.requiresChildrenToProduceSameNumberOfPartitions) {
+    if (operator.requiresChildPartitioningsToBeCompatible) {
       if (operator.children.map(_.outputPartitioning.numPartitions).distinct.size > 1) {
         val newChildren = operator.children.zip(operator.requiredChildDistribution).map {
           case (child, requiredDistribution) =>
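
One reading of why the new test fails (an inference from this diff, not something the commit message spells out): the flag consulted here now promises full partitioning compatibility, but the inner guard still compares only partition counts. Children whose partitionings hash on different expressions while reporting the same numPartitions therefore pass through without an exchange. A toy illustration:

object CountOnlyGuardSketch {
  def main(args: Array[String]): Unit = {
    // In the new PlannerSuite test, both children report exactly one output
    // partition, so the count-only guard (distinct.size > 1) never fires...
    val childNumPartitions = Seq(1, 1)
    println(childNumPartitions.distinct.size > 1) // false: no repartitioning is triggered
    // ...even though the two HashPartitionings cluster on different expressions
    // and do not guarantee each other; compatibility is never consulted here.
  }
}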

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 3 additions & 3 deletions

@@ -110,10 +110,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
 
   /**
-   * Specifies whether this operator requires all of its children to produce the same number of
-   * output partitions.
+   * Specifies whether this operator requires all of its children to have [[outputPartitioning]]s
+   * that are compatible with each other.
    */
-  def requiresChildrenToProduceSameNumberOfPartitions: Boolean = false
+  def requiresChildPartitioningsToBeCompatible: Boolean = false
 
   /** Specifies whether this operator outputs UnsafeRows */
   def outputsUnsafeRows: Boolean = false

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala

Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ case class LeftSemiJoinHash(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  override def requiresChildrenToProduceSameNumberOfPartitions: Boolean = true
+  override def requiresChildPartitioningsToBeCompatible: Boolean = true
 
   protected override def doExecute(): RDD[InternalRow] = {
     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ case class ShuffledHashJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  override def requiresChildrenToProduceSameNumberOfPartitions: Boolean = true
+  override def requiresChildPartitioningsToBeCompatible: Boolean = true
 
   protected override def doExecute(): RDD[InternalRow] = {
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ case class ShuffledHashOuterJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  override def requiresChildrenToProduceSameNumberOfPartitions: Boolean = true
+  override def requiresChildPartitioningsToBeCompatible: Boolean = true
 
   override def outputPartitioning: Partitioning = joinType match {
     case LeftOuter => left.outputPartitioning

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ case class SortMergeJoin(
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
-  override def requiresChildrenToProduceSameNumberOfPartitions: Boolean = true
+  override def requiresChildPartitioningsToBeCompatible: Boolean = true
 
   override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
 
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 48 additions & 8 deletions

@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.TestData._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions .{Ascending, Literal, Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -210,9 +210,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
   // --- Unit tests of EnsureRequirements ---------------------------------------------------------
 
   private def assertDistributionRequirementsAreSatisfied(outputPlan: SparkPlan): Unit = {
-    if (outputPlan.requiresChildrenToProduceSameNumberOfPartitions) {
-      if (outputPlan.children.map(_.outputPartitioning.numPartitions).toSet.size != 1) {
-        fail(s"Children did not produce the same number of partitions:\n$outputPlan")
+    if (outputPlan.requiresChildPartitioningsToBeCompatible) {
+      val childPartitionings = outputPlan.children.map(_.outputPartitioning)
+      if (!Partitioning.allCompatible(childPartitionings)) {
+        fail(s"Partitionings are not compatible: $childPartitionings")
       }
     }
     outputPlan.children.zip(outputPlan.requiredChildDistribution).foreach {
@@ -222,15 +223,50 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
     }
   }
 
-  test("EnsureRequirements ensures that children produce same number of partitions when required") {
+  test("EnsureRequirements ensures that child partitionings guarantee each other, if required") {
+    // Consider an operator that requires inputs that are clustered by two expressions (e.g.
+    // sort merge join where there are multiple columns in the equi-join condition)
+    val clusteringA = Literal(1) :: Nil
+    val clusteringB = Literal(2) :: Nil
+    val distribution = ClusteredDistribution(clusteringA ++ clusteringB)
+    // Say that the left and right inputs are each partitioned by _one_ of the two join columns:
+    val leftPartitioning = HashPartitioning(clusteringA, 1)
+    val rightPartitioning = HashPartitioning(clusteringB, 1)
+    // Individually, each input's partitioning satisfies the clustering distribution:
+    assert(leftPartitioning.satisfies(distribution))
+    assert(rightPartitioning.satisfies(distribution))
+    // However, these partitionings are not compatible with each other, so we still need to
+    // repartition both inputs prior to performing the join:
+    assert(!leftPartitioning.guarantees(rightPartitioning))
+    assert(!rightPartitioning.guarantees(leftPartitioning))
+    val inputPlan = DummyPlan(
+      children = Seq(
+        DummyPlan(outputPartitioning = HashPartitioning(clusteringA, 1)),
+        DummyPlan(outputPartitioning = HashPartitioning(clusteringB, 1))
+      ),
+      requiresChildPartitioningsToBeCompatible = true,
+      requiredChildDistribution = Seq(distribution, distribution),
+      requiredChildOrdering = Seq(Seq.empty, Seq.empty)
+    )
+    val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
+    assertDistributionRequirementsAreSatisfied(outputPlan)
+    if (outputPlan.collect { case Exchange(_, _) => true }.isEmpty) {
+      fail(s"Exchanges should have been added:\n$outputPlan")
+    }
+  }
+
+  test("EnsureRequirements ensures that children produce same number of partitions, if required") {
+    // This is similar to the previous test, except it checks that partitionings are not compatible
+    // unless they produce the same number of partitions. This requirement is also enforced via
+    // assertions in Exchange.
     val clustering = Literal(1) :: Nil
     val distribution = ClusteredDistribution(clustering)
     val inputPlan = DummyPlan(
       children = Seq(
         DummyPlan(outputPartitioning = HashPartitioning(clustering, 1)),
         DummyPlan(outputPartitioning = HashPartitioning(clustering, 2))
       ),
-      requiresChildrenToProduceSameNumberOfPartitions = true,
+      requiresChildPartitioningsToBeCompatible = true,
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(Seq.empty, Seq.empty)
     )
@@ -239,14 +275,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
   }
 
   test("EnsureRequirements should not repartition if only ordering requirement is unsatisfied") {
+    // Consider an operator that imposes both output distribution and ordering requirements on its
+    // children, such as sort merge join. If the distribution requirements are satisfied but the
+    // output ordering requirements are unsatisfied, then the planner should only add sorts and
+    // should not need to add additional shuffles / exchanges.
     val outputOrdering = Seq(SortOrder(Literal(1), Ascending))
     val distribution = ClusteredDistribution(Literal(1) :: Nil)
     val inputPlan = DummyPlan(
       children = Seq(
         DummyPlan(outputPartitioning = SinglePartition),
         DummyPlan(outputPartitioning = SinglePartition)
      ),
-      requiresChildrenToProduceSameNumberOfPartitions = true,
+      requiresChildPartitioningsToBeCompatible = true,
       requiredChildDistribution = Seq(distribution, distribution),
       requiredChildOrdering = Seq(outputOrdering, outputOrdering)
    )
@@ -265,7 +305,7 @@ private case class DummyPlan(
     override val children: Seq[SparkPlan] = Nil,
     override val outputOrdering: Seq[SortOrder] = Nil,
     override val outputPartitioning: Partitioning = UnknownPartitioning(0),
-    override val requiresChildrenToProduceSameNumberOfPartitions: Boolean = false,
+    override val requiresChildPartitioningsToBeCompatible: Boolean = false,
     override val requiredChildDistribution: Seq[Distribution] = Nil,
     override val requiredChildOrdering: Seq[Seq[SortOrder]] = Nil
 ) extends SparkPlan {
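
The crux of the first new test, restated as a self-contained sketch. The classes below are illustrative stand-ins using the subset-based satisfies semantics that the test itself asserts; the point is that satisfying a required distribution is a per-child property, while compatibility is a pairwise property across children:

object SatisfiesVsGuaranteesSketch {
  case class Dist(clustering: Set[String])
  case class Part(exprs: Set[String], numPartitions: Int) {
    // Hash-partitioning on a subset of the clustering keys satisfies the distribution...
    def satisfies(d: Dist): Boolean = exprs.subsetOf(d.clustering)
    // ...but only an identical partitioning guarantees identical row placement.
    def guarantees(other: Part): Boolean = this == other
  }

  def main(args: Array[String]): Unit = {
    val required = Dist(Set("a", "b"))
    val left = Part(Set("a"), 1)
    val right = Part(Set("b"), 1)
    assert(left.satisfies(required) && right.satisfies(required)) // each child is fine alone
    assert(!left.guarantees(right) && !right.guarantees(left))    // but not co-partitioned
    println("each side satisfies the distribution, yet the pair is incompatible")
  }
}

This is exactly the situation the test constructs with HashPartitioning(clusteringA, 1) and HashPartitioning(clusteringB, 1): each satisfies ClusteredDistribution(clusteringA ++ clusteringB) on its own, yet neither guarantees the other, so EnsureRequirements should add exchanges.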
