diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 5dffba2ee8e08..ae39e2e183e4a 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -137,6 +137,22 @@ private[spark] class PartitionIdPassthrough(override val numPartitions: Int) ext override def getPartition(key: Any): Int = key.asInstanceOf[Int] } +/** + * A [[org.apache.spark.Partitioner]] that partitions all records using partition value map. + * The `valueMap` is a map that contains tuples of (partition value, partition id). It is generated + * by [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]], used to partition + * the other side of a join to make sure records with same partition value are in the same + * partition. + */ +private[spark] class KeyGroupedPartitioner( + valueMap: mutable.Map[Seq[Any], Int], + override val numPartitions: Int) extends Partitioner { + override def getPartition(key: Any): Int = { + val keys = key.asInstanceOf[Seq[Any]] + valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions)) + } +} + /** * A [[org.apache.spark.Partitioner]] that partitions all records into a single partition. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 211b5a05eb70c..79341da0db788 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -335,28 +335,39 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa /** * Represents a partitioning where rows are split across partitions based on transforms defined - * by `expressions`. `partitionValuesOpt`, if defined, should contain value of partition key(s) in + * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in * ascending order, after evaluated by the transforms in `expressions`, for each input partition. - * In addition, its length must be the same as the number of input partitions (and thus is a 1-1 - * mapping). The `partitionValues` may contain duplicated partition values. + * In addition, its length must be the same as the number of Spark partitions (and thus is a 1-1 + * mapping), and each row in `partitionValues` must be unique. * - * For example, if `expressions` is `[years(ts_col)]`, then a valid value of `partitionValuesOpt` is - * `[0, 1, 2]`, which represents 3 input partitions with distinct partition values. All rows - * in each partition have the same value for column `ts_col` (which is of timestamp type), after - * being applied by the `years` transform. + * The `originalPartitionValues`, on the other hand, are partition values from the original input + * splits returned by data sources. It may contain duplicated values. * - * On the other hand, `[0, 0, 1]` is not a valid value for `partitionValuesOpt` since `0` is - * duplicated twice. + * For example, if a data source reports partition transform expressions `[years(ts_col)]` with 4 + * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions` + * in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, which + * represents 3 input partitions with distinct partition values. 
All rows in each partition have + * the same value for column `ts_col` (which is of timestamp type), after being applied by the + * `years` transform. This is generated after combining the two splits with partition value `2` + * into a single Spark partition. + * + * On the other hand, in this example `[0, 1, 2, 2]` is the value of `originalPartitionValues` + * which is calculated from the original input splits. * * @param expressions partition expressions for the partitioning. * @param numPartitions the number of partitions - * @param partitionValues the values for the cluster keys of the distribution, must be - * in ascending order. + * @param partitionValues the values for the final cluster keys (that is, after applying grouping + * on the input splits according to `expressions`) of the distribution, + * must be in ascending order, and must NOT contain duplicated values. + * @param originalPartitionValues the original input partition values before any grouping has been + * applied, must be in ascending order, and may contain duplicated + * values */ case class KeyGroupedPartitioning( expressions: Seq[Expression], numPartitions: Int, - partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning { + partitionValues: Seq[InternalRow] = Seq.empty, + originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning { override def satisfies0(required: Distribution): Boolean = { super.satisfies0(required) || { @@ -369,7 +380,15 @@ case class KeyGroupedPartitioning( } else { // We'll need to find leaf attributes from the partition expressions first. val attributes = expressions.flatMap(_.collectLeaves()) - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // check that join keys (required clustering keys) + // overlap with partition keys (KeyGroupedPartitioning attributes) + requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) + } else { + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } } case _ => @@ -378,8 +397,20 @@ case class KeyGroupedPartitioning( } } - override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = - KeyGroupedShuffleSpec(this, distribution) + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = { + val result = KeyGroupedShuffleSpec(this, distribution) + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // If allowing join keys to be subset of clustering keys, we should create a new + // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as + // the returned shuffle spec. 
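      // A hypothetical illustration of the projection below: with partition expressions
      // [identity(id), identity(data)] and a join on `id` only, `keyPositions` is
      // [Set(0), Set()], so `joinKeyPositions` becomes Seq(0) and the projected partitioning
      // keeps only the `id` expression and the `id` column of every partition value row.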
+ val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) + val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, + partitionValues, originalPartitionValues) + result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) + } else { + result + } + } lazy val uniquePartitionValues: Seq[InternalRow] = { partitionValues @@ -392,8 +423,25 @@ case class KeyGroupedPartitioning( object KeyGroupedPartitioning { def apply( expressions: Seq[Expression], - partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { - KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues) + projectionPositions: Seq[Int], + partitionValues: Seq[InternalRow], + originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { + val projectedExpressions = projectionPositions.map(expressions(_)) + val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) + val projectedOriginalPartitionValues = + originalPartitionValues.map(project(expressions, projectionPositions, _)) + + KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length, + projectedPartitionValues, projectedOriginalPartitionValues) + } + + def project( + expressions: Seq[Expression], + positions: Seq[Int], + input: InternalRow): InternalRow = { + val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType)) + .toArray + new GenericInternalRow(projectedValues) } def supportsExpressions(expressions: Seq[Expression]): Boolean = { @@ -686,6 +734,14 @@ case class HashShuffleSpec( override def numPartitions: Int = partitioning.numPartitions } +/** + * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]]. + * + * @param partitioning key grouped partitioning + * @param distribution distribution + * @param joinKeyPosition position of join keys among cluster keys. + * This is set if joining on a subset of cluster keys is allowed. + */ case class CoalescedHashShuffleSpec( from: ShuffleSpec, partitions: Seq[CoalescedBoundary]) extends ShuffleSpec { @@ -708,7 +764,8 @@ case class CoalescedHashShuffleSpec( case class KeyGroupedShuffleSpec( partitioning: KeyGroupedPartitioning, - distribution: ClusteredDistribution) extends ShuffleSpec { + distribution: ClusteredDistribution, + joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec { /** * A sequence where each element is a set of positions of the partition expression to the cluster @@ -743,7 +800,7 @@ case class KeyGroupedShuffleSpec( // 3.3 each pair of partition expressions at the same index must share compatible // transform functions. // 4. the partition values from both sides are following the same order. 
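    //
    // For example, two sides that both report KeyGroupedPartitioning over `bucket(8, id)` with
    // three partitions whose values are [0, 1, 2] in the same order satisfy all of the above,
    // so no extra shuffle is needed for the join.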
- case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) => + case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { @@ -780,7 +837,13 @@ case class KeyGroupedShuffleSpec( case _ => false } - override def canCreatePartitioning: Boolean = false + override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && + // Only support partition expressions are AttributeReference for now + partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { + KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) + } } case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala index b0e530907310a..9a0bdc6bcfd11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala @@ -79,6 +79,11 @@ object InternalRowComparableWrapper { rightPartitioning.partitionValues .map(new InternalRowComparableWrapper(_, partitionDataTypes)) .foreach(partition => partitionsSet.add(partition)) - partitionsSet.map(_.row).toSeq + // SPARK-41471: We keep to order of partitions to make sure the order of + // partitions is deterministic in different case. + val partitionOrdering: Ordering[InternalRow] = { + RowOrdering.createNaturalAscendingOrdering(partitionDataTypes) + } + partitionsSet.map(_.row).toSeq.sorted(partitionOrdering) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8186d5fa00c3a..aca25d22fae99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1514,6 +1514,28 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_SHUFFLE_ENABLED = + buildConf("spark.sql.sources.v2.bucketing.shuffle.enabled") + .doc("During a storage-partitioned join, whether to allow to shuffle only one side." + + "When only one side is KeyGroupedPartitioning, if the conditions are met, spark will " + + "only shuffle the other side. This optimization will reduce the amount of data that " + + s"needs to be shuffle. This config requires ${V2_BUCKETING_ENABLED.key} to be enabled") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + + val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS = + buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled") + .doc("Whether to allow storage-partition join in the case where join keys are" + + "a subset of the partition keys of the source tables. At planning time, " + + "Spark will group the partitions by only those keys that are in the join keys." + + s"This is currently enabled only if ${REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key} " + + "is false." 
+ ) + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -4942,6 +4964,12 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def v2BucketingPartiallyClusteredDistributionEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED) + def v2BucketingShuffleEnabled: Boolean = + getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED) + + def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean = + getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 2a3a5cdeb82b8..094a7b20808ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read._ -import org.apache.spark.sql.internal.SQLConf /** * Physical plan node for scanning a batch of data from a data source v2. @@ -101,7 +100,7 @@ case class BatchScanExec( "partition values that are not present in the original partitioning.") } - groupPartitions(newPartitions).getOrElse(Seq.empty).map(_._2) + groupPartitions(newPartitions).get.groupedParts.map(_.parts) case _ => // no validation is needed as the data source did not report any specific partitioning @@ -121,7 +120,12 @@ case class BatchScanExec( val newPartValues = spjParams.commonPartitionValues.get.flatMap { case (partValue, numSplits) => Seq.fill(numSplits)(partValue) } - k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues) + val expressions = spjParams.joinKeyPositions match { + case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i)) + case _ => k.expressions + } + k.copy(expressions = expressions, numPartitions = newPartValues.length, + partitionValues = newPartValues) case p => p } } @@ -133,92 +137,89 @@ case class BatchScanExec( // return an empty RDD with 1 partition if dynamic filtering removed the only split sparkContext.parallelize(Array.empty[InternalRow], 1) } else { - var finalPartitions = filteredPartitions - - outputPartitioning match { + val finalPartitions = outputPartitioning match { case p: KeyGroupedPartitioning => - if (conf.v2BucketingPushPartValuesEnabled && - conf.v2BucketingPartiallyClusteredDistributionEnabled) { - assert(filteredPartitions.forall(_.size == 1), - "Expect partitions to be not grouped when " + - s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + - "is enabled") - - val groupedPartitions = groupPartitions(finalPartitions.map(_.head), - groupSplits = true).getOrElse(Seq.empty) - - // This means the input partitions are not grouped by partition values. We'll need to - // check `groupByPartitionValues` and decide whether to group and replicate splits - // within a partition. 
- if (spjParams.commonPartitionValues.isDefined && - spjParams.applyPartialClustering) { - // A mapping from the common partition values to how many splits the partition - // should contain. Note this no longer maintain the partition key ordering. - val commonPartValuesMap = spjParams.commonPartitionValues + assert(spjParams.keyGroupedPartitioning.isDefined) + val expressions = spjParams.keyGroupedPartitioning.get + + // Re-group the input partitions if we are projecting on a subset of join keys + val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match { + case Some(projectPositions) => + val projectedExpressions = projectPositions.map(i => expressions(i)) + val parts = filteredPartitions.flatten.groupBy(part => { + val row = part.asInstanceOf[HasPartitionKey].partitionKey() + val projectedRow = KeyGroupedPartitioning.project( + expressions, projectPositions, row) + InternalRowComparableWrapper(projectedRow, projectedExpressions) + }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq + (parts, projectedExpressions) + case _ => + val groupedParts = filteredPartitions.map(splits => { + assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) + (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) + }) + (groupedParts, expressions) + } + + // When partially clustered, the input partitions are not grouped by partition + // values. Here we'll need to check `commonPartitionValues` and decide how to group + // and replicate splits within a partition. + if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) { + // A mapping from the common partition values to how many splits the partition + // should contain. + val commonPartValuesMap = spjParams.commonPartitionValues .get - .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) + .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap - val nestGroupedPartitions = groupedPartitions.map { - case (partValue, splits) => - // `commonPartValuesMap` should contain the part value since it's the super set. - val numSplits = commonPartValuesMap - .get(InternalRowComparableWrapper(partValue, p.expressions)) - assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + - "common partition values from Spark plan") - - val newSplits = if (spjParams.replicatePartitions) { - // We need to also replicate partitions according to the other side of join - Seq.fill(numSplits.get)(splits) - } else { - // Not grouping by partition values: this could be the side with partially - // clustered distribution. Because of dynamic filtering, we'll need to check if - // the final number of splits of a partition is smaller than the original - // number, and fill with empty splits if so. This is necessary so that both - // sides of a join will have the same number of partitions & splits. - splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) - } - (InternalRowComparableWrapper(partValue, p.expressions), newSplits) + val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => + // `commonPartValuesMap` should contain the part value since it's the super set. 
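            // For example, if the common partition values are [(v0, 2), (v1, 1)] and this side
            // has a single split for v0, the logic below either replicates that split twice
            // (replicatePartitions) or pads it with one empty split, so both sides end up with
            // the same number of partitions and splits.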
+ val numSplits = commonPartValuesMap + .get(InternalRowComparableWrapper(partValue, partExpressions)) + assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + + "common partition values from Spark plan") + + val newSplits = if (spjParams.replicatePartitions) { + // We need to also replicate partitions according to the other side of join + Seq.fill(numSplits.get)(splits) + } else { + // Not grouping by partition values: this could be the side with partially + // clustered distribution. Because of dynamic filtering, we'll need to check if + // the final number of splits of a partition is smaller than the original + // number, and fill with empty splits if so. This is necessary so that both + // sides of a join will have the same number of partitions & splits. + splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) } + (InternalRowComparableWrapper(partValue, partExpressions), newSplits) + } - // Now fill missing partition keys with empty partitions - val partitionMapping = nestGroupedPartitions.toMap - finalPartitions = spjParams.commonPartitionValues.get.flatMap { - case (partValue, numSplits) => - // Use empty partition for those partition values that are not present. - partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), - Seq.fill(numSplits)(Seq.empty)) - } - } else { - // either `commonPartitionValues` is not defined, or it is defined but - // `applyPartialClustering` is false. - val partitionMapping = groupedPartitions.map { case (row, parts) => - InternalRowComparableWrapper(row, p.expressions) -> parts - }.toMap - - // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there - // could exist duplicated partition values, as partition grouping is not done - // at the beginning and postponed to this method. It is important to use unique - // partition values here so that grouped partitions won't get duplicated. - finalPartitions = p.uniquePartitionValues.map { partValue => - // Use empty partition for those partition values that are not present + // Now fill missing partition keys with empty partitions + val partitionMapping = nestGroupedPartitions.toMap + spjParams.commonPartitionValues.get.flatMap { + case (partValue, numSplits) => + // Use empty partition for those partition values that are not present. partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) - } + InternalRowComparableWrapper(partValue, partExpressions), + Seq.fill(numSplits)(Seq.empty)) } } else { - val partitionMapping = finalPartitions.map { parts => - val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey() - InternalRowComparableWrapper(row, p.expressions) -> parts + // either `commonPartitionValues` is not defined, or it is defined but + // `applyPartialClustering` is false. + val partitionMapping = groupedPartitions.map { case (partValue, splits) => + InternalRowComparableWrapper(partValue, partExpressions) -> splits }.toMap - finalPartitions = p.partitionValues.map { partValue => + + // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there + // could exist duplicated partition values, as partition grouping is not done + // at the beginning and postponed to this method. It is important to use unique + // partition values here so that grouped partitions won't get duplicated. 
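          // For example, if `p.partitionValues` is [0, 1, 2, 2], `uniquePartitionValues` is
          // [0, 1, 2], so the grouped splits for value 2 are emitted once rather than twice.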
+ p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) + InternalRowComparableWrapper(partValue, partExpressions), Seq.empty) } } - case _ => + case _ => filteredPartitions } new DataSourceRDD( @@ -253,6 +254,7 @@ case class BatchScanExec( case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, + joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { @@ -266,6 +268,7 @@ case class StoragePartitionJoinParams( } override def hashCode(): Int = Objects.hashCode( + joinKeyPositions: Option[Seq[Int]], commonPartitionValues: Option[Seq[(InternalRow, Int)]], applyPartialClustering: java.lang.Boolean, replicatePartitions: java.lang.Boolean) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index f688d3514d9aa..b2f94cae2dfa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -62,8 +62,9 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { redact(result) } - def partitions: Seq[Seq[InputPartition]] = - groupedPartitions.map(_.map(_._2)).getOrElse(inputPartitions.map(Seq(_))) + def partitions: Seq[Seq[InputPartition]] = { + groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_))) + } /** * Shorthand for calling redact() without specifying redacting rules @@ -94,8 +95,10 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { keyGroupedPartitioning match { case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) => groupedPartitions - .map { partitionValues => - KeyGroupedPartitioning(exprs, partitionValues.size, partitionValues.map(_._1)) + .map { keyGroupedPartsInfo => + val keyGroupedParts = keyGroupedPartsInfo.groupedParts + KeyGroupedPartitioning(exprs, keyGroupedParts.size, keyGroupedParts.map(_.value), + keyGroupedPartsInfo.originalParts.map(_.partitionKey())) } .getOrElse(super.outputPartitioning) case _ => @@ -103,7 +106,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } } - @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] = { + @transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = { // Early check if we actually need to materialize the input partitions. keyGroupedPartitioning match { case Some(_) => groupPartitions(inputPartitions) @@ -117,24 +120,21 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { * - all input partitions implement [[HasPartitionKey]] * - `keyGroupedPartitioning` is set * - * The result, if defined, is a list of tuples where the first element is a partition value, - * and the second element is a list of input partitions that share the same partition value. + * The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a list of + * [[KeyGroupedPartition]], as well as a list of partition values from the original input splits, + * sorted according to the partition keys in ascending order. 
* * A non-empty result means each partition is clustered on a single key and therefore eligible * for further optimizations to eliminate shuffling in some operations such as join and aggregate. */ - def groupPartitions( - inputPartitions: Seq[InputPartition], - groupSplits: Boolean = !conf.v2BucketingPushPartValuesEnabled || - !conf.v2BucketingPartiallyClusteredDistributionEnabled): - Option[Seq[(InternalRow, Seq[InputPartition])]] = { - + def groupPartitions(inputPartitions: Seq[InputPartition]): Option[KeyGroupedPartitionInfo] = { if (!SQLConf.get.v2BucketingEnabled) return None + keyGroupedPartitioning.flatMap { expressions => val results = inputPartitions.takeWhile { case _: HasPartitionKey => true case _ => false - }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p)) + }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p.asInstanceOf[HasPartitionKey])) if (results.length != inputPartitions.length || inputPartitions.isEmpty) { // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here. @@ -143,32 +143,24 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { // also sort the input partitions according to their partition key order. This ensures // a canonical order from both sides of a bucketed join, for example. val partitionDataTypes = expressions.map(_.dataType) - val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = { - RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1) - } - - val partitions = if (groupSplits) { - // Group the splits by their partition value - results + val rowOrdering = RowOrdering.createNaturalAscendingOrdering(partitionDataTypes) + val sortedKeyToPartitions = results.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + val sortedGroupedPartitions = sortedKeyToPartitions .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2)) .groupBy(_._1) .toSeq - .map { - case (key, s) => (key.row, s.map(_._2)) - } - } else { - // No splits grouping, each split will become a separate Spark partition - results.map(t => (t._1, Seq(t._2))) - } + .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) } + .sorted(rowOrdering.on((k: KeyGroupedPartition) => k.value)) - Some(partitions.sorted(partitionOrdering)) + Some(KeyGroupedPartitionInfo(sortedGroupedPartitions, sortedKeyToPartitions.map(_._2))) } } } override def outputOrdering: Seq[SortOrder] = { // when multiple partitions are grouped together, ordering inside partitions is not preserved - val partitioningPreservesOrdering = groupedPartitions.forall(_.forall(_._2.length <= 1)) + val partitioningPreservesOrdering = groupedPartitions + .forall(_.groupedParts.forall(_.parts.length <= 1)) ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering) } @@ -217,3 +209,19 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } } } + +/** + * A key-grouped Spark partition, which could consist of multiple input splits + * + * @param value the partition value shared by all the input splits + * @param parts the input splits that are grouped into a single Spark partition + */ +private[v2] case class KeyGroupedPartition(value: InternalRow, parts: Seq[InputPartition]) + +/** + * Information about key-grouped partitions, which contains a list of grouped partitions as well + * as the original input partitions before the grouping. 
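 * For example, as populated by `groupPartitions` above, four input splits with partition keys
 * [2, 0, 1, 2] yield three `groupedParts` entries (for keys 0, 1 and 2, the last holding both
 * splits for key 2), while `originalParts` keeps all four splits in key order [0, 1, 2, 2].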
+ */ +private[v2] case class KeyGroupedPartitionInfo( + groupedParts: Seq[KeyGroupedPartition], + originalParts: Seq[HasPartitionKey]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index ee0ea11816f9a..81d457e8ff3d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -288,12 +288,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyGroupedPartitioning(clustering, _, _)), _) => + case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyGroupedPartitioning(clustering, _, _))) => + case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( @@ -380,7 +380,8 @@ case class EnsureRequirements( val rightSpec = specs(1) var isCompatible = false - if (!conf.v2BucketingPushPartValuesEnabled) { + if (!conf.v2BucketingPushPartValuesEnabled && + !conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { isCompatible = leftSpec.isCompatibleWith(rightSpec) } else { logInfo("Pushing common partition values for storage-partitioned join") @@ -483,7 +484,10 @@ case class EnsureRequirements( s"'$joinType'. 
Skipping partially clustered distribution.") replicateRightSide = false } else { - val partValues = if (replicateLeftSide) rightPartValues else leftPartValues + // In partially clustered distribution, we should use un-grouped partition values + val spec = if (replicateLeftSide) rightSpec else leftSpec + val partValues = spec.partitioning.originalPartitionValues + val numExpectedPartitions = partValues .map(InternalRowComparableWrapper(_, partitionExprs)) .groupBy(identity) @@ -502,10 +506,10 @@ case class EnsureRequirements( } // Now we need to push-down the common partition key to the scan in each child - newLeft = populatePartitionValues( - left, mergedPartValues, applyPartialClustering, replicateLeftSide) - newRight = populatePartitionValues( - right, mergedPartValues, applyPartialClustering, replicateRightSide) + newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions, + applyPartialClustering, replicateLeftSide) + newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions, + applyPartialClustering, replicateRightSide) } } @@ -527,19 +531,21 @@ case class EnsureRequirements( private def populatePartitionValues( plan: SparkPlan, values: Seq[(InternalRow, Int)], + joinKeyPositions: Option[Seq[Int]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => scan.copy( spjParams = scan.spjParams.copy( commonPartitionValues = Some(values), + joinKeyPositions = joinKeyPositions, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions ) ) case node => node.mapChildren(child => populatePartitionValues( - child, values, applyPartialClustering, replicatePartitions)) + child, values, joinKeyPositions, applyPartialClustering, replicatePartitions)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 91f2099ce2d53..509f1e6a1e4f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange import java.util.function.Supplier +import scala.collection.mutable import scala.concurrent.Future import org.apache.spark._ @@ -29,6 +30,7 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProces import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ @@ -299,6 +301,11 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner + case k @ KeyGroupedPartitioning(expressions, n, _, _) => + val valueMap = k.uniquePartitionValues.zipWithIndex.map { + case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) + }.toMap + new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), n) case _ => throw new 
IllegalStateException(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. } @@ -325,6 +332,8 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) case SinglePartition => identity + case KeyGroupedPartitioning(expressions, _, _, _) => + row => bindReferences(expressions, outputAttributes).map(_.eval(row)) case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index f4317e632761c..1a0efa7c4aafb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,9 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) => - KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, - partitionValues) + case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) => + KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues, + originalPartValues) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 71e030f535e9d..b342a382a749e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -98,14 +98,17 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val catalystDistribution = physical.ClusteredDistribution( Seq(TransformExpression(YearsFunction, Seq(attr("ts"))))) val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) + val projectedPositions = catalystDistribution.clustering.indices checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, + partitionValues, partitionValues)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, + partitionValues, partitionValues)) } test("non-clustered distribution: no partition") { @@ -131,7 +134,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { // Has exactly one partition. 
val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v))) checkQueryPlan(df, distribution, - physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues)) + physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues)) } test("non-clustered distribution: no V2 catalog") { @@ -1062,6 +1065,187 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-41471: shuffle one side: only one side reports partitioning") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (shuffle) { + assert(shuffles.size == 1, "only shuffle one side not report partitioning") + } else { + assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" + + " is not enabled") + } + + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5))) + } + } + } + + test("SPARK-41471: shuffle one side: shuffle side has more partition value") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) { + Seq("JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN", "FULL OUTER JOIN").foreach { joinType => + val df = sql(s"SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i $joinType testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (shuffle) { + assert(shuffles.size == 1, "only shuffle one side not report partitioning") + } else { + assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one " + + "side is not enabled") + } + joinType match { + case "JOIN" => + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5))) + case "LEFT OUTER JOIN" => + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5), + Row(4, "cc", 15.5, null))) + case "RIGHT OUTER JOIN" => + checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null, null, 50.0), + Row(1, 
"aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5))) + case "FULL OUTER JOIN" => + checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null, null, 50.0), + Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5), + Row(4, "cc", 15.5, null))) + } + } + } + } + } + + test("SPARK-41471: shuffle one side: only one side reports partitioning with two identity") { + val items_partitions = Array(identity("id"), identity("arrive_time")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id and i.arrive_time = p.time ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (shuffle) { + assert(shuffles.size == 1, "only shuffle one side not report partitioning") + } else { + assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" + + " is not enabled") + } + + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0))) + } + } + } + + test("SPARK-41471: shuffle one side: partitioning with transform") { + val items_partitions = Array(years("arrive_time")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2021-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.arrive_time = p.time ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (shuffle) { + assert(shuffles.size == 2, "partitioning with transform not work now") + } else { + assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" + + " is not enabled") + } + + checkAnswer(df, Seq( + Row(1, "aa", 40.0, 42.0), + Row(3, "bb", 10.0, 42.0), + Row(4, "cc", 15.5, 19.5))) + } + } + } + + test("SPARK-41471: shuffle one side: work with group partition split") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, 
cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5))) + } + } + } + test("SPARK-44641: duplicated records when SPJ is not triggered") { val items_partitions = Array(bucket(8, "id")) createTable(items, items_schema, items_partitions) @@ -1118,6 +1302,282 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-48065: SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id")) + createTable(table1, columns, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-01' as timestamp)), " + + "(2, 'cc', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp))") + + createTable(table2, columns, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(2, 'ww', cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.id AS id, t1.data AS t1data, t2.data AS t2data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.id = t2.id AND t1.data = t2.data ORDER BY t1.id, t1data, t2data + |""".stripMargin) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + if (partiallyClustered) { + assert(scans == Seq(8, 8)) + } else { + assert(scans == Seq(4, 4)) + } + checkAnswer(df, Seq( + Row(3, "dd", "dd"), + Row(3, "dd", "dd"), + Row(3, "dd", "dd"), + Row(3, "dd", "dd") + )) + } + } + } + } + + test("SPARK-44647: test join key is subset of cluster key " + + "with push values and partially-clustered") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id"), identity("data")) + createTable(table1, schema, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as 
timestamp)), " + + "(2, 'bb', cast('2020-01-01' as timestamp)), " + + "(2, 'cc', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp))") + + createTable(table2, schema, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(3, 'yy', cast('2020-01-01' as timestamp)), " + + "(3, 'yy', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(2, 'ww', cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + + val df = sql("SELECT t1.id AS id, t1.data AS t1data, t2.data AS t2data " + + s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + + "ON t1.id = t2.id ORDER BY t1.id, t1data, t2data") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true) => assert(scans == Seq(8, 8)) + // SPJ and not partially-clustered + case (true, false) => assert(scans == Seq(4, 4)) + // No SPJ + case _ => assert(scans == Seq(5, 4)) + } + + checkAnswer(df, Seq( + Row(2, "bb", "ww"), + Row(2, "cc", "ww"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy") + )) + } + } + } + } + } + + test("SPARK-44647: test join key is the second cluster key") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id"), identity("data")) + createTable(table1, schema, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-02' as timestamp)), " + + "(3, 'cc', cast('2020-01-03' as timestamp))") + + createTable(table2, schema, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'aa', cast('2020-01-01' as timestamp)), " + + "(5, 'bb', cast('2020-01-02' as timestamp)), " + + "(6, 'cc', cast('2020-01-03' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> 
+ pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + + val df = sql("SELECT t1.id AS t1id, t2.id as t2id, t1.data AS data " + + s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + + "ON t1.data = t2.data ORDER BY t1id, t1id, data") + + checkAnswer(df, Seq(Row(1, 4, "aa"), Row(2, 5, "bb"), Row(3, 6, "cc"))) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (pushDownValues, allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true, true) => assert(scans == Seq(3, 3)) + // non-SPJ or SPJ/partially-clustered + case _ => assert(scans == Seq(3, 3)) + } + } + } + } + } + } + + test("SPARK-44647: test join key is the second partition key and a transform") { + val items_partitions = Array(bucket(8, "id"), days("arrive_time")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id"), days("time")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(1, 44.0, cast('2020-01-15' as timestamp)), " + + s"(1, 45.0, cast('2020-01-15' as timestamp)), " + + s"(2, 11.0, cast('2020-01-01' as timestamp)), " + + s"(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, " + + "p.item_id, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.arrive_time = p.time " + + "ORDER BY id, purchase_price, p.item_id, sale_price") + + // Currently SPJ for case where join key not same as partition key + // only supported when push-part-values enabled + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true) => assert(scans == Seq(5, 5)) + // SPJ and not 
partially-clustered + case (true, false) => assert(scans == Seq(3, 3)) + // No SPJ + case _ => assert(scans == Seq(4, 4)) + } + + checkAnswer(df, + Seq( + Row(1, "aa", 40.0, 1, 42.0), + Row(1, "aa", 40.0, 2, 11.0), + Row(1, "aa", 41.0, 1, 44.0), + Row(1, "aa", 41.0, 1, 45.0), + Row(2, "bb", 10.0, 1, 42.0), + Row(2, "bb", 10.0, 2, 11.0), + Row(2, "bb", 10.5, 1, 42.0), + Row(2, "bb", 10.5, 2, 11.0), + Row(3, "cc", 15.5, 3, 19.5) + ) + ) + } + } + } + } + } + test("SPARK-45652: SPJ should handle empty partition after dynamic filtering") { withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", @@ -1160,3 +1620,46 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-44647: shuffle one side and join keys are less than partition keys") { + val items_partitions = Array(identity("id"), identity("name")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 89.0, cast('2020-01-03' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { pushdownValues => + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key + -> partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "SPJ should be triggered") + checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), + Row(1, "aa", 30.0, 89.0), + Row(1, "aa", 40.0, 42.0), + Row(1, "aa", 40.0, 89.0), + Row(3, "bb", 10.0, 19.5))) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 09da1e1e7b013..3b0bb088a1076 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -18,15 +18,17 @@ package org.apache.spark.sql.execution.exchange import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum +import org.apache.spark.sql.catalyst.optimizer.BuildRight import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.plans.physical.{SinglePartition, _} import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan import 
org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf @@ -1109,6 +1111,32 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } + test("SPARK-41471: shuffle right side when" + + " spark.sql.sources.v2.bucketing.shuffle.enabled is true") { + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") { + + val a1 = AttributeReference("a1", IntegerType)() + + val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) + val plan1 = DummySparkPlan(outputPartitioning = KeyGroupedPartitioning( + identity(a1) :: Nil, 4, partitionValue)) + val plan2 = DummySparkPlan(outputPartitioning = SinglePartition) + + val smjExec = ShuffledHashJoinExec( + a1 :: Nil, a1 :: Nil, Inner, BuildRight, None, plan1, plan2) + EnsureRequirements.apply(smjExec) match { + case ShuffledHashJoinExec(_, _, _, _, _, + DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), + ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _), + DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => + assert(left.expressions == a1 :: Nil) + assert(attrs == a1 :: Nil) + assert(partitionValue == pv) + case other => fail(other.toString) + } + } + } + test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") { val lKey = AttributeReference("key", IntegerType)() val lKey2 = AttributeReference("key2", IntegerType)()
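
A minimal usage sketch for the two new configs, mirroring the scenarios covered by the tests above. It assumes the `testcat` catalog and the `items`/`purchases` tables from KeyGroupedPartitioningSuite, and assumes the key of the pre-existing bucketing flag referenced by the new config docs is `spark.sql.sources.v2.bucketing.enabled`.

// items is partitioned by identity(id); purchases reports no partitioning.
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.shuffle.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled", "true")

// With the shuffle config on, only the purchases side should be shuffled to match the
// KeyGroupedPartitioning reported by items; with the subset config on, a join on a strict
// subset of the partition keys can still avoid shuffling both sides.
spark.sql(
  """SELECT id, name, i.price AS purchase_price, p.price AS sale_price
    |FROM testcat.ns.items i JOIN testcat.ns.purchases p ON i.id = p.item_id
    |ORDER BY id, purchase_price, sale_price""".stripMargin).explain()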