@@ -94,7 +94,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/** Specifies how data is partitioned across different nodes in the cluster. */
def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!

/** Specifies any partition requirements on the input data for this operator. */
/**
* Specifies the data distribution requirements of all the children for this operator. By default
* it's [[UnspecifiedDistribution]] for each child, which means each child can have any
* distribution.
*
* If an operator overrides this method and specifies distribution requirements (excluding
* [[UnspecifiedDistribution]] and [[BroadcastDistribution]]) for more than one child, Spark
* guarantees that the outputs of these children will have the same number of partitions, so that
* the operator can safely zip partitions of these children's result RDDs. Some operators can
* leverage this guarantee to satisfy interesting requirements, e.g., a non-broadcast join can
* specify HashClusteredDistribution(a, b) for its left child and HashClusteredDistribution(c, d)
* for its right child. It is then guaranteed that the left and right children are co-partitioned
* by (a, b) / (c, d), meaning tuples with the same values are in partitions with the same index,
* e.g., (a=1, b=2) and (c=1, d=2) are both in the second partition of the left and right
* children.
*/
def requiredChildDistribution: Seq[Distribution] =
Seq.fill(children.size)(UnspecifiedDistribution)

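To make the new contract concrete, here is a minimal sketch (not part of this diff) of how a binary operator can rely on it: by requiring HashClusteredDistribution on both children, their result RDDs are co-partitioned and can be zipped directly. ZipJoinExec and joinRows are hypothetical names; the distributions and the zipPartitions pattern mirror what the join operators in this change do.

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
    import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
    import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}

    // Hypothetical operator, sketched only to show how the guarantee above can be used.
    case class ZipJoinExec(
        leftKeys: Seq[Expression],
        rightKeys: Seq[Expression],
        left: SparkPlan,
        right: SparkPlan) extends BinaryExecNode {

      override def output: Seq[Attribute] = left.output ++ right.output

      override def requiredChildDistribution: Seq[Distribution] =
        HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil

      protected override def doExecute(): RDD[InternalRow] = {
        // Safe: Spark guarantees both children produce the same number of partitions, and
        // partition i on the left holds the same key values as partition i on the right.
        left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
          joinRows(leftIter, rightIter)  // hypothetical join logic over co-partitioned iterators
        }
      }

      // Placeholder for the actual per-partition join.
      private def joinRows(l: Iterator[InternalRow], r: Iterator[InternalRow]): Iterator[InternalRow] = ???
    }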
@@ -42,23 +42,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
}

/**
* Given a required distribution, returns a partitioning that satisfies that distribution.
* @param requiredDistribution The distribution that is required by the operator
* @param numPartitions Used when the distribution doesn't require a specific number of partitions
*/
private def createPartitioning(
requiredDistribution: Distribution,
numPartitions: Int): Partitioning = {
requiredDistribution match {
case AllTuples => SinglePartition
case ClusteredDistribution(clustering, desiredPartitions) =>
HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions))
case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
case dist => sys.error(s"Do not know how to satisfy distribution $dist")
}
}

/**
* Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled
* and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]].
@@ -84,8 +67,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
// shuffle data when we have more than one child because data generated by
// these children may not be partitioned in the same way.
// Please see the comment in withCoordinator for more details.
val supportsDistribution =
requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
val supportsDistribution = requiredChildDistributions.forall { dist =>
dist.isInstanceOf[ClusteredDistribution] || dist.isInstanceOf[HashClusteredDistribution]
}
children.length > 1 && supportsDistribution
}

@@ -138,8 +122,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
//
// It will be great to introduce a new Partitioning to represent the post-shuffle
// partitions when one post-shuffle partition includes multiple pre-shuffle partitions.
val targetPartitioning =
createPartitioning(distribution, defaultNumPreShufflePartitions)
val targetPartitioning = distribution.createPartitioning(defaultNumPreShufflePartitions)
assert(targetPartitioning.isInstanceOf[HashPartitioning])
ShuffleExchangeExec(targetPartitioning, child, Some(coordinator))
}
@@ -158,71 +141,56 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
assert(requiredChildDistributions.length == children.length)
assert(requiredChildOrderings.length == children.length)

// Ensure that the operator's children satisfy their output distribution requirements:
// Ensure that the operator's children satisfy their output distribution requirements.
children = children.zip(requiredChildDistributions).map {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
val numPartitions = distribution.requiredNumPartitions
.getOrElse(defaultNumPreShufflePartitions)
ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
}

// If the operator has multiple children and specifies child output distributions (e.g. join),
// then the children's output partitionings must be compatible:
def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match {
case UnspecifiedDistribution => false
case BroadcastDistribution(_) => false
// Get the indexes of children which have specified distribution requirements and need to have
// the same number of partitions.
val childrenIndexes = requiredChildDistributions.zipWithIndex.filter {
case (UnspecifiedDistribution, _) => false
case (_: BroadcastDistribution, _) => false
case _ => true
}
if (children.length > 1
&& requiredChildDistributions.exists(requireCompatiblePartitioning)
&& !Partitioning.allCompatible(children.map(_.outputPartitioning))) {

// First check if the existing partitions of the children all match. This means they are
// partitioned by the same partitioning into the same number of partitions. In that case,
// don't try to make them match `defaultPartitions`, just use the existing partitioning.
val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
case (child, distribution) =>
child.outputPartitioning.guarantees(
createPartitioning(distribution, maxChildrenNumPartitions))
}.map(_._2)

val childrenNumPartitions =
childrenIndexes.map(children(_).outputPartitioning.numPartitions).toSet

if (childrenNumPartitions.size > 1) {
// Get the number of partitions which is explicitly required by the distributions.
val requiredNumPartitions = {
val numPartitionsSet = childrenIndexes.flatMap {
index => requiredChildDistributions(index).requiredNumPartitions
}.toSet
assert(numPartitionsSet.size <= 1,
s"$operator have incompatible requirements of the number of partitions for its children")
numPartitionsSet.headOption
}

children = if (useExistingPartitioning) {
// We do not need to shuffle any child's output.
children
} else {
// We need to shuffle at least one child's output.
// Now, we will determine the number of partitions that will be used by created
// partitioning schemes.
val numPartitions = {
// Let's see if we need to shuffle all child's outputs when we use
// maxChildrenNumPartitions.
val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
case (child, distribution) =>
!child.outputPartitioning.guarantees(
createPartitioning(distribution, maxChildrenNumPartitions))
val targetNumPartitions = requiredNumPartitions.getOrElse(childrenNumPartitions.max)

children = children.zip(requiredChildDistributions).zipWithIndex.map {
case ((child, distribution), index) if childrenIndexes.contains(index) =>
if (child.outputPartitioning.numPartitions == targetNumPartitions) {
child
} else {
val defaultPartitioning = distribution.createPartitioning(targetNumPartitions)
child match {
// If child is an exchange, we replace it with a new one having defaultPartitioning.
case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(defaultPartitioning, c)
case _ => ShuffleExchangeExec(defaultPartitioning, child)
}
}
// If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
// number of partitions. Otherwise, we use maxChildrenNumPartitions.
if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
}

children.zip(requiredChildDistributions).map {
case (child, distribution) =>
val targetPartitioning = createPartitioning(distribution, numPartitions)
if (child.outputPartitioning.guarantees(targetPartitioning)) {
child
} else {
child match {
// If child is an exchange, we replace it with
// a new one having targetPartitioning.
case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c)
case _ => ShuffleExchangeExec(targetPartitioning, child)
}
}
}
case ((child, _), _) => child
}
}

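For reference, a rough illustration of the two Distribution methods this rule now delegates to, requiredNumPartitions and createPartitioning, as I read the API after this change (the attributes a and b below are made up; the commented results are my understanding, not output from the diff):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.physical._
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()

    // A clustered distribution does not pin the number of partitions by default ...
    val dist = ClusteredDistribution(a :: b :: Nil)
    dist.requiredNumPartitions        // None: any number of partitions will do
    dist.createPartitioning(200)      // HashPartitioning(a :: b :: Nil, 200)

    // ... while AllTuples only makes sense as a single partition.
    AllTuples.requiredNumPartitions   // Some(1)
    AllTuples.createPartitioning(1)   // SinglePartition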
@@ -249,10 +217,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
}

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator @ ShuffleExchangeExec(partitioning, child, _) =>
child.children match {
case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil =>
if (childPartitioning.guarantees(partitioning)) child else operator
// TODO: remove this after we create a physical operator for `RepartitionByExpression`.
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
child.outputPartitioning match {
case lower: HashPartitioning if upper.semanticEquals(lower) => child
case _ => operator
}
case operator: SparkPlan => ensureDistributionAndOrdering(operator)
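The semanticEquals check above compares canonicalized partitionings, so it ignores cosmetic differences such as attribute names but still distinguishes partition counts. A small sketch of that behavior as I understand it (the attribute a is made up; the commented results are expectations, not verified output):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()

    val lower = HashPartitioning(a :: Nil, 200)
    val upper = HashPartitioning(a.withName("A") :: Nil, 200)
    upper.semanticEquals(lower)   // true: same underlying attribute and partition count, drop the upper exchange

    val other = HashPartitioning(a :: Nil, 100)
    upper.semanticEquals(other)   // false: different number of partitions, so the exchange is kept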
@@ -46,7 +46,7 @@ case class ShuffledHashJoinExec(
"avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil

private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = {
val buildDataSize = longMetric("buildDataSize")
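The switch from ClusteredDistribution to HashClusteredDistribution is what lets the join zip its children's partitions: clustering only demands that equal keys be co-located within each child, while hash-clustering demands hashing by exactly the join keys on both sides. A sketch of the difference as I understand the satisfies semantics introduced here (attributes a and b are made up; the commented results are expectations):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.physical._
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val byA = HashPartitioning(a :: Nil, 200)

    // Hashing by a subset of the keys still co-locates rows with equal (a, b) ...
    byA.satisfies(ClusteredDistribution(a :: b :: Nil))       // true
    // ... but two children hashed by different key sets cannot be zipped, so the
    // stricter requirement rejects it and forces a shuffle on the exact join keys.
    byA.satisfies(HashClusteredDistribution(a :: b :: Nil))   // false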
@@ -78,7 +78,7 @@ case class SortMergeJoinExec(
}

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil

override def outputOrdering: Seq[SortOrder] = joinType match {
// For inner join, orders of both sides keys should be kept.
@@ -460,7 +460,7 @@ case class CoGroupExec(
right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
HashClusteredDistribution(leftGroup) :: HashClusteredDistribution(rightGroup) :: Nil

override def requiredChildOrdering: Seq[Seq[SortOrder]] =
leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
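For completeness, a hypothetical Dataset-level program that plans into CoGroupExec and therefore benefits from the co-partitioning guarantee (the case classes and data are made up for illustration):

    import org.apache.spark.sql.SparkSession

    case class Click(userId: Long, url: String)
    case class Purchase(userId: Long, amount: Double)

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val clicks = Seq(Click(1L, "/home"), Click(2L, "/cart")).toDS()
    val purchases = Seq(Purchase(1L, 9.99)).toDS()

    // cogroup requires both sides to be hash-clustered on the grouping key, so each
    // userId's clicks and purchases land in the same partition index on both sides.
    clicks.groupByKey(_.userId)
      .cogroup(purchases.groupByKey(_.userId)) { (userId, cs, ps) =>
        Iterator((userId, cs.size, ps.size))
      }
      .show()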