Skip to content

Commit c6667e7

Browse files
committed
Add PartitioningCollection.
1 parent e616d3b commit c6667e7

File tree

5 files changed

+73
-8
lines changed

5 files changed

+73
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,43 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
305305

306306
override def toNullUnsafePartitioning: Partitioning = this
307307
}
308+
309+
/**
310+
* A collection of [[Partitioning]]s.
311+
*/
312+
case class PartitioningCollection(partitionings: Seq[Partitioning])
313+
extends Expression with Partitioning with Unevaluable {
314+
315+
require(
316+
partitionings.map(_.numPartitions).distinct.length == 1,
317+
s"PartitioningCollection requires all of its partitionings have the same numPartitions.")
318+
319+
override def children: Seq[Expression] = partitionings.collect {
320+
case expr: Expression => expr
321+
}
322+
323+
override def nullable: Boolean = false
324+
325+
override def dataType: DataType = IntegerType
326+
327+
override val numPartitions = partitionings.map(_.numPartitions).distinct.head
328+
329+
override def satisfies(required: Distribution): Boolean =
330+
partitionings.exists(_.satisfies(required))
331+
332+
override def compatibleWith(other: Partitioning): Boolean =
333+
partitionings.exists(_.compatibleWith(other))
334+
335+
override def guarantees(other: Partitioning): Boolean =
336+
partitionings.exists(_.guarantees(other))
337+
338+
override def keyExpressions: Seq[Expression] = partitionings.head.keyExpressions
339+
340+
override def toNullUnsafePartitioning: Partitioning = {
341+
PartitioningCollection(partitionings.map(_.toNullUnsafePartitioning))
342+
}
343+
344+
override def toString: String = {
345+
partitionings.map(_.toString).mkString("(", " or ", ")")
346+
}
347+
}

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,20 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
164164
// TODO: Handle BroadcastPartitioning.
165165
}
166166
def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match {
167-
case NullSafeHashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
167+
case NullSafeHashPartitioning(expressions, _) =>
168+
// Since NullSafeHashPartitioning and NullUnsafeHashPartitioning may be used together
169+
// for a join operator. We need to make sure they calculate the partition id with
170+
// the same way.
171+
val materalizeExpressions = newMutableProjection(expressions, child.output)()
172+
val partitionExpressionSchema = expressions.map {
173+
case ne: NamedExpression => ne.toAttribute
174+
case expr => Alias(expr, "partitionExpr")().toAttribute
175+
}
176+
val partitionId = RowHashCode
177+
val partitionIdExtractor =
178+
newMutableProjection(partitionId :: Nil, partitionExpressionSchema)()
179+
(row: InternalRow) => partitionIdExtractor(materalizeExpressions(row))
180+
// newMutableProjection(expressions, child.output)()
168181
case NullUnsafeHashPartitioning(expressions, numPartition) =>
169182
// For NullUnsafeHashPartitioning, we do not want to send rows having any expression
170183
// in `expressions` evaluated as null to the same node.
@@ -261,7 +274,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
261274
child: SparkPlan): SparkPlan = {
262275

263276
def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
264-
if (child.outputPartitioning.guarantees(partitioning)) {
277+
if (!child.outputPartitioning.guarantees(partitioning)) {
265278
Exchange(partitioning, child)
266279
} else {
267280
child

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
2121
import org.apache.spark.rdd.RDD
2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.expressions.Expression
24-
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, NullUnsafeClusteredDistribution, NullSafeClusteredDistribution, Partitioning}
24+
import org.apache.spark.sql.catalyst.plans.physical._
2525
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
2626

2727
/**
@@ -38,7 +38,8 @@ case class ShuffledHashJoin(
3838
right: SparkPlan)
3939
extends BinaryNode with HashJoin {
4040

41-
override def outputPartitioning: Partitioning = left.outputPartitioning
41+
override def outputPartitioning: Partitioning =
42+
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
4243

4344
override def requiredChildDistribution: Seq[Distribution] =
4445
NullSafeClusteredDistribution(leftKeys) :: NullSafeClusteredDistribution(rightKeys) :: Nil

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,19 @@ case class ShuffledHashOuterJoin(
4848
NullUnsafeClusteredDistribution(rightKeys) :: Nil
4949

5050
override def outputPartitioning: Partitioning = joinType match {
51-
case LeftOuter => left.outputPartitioning
52-
case RightOuter => right.outputPartitioning
53-
case FullOuter => left.outputPartitioning.toNullUnsafePartitioning
51+
case LeftOuter =>
52+
val partitions =
53+
Seq(left.outputPartitioning, right.outputPartitioning.toNullUnsafePartitioning)
54+
PartitioningCollection(partitions)
55+
case RightOuter =>
56+
val partitions =
57+
Seq(right.outputPartitioning, left.outputPartitioning.toNullUnsafePartitioning)
58+
PartitioningCollection(partitions)
59+
case FullOuter =>
60+
val partitions =
61+
Seq(left.outputPartitioning.toNullUnsafePartitioning,
62+
right.outputPartitioning.toNullUnsafePartitioning)
63+
PartitioningCollection(partitions)
5464
case x =>
5565
throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
5666
}

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ case class SortMergeJoin(
4040

4141
override def output: Seq[Attribute] = left.output ++ right.output
4242

43-
override def outputPartitioning: Partitioning = left.outputPartitioning
43+
override def outputPartitioning: Partitioning =
44+
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
4445

4546
override def requiredChildDistribution: Seq[Distribution] =
4647
NullUnsafeClusteredDistribution(leftKeys) :: NullUnsafeClusteredDistribution(rightKeys) :: Nil

0 commit comments

Comments
 (0)