Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,6 @@ sealed trait Partitioning {
*/
def satisfies(required: Distribution): Boolean

/**
* Returns true iff all distribution guarantees made by this partitioning can also be made
* for the `other` specified partitioning.
* For example, two [[HashPartitioning HashPartitioning]]s are
* only compatible if the `numPartitions` of them is the same.
*/
def compatibleWith(other: Partitioning): Boolean

/** Returns the expressions that are used to key the partitioning. */
def keyExpressions: Seq[Expression]
}
Expand All @@ -104,11 +96,6 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case _ => false
}

override def compatibleWith(other: Partitioning): Boolean = other match {
case UnknownPartitioning(_) => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

Expand All @@ -117,11 +104,6 @@ case object SinglePartition extends Partitioning {

override def satisfies(required: Distribution): Boolean = true

override def compatibleWith(other: Partitioning): Boolean = other match {
case SinglePartition => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

Expand All @@ -130,11 +112,6 @@ case object BroadcastPartitioning extends Partitioning {

override def satisfies(required: Distribution): Boolean = true

override def compatibleWith(other: Partitioning): Boolean = other match {
case SinglePartition => true
case _ => false
}

override def keyExpressions: Seq[Expression] = Nil
}

Expand All @@ -159,12 +136,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}

override def compatibleWith(other: Partitioning): Boolean = other match {
case BroadcastPartitioning => true
case h: HashPartitioning if h == this => true
case _ => false
}

override def keyExpressions: Seq[Expression] = expressions
}

Expand Down Expand Up @@ -199,11 +170,5 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}

override def compatibleWith(other: Partitioning): Boolean = other match {
case BroadcastPartitioning => true
case r: RangePartitioning if r == this => true
case _ => false
}

override def keyExpressions: Seq[Expression] = ordering.map(_.child)
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,41 +202,6 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
// True iff every child's outputPartitioning satisfies the corresponding
// required data distribution.
def meetsRequirements: Boolean =
operator.requiredChildDistribution.zip(operator.children).forall {
case (required, child) =>
val valid = child.outputPartitioning.satisfies(required)
logDebug(
s"${if (valid) "Valid" else "Invalid"} distribution," +
s"required: $required current: ${child.outputPartitioning}")
valid
}

// True iff any of the children are incorrectly sorted.
def needsAnySort: Boolean =
operator.requiredChildOrdering.zip(operator.children).exists {
case (required, child) => required.nonEmpty && required != child.outputOrdering
}

// True iff outputPartitionings of children are compatible with each other.
// It is possible that every child satisfies its required data distribution
// but two children have incompatible outputPartitionings. For example,
// A dataset is range partitioned by "a.asc" (RangePartitioning) and another
// dataset is hash partitioned by "a" (HashPartitioning). Tuples in these two
// datasets are both clustered by "a", but these two outputPartitionings are not
// compatible.
// TODO: ASSUMES TRANSITIVITY?
def compatible: Boolean =
operator.children
.map(_.outputPartitioning)
.sliding(2)
.forall {
case Seq(a) => true
case Seq(a, b) => a.compatibleWith(b)
}

// Adds Exchange or Sort operators as required
def addOperatorsIfNecessary(
partitioning: Partitioning,
Expand Down Expand Up @@ -269,33 +234,26 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
addSortIfNecessary(addShuffleIfNecessary(child))
}

if (meetsRequirements && compatible && !needsAnySort) {
operator
} else {
// At least one child does not satisfies its required data distribution or
// at least one child's outputPartitioning is not compatible with another child's
// outputPartitioning. In this case, we need to add Exchange operators.
val requirements =
(operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
val requirements =
(operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)

val fixedChildren = requirements.zipped.map {
case (AllTuples, rowOrdering, child) =>
addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
case (ClusteredDistribution(clustering), rowOrdering, child) =>
addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
case (OrderedDistribution(ordering), rowOrdering, child) =>
addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
val fixedChildren = requirements.zipped.map {
case (AllTuples, rowOrdering, child) =>
addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
case (ClusteredDistribution(clustering), rowOrdering, child) =>
addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
case (OrderedDistribution(ordering), rowOrdering, child) =>
addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)

case (UnspecifiedDistribution, Seq(), child) =>
child
case (UnspecifiedDistribution, rowOrdering, child) =>
sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)
case (UnspecifiedDistribution, Seq(), child) =>
child
case (UnspecifiedDistribution, rowOrdering, child) =>
sqlContext.planner.BasicOperators.getSortOperator(rowOrdering, global = false, child)

case (dist, ordering, _) =>
sys.error(s"Don't know how to ensure $dist with ordering $ordering")
}

operator.withNewChildren(fixedChildren)
case (dist, ordering, _) =>
sys.error(s"Don't know how to ensure $dist with ordering $ordering")
}

operator.withNewChildren(fixedChildren)
}
}