Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,19 @@ case class ClusteredDistribution(
* This is a strictly stronger guarantee than [[ClusteredDistribution]]. Given a tuple and the
* number of partitions, this distribution strictly requires which partition the tuple should be in.
*/
case class HashClusteredDistribution(expressions: Seq[Expression]) extends Distribution {
case class HashClusteredDistribution(
expressions: Seq[Expression],
requiredNumPartitions: Option[Int] = None) extends Distribution {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

require(
expressions != Nil,
"The expressions for hash of a HashPartitionedDistribution should not be Nil. " +
"The expressions for hash of a HashClusteredDistribution should not be Nil. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")

override def requiredNumPartitions: Option[Int] = None

override def createPartitioning(numPartitions: Int): Partitioning = {
assert(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions,
s"This HashClusteredDistribution requires ${requiredNumPartitions.get} partitions, but " +
s"the actual number of partitions is $numPartitions.")
HashPartitioning(expressions, numPartitions)
}
}
Expand Down Expand Up @@ -163,11 +166,22 @@ trait Partitioning {
* i.e. the current dataset does not need to be re-partitioned for the `required`
* Distribution (it is possible that tuples within a partition need to be reorganized).
*
* A [[Partitioning]] can never satisfy a [[Distribution]] if its `numPartitions` does't match
* [[Distribution.requiredNumPartitions]].
*/
final def satisfies(required: Distribution): Boolean = {
required.requiredNumPartitions.forall(_ == numPartitions) && satisfies0(required)
}

/**
* The actual method that defines whether this [[Partitioning]] can satisfy the given
* [[Distribution]], after the `numPartitions` check.
*
* By default a [[Partitioning]] can satisfy [[UnspecifiedDistribution]], and [[AllTuples]] if
* the [[Partitioning]] only have one partition. Implementations can overwrite this method with
* special logic.
* the [[Partitioning]] only have one partition. Implementations can also overwrite this method
* with special logic.
*/
def satisfies(required: Distribution): Boolean = required match {
protected def satisfies0(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case AllTuples => numPartitions == 1
case _ => false
Expand All @@ -186,9 +200,8 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning
case object SinglePartition extends Partitioning {
val numPartitions = 1

override def satisfies(required: Distribution): Boolean = required match {
override def satisfies0(required: Distribution): Boolean = required match {
Copy link
Contributor

@tdas tdas Jun 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add docs to explain what is satisfies0 and how it different from satisfies?
Otherwise its quite confusing.
When does one override satisfies, and when does one override satisfies0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added in the base class

case _: BroadcastDistribution => false
case ClusteredDistribution(_, Some(requiredNumPartitions)) => requiredNumPartitions == 1
case _ => true
}
}
Expand All @@ -205,16 +218,15 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def nullable: Boolean = false
override def dataType: DataType = IntegerType

override def satisfies(required: Distribution): Boolean = {
super.satisfies(required) || {
override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
required match {
case h: HashClusteredDistribution =>
expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
case (l, r) => l.semanticEquals(r)
}
case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions)
case ClusteredDistribution(requiredClustering, _) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}
}
Expand Down Expand Up @@ -246,15 +258,14 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
override def nullable: Boolean = false
override def dataType: DataType = IntegerType

override def satisfies(required: Distribution): Boolean = {
super.satisfies(required) || {
override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
required match {
case OrderedDistribution(requiredOrdering) =>
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case ClusteredDistribution(requiredClustering, requiredNumPartitions) =>
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) &&
(requiredNumPartitions.isEmpty || requiredNumPartitions.get == numPartitions)
case ClusteredDistribution(requiredClustering, _) =>
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}
}
Expand Down Expand Up @@ -295,7 +306,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
* Returns true if any `partitioning` of this collection satisfies the given
* [[Distribution]].
*/
override def satisfies(required: Distribution): Boolean =
override def satisfies0(required: Distribution): Boolean =
partitionings.exists(_.satisfies(required))

override def toString: String = {
Expand All @@ -310,7 +321,7 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
override val numPartitions: Int = 1

override def satisfies(required: Distribution): Boolean = required match {
override def satisfies0(required: Distribution): Boolean = required match {
case BroadcastDistribution(m) if m == mode => true
case _ => false
}
Expand Down
Loading