@@ -312,26 +312,37 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
 * Represents a partitioning where rows are split across partitions based on transforms defined
 * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in
 * ascending order, after evaluated by the transforms in `expressions`, for each input partition.
- * In addition, its length must be the same as the number of input partitions (and thus is a 1-1
- * mapping). The `partitionValues` may contain duplicated partition values.
+ * In addition, its length must be the same as the number of Spark partitions (and thus is a 1-1
+ * mapping), and each row in `partitionValues` must be unique.
 *
- * For example, if `expressions` is `[years(ts_col)]`, then a valid value of `partitionValues` is
- * `[0, 1, 2]`, which represents 3 input partitions with distinct partition values. All rows
- * in each partition have the same value for column `ts_col` (which is of timestamp type), after
- * being applied by the `years` transform.
+ * The `originalPartitionValues`, on the other hand, are partition values from the original input
+ * splits returned by data sources. It may contain duplicated values.
 *
- * On the other hand, `[0, 0, 1]` is not a valid value for `partitionValues` since `0` is
- * duplicated twice.
+ * For example, if a data source reports partition transform expressions `[years(ts_col)]` with 4
+ * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions`
+ * in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, which
+ * represents 3 input partitions with distinct partition values. All rows in each partition have
+ * the same value for column `ts_col` (which is of timestamp type), after being applied by the
+ * `years` transform. This is generated after combining the two splits with partition value `2`
+ * into a single Spark partition.
+ *
+ * On the other hand, in this example `[0, 1, 2, 2]` is the value of `originalPartitionValues`
+ * which is calculated from the original input splits.
 *
 * @param expressions partition expressions for the partitioning.
 * @param numPartitions the number of partitions
- * @param partitionValues the values for the cluster keys of the distribution, must be
- *                        in ascending order.
+ * @param partitionValues the values for the final cluster keys (that is, after applying grouping
+ *                        on the input splits according to `expressions`) of the distribution,
+ *                        must be in ascending order, and must NOT contain duplicated values.
+ * @param originalPartitionValues the original input partition values before any grouping has been
+ *                                applied, must be in ascending order, and may contain duplicated
+ *                                values
 */
case class KeyGroupedPartitioning(
    expressions: Seq[Expression],
    numPartitions: Int,
-    partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
+    partitionValues: Seq[InternalRow] = Seq.empty,
+    originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {

override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
@@ -368,7 +379,7 @@ object KeyGroupedPartitioning {
def apply(
expressions: Seq[Expression],
partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
-    KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues)
+    KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues)
}

def supportsExpressions(expressions: Seq[Expression]): Boolean = {
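A minimal standalone sketch of the `partitionValues` / `originalPartitionValues` contract described in the comment above (not part of this PR: Int stands in for InternalRow, String for an input split):

object KeyGroupedExample extends App {
  // Input splits reported by the data source, keyed by years(ts_col).
  val splits: Seq[(Int, String)] =
    Seq((0, "split-a"), (1, "split-b"), (2, "split-c"), (2, "split-d"))

  // One value per input split, duplicates preserved (originalPartitionValues).
  val originalPartitionValues: Seq[Int] = splits.map(_._1).sorted

  // Splits sharing a key are combined into one Spark partition, so the
  // final values are de-duplicated and sorted (partitionValues).
  val partitionValues: Seq[Int] = splits.map(_._1).distinct.sorted

  assert(originalPartitionValues == Seq(0, 1, 2, 2))
  assert(partitionValues == Seq(0, 1, 2)) // 1-1 mapping to Spark partitions
  println(s"original = $originalPartitionValues, grouped = $partitionValues")
}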
@@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par
import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.read._
-import org.apache.spark.sql.internal.SQLConf

/**
* Physical plan node for scanning a batch of data from a data source v2.
@@ -101,7 +100,7 @@ case class BatchScanExec(
"partition values that are not present in the original partitioning.")
}

-        groupPartitions(newPartitions).get.map(_._2)
+        groupPartitions(newPartitions).get.groupedParts.map(_.parts)

case _ =>
// no validation is needed as the data source did not report any specific partitioning
@@ -137,81 +136,63 @@ case class BatchScanExec(

outputPartitioning match {
case p: KeyGroupedPartitioning =>
-        if (conf.v2BucketingPushPartValuesEnabled &&
-            conf.v2BucketingPartiallyClusteredDistributionEnabled) {
-          assert(filteredPartitions.forall(_.size == 1),
-            "Expect partitions to be not grouped when " +
-              s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
-              "is enabled")
-
-          val groupedPartitions = groupPartitions(finalPartitions.map(_.head),
-            groupSplits = true).get
-
-          // This means the input partitions are not grouped by partition values. We'll need to
-          // check `groupByPartitionValues` and decide whether to group and replicate splits
-          // within a partition.
-          if (spjParams.commonPartitionValues.isDefined &&
-              spjParams.applyPartialClustering) {
-            // A mapping from the common partition values to how many splits the partition
-            // should contain.
-            val commonPartValuesMap = spjParams.commonPartitionValues
+        val groupedPartitions = filteredPartitions.map(splits => {
+          assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
+          (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
+        })
+
+        // When partially clustered, the input partitions are not grouped by partition
+        // values. Here we'll need to check `commonPartitionValues` and decide how to group
+        // and replicate splits within a partition.
+        if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) {
+          // A mapping from the common partition values to how many splits the partition
+          // should contain.
+          val commonPartValuesMap = spjParams.commonPartitionValues
            .get
            .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
            .toMap
-            val nestGroupedPartitions = groupedPartitions.map {
-              case (partValue, splits) =>
-                // `commonPartValuesMap` should contain the part value since it's the super set.
-                val numSplits = commonPartValuesMap
-                  .get(InternalRowComparableWrapper(partValue, p.expressions))
-                assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
-                  "common partition values from Spark plan")
-
-                val newSplits = if (spjParams.replicatePartitions) {
-                  // We need to also replicate partitions according to the other side of join
-                  Seq.fill(numSplits.get)(splits)
-                } else {
-                  // Not grouping by partition values: this could be the side with partially
-                  // clustered distribution. Because of dynamic filtering, we'll need to check if
-                  // the final number of splits of a partition is smaller than the original
-                  // number, and fill with empty splits if so. This is necessary so that both
-                  // sides of a join will have the same number of partitions & splits.
-                  splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
-                }
-                (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
-            }
+          val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
+            // `commonPartValuesMap` should contain the part value since it's the super set.
+            val numSplits = commonPartValuesMap
+              .get(InternalRowComparableWrapper(partValue, p.expressions))
+            assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
+              "common partition values from Spark plan")
+
+            val newSplits = if (spjParams.replicatePartitions) {
+              // We need to also replicate partitions according to the other side of join
+              Seq.fill(numSplits.get)(splits)
+            } else {
+              // Not grouping by partition values: this could be the side with partially
+              // clustered distribution. Because of dynamic filtering, we'll need to check if
+              // the final number of splits of a partition is smaller than the original
+              // number, and fill with empty splits if so. This is necessary so that both
+              // sides of a join will have the same number of partitions & splits.
+              splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
+            }
+            (InternalRowComparableWrapper(partValue, p.expressions), newSplits)
+          }

-            // Now fill missing partition keys with empty partitions
-            val partitionMapping = nestGroupedPartitions.toMap
-            finalPartitions = spjParams.commonPartitionValues.get.flatMap {
-              case (partValue, numSplits) =>
-                // Use empty partition for those partition values that are not present.
-                partitionMapping.getOrElse(
-                  InternalRowComparableWrapper(partValue, p.expressions),
-                  Seq.fill(numSplits)(Seq.empty))
-            }
-          } else {
-            // either `commonPartitionValues` is not defined, or it is defined but
-            // `applyPartialClustering` is false.
-            val partitionMapping = groupedPartitions.map { case (row, parts) =>
-              InternalRowComparableWrapper(row, p.expressions) -> parts
-            }.toMap
-
-            // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
-            // could exist duplicated partition values, as partition grouping is not done
-            // at the beginning and postponed to this method. It is important to use unique
-            // partition values here so that grouped partitions won't get duplicated.
-            finalPartitions = p.uniquePartitionValues.map { partValue =>
-              // Use empty partition for those partition values that are not present
-              partitionMapping.getOrElse(
-                InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
-            }
-          }
-        } else {
-          val partitionMapping = finalPartitions.map { parts =>
-            val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
-            InternalRowComparableWrapper(row, p.expressions) -> parts
-          }.toMap
-          finalPartitions = p.partitionValues.map { partValue =>
+          // Now fill missing partition keys with empty partitions
+          val partitionMapping = nestGroupedPartitions.toMap
+          finalPartitions = spjParams.commonPartitionValues.get.flatMap {
+            case (partValue, numSplits) =>
+              // Use empty partition for those partition values that are not present.
+              partitionMapping.getOrElse(
+                InternalRowComparableWrapper(partValue, p.expressions),
+                Seq.fill(numSplits)(Seq.empty))
+          }
+        } else {
+          // either `commonPartitionValues` is not defined, or it is defined but
+          // `applyPartialClustering` is false.
+          val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
+            InternalRowComparableWrapper(partValue, p.expressions) -> splits
+          }.toMap
+
+          // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
+          // could exist duplicated partition values, as partition grouping is not done
+          // at the beginning and postponed to this method. It is important to use unique
+          // partition values here so that grouped partitions won't get duplicated.
+          finalPartitions = p.uniquePartitionValues.map { partValue =>
            // Use empty partition for those partition values that are not present
            partitionMapping.getOrElse(
              InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
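The replicate-vs-pad choice above is the heart of partially clustered distribution. A standalone sketch of just that decision (not Spark code: Int stands in for the partition key, String for an input split, and commonPartValues is a hypothetical stand-in for the negotiated common partition values):

object PartialClusteringExample extends App {
  // Hypothetical: partition value -> expected split count, as agreed
  // between both sides of the join.
  val commonPartValues: Map[Int, Int] = Map(0 -> 2, 1 -> 3)

  def newSplits(
      key: Int,
      splits: Seq[String],
      replicatePartitions: Boolean): Seq[Seq[String]] = {
    val numSplits = commonPartValues(key)
    if (replicatePartitions) {
      // Replicate the whole partition to line up with the other join side.
      Seq.fill(numSplits)(splits)
    } else {
      // One split per slot, padded with empty splits so both sides end up
      // with the same number of partitions & splits.
      splits.map(Seq(_)).padTo(numSplits, Seq.empty)
    }
  }

  println(newSplits(0, Seq("s1"), replicatePartitions = true))
  // List(List(s1), List(s1))
  println(newSplits(1, Seq("s1", "s2"), replicatePartitions = false))
  // List(List(s1), List(s2), List())
}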
@@ -62,8 +62,9 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
redact(result)
}

-  def partitions: Seq[Seq[InputPartition]] =
-    groupedPartitions.map(_.map(_._2)).getOrElse(inputPartitions.map(Seq(_)))
+  def partitions: Seq[Seq[InputPartition]] = {
+    groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_)))
+  }

/**
* Shorthand for calling redact() without specifying redacting rules
@@ -94,16 +95,18 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
keyGroupedPartitioning match {
case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) =>
groupedPartitions
-        .map { partitionValues =>
-          KeyGroupedPartitioning(exprs, partitionValues.size, partitionValues.map(_._1))
+        .map { keyGroupedPartsInfo =>
+          val keyGroupedParts = keyGroupedPartsInfo.groupedParts
+          KeyGroupedPartitioning(exprs, keyGroupedParts.size, keyGroupedParts.map(_.value),
+            keyGroupedPartsInfo.originalParts.map(_.partitionKey()))
}
.getOrElse(super.outputPartitioning)
case _ =>
super.outputPartitioning
}
}

-  @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] = {
+  @transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = {
// Early check if we actually need to materialize the input partitions.
keyGroupedPartitioning match {
case Some(_) => groupPartitions(inputPartitions)
@@ -117,24 +120,21 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
* - all input partitions implement [[HasPartitionKey]]
* - `keyGroupedPartitioning` is set
*
-   * The result, if defined, is a list of tuples where the first element is a partition value,
-   * and the second element is a list of input partitions that share the same partition value.
+   * The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a list of
+   * [[KeyGroupedPartition]], as well as a list of partition values from the original input splits,
+   * sorted according to the partition keys in ascending order.
*
* A non-empty result means each partition is clustered on a single key and therefore eligible
* for further optimizations to eliminate shuffling in some operations such as join and aggregate.
*/
-  def groupPartitions(
-      inputPartitions: Seq[InputPartition],
-      groupSplits: Boolean = !conf.v2BucketingPushPartValuesEnabled ||
-        !conf.v2BucketingPartiallyClusteredDistributionEnabled):
-      Option[Seq[(InternalRow, Seq[InputPartition])]] = {
-
+  def groupPartitions(inputPartitions: Seq[InputPartition]): Option[KeyGroupedPartitionInfo] = {
if (!SQLConf.get.v2BucketingEnabled) return None

keyGroupedPartitioning.flatMap { expressions =>
val results = inputPartitions.takeWhile {
case _: HasPartitionKey => true
case _ => false
-      }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p))
+      }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p.asInstanceOf[HasPartitionKey]))

if (results.length != inputPartitions.length || inputPartitions.isEmpty) {
// Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here.
@@ -143,32 +143,25 @@ trait DataSourceV2ScanExecBase extends LeafExecNode {
// also sort the input partitions according to their partition key order. This ensures
// a canonical order from both sides of a bucketed join, for example.
val partitionDataTypes = expressions.map(_.dataType)
-        val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = {
+        val partitionOrdering: Ordering[(InternalRow, InputPartition)] = {
RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1)
}

-        val partitions = if (groupSplits) {
-          // Group the splits by their partition value
-          results
+        val sortedKeyToPartitions = results.sorted(partitionOrdering)
+        val groupedPartitions = sortedKeyToPartitions
          .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2))
          .groupBy(_._1)

[Inline review thread on the `.groupBy(_._1)` line:]

@LuciferYang (Contributor) commented on Sep 6, 2023:

  The problem likely comes from this groupBy, as there are some differences
  between Scala 2.12 and Scala 2.13. For example:

  • Scala 2.12.18

    Welcome to Scala 2.12.18 (OpenJDK 64-Bit Server VM, Java 1.8.0_382).
    Type in expressions for evaluation. Or try :help.

    scala> val input = Seq((50,50),(51,51),(52,52))
    input: Seq[(Int, Int)] = List((50,50), (51,51), (52,52))

    scala> input.groupBy(_._1).toSeq
    res0: Seq[(Int, Seq[(Int, Int)])] = Vector((50,List((50,50))), (51,List((51,51))), (52,List((52,52))))

  • Scala 2.13.8

    Welcome to Scala 2.13.8 (OpenJDK 64-Bit Server VM, Java 1.8.0_382).
    Type in expressions for evaluation. Or try :help.

    scala> val input = Seq((50,50),(51,51),(52,52))
    val input: Seq[(Int, Int)] = List((50,50), (51,51), (52,52))

    scala> input.groupBy(_._1).toSeq
    val res0: Seq[(Int, Seq[(Int, Int)])] = List((52,List((52,52))), (50,List((50,50))), (51,List((51,51))))

  We can see that when using Scala 2.13.8, the order of the results has changed.

  A possible fix may be:

  1. Using another function to replace groupBy to maintain the output order,
     such as foldLeft with LinkedHashMap?
  2. Re-sorting the groupedPartitions?

  Perhaps there are other better ways to fix it?

@sunchao (Member, Author) replied on Sep 6, 2023:

  Thanks @LuciferYang for the findings! Yes it's a bug as I was assuming the
  order will be preserved in the groupBy. Let me open a follow-up PR to fix this.

@sunchao (Member, Author):

  I opened #42839 to fix this.

.toSeq
-          .map {
-            case (key, s) => (key.row, s.map(_._2))
-          }
-        } else {
-          // No splits grouping, each split will become a separate Spark partition
-          results.map(t => (t._1, Seq(t._2)))
-        }
+          .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) }

-        Some(partitions.sorted(partitionOrdering))
+        Some(KeyGroupedPartitionInfo(groupedPartitions, sortedKeyToPartitions.map(_._2)))
}
}
}
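For the order sensitivity flagged in the review thread above, here is a standalone sketch of the first fix LuciferYang suggests, a LinkedHashMap-based grouping that preserves first-seen key order on both Scala 2.12 and 2.13 (an illustration only, not necessarily what #42839 landed):

import scala.collection.mutable

object OrderPreservingGroupBy extends App {
  // groupBy on Scala 2.13 returns a HashMap, so toSeq order follows key
  // hashing rather than input order; a LinkedHashMap keeps the order in
  // which keys are first encountered.
  def groupPreservingOrder[K, V](input: Seq[(K, V)]): Seq[(K, Seq[V])] = {
    val grouped = mutable.LinkedHashMap.empty[K, mutable.Buffer[V]]
    input.foreach { case (k, v) =>
      grouped.getOrElseUpdate(k, mutable.Buffer.empty[V]) += v
    }
    grouped.toSeq.map { case (k, vs) => (k, vs.toSeq) }
  }

  val input = Seq((50, 50), (51, 51), (52, 52))
  // Unlike input.groupBy(_._1).toSeq, the result here is always
  // List((50,...), (51,...), (52,...)) regardless of Scala version.
  println(groupPreservingOrder(input))
}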

override def outputOrdering: Seq[SortOrder] = {
// when multiple partitions are grouped together, ordering inside partitions is not preserved
-    val partitioningPreservesOrdering = groupedPartitions.forall(_.forall(_._2.length <= 1))
+    val partitioningPreservesOrdering = groupedPartitions
+      .forall(_.groupedParts.forall(_.parts.length <= 1))
ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering)
}

@@ -217,3 +210,19 @@
}
}
}

+/**
+ * A key-grouped Spark partition, which could consist of multiple input splits
+ *
+ * @param value the partition value shared by all the input splits
+ * @param parts the input splits that are grouped into a single Spark partition
+ */
+private[v2] case class KeyGroupedPartition(value: InternalRow, parts: Seq[InputPartition])
+
+/**
+ * Information about key-grouped partitions, which contains a list of grouped partitions as well
+ * as the original input partitions before the grouping.
+ */
+private[v2] case class KeyGroupedPartitionInfo(
+    groupedParts: Seq[KeyGroupedPartition],
+    originalParts: Seq[HasPartitionKey])

[Inline review thread on KeyGroupedPartitionInfo:]

A Member commented:

  It seems like it would refer to info about one KeyGroupedPartition.
  How about KeyGroupedPartitionInfos?

@sunchao (Member, Author) replied:

  I thought about it but is Infos a proper plural noun?
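A small standalone model of how the two new classes relate (hypothetical code: Int replaces InternalRow, and (Int, String) pairs replace HasPartitionKey input splits):

object KeyGroupedInfoExample extends App {
  case class Partition(value: Int, parts: Seq[String])
  case class PartitionInfo(groupedParts: Seq[Partition], originalParts: Seq[(Int, String)])

  // Input splits already sorted by key; two of them share key 2.
  val sorted = Seq((0, "a"), (1, "b"), (2, "c"), (2, "d"))

  val grouped = sorted
    .groupBy(_._1).toSeq.sortBy(_._1) // re-sort: groupBy output order is unspecified
    .map { case (k, s) => Partition(k, s.map(_._2)) }

  val info = PartitionInfo(grouped, sorted)
  assert(info.groupedParts.map(_.value) == Seq(0, 1, 2))  // de-duplicated
  assert(info.originalParts.map(_._1) == Seq(0, 1, 2, 2)) // duplicates kept
  println(info)
}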
@@ -288,12 +288,12 @@ case class EnsureRequirements(
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, leftPartitioning, None))
-      case (Some(KeyGroupedPartitioning(clustering, _, _)), _) =>
+      case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, None, rightPartitioning))
-      case (_, Some(KeyGroupedPartitioning(clustering, _, _))) =>
+      case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
.orElse(reorderJoinKeysRecursively(
@@ -483,7 +483,10 @@ case class EnsureRequirements(
s"'$joinType'. Skipping partially clustered distribution.")
replicateRightSide = false
} else {
-          val partValues = if (replicateLeftSide) rightPartValues else leftPartValues
+          // In partially clustered distribution, we should use un-grouped partition values
+          val spec = if (replicateLeftSide) rightSpec else leftSpec
+          val partValues = spec.partitioning.originalPartitionValues

val numExpectedPartitions = partValues
.map(InternalRowComparableWrapper(_, partitionExprs))
.groupBy(identity)
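The numExpectedPartitions snippet above is truncated after .groupBy(identity). Assuming it continues with a per-key size count, as the variable name suggests, the counting step can be sketched standalone (Int standing in for the wrapped partition value):

object ExpectedPartitionsExample extends App {
  // Un-grouped (original) partition values, duplicates included.
  val partValues = Seq(0, 1, 2, 2)

  // How many splits each partition value is expected to contribute.
  val numExpectedPartitions: Map[Int, Int] =
    partValues.groupBy(identity).map { case (k, vs) => (k, vs.size) }

  assert(numExpectedPartitions == Map(0 -> 1, 1 -> 1, 2 -> 2))
  println(numExpectedPartitions)
}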