Commit fce9053

Add the concept of nullSafe to ClusteredDistribution and Partitioning.
1 parent 8be198c commit fce9053

7 files changed: +195 -44 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 55 additions & 6 deletions
@@ -49,14 +49,21 @@ case object AllTuples extends Distribution
  * can mean such tuples are either co-located in the same partition or they will be contiguous
  * within a single partition.
  */
-case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
+case class ClusteredDistribution(
+    clustering: Seq[Expression],
+    nullSafe: Boolean) extends Distribution {
   require(
     clustering != Nil,
     "The clustering expressions of a ClusteredDistribution should not be Nil. " +
     "An AllTuples should be used to represent a distribution that only has " +
     "a single partition.")
 }

+object ClusteredDistribution {
+  def apply(clustering: Seq[Expression]): ClusteredDistribution =
+    ClusteredDistribution(clustering, nullSafe = true)
+}
+
 /**
  * Represents data where tuples have been ordered according to the `ordering`
  * [[Expression Expressions]]. This is a strictly stronger guarantee than
@@ -90,9 +97,20 @@ sealed trait Partitioning {
   /**
    * Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
    * guarantees the same partitioning scheme described by `other`.
+   *
+   * For example, HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = true)
+   * guarantees HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = false).
+   * However, HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = false) does not
+   * guarantee HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = true).
    */
-  // TODO: Add an example once we have the `nullSafe` concept.
   def guarantees(other: Partitioning): Boolean
+
+  /**
+   * If a [[Partitioning]] supports the `nullSafe` setting, returns a new instance of this
+   * [[Partitioning]] with the given nullSafe setting. Otherwise, returns this
+   * [[Partitioning]].
+   */
+  def withNullSafeSetting(newNullSafe: Boolean): Partitioning
 }

 case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -102,6 +120,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
   }

   override def guarantees(other: Partitioning): Boolean = false
+
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
 }

 case object SinglePartition extends Partitioning {
@@ -113,6 +133,8 @@ case object SinglePartition extends Partitioning {
     case SinglePartition => true
     case _ => false
   }
+
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
 }

 case object BroadcastPartitioning extends Partitioning {
@@ -124,14 +146,19 @@ case object BroadcastPartitioning extends Partitioning {
     case BroadcastPartitioning => true
     case _ => false
   }
+
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
 }

 /**
  * Represents a partitioning where rows are split up across partitions based on the hash
  * of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
  * in the same partition.
  */
-case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
+case class HashPartitioning(
+    expressions: Seq[Expression],
+    numPartitions: Int,
+    nullSafe: Boolean)
   extends Expression with Partitioning with Unevaluable {

   override def children: Seq[Expression] = expressions
@@ -142,16 +169,30 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)

   override def satisfies(required: Distribution): Boolean = required match {
     case UnspecifiedDistribution => true
-    case ClusteredDistribution(requiredClustering) =>
+    case ClusteredDistribution(requiredClustering, _) if nullSafe =>
+      clusteringSet.subsetOf(requiredClustering.toSet)
+    case ClusteredDistribution(requiredClustering, false) if !nullSafe =>
       clusteringSet.subsetOf(requiredClustering.toSet)
     case _ => false
   }

   override def guarantees(other: Partitioning): Boolean = other match {
-    case o: HashPartitioning =>
+    case o: HashPartitioning if (nullSafe || (!nullSafe && !o.nullSafe)) =>
       this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
     case _ => false
   }
+
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = {
+    HashPartitioning(expressions, numPartitions, nullSafe = newNullSafe)
+  }
+
+  override def toString: String =
+    s"${super.toString} numPartitions=$numPartitions nullSafe=$nullSafe"
+}
+
+object HashPartitioning {
+  def apply(expressions: Seq[Expression], numPartitions: Int): HashPartitioning =
+    HashPartitioning(expressions, numPartitions, nullSafe = true)
 }

 /**
@@ -180,7 +221,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case OrderedDistribution(requiredOrdering) =>
       val minSize = Seq(requiredOrdering.size, ordering.size).min
       requiredOrdering.take(minSize) == ordering.take(minSize)
-    case ClusteredDistribution(requiredClustering) =>
+    case ClusteredDistribution(requiredClustering, _) =>
       clusteringSet.subsetOf(requiredClustering.toSet)
     case _ => false
   }
@@ -189,6 +230,10 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case o: RangePartitioning => this == o
     case _ => false
   }
+
+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
+
+  override def toString: String = s"${super.toString} numPartitions=$numPartitions"
 }

 /**
@@ -235,6 +280,10 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
   override def guarantees(other: Partitioning): Boolean =
     partitionings.exists(_.guarantees(other))

+  override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = {
+    PartitioningCollection(partitionings.map(_.withNullSafeSetting(newNullSafe)))
+  }
+
   override def toString: String = {
     partitionings.map(_.toString).mkString("(", " or ", ")")
   }

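Editor's note: the asymmetry documented in `guarantees` is the heart of this change. A null-safe hash layout co-locates equal keys and null keys alike, so it can stand in for a null-unsafe one, but not the other way around. Below is a minimal, self-contained Scala sketch of that rule; `SimpleHashPartitioning` is a hypothetical stand-in for illustration, not the Spark class, and the guard mirrors the patch's `nullSafe || (!nullSafe && !o.nullSafe)`, which simplifies to `nullSafe || !o.nullSafe`.

// Illustrative sketch only, not Spark code.
case class SimpleHashPartitioning(keys: Set[String], numPartitions: Int, nullSafe: Boolean) {
  def guarantees(other: SimpleHashPartitioning): Boolean =
    (nullSafe || !other.nullSafe) &&      // null-safe can serve null-unsafe, not vice versa
      keys == other.keys &&
      numPartitions == other.numPartitions
}

object GuaranteesDemo extends App {
  val safe   = SimpleHashPartitioning(Set("a"), 10, nullSafe = true)
  val unsafe = SimpleHashPartitioning(Set("a"), 10, nullSafe = false)
  assert(safe.guarantees(unsafe))   // a null-safe layout satisfies a null-unsafe consumer
  assert(!unsafe.guarantees(safe))  // null keys may be scattered, so the reverse fails
}
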
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala

Lines changed: 74 additions & 0 deletions
@@ -104,6 +104,80 @@ class DistributionSuite extends SparkFunSuite {
     */
   }

+  test("HashPartitioning (with nullSafe = false) is the output partitioning") {
+    // Cases which do not need an exchange between two data properties.
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      UnspecifiedDistribution,
+      true)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('a, 'b, 'c), false),
+      true)
+
+    checkSatisfied(
+      HashPartitioning(Seq('b, 'c), 10, false),
+      ClusteredDistribution(Seq('a, 'b, 'c), false),
+      true)
+
+    checkSatisfied(
+      SinglePartition,
+      ClusteredDistribution(Seq('a, 'b, 'c), false),
+      true)
+
+    checkSatisfied(
+      SinglePartition,
+      OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+      true)
+
+    // Cases which need an exchange between two data properties.
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('a, 'b, 'c)),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('b, 'c), 10, false),
+      ClusteredDistribution(Seq('a, 'b, 'c)),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('b, 'c)),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('d, 'e)),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('b, 'c), false),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      ClusteredDistribution(Seq('d, 'e), false),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      AllTuples,
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('a, 'b, 'c), 10, false),
+      OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq('b, 'c), 10, false),
+      OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
+      false)
+  }
+
   test("RangePartitioning is the output partitioning") {
     // Cases which do not need an exchange between two data properties.
     checkSatisfied(

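Editor's note: these tests go through a `checkSatisfied` helper that the diff does not touch. A plausible sketch of what such a helper does, assuming it simply compares `Partitioning.satisfies` against the expected boolean (the real suite's failure message may differ):

// Hypothetical reconstruction of the helper used above, written to sit inside
// the suite (it relies on the suite's imports and ScalaTest's `fail`).
protected def checkSatisfied(
    inputPartitioning: Partitioning,
    requiredDistribution: Distribution,
    satisfied: Boolean): Unit = {
  if (inputPartitioning.satisfies(requiredDistribution) != satisfied) {
    fail(
      s"""
         |== Input Partitioning ==
         |$inputPartitioning
         |== Required Distribution ==
         |$requiredDistribution
         |== Expected satisfies == $satisfied, but got ${!satisfied}
       """.stripMargin)
  }
}
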
sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 8 additions & 4 deletions
@@ -148,7 +148,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
   protected override def doExecute(): RDD[InternalRow] = attachTree(this , "execute") {
     val rdd = child.execute()
     val part: Partitioner = newPartitioning match {
-      case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
+      case HashPartitioning(expressions, numPartitions, nullSafe) =>
+        new HashPartitioner(numPartitions)
       case RangePartitioning(sortingExpressions, numPartitions) =>
         // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
         // partition bounds. To get accurate samples, we need to copy the mutable keys.
@@ -167,7 +168,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
       // TODO: Handle BroadcastPartitioning.
     }
     def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match {
-      case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
+      // TODO: If nullSafe is false, we can randomly distribute rows having any null in
+      // clustering.
+      case HashPartitioning(expressions, _, _) => newMutableProjection(expressions, child.output)()
       case RangePartitioning(_, _) | SinglePartition => identity
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
     }
@@ -240,8 +243,9 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
     val fixedChildren = requirements.zipped.map {
       case (AllTuples, rowOrdering, child) =>
         addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
-      case (ClusteredDistribution(clustering), rowOrdering, child) =>
-        addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+      case (ClusteredDistribution(clustering, nullSafe), rowOrdering, child) =>
+        val hashPartitioning = HashPartitioning(clustering, numPartitions, nullSafe)
+        addOperatorsIfNecessary(hashPartitioning, rowOrdering, child)
       case (OrderedDistribution(ordering), rowOrdering, child) =>
         addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)

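Editor's note: the new TODO in `getPartitionKeyExtractor` points at the optimization a null-unsafe requirement unlocks. Rows whose clustering key is null can never match anything in a null-unsafe consumer, so they could be scattered across partitions instead of all hashing into one hot null bucket. A hedged sketch of the idea; `choosePartition` is a hypothetical helper, not part of this commit:

import scala.util.Random

// Hypothetical illustration of the TODO above. With nullSafe = false, a row whose
// clustering key is null may be sent to any partition, because a null-unsafe
// consumer (e.g. an inner equi-join) will never match it against another row.
def choosePartition(key: Any, numPartitions: Int, nullSafe: Boolean, rnd: Random): Int =
  if (key == null && !nullSafe) {
    rnd.nextInt(numPartitions)  // scatter null keys to avoid a skewed null bucket
  } else {
    // non-negative modulo of the hash code, as a HashPartitioner would compute
    ((key.## % numPartitions) + numPartitions) % numPartitions
  }
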
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala

Lines changed: 2 additions & 1 deletion
@@ -42,7 +42,8 @@ case class ShuffledHashJoin(
     PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))

   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys, nullSafe = false) ::
+      ClusteredDistribution(rightKeys, nullSafe = false) :: Nil

   protected override def doExecute(): RDD[InternalRow] = {
     buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>

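Editor's note: relaxing the requirement to `nullSafe = false` is sound here because an equi-join predicate never matches null keys (`NULL = NULL` is not true in SQL), so it does not matter which partitions the null-keyed rows land in. A small Scala analogy of that semantics, illustrative only:

// Illustrative only: modelling SQL NULL as None, an inner equi-join drops
// null-keyed rows before comparing, so their placement is irrelevant.
case class Rec(key: Option[Int], value: String)

def innerJoin(left: Seq[Rec], right: Seq[Rec]): Seq[(Rec, Rec)] =
  for {
    l  <- left
    r  <- right
    lk <- l.key.toSeq  // a None (SQL NULL) key is filtered out here ...
    rk <- r.key.toSeq  // ... and here, mirroring NULL = NULL being unknown
    if lk == rk
  } yield (l, r)
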
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala

Lines changed: 15 additions & 4 deletions
@@ -42,12 +42,23 @@ case class ShuffledHashOuterJoin(
     right: SparkPlan) extends BinaryNode with HashOuterJoin {

   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys, nullSafe = false) ::
+      ClusteredDistribution(rightKeys, nullSafe = false) :: Nil

   override def outputPartitioning: Partitioning = joinType match {
-    case LeftOuter => left.outputPartitioning
-    case RightOuter => right.outputPartitioning
-    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case LeftOuter =>
+      val partitions =
+        left.outputPartitioning :: right.outputPartitioning.withNullSafeSetting(false) :: Nil
+      PartitioningCollection(partitions)
+    case RightOuter =>
+      val partitions =
+        Seq(right.outputPartitioning, left.outputPartitioning.withNullSafeSetting(false))
+      PartitioningCollection(partitions)
+    case FullOuter =>
+      val partitions =
+        left.outputPartitioning.withNullSafeSetting(false) ::
+          right.outputPartitioning.withNullSafeSetting(false) :: Nil
+      PartitioningCollection(partitions)
     case x =>
       throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
   }

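Editor's note: the downgraded output partitionings make sense once you trace where an outer join manufactures nulls. A LEFT OUTER JOIN emits null right-side keys for unmatched left rows, in whatever partition those left rows occupy, so the right child's layout can only be claimed in its null-unsafe form (and symmetrically for RIGHT and FULL OUTER). A hedged Scala sketch of that effect, illustrative rather than Spark code:

// Illustrative only: after a LEFT OUTER JOIN, right-side keys of unmatched rows
// are nulls (None) emitted wherever the left row lives, so the claim "equal right
// keys are co-located" no longer covers nulls; it survives only as nullSafe = false.
case class Joined(leftKey: Int, rightKey: Option[Int])

def leftOuter(leftKeys: Seq[Int], rightKeys: Seq[Int]): Seq[Joined] =
  leftKeys.map { lk =>
    if (rightKeys.contains(lk)) Joined(lk, Some(lk)) // matched: right key follows left layout
    else Joined(lk, None)                            // unmatched: a null right key appears here
  }
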
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 2 additions & 1 deletion
@@ -44,7 +44,8 @@ case class SortMergeJoin(
     PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))

   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys, nullSafe = false) ::
+      ClusteredDistribution(rightKeys, nullSafe = false) :: Nil

   // this is to manually construct an ordering that can be used to compare keys from both sides
   private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 39 additions & 28 deletions
@@ -170,35 +170,46 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {

     // Disable broadcast join
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
-      {
-        val numExchanges = sql(
-          """
-            |SELECT *
-            |FROM
-            |  normal JOIN small ON (normal.key = small.key)
-            |  JOIN tiny ON (small.key = tiny.key)
-          """.stripMargin
-        ).queryExecution.executedPlan.collect {
-          case exchange: Exchange => exchange
-        }.length
-        assert(numExchanges === 3)
+      val joins = Array("JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN", "FULL OUTER JOIN")
+      var i = 0
+      while (i < joins.length) {
+        var j = 0
+        while (j < joins.length) {
+          val firstJoin: String = joins(i)
+          val secondJoin: String = joins(j)
+
+          {
+            val numExchanges: Int = sql(
+              s"""
+                |SELECT *
+                |FROM
+                |  normal $firstJoin small ON (normal.key = small.key)
+                |  $secondJoin tiny ON (small.key = tiny.key)
+              """.stripMargin
+            ).queryExecution.executedPlan.collect {
+              case exchange: Exchange => exchange
+            }.length
+            assert(numExchanges === 3)
+          }
+
+          {
+            val numExchanges: Int = sql(
+              s"""
+                |SELECT *
+                |FROM
+                |  normal $firstJoin small ON (normal.key = small.key)
+                |  $secondJoin tiny ON (normal.key = tiny.key)
+              """.stripMargin
+            ).queryExecution.executedPlan.collect {
+              case exchange: Exchange => exchange
+            }.length
+            assert(numExchanges === 3)
+          }
+
+          j += 1
+        }
+        i += 1
       }
-
-      {
-        // This second query joins on different keys:
-        val numExchanges = sql(
-          """
-            |SELECT *
-            |FROM
-            |  normal JOIN small ON (normal.key = small.key)
-            |  JOIN tiny ON (normal.key = tiny.key)
-          """.stripMargin
-        ).queryExecution.executedPlan.collect {
-          case exchange: Exchange => exchange
-        }.length
-        assert(numExchanges === 3)
-      }
-
     }
   }
 }

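Editor's note on the test style: the nested `while` loops iterate the cross product of join types to check that every combination still plans exactly three exchanges. Should the suite ever be tidied, an equivalent and more idiomatic formulation is a for-comprehension; a sketch only, assuming the test body stays the same:

// Sketch only: equivalent iteration over all (firstJoin, secondJoin) pairs.
val joins = Seq("JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN", "FULL OUTER JOIN")
for (firstJoin <- joins; secondJoin <- joins) {
  // ... run both queries with firstJoin/secondJoin interpolated and
  // assert(numExchanges === 3), exactly as in the test body above
}
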