@@ -49,14 +49,21 @@ case object AllTuples extends Distribution
* can mean such tuples are either co-located in the same partition or they will be contiguous
* within a single partition.
*/
case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
case class ClusteredDistribution(
clustering: Seq[Expression],
nullSafe: Boolean) extends Distribution {
require(
clustering != Nil,
"The clustering expressions of a ClusteredDistribution should not be Nil. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")
}

object ClusteredDistribution {
def apply(clustering: Seq[Expression]): ClusteredDistribution =
ClusteredDistribution(clustering, nullSafe = true)
}
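
As a quick illustration of the new parameter, a hedged sketch (the `'a` symbol syntax assumes Catalyst's expression DSL is in scope; it is not part of this patch):

```scala
// nullSafe defaults to true via the companion apply, so existing
// two-argument call sites compile unchanged.
val d1 = ClusteredDistribution(Seq('a, 'b))
// A null-unsafe requirement: rows whose clustering keys contain nulls
// are not required to be co-located.
val d2 = ClusteredDistribution(Seq('a, 'b), nullSafe = false)
```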

/**
* Represents data where tuples have been ordered according to the `ordering`
* [[Expression Expressions]]. This is a strictly stronger guarantee than
@@ -90,9 +97,22 @@ sealed trait Partitioning {
/**
* Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
* guarantees the same partitioning scheme described by `other`.
*
* If a [[Partitioning]] supports the `nullSafe` setting, the null-safe version of this
* [[Partitioning]] should always guarantee its null-unsafe version.
* For example, HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = true)
* guarantees HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = false).
* However, HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = false) does not
* guarantee HashPartitioning(expressions = 'a, numPartitions = 10, nullSafe = true).
*/
def guarantees(other: Partitioning): Boolean

/**
* If a [[Partitioning]] supports `nullSafe` setting, returns a new instance of this
* [[Partitioning]] with the given nullSafe setting. Otherwise, returns this
* [[Partitioning]].
*/
def withNullSafeSetting(newNullSafe: Boolean): Partitioning
}
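
To make the contract concrete, a REPL-style sketch using the three-argument `HashPartitioning` added later in this diff (again assuming the expression DSL; illustrative only):

```scala
val safe   = HashPartitioning(Seq('a), numPartitions = 10, nullSafe = true)
val unsafe = HashPartitioning(Seq('a), numPartitions = 10, nullSafe = false)

safe.guarantees(unsafe)   // true: null-safe hashing also co-locates null keys
unsafe.guarantees(safe)   // false: rows with null keys may sit in any partition
unsafe.withNullSafeSetting(true).guarantees(safe)  // true
```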

case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -102,6 +122,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
}

override def guarantees(other: Partitioning): Boolean = false

override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
}

case object SinglePartition extends Partitioning {
@@ -113,6 +135,8 @@ case object SinglePartition extends Partitioning {
case SinglePartition => true
case _ => false
}

override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
}

case object BroadcastPartitioning extends Partitioning {
@@ -124,14 +148,19 @@ case object BroadcastPartitioning extends Partitioning {
case BroadcastPartitioning => true
case _ => false
}

override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this
}

/**
* Represents a partitioning where rows are split up across partitions based on the hash
* of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
* in the same partition.
*/
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case class HashPartitioning(
expressions: Seq[Expression],
numPartitions: Int,
nullSafe: Boolean)
extends Expression with Partitioning with Unevaluable {

override def children: Seq[Expression] = expressions
@@ -142,16 +171,30 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)

override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case ClusteredDistribution(requiredClustering) =>
// A null-safe partitioning co-locates null keys too, so it satisfies any
// clustering requirement over its expressions.
case ClusteredDistribution(requiredClustering, _) if nullSafe =>
clusteringSet.subsetOf(requiredClustering.toSet)
// A null-unsafe partitioning only satisfies null-unsafe requirements.
case ClusteredDistribution(requiredClustering, false) if !nullSafe =>
clusteringSet.subsetOf(requiredClustering.toSet)
case _ => false
}

override def guarantees(other: Partitioning): Boolean = other match {
case o: HashPartitioning =>
// Either this partitioning is null-safe (the stronger guarantee), or both are null-unsafe.
case o: HashPartitioning if nullSafe || !o.nullSafe =>
this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
case _ => false
}

override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = {
HashPartitioning(expressions, numPartitions, nullSafe = newNullSafe)
}

override def toString: String =
s"${super.toString} numPartitions=$numPartitions nullSafe=$nullSafe"
}

object HashPartitioning {
def apply(expressions: Seq[Expression], numPartitions: Int): HashPartitioning =
HashPartitioning(expressions, numPartitions, nullSafe = true)
}
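
And a matching sketch for `satisfies` (constructors from this patch; illustrative only):

```scala
val p = HashPartitioning(Seq('a), 10, nullSafe = false)

p.satisfies(ClusteredDistribution(Seq('a, 'b), nullSafe = false))  // true
// The companion apply defaults the requirement to nullSafe = true,
// which null-unsafe hashing cannot meet:
p.satisfies(ClusteredDistribution(Seq('a, 'b)))                    // false
```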

/**
@@ -180,7 +223,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case OrderedDistribution(requiredOrdering) =>
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case ClusteredDistribution(requiredClustering) =>
case ClusteredDistribution(requiredClustering, _) =>
clusteringSet.subsetOf(requiredClustering.toSet)
case _ => false
}
@@ -189,6 +232,10 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case o: RangePartitioning => this == o
case _ => false
}

override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = this

override def toString: String = s"${super.toString} numPartitions=$numPartitions"
}

/**
@@ -235,6 +282,10 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
override def guarantees(other: Partitioning): Boolean =
partitionings.exists(_.guarantees(other))

override def withNullSafeSetting(newNullSafe: Boolean): Partitioning = {
PartitioningCollection(partitionings.map(_.withNullSafeSetting(newNullSafe)))
}

override def toString: String = {
partitionings.map(_.toString).mkString("(", " or ", ")")
}
@@ -104,6 +104,80 @@ class DistributionSuite extends SparkFunSuite {
*/
}

test("HashPartitioning (with nullSafe = false) is the output partitioning") {
// Cases which do not need an exchange between two data properties.
checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10, false),
UnspecifiedDistribution,
true)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10, false),
ClusteredDistribution(Seq('a, 'b, 'c), false),
true)

checkSatisfied(
HashPartitioning(Seq('b, 'c), 10, false),
ClusteredDistribution(Seq('a, 'b, 'c), false),
true)

checkSatisfied(
SinglePartition,
ClusteredDistribution(Seq('a, 'b, 'c), false),
true)

checkSatisfied(
SinglePartition,
OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
true)

// Cases which need an exchange between two data properties.
checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10, false),
ClusteredDistribution(Seq('a, 'b, 'c)),
false)

checkSatisfied(
HashPartitioning(Seq('b, 'c), 10, false),
ClusteredDistribution(Seq('a, 'b, 'c)),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10, false),
ClusteredDistribution(Seq('b, 'c)),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10, false),
ClusteredDistribution(Seq('d, 'e)),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10, false),
ClusteredDistribution(Seq('b, 'c), false),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10, false),
ClusteredDistribution(Seq('d, 'e), false),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10, false),
AllTuples,
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10, false),
OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
false)

checkSatisfied(
HashPartitioning(Seq('b, 'c), 10, false),
OrderedDistribution(Seq('a.asc, 'b.asc, 'c.asc)),
false)
}

test("RangePartitioning is the output partitioning") {
// Cases which do not need an exchange between two data properties.
checkSatisfied(
@@ -148,7 +148,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
val rdd = child.execute()
val part: Partitioner = newPartitioning match {
case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, numPartitions, _) =>
new HashPartitioner(numPartitions)
case RangePartitioning(sortingExpressions, numPartitions) =>
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
@@ -167,7 +168,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
// TODO: Handle BroadcastPartitioning.
}
def getPartitionKeyExtractor(): InternalRow => InternalRow = newPartitioning match {
case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
// TODO: If nullSafe is false, rows whose clustering expressions evaluate to null
// could be distributed randomly rather than hashed to a single partition.
case HashPartitioning(expressions, _, _) => newMutableProjection(expressions, child.output)()
case RangePartitioning(_, _) | SinglePartition => identity
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
}
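
The TODO above points at a possible optimization. A standalone sketch of the idea in plain Scala, not Spark code, under the assumption that null clustering keys never need co-location when nullSafe = false:

```scala
import scala.util.Random

// Pick a partition for a row's clustering key. When nullSafe = false, a row
// whose key contains a null can never match in an equi-join, so scattering
// such rows randomly avoids a skewed, oversized "null" partition.
def choosePartition(key: Seq[Any], numPartitions: Int, nullSafe: Boolean): Int =
  if (!nullSafe && key.contains(null)) {
    Random.nextInt(numPartitions)
  } else {
    ((key.hashCode % numPartitions) + numPartitions) % numPartitions
  }
```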
@@ -210,7 +213,12 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[

def addShuffleIfNecessary(child: SparkPlan): SparkPlan = {
if (!child.outputPartitioning.guarantees(partitioning)) {
Exchange(partitioning, child)
// If the child's outputPartitioning does not guarantee the required partitioning,
// we need to add an Exchange operator. Here, we always use the null-safe version
// of the given partitioning, because the null-safe version always guarantees the
// null-unsafe version and we have no special handling for null-unsafe
// partitioning for now.
Exchange(partitioning.withNullSafeSetting(newNullSafe = true), child)
} else {
child
}
@@ -240,8 +248,9 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
val fixedChildren = requirements.zipped.map {
case (AllTuples, rowOrdering, child) =>
addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
case (ClusteredDistribution(clustering), rowOrdering, child) =>
addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
case (ClusteredDistribution(clustering, nullSafe), rowOrdering, child) =>
val hashPartitioning = HashPartitioning(clustering, numPartitions, nullSafe)
addOperatorsIfNecessary(hashPartitioning, rowOrdering, child)
case (OrderedDistribution(ordering), rowOrdering, child) =>
addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)

@@ -42,7 +42,8 @@ case class ShuffledHashJoin(
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
ClusteredDistribution(leftKeys, nullSafe = false) ::
ClusteredDistribution(rightKeys, nullSafe = false) :: Nil
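
Why nullSafe = false is acceptable for an inner hash join: under SQL semantics a null key equals nothing, so a row with a null join key can live in any partition without changing the result. A standalone illustration:

```scala
// Plain Scala, not Spark code: null-keyed rows never pair up in an equi-join.
val left  = Seq((null, "l1"), ("k", "l2"))
val right = Seq((null, "r1"), ("k", "r2"))
val joined = for {
  (lk, lv) <- left
  (rk, rv) <- right
  if lk != null && lk == rk   // SQL equality: null = null does not hold
} yield (lk, lv, rv)
// joined == Seq(("k", "l2", "r2"))
```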

protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
@@ -42,12 +42,23 @@ case class ShuffledHashOuterJoin(
right: SparkPlan) extends BinaryNode with HashOuterJoin {

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
ClusteredDistribution(leftKeys, nullSafe = false) ::
ClusteredDistribution(rightKeys, nullSafe = false) :: Nil

override def outputPartitioning: Partitioning = joinType match {
case LeftOuter => left.outputPartitioning
case RightOuter => right.outputPartitioning
case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
case LeftOuter =>
val partitions =
left.outputPartitioning :: right.outputPartitioning.withNullSafeSetting(false) :: Nil
PartitioningCollection(partitions)
case RightOuter =>
val partitions =
right.outputPartitioning :: left.outputPartitioning.withNullSafeSetting(false) :: Nil
PartitioningCollection(partitions)
[Contributor review comment] Need to overwrite the PartitioningCollection.nullSafe.

case FullOuter =>
val partitions =
left.outputPartitioning.withNullSafeSetting(false) ::
right.outputPartitioning.withNullSafeSetting(false) :: Nil
PartitioningCollection(partitions)
case x =>
throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
}
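
Why the outer cases demote a side to nullSafe = false: an outer join fabricates null key columns for unmatched rows, and those rows stay in whatever partition the preserved-side row occupied. A standalone illustration for LeftOuter:

```scala
// Plain Scala, not Spark code. Output partitions of a LeftOuter join that was
// hash-partitioned on the join key; None stands for a fabricated null:
val partition0 = Seq(("a", Some("a")), ("b", None))  // "b" had no right match
val partition1 = Seq(("c", None))                    // neither did "c"
// Null right-side keys appear in both partitions, so on the right keys the
// output is only null-unsafely partitioned.
```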
@@ -44,7 +44,8 @@ case class SortMergeJoin(
PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
ClusteredDistribution(leftKeys, nullSafe = false) ::
ClusteredDistribution(rightKeys, nullSafe = false) :: Nil

// this is to manually construct an ordering that can be used to compare keys from both sides
private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))