From 1436c5a6ec0d2765e7e76174e2fa77d929c2e6eb Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Tue, 5 Sep 2023 13:22:34 +0800 Subject: [PATCH 1/5] [SPARK-44647][SQL] Support SPJ where join keys are less than cluster keys ### What changes were proposed in this pull request? - Add new conf spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled - Change key compatibility checks in EnsureRequirements. Remove checks where all partition keys must be in join keys to allow isKeyCompatible = true in this case (if this flag is enabled) - "Project" partitions by join keys in KeyGroupedPartitioning/KeyGroupedShuffleSpec - Add join key grouping to the partition grouping in BatchScanExec ### Why are the changes needed? - Support Storage Partition Join in cases where the join condition does not contain all the partition keys, but just some of them ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? -Added tests in KeyGroupedPartitioningSuite -Because of https://github.com/apache/spark/pull/37886 we have to select all join keys to trigger SPJ in this case, otherwise DSV2 scan does not report KeyGroupedPartitioning and SPJ does not get triggered. Need to see how to relax this in separate PR. --- .../plans/physical/partitioning.scala | 52 +++- .../apache/spark/sql/internal/SQLConf.scala | 15 + .../datasources/v2/BatchScanExec.scala | 56 ++-- .../exchange/EnsureRequirements.scala | 15 +- .../KeyGroupedPartitioningSuite.scala | 269 +++++++++++++++++- 5 files changed, 375 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0be4a61f27587..545c0ec118faa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -355,7 +355,14 @@ case class KeyGroupedPartitioning( } else { // We'll need to find leaf attributes from the partition expressions first. val attributes = expressions.flatMap(_.collectLeaves()) - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + + if (SQLConf.get.getConf( + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)) { + requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) + } else { + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } } case _ => @@ -364,8 +371,21 @@ case class KeyGroupedPartitioning( } } - override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = - KeyGroupedShuffleSpec(this, distribution) + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = { + var result = KeyGroupedShuffleSpec(this, distribution) + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // If allowing join keys to be subset of clustering keys, we should create a new + // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as + // the returned shuffle spec. + val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) + val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, + partitionValues, originalPartitionValues) + result = result.copy(partitioning = projectedPartitioning, joinKeyPositions = + Some(joinKeyPositions)) + } + + result + } lazy val uniquePartitionValues: Seq[InternalRow] = { partitionValues @@ -378,8 +398,25 @@ case class KeyGroupedPartitioning( object KeyGroupedPartitioning { def apply( expressions: Seq[Expression], - partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { - KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues) + projectionPositions: Seq[Int], + partitionValues: Seq[InternalRow], + originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { + val projectedExpressions = projectionPositions.map(expressions(_)) + val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) + val projectedOriginalPartitionValues = + originalPartitionValues.map(project(expressions, projectionPositions, _)) + + KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length, + projectedPartitionValues, projectedOriginalPartitionValues) + } + + def project( + expressions: Seq[Expression], + positions: Seq[Int], + input: InternalRow): InternalRow = { + val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType)) + .toArray + new GenericInternalRow(projectedValues) } def supportsExpressions(expressions: Seq[Expression]): Boolean = { @@ -674,7 +711,8 @@ case class HashShuffleSpec( case class KeyGroupedShuffleSpec( partitioning: KeyGroupedPartitioning, - distribution: ClusteredDistribution) extends ShuffleSpec { + distribution: ClusteredDistribution, + joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec { /** * A sequence where each element is a set of positions of the partition expression to the cluster @@ -709,7 +747,7 @@ case class KeyGroupedShuffleSpec( // 3.3 each pair of partition expressions at the same index must share compatible // transform functions. // 4. the partition values from both sides are following the same order. - case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) => + case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8c8b33921e321..6410a394bf9eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1530,6 +1530,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS = + buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled") + .doc("Whether to allow storage-partition join in the case where join keys are" + + "a subset of the partition keys of the source tables. At planning time, " + + "Spark will group the partitions by only those keys that are in the join keys." + + "This is currently enabled only if spark.sql.sources.v2.bucketing.pushPartValues.enabled " + + "is also enabled." + ) + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -4936,6 +4948,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def v2BucketingShuffleEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED) + def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean = + getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 932ac0f5a1b15..094a7b20808ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -120,7 +120,12 @@ case class BatchScanExec( val newPartValues = spjParams.commonPartitionValues.get.flatMap { case (partValue, numSplits) => Seq.fill(numSplits)(partValue) } - k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues) + val expressions = spjParams.joinKeyPositions match { + case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i)) + case _ => k.expressions + } + k.copy(expressions = expressions, numPartitions = newPartValues.length, + partitionValues = newPartValues) case p => p } } @@ -132,14 +137,29 @@ case class BatchScanExec( // return an empty RDD with 1 partition if dynamic filtering removed the only split sparkContext.parallelize(Array.empty[InternalRow], 1) } else { - var finalPartitions = filteredPartitions - - outputPartitioning match { + val finalPartitions = outputPartitioning match { case p: KeyGroupedPartitioning => - val groupedPartitions = filteredPartitions.map(splits => { - assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) - (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) - }) + assert(spjParams.keyGroupedPartitioning.isDefined) + val expressions = spjParams.keyGroupedPartitioning.get + + // Re-group the input partitions if we are projecting on a subset of join keys + val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match { + case Some(projectPositions) => + val projectedExpressions = projectPositions.map(i => expressions(i)) + val parts = filteredPartitions.flatten.groupBy(part => { + val row = part.asInstanceOf[HasPartitionKey].partitionKey() + val projectedRow = KeyGroupedPartitioning.project( + expressions, projectPositions, row) + InternalRowComparableWrapper(projectedRow, projectedExpressions) + }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq + (parts, projectedExpressions) + case _ => + val groupedParts = filteredPartitions.map(splits => { + assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) + (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) + }) + (groupedParts, expressions) + } // When partially clustered, the input partitions are not grouped by partition // values. Here we'll need to check `commonPartitionValues` and decide how to group @@ -149,12 +169,12 @@ case class BatchScanExec( // should contain. val commonPartValuesMap = spjParams.commonPartitionValues .get - .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) + .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. val numSplits = commonPartValuesMap - .get(InternalRowComparableWrapper(partValue, p.expressions)) + .get(InternalRowComparableWrapper(partValue, partExpressions)) assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + "common partition values from Spark plan") @@ -169,37 +189,37 @@ case class BatchScanExec( // sides of a join will have the same number of partitions & splits. splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) } - (InternalRowComparableWrapper(partValue, p.expressions), newSplits) + (InternalRowComparableWrapper(partValue, partExpressions), newSplits) } // Now fill missing partition keys with empty partitions val partitionMapping = nestGroupedPartitions.toMap - finalPartitions = spjParams.commonPartitionValues.get.flatMap { + spjParams.commonPartitionValues.get.flatMap { case (partValue, numSplits) => // Use empty partition for those partition values that are not present. partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), + InternalRowComparableWrapper(partValue, partExpressions), Seq.fill(numSplits)(Seq.empty)) } } else { // either `commonPartitionValues` is not defined, or it is defined but // `applyPartialClustering` is false. val partitionMapping = groupedPartitions.map { case (partValue, splits) => - InternalRowComparableWrapper(partValue, p.expressions) -> splits + InternalRowComparableWrapper(partValue, partExpressions) -> splits }.toMap // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there // could exist duplicated partition values, as partition grouping is not done // at the beginning and postponed to this method. It is important to use unique // partition values here so that grouped partitions won't get duplicated. - finalPartitions = p.uniquePartitionValues.map { partValue => + p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) + InternalRowComparableWrapper(partValue, partExpressions), Seq.empty) } } - case _ => + case _ => filteredPartitions } new DataSourceRDD( @@ -234,6 +254,7 @@ case class BatchScanExec( case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, + joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { @@ -247,6 +268,7 @@ case class StoragePartitionJoinParams( } override def hashCode(): Int = Objects.hashCode( + joinKeyPositions: Option[Seq[Int]], commonPartitionValues: Option[Seq[(InternalRow, Int)]], applyPartialClustering: java.lang.Boolean, replicatePartitions: java.lang.Boolean) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index f8e6fd1d0167f..8552c950f6776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -380,7 +380,8 @@ case class EnsureRequirements( val rightSpec = specs(1) var isCompatible = false - if (!conf.v2BucketingPushPartValuesEnabled) { + if (!conf.v2BucketingPushPartValuesEnabled && + !conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { isCompatible = leftSpec.isCompatibleWith(rightSpec) } else { logInfo("Pushing common partition values for storage-partitioned join") @@ -505,10 +506,10 @@ case class EnsureRequirements( } // Now we need to push-down the common partition key to the scan in each child - newLeft = populatePartitionValues( - left, mergedPartValues, applyPartialClustering, replicateLeftSide) - newRight = populatePartitionValues( - right, mergedPartValues, applyPartialClustering, replicateRightSide) + newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions, + applyPartialClustering, replicateLeftSide) + newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions, + applyPartialClustering, replicateRightSide) } } @@ -530,19 +531,21 @@ case class EnsureRequirements( private def populatePartitionValues( plan: SparkPlan, values: Seq[(InternalRow, Int)], + joinKeyPositions: Option[Seq[Int]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => scan.copy( spjParams = scan.spjParams.copy( commonPartitionValues = Some(values), + joinKeyPositions = joinKeyPositions, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions ) ) case node => node.mapChildren(child => populatePartitionValues( - child, values, applyPartialClustering, replicatePartitions)) + child, values, joinKeyPositions, applyPartialClustering, replicatePartitions)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index b22aba61aabd8..97341391c701b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -98,14 +98,17 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val catalystDistribution = physical.ClusteredDistribution( Seq(TransformExpression(YearsFunction, Seq(attr("ts"))))) val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) + val projectedPositions = catalystDistribution.clustering.indices checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, + partitionValues, partitionValues)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, + partitionValues, partitionValues)) } test("non-clustered distribution: no partition") { @@ -1276,4 +1279,266 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-44647: test join key is subset of cluster key " + + "with push values and partially-clustered") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id"), identity("data")) + createTable(table1, schema, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-01' as timestamp)), " + + "(2, 'cc', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp))") + + createTable(table2, schema, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(3, 'yy', cast('2020-01-01' as timestamp)), " + + "(3, 'yy', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(2, 'ww', cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + + val df = sql("SELECT t1.id AS id, t1.data AS t1data, t2.data AS t2data " + + s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + + "ON t1.id = t2.id ORDER BY t1.id, t1data, t2data") + + // Currently SPJ for case where join key not same as partition key + // only supported when push-part-values enabled + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true) => assert(scans == Seq(8, 8)) + // SPJ and not partially-clustered + case (true, false) => assert(scans == Seq(4, 4)) + // No SPJ + case _ => assert(scans == Seq(5, 4)) + } + + checkAnswer(df, Seq( + Row(2, "bb", "ww"), + Row(2, "cc", "ww"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy") + )) + } + } + } + } + } + + test("SPARK-44647: test join key is the second cluster key") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id"), identity("data")) + createTable(table1, schema, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-02' as timestamp)), " + + "(3, 'cc', cast('2020-01-03' as timestamp))") + + createTable(table2, schema, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'aa', cast('2020-01-01' as timestamp)), " + + "(5, 'bb', cast('2020-01-02' as timestamp)), " + + "(6, 'cc', cast('2020-01-03' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> + pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + + val df = sql("SELECT t1.id AS t1id, t2.id as t2id, t1.data AS data " + + s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + + "ON t1.data = t2.data ORDER BY t1id, t1id, data") + + checkAnswer(df, Seq(Row(1, 4, "aa"), Row(2, 5, "bb"), Row(3, 6, "cc"))) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (pushDownValues, allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true, true) => assert(scans == Seq(3, 3)) + // non-SPJ or SPJ/partially-clustered + case _ => assert(scans == Seq(3, 3)) + } + } + } + } + } + } + + test("SPARK-44647: test join key is the second partition key and a transform") { + val items_partitions = Array(bucket(8, "id"), days("arrive_time")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id"), days("time")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(1, 44.0, cast('2020-01-15' as timestamp)), " + + s"(1, 45.0, cast('2020-01-15' as timestamp)), " + + s"(2, 11.0, cast('2020-01-01' as timestamp)), " + + s"(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, " + + "p.item_id, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.arrive_time = p.time " + + "ORDER BY id, purchase_price, p.item_id, sale_price") + + // Currently SPJ for case where join key not same as partition key + // only supported when push-part-values enabled + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true) => assert(scans == Seq(5, 5)) + // SPJ and not partially-clustered + case (true, false) => assert(scans == Seq(3, 3)) + // No SPJ + case _ => assert(scans == Seq(4, 4)) + } + + checkAnswer(df, + Seq( + Row(1, "aa", 40.0, 1, 42.0), + Row(1, "aa", 40.0, 2, 11.0), + Row(1, "aa", 41.0, 1, 44.0), + Row(1, "aa", 41.0, 1, 45.0), + Row(2, "bb", 10.0, 1, 42.0), + Row(2, "bb", 10.0, 2, 11.0), + Row(2, "bb", 10.5, 1, 42.0), + Row(2, "bb", 10.5, 2, 11.0), + Row(3, "cc", 15.5, 3, 19.5) + ) + ) + } + } + } + } + } + + test("SPARK-44647: shuffle one side and join keys are less than partition keys") { + val items_partitions = Array(identity("id"), identity("name")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 89.0, cast('2020-01-03' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { pushdownValues => + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key + -> partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + // Currently SPJ for case where join key not same as partition key + // only supported when push-part-values enabled + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "SPJ should be triggered") + checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), + Row(1, "aa", 30.0, 89.0), + Row(1, "aa", 40.0, 42.0), + Row(1, "aa", 40.0, 89.0), + Row(3, "bb", 10.0, 19.5))) + } + } + } + } } From 0df6e97d7c33560eb3943b7dbf478c21882bff51 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 8 Sep 2023 13:38:04 +0800 Subject: [PATCH 2/5] Review comments --- .../plans/physical/partitioning.scala | 22 ++++++++++++------- .../apache/spark/sql/internal/SQLConf.scala | 4 ++-- .../KeyGroupedPartitioningSuite.scala | 4 ---- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 545c0ec118faa..78cf1f59f59f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -356,8 +356,8 @@ case class KeyGroupedPartitioning( // We'll need to find leaf attributes from the partition expressions first. val attributes = expressions.flatMap(_.collectLeaves()) - if (SQLConf.get.getConf( - SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)) { + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // check that all join keys (required clustering keys) contained in partitioning requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) && expressions.forall(_.collectLeaves().size == 1) } else { @@ -372,19 +372,18 @@ case class KeyGroupedPartitioning( } override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = { - var result = KeyGroupedShuffleSpec(this, distribution) + val result = KeyGroupedShuffleSpec(this, distribution) if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { // If allowing join keys to be subset of clustering keys, we should create a new // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as // the returned shuffle spec. val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, - partitionValues, originalPartitionValues) - result = result.copy(partitioning = projectedPartitioning, joinKeyPositions = - Some(joinKeyPositions)) + partitionValues, originalPartitionValues) + result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) + } else { + result } - - result } lazy val uniquePartitionValues: Seq[InternalRow] = { @@ -709,6 +708,13 @@ case class HashShuffleSpec( override def numPartitions: Int = partitioning.numPartitions } +/** + * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]]. + * @param partitioning key grouped partitioning + * @param distribution distribution + * @param joinKeyPosition position of join keys among cluster keys. + * This is set if joining on a subset of cluster keys is allowed. + */ case class KeyGroupedShuffleSpec( partitioning: KeyGroupedPartitioning, distribution: ClusteredDistribution, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6410a394bf9eb..02734df717398 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1535,8 +1535,8 @@ object SQLConf { .doc("Whether to allow storage-partition join in the case where join keys are" + "a subset of the partition keys of the source tables. At planning time, " + "Spark will group the partitions by only those keys that are in the join keys." + - "This is currently enabled only if spark.sql.sources.v2.bucketing.pushPartValues.enabled " + - "is also enabled." + "This is currently enabled only if spark.sql.requireAllClusterKeysForDistribution " + + "is false." ) .version("4.0.0") .booleanConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 97341391c701b..ffd1c8e31e919 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1321,8 +1321,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + "ON t1.id = t2.id ORDER BY t1.id, t1data, t2data") - // Currently SPJ for case where join key not same as partition key - // only supported when push-part-values enabled val shuffles = collectShuffles(df.queryExecution.executedPlan) if (allowJoinKeysSubsetOfPartitionKeys) { assert(shuffles.isEmpty, "SPJ should be triggered") @@ -1528,8 +1526,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") - // Currently SPJ for case where join key not same as partition key - // only supported when push-part-values enabled val shuffles = collectShuffles(df.queryExecution.executedPlan) assert(shuffles.size == 1, "SPJ should be triggered") checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), From a62e32b5eb673beaeaaaac5a6abeccf047184752 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 8 Sep 2023 13:38:04 +0800 Subject: [PATCH 3/5] Review comments --- .../spark/sql/catalyst/plans/physical/partitioning.scala | 1 + .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 78cf1f59f59f9..a61bd3b7324be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -710,6 +710,7 @@ case class HashShuffleSpec( /** * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]]. + * * @param partitioning key grouped partitioning * @param distribution distribution * @param joinKeyPosition position of join keys among cluster keys. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 02734df717398..13007b7394918 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1533,10 +1533,9 @@ object SQLConf { val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS = buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled") .doc("Whether to allow storage-partition join in the case where join keys are" + - "a subset of the partition keys of the source tables. At planning time, " + + "a subset of the partition keys of the source tables. At planning time, " + "Spark will group the partitions by only those keys that are in the join keys." + - "This is currently enabled only if spark.sql.requireAllClusterKeysForDistribution " + - "is false." + s"This is currently enabled only if $REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION is false." + ) .version("4.0.0") .booleanConf From 6a7ca35f145caa101e2287e5f66b1fd6caf73892 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Sat, 9 Sep 2023 17:32:31 +0800 Subject: [PATCH 4/5] Fix typo --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 13007b7394918..b722972a8b98a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1535,7 +1535,7 @@ object SQLConf { .doc("Whether to allow storage-partition join in the case where join keys are" + "a subset of the partition keys of the source tables. At planning time, " + "Spark will group the partitions by only those keys that are in the join keys." + - s"This is currently enabled only if $REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION is false." + + s"This is currently enabled only if $REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION is false." ) .version("4.0.0") .booleanConf From e83265293869c232a79b302bea002f82f6623920 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Sun, 10 Sep 2023 19:24:27 +0800 Subject: [PATCH 5/5] Fix sqlconf test --- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b722972a8b98a..49a4b0bf98bb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1535,7 +1535,8 @@ object SQLConf { .doc("Whether to allow storage-partition join in the case where join keys are" + "a subset of the partition keys of the source tables. At planning time, " + "Spark will group the partitions by only those keys that are in the join keys." + - s"This is currently enabled only if $REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION is false." + s"This is currently enabled only if ${REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key} " + + "is false." ) .version("4.0.0") .booleanConf