16 changes: 16 additions & 0 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -137,6 +137,22 @@ private[spark] class PartitionIdPassthrough(override val numPartitions: Int) ext
override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}

/**
* A [[org.apache.spark.Partitioner]] that partitions all records based on a partition value map.
* The `valueMap` contains tuples of (partition value, partition id) and is generated by
* [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]]. It is used to partition
* the other side of a join so that records with the same partition value end up in the same
* partition.
*/
private[spark] class KeyGroupedPartitioner(
valueMap: mutable.Map[Seq[Any], Int],
override val numPartitions: Int) extends Partitioner {
override def getPartition(key: Any): Int = {
val keys = key.asInstanceOf[Seq[Any]]
valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions))
}
}

/**
* A [[org.apache.spark.Partitioner]] that partitions all records into a single partition.
*/
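To make the new `KeyGroupedPartitioner` concrete, here is a minimal sketch with made-up partition values (illustrative only: the class is `private[spark]`, so this assumes code compiled inside the `org.apache.spark` package, and the value map below is hypothetical rather than taken from the PR):

```scala
package org.apache.spark

import scala.collection.mutable

object KeyGroupedPartitionerSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical map from partition values (e.g. results of years(ts_col))
    // to Spark partition ids, as KeyGroupedPartitioning would produce it.
    val valueMap = mutable.Map[Seq[Any], Int](
      Seq(2020) -> 0,
      Seq(2021) -> 1,
      Seq(2022) -> 2)

    val partitioner = new KeyGroupedPartitioner(valueMap, 3)

    // Records from the other side of the join whose key matches a known
    // partition value are routed to that partition.
    assert(partitioner.getPartition(Seq(2021)) == 1)

    // An unseen value falls back to a non-negative hash modulo numPartitions,
    // and the chosen id is memoized in valueMap.
    val p = partitioner.getPartition(Seq(2023))
    assert(p >= 0 && p < 3)
  }
}
```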
@@ -335,28 +335,39 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa

/**
* Represents a partitioning where rows are split across partitions based on transforms defined
* by `expressions`. `partitionValuesOpt`, if defined, should contain value of partition key(s) in
* by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in
* ascending order, after evaluated by the transforms in `expressions`, for each input partition.
* In addition, its length must be the same as the number of input partitions (and thus is a 1-1
* mapping). The `partitionValues` may contain duplicated partition values.
* In addition, its length must be the same as the number of Spark partitions (and thus is a 1-1
* mapping), and each row in `partitionValues` must be unique.
*
* For example, if `expressions` is `[years(ts_col)]`, then a valid value of `partitionValuesOpt` is
* `[0, 1, 2]`, which represents 3 input partitions with distinct partition values. All rows
* in each partition have the same value for column `ts_col` (which is of timestamp type), after
* being applied by the `years` transform.
* The `originalPartitionValues`, on the other hand, are the partition values from the original
* input splits returned by data sources, and may contain duplicated values.
*
* On the other hand, `[0, 0, 1]` is not a valid value for `partitionValuesOpt` since `0` is
* duplicated twice.
* For example, if a data source reports partition transform expressions `[years(ts_col)]` with 4
* input splits whose corresponding partition values are `[0, 1, 2, 2]`, then `expressions` is
* `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, which represents 3 input partitions
* with distinct partition values. All rows in each partition have the same value for column
* `ts_col` (which is of timestamp type) after the `years` transform is applied. The value
* `[0, 1, 2]` is obtained after combining the two splits with partition value `2` into a single
* Spark partition.
*
* On the other hand, in this example `[0, 1, 2, 2]` is the value of `originalPartitionValues`
* which is calculated from the original input splits.
*
* @param expressions partition expressions for the partitioning.
* @param numPartitions the number of partitions
* @param partitionValues the values for the cluster keys of the distribution, must be
* in ascending order.
* @param partitionValues the values for the final cluster keys (that is, after applying grouping
* on the input splits according to `expressions`) of the distribution,
* must be in ascending order, and must NOT contain duplicated values.
* @param originalPartitionValues the original input partition values before any grouping has been
* applied, must be in ascending order, and may contain duplicated
* values
*/
case class KeyGroupedPartitioning(
expressions: Seq[Expression],
numPartitions: Int,
partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {
partitionValues: Seq[InternalRow] = Seq.empty,
originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning {

override def satisfies0(required: Distribution): Boolean = {
super.satisfies0(required) || {
@@ -369,7 +380,15 @@ case class KeyGroupedPartitioning(
} else {
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))

if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// check that join keys (required clustering keys)
// overlap with partition keys (KeyGroupedPartitioning attributes)
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
} else {
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
}

case _ =>
@@ -378,8 +397,20 @@ case class KeyGroupedPartitioning(
}
}

override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
KeyGroupedShuffleSpec(this, distribution)
override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
val result = KeyGroupedShuffleSpec(this, distribution)
if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// If allowing join keys to be subset of clustering keys, we should create a new
// `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as
// the returned shuffle spec.
val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
partitionValues, originalPartitionValues)
result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
} else {
result
}
}

lazy val uniquePartitionValues: Seq[InternalRow] = {
partitionValues
@@ -392,8 +423,25 @@ case class KeyGroupedPartitioning(
object KeyGroupedPartitioning {
def apply(
expressions: Seq[Expression],
partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues)
projectionPositions: Seq[Int],
partitionValues: Seq[InternalRow],
originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
val projectedExpressions = projectionPositions.map(expressions(_))
val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
val projectedOriginalPartitionValues =
originalPartitionValues.map(project(expressions, projectionPositions, _))

KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length,
projectedPartitionValues, projectedOriginalPartitionValues)
}

def project(
expressions: Seq[Expression],
positions: Seq[Int],
input: InternalRow): InternalRow = {
val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType))
.toArray
new GenericInternalRow(projectedValues)
}

def supportsExpressions(expressions: Seq[Expression]): Boolean = {
@@ -686,6 +734,14 @@ case class HashShuffleSpec(
override def numPartitions: Int = partitioning.numPartitions
}

/**
* [[ShuffleSpec]] created by [[KeyGroupedPartitioning]].
*
* @param partitioning key grouped partitioning
* @param distribution distribution
* @param joinKeyPositions positions of the join keys among the cluster keys.
* This is set if joining on a subset of cluster keys is allowed.
*/
case class CoalescedHashShuffleSpec(
from: ShuffleSpec,
partitions: Seq[CoalescedBoundary]) extends ShuffleSpec {
@@ -708,7 +764,8 @@ case class CoalescedHashShuffleSpec(

case class KeyGroupedShuffleSpec(
partitioning: KeyGroupedPartitioning,
distribution: ClusteredDistribution) extends ShuffleSpec {
distribution: ClusteredDistribution,
joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec {

/**
* A sequence where each element is a set of positions of the partition expression to the cluster
@@ -743,7 +800,7 @@ case class KeyGroupedShuffleSpec(
// 3.3 each pair of partition expressions at the same index must share compatible
// transform functions.
// 4. the partition values from both sides are following the same order.
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
@@ -780,7 +837,13 @@ case class KeyGroupedShuffleSpec(
case _ => false
}

override def canCreatePartitioning: Boolean = false
override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// Only support partition expressions that are AttributeReference for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
}
}

case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
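To read the `partitionValues` / `originalPartitionValues` contract above with concrete values, here is a hedged sketch using made-up single-column rows (a real data source would report these through its input splits, and Spark groups them via `InternalRowComparableWrapper` rather than plain row equality):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

// Four hypothetical input splits for the transform [years(ts_col)], reporting
// partition values [0, 1, 2, 2]; two splits share the value 2.
val originalPartitionValues: Seq[InternalRow] =
  Seq(0, 1, 2, 2).map(v => new GenericInternalRow(Array[Any](v)))

// After splits with equal partition values are combined into one Spark
// partition, partitionValues keeps one distinct value per Spark partition.
val partitionValues: Seq[InternalRow] =
  Seq(0, 1, 2).map(v => new GenericInternalRow(Array[Any](v)))

// A KeyGroupedPartitioning built from these values would have
// numPartitions = 3, partitionValues of length 3 (one row per Spark
// partition), and originalPartitionValues of length 4 (one row per split).
// Constructing the years(ts_col) transform expression is omitted here.
```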
@@ -79,6 +79,11 @@ object InternalRowComparableWrapper {
rightPartitioning.partitionValues
.map(new InternalRowComparableWrapper(_, partitionDataTypes))
.foreach(partition => partitionsSet.add(partition))
partitionsSet.map(_.row).toSeq
// SPARK-41471: We sort the partitions to make sure the order of partitions is
// deterministic across different cases.
val partitionOrdering: Ordering[InternalRow] = {
RowOrdering.createNaturalAscendingOrdering(partitionDataTypes)
}
partitionsSet.map(_.row).toSeq.sorted(partitionOrdering)
}
}
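The sort added above is what makes the merged partition order stable; a minimal sketch of the same ordering applied to made-up single-column integer partition values:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, RowOrdering}
import org.apache.spark.sql.types.IntegerType

// Partition values collected from both sides, in whatever order the set
// iteration produced them.
val merged: Seq[InternalRow] =
  Seq(2, 0, 1).map(v => new GenericInternalRow(Array[Any](v)))

// A natural ascending ordering over the partition data types gives a
// deterministic final order: [0, 1, 2].
val ordering: Ordering[InternalRow] =
  RowOrdering.createNaturalAscendingOrdering(Seq(IntegerType))
val sorted = merged.sorted(ordering)
```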
@@ -1514,6 +1514,28 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_SHUFFLE_ENABLED =
buildConf("spark.sql.sources.v2.bucketing.shuffle.enabled")
.doc("During a storage-partitioned join, whether to allow to shuffle only one side." +
"When only one side is KeyGroupedPartitioning, if the conditions are met, spark will " +
"only shuffle the other side. This optimization will reduce the amount of data that " +
s"needs to be shuffle. This config requires ${V2_BUCKETING_ENABLED.key} to be enabled")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS =
buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled")
.doc("Whether to allow storage-partition join in the case where join keys are" +
"a subset of the partition keys of the source tables. At planning time, " +
"Spark will group the partitions by only those keys that are in the join keys." +
s"This is currently enabled only if ${REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key} " +
"is false."
)
.version("4.0.0")
.booleanConf
.createWithDefault(false)
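Both new flags default to false, so a session has to opt in explicitly. A sketch of how that might look (assuming `spark` is an existing `SparkSession`; the base `spark.sql.sources.v2.bucketing.enabled` flag is the pre-existing prerequisite referenced in the doc above):

```scala
// Illustrative session setup, not part of this diff.
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.shuffle.enabled", "true")
spark.conf.set(
  "spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled", "true")
```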

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -4942,6 +4964,12 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingPartiallyClusteredDistributionEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED)

def v2BucketingShuffleEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)

def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)

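These accessors are how the planner-side checks earlier in this diff (for example `canCreatePartitioning` and `satisfies0`) read the flags; a trivial sketch of that guard pattern, assuming it runs on the driver where `SQLConf.get` is available:

```scala
import org.apache.spark.sql.internal.SQLConf

val conf = SQLConf.get
if (conf.v2BucketingShuffleEnabled) {
  // consider shuffling only the non-key-grouped side of the join
}
if (conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
  // group input partitions by the join keys only
}
```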