From 13d4714ee71c4153c3b6e4ff96fc7487e9308aee Mon Sep 17 00:00:00 2001
From: Jia Fan
Date: Thu, 24 Aug 2023 08:35:16 -0700
Subject: [PATCH 1/5] [SPARK-41471][SQL] Reduce Spark shuffle when only one side of a join is KeyGroupedPartitioning

### What changes were proposed in this pull request?
When only one side of an SPJ (Storage-Partitioned Join) is KeyGroupedPartitioning, Spark currently needs to shuffle both sides using HashPartitioning. However, we may just need to shuffle the other side according to the partition transforms defined in KeyGroupedPartitioning. This is especially useful when the other side is relatively small.
1. Add a new config `spark.sql.sources.v2.bucketing.shuffle.enabled` to control whether this feature is enabled.
2. Add `KeyGroupedPartitioner`, used to partition one side when we already know the transform values of the other side (currently the KeyGroupedPartitioning side). In `EnsureRequirements`, Spark already knows the mapping from partition value to partition id on the KeyGroupedPartitioning side. We save that mapping in `KeyGroupedPartitioner` and use it to shuffle the other side, so that records with the same key are shuffled into the same partition.
3. Only the `identity` transform is supported for now. There is a separate problem: the same transform can report different values between a DS V2 connector implementation and the corresponding catalog function, so until that is solved we should only support `identity`. E.g., in the test package, `YearFunction` in https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala#L47 and https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala#L143

### Why are the changes needed?
Reduce data shuffling in specific SPJ scenarios.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added new tests.

Closes #42194 from Hisoka-X/SPARK-41471_one_side_keygroup.

Authored-by: Jia Fan
Signed-off-by: Chao Sun
(cherry picked from commit ce12f6dbad2d713c6a2a52ac36d1f7a910399ad4)
---
 .../scala/org/apache/spark/Partitioner.scala  |  16 ++
 .../plans/physical/partitioning.scala         |   8 +-
 .../util/InternalRowComparableWrapper.scala   |   7 +-
 .../apache/spark/sql/internal/SQLConf.scala   |  13 ++
 .../datasources/v2/BatchScanExec.scala        |   2 +-
 .../exchange/ShuffleExchangeExec.scala        |   9 +
 .../KeyGroupedPartitioningSuite.scala         | 181 ++++++++++++++++++
 .../exchange/EnsureRequirementsSuite.scala    |  32 +++-
 8 files changed, 263 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 5dffba2ee8e08..ae39e2e183e4a 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -137,6 +137,22 @@ private[spark] class PartitionIdPassthrough(override val numPartitions: Int) ext
   override def getPartition(key: Any): Int = key.asInstanceOf[Int]
 }
 
+/**
+ * A [[org.apache.spark.Partitioner]] that partitions all records using partition value map.
+ * The `valueMap` is a map that contains tuples of (partition value, partition id). It is generated
+ * by [[org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning]], used to partition
+ * the other side of a join to make sure records with same partition value are in the same
+ * partition.
+ */ +private[spark] class KeyGroupedPartitioner( + valueMap: mutable.Map[Seq[Any], Int], + override val numPartitions: Int) extends Partitioner { + override def getPartition(key: Any): Int = { + val keys = key.asInstanceOf[Seq[Any]] + valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions)) + } +} + /** * A [[org.apache.spark.Partitioner]] that partitions all records into a single partition. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 211b5a05eb70c..4fea3e9bfd4bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -780,7 +780,13 @@ case class KeyGroupedShuffleSpec( case _ => false } - override def canCreatePartitioning: Boolean = false + override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && + // Only support partition expressions are AttributeReference for now + partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { + KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) + } } case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala index b0e530907310a..9a0bdc6bcfd11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala @@ -79,6 +79,11 @@ object InternalRowComparableWrapper { rightPartitioning.partitionValues .map(new InternalRowComparableWrapper(_, partitionDataTypes)) .foreach(partition => partitionsSet.add(partition)) - partitionsSet.map(_.row).toSeq + // SPARK-41471: We keep to order of partitions to make sure the order of + // partitions is deterministic in different case. + val partitionOrdering: Ordering[InternalRow] = { + RowOrdering.createNaturalAscendingOrdering(partitionDataTypes) + } + partitionsSet.map(_.row).toSeq.sorted(partitionOrdering) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8186d5fa00c3a..b8bf691ac2bbc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1514,6 +1514,16 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_SHUFFLE_ENABLED = + buildConf("spark.sql.sources.v2.bucketing.shuffle.enabled") + .doc("During a storage-partitioned join, whether to allow to shuffle only one side." + + "When only one side is KeyGroupedPartitioning, if the conditions are met, spark will " + + "only shuffle the other side. This optimization will reduce the amount of data that " + + s"needs to be shuffle. 
This config requires ${V2_BUCKETING_ENABLED.key} to be enabled") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -4942,6 +4952,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def v2BucketingPartiallyClusteredDistributionEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED) + def v2BucketingShuffleEnabled: Boolean = + getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 2a3a5cdeb82b8..a29d49b025bf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -153,7 +153,7 @@ case class BatchScanExec( if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) { // A mapping from the common partition values to how many splits the partition - // should contain. Note this no longer maintain the partition key ordering. + // should contain. val commonPartValuesMap = spjParams.commonPartitionValues .get .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 91f2099ce2d53..750b96dc83dc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.exchange import java.util.function.Supplier +import scala.collection.mutable import scala.concurrent.Future import org.apache.spark._ @@ -29,6 +30,7 @@ import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProces import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ @@ -299,6 +301,11 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner + case k @ KeyGroupedPartitioning(expressions, n, _) => + val valueMap = k.uniquePartitionValues.zipWithIndex.map { + case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) + }.toMap + new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), n) case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning") // TODO: Handle BroadcastPartitioning. 
} @@ -325,6 +332,8 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) case SinglePartition => identity + case KeyGroupedPartitioning(expressions, _, _) => + row => bindReferences(expressions, outputAttributes).map(_.eval(row)) case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 71e030f535e9d..308c62f4f55d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1062,6 +1062,187 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-41471: shuffle one side: only one side reports partitioning") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (shuffle) { + assert(shuffles.size == 1, "only shuffle one side not report partitioning") + } else { + assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" + + " is not enabled") + } + + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5))) + } + } + } + + test("SPARK-41471: shuffle one side: shuffle side has more partition value") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) { + Seq("JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN", "FULL OUTER JOIN").foreach { joinType => + val df = sql(s"SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i $joinType testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (shuffle) { + assert(shuffles.size == 1, "only shuffle one side not 
report partitioning") + } else { + assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one " + + "side is not enabled") + } + joinType match { + case "JOIN" => + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5))) + case "LEFT OUTER JOIN" => + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5), + Row(4, "cc", 15.5, null))) + case "RIGHT OUTER JOIN" => + checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null, null, 50.0), + Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5))) + case "FULL OUTER JOIN" => + checkAnswer(df, Seq(Row(null, null, null, 26.0), Row(null, null, null, 50.0), + Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5), + Row(4, "cc", 15.5, null))) + } + } + } + } + } + + test("SPARK-41471: shuffle one side: only one side reports partitioning with two identity") { + val items_partitions = Array(identity("id"), identity("arrive_time")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id and i.arrive_time = p.time ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (shuffle) { + assert(shuffles.size == 1, "only shuffle one side not report partitioning") + } else { + assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" + + " is not enabled") + } + + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0))) + } + } + } + + test("SPARK-41471: shuffle one side: partitioning with transform") { + val items_partitions = Array(years("arrive_time")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2021-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2021-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.arrive_time = p.time ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (shuffle) { + assert(shuffles.size == 2, "partitioning with transform not work now") + } else { + assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" + + " is not enabled") + } + + checkAnswer(df, Seq( + Row(1, "aa", 40.0, 42.0), + Row(3, "bb", 10.0, 42.0), + Row(4, "cc", 15.5, 19.5))) + } + } + } + + test("SPARK-41471: shuffle one side: work with group 
partition split") { + val items_partitions = Array(identity("id")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { shuffle => + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> shuffle.toString, + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "true") { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5))) + } + } + } + test("SPARK-44641: duplicated records when SPJ is not triggered") { val items_partitions = Array(bucket(8, "id")) createTable(items, items_schema, items_partitions) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 09da1e1e7b013..3c9b92e5f66b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -18,15 +18,17 @@ package org.apache.spark.sql.execution.exchange import org.apache.spark.api.python.PythonEvalType +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Sum +import org.apache.spark.sql.catalyst.optimizer.BuildRight import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.plans.physical.{SinglePartition, _} import org.apache.spark.sql.catalyst.statsEstimation.StatsTestPlan import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.internal.SQLConf @@ -1109,6 +1111,32 @@ class EnsureRequirementsSuite extends SharedSparkSession { } } + test("SPARK-41471: shuffle right side when" + + " spark.sql.sources.v2.bucketing.shuffle.enabled is true") { + withSQLConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") { + + val a1 = AttributeReference("a1", IntegerType)() + + val partitionValue = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) + val plan1 = DummySparkPlan(outputPartitioning = KeyGroupedPartitioning( + identity(a1) :: Nil, 4, partitionValue)) + val plan2 = DummySparkPlan(outputPartitioning = SinglePartition) + + val smjExec = ShuffledHashJoinExec( + a1 :: Nil, a1 :: 
Nil, Inner, BuildRight, None, plan1, plan2)
+      EnsureRequirements.apply(smjExec) match {
+        case ShuffledHashJoinExec(_, _, _, _, _,
+          DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _),
+          ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv),
+            DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) =>
+          assert(left.expressions == a1 :: Nil)
+          assert(attrs == a1 :: Nil)
+          assert(partitionValue == pv)
+        case other => fail(other.toString)
+      }
+    }
+  }
+
   test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") {
     val lKey = AttributeReference("key", IntegerType)()
     val lKey2 = AttributeReference("key2", IntegerType)()

From d700f2c60896c3b266967591cf81e93ac53d41ab Mon Sep 17 00:00:00 2001
From: Igor Berman
Date: Wed, 18 Jun 2025 18:32:52 +0300
Subject: [PATCH 2/5] [SPARK-45036][SQL] SPJ: Simplify the logic to handle partially clustered distribution

### What changes were proposed in this pull request?
In SPJ, currently the logic to handle partially clustered distribution is a bit complicated. For instance, when the feature is enabled (by enabling both `conf.v2BucketingPushPartValuesEnabled` and `conf.v2BucketingPartiallyClusteredDistributionEnabled`), Spark should postpone the combining of input splits until it is about to create an input RDD in `BatchScanExec`. To implement this, `groupPartitions` in `DataSourceV2ScanExecBase` currently takes the flag as input and has two different behaviors, which could be confusing.

This PR introduces a new field in `KeyGroupedPartitioning`, named `originalPartitionValues`, that is used to store the original partition values from the input before split combining has been applied. The field is used when partially clustered distribution is enabled. With this, `groupPartitions` becomes easier to understand.

In addition, this also simplifies `BatchScanExec.inputRDD` by combining two branches where partially clustered distribution is not enabled.

### Why are the changes needed?
To simplify the current logic in SPJ w.r.t. partially clustered distribution.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing tests.

Closes #42757 from sunchao/SPARK-45036.

Authored-by: Chao Sun
Signed-off-by: Dongjoon Hyun
(cherry picked from commit 9e2aafb13739f9c07f8218cd325c5532063b1a51)
---
 .../plans/physical/partitioning.scala         |  37 ++++--
 .../datasources/v2/BatchScanExec.scala        | 117 ++++++++----------
 .../v2/DataSourceV2ScanExecBase.scala         |  65 +++++-----
 .../exchange/EnsureRequirements.scala         |   9 +-
 .../exchange/ShuffleExchangeExec.scala        |   4 +-
 .../DistributionAndOrderingSuiteBase.scala    |   6 +-
 .../KeyGroupedPartitioningSuite.scala         |   2 +-
 .../exchange/EnsureRequirementsSuite.scala    |   2 +-
 8 files changed, 123 insertions(+), 119 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 4fea3e9bfd4bc..1eacb80567658 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -335,28 +335,39 @@ case class CoalescedHashPartitioning(from: HashPartitioning, partitions: Seq[Coa
 
 /**
  * Represents a partitioning where rows are split across partitions based on transforms defined
- * by `expressions`. `partitionValuesOpt`, if defined, should contain value of partition key(s) in
+ * by `expressions`. `partitionValues`, if defined, should contain value of partition key(s) in
  * ascending order, after evaluated by the transforms in `expressions`, for each input partition.
- * In addition, its length must be the same as the number of input partitions (and thus is a 1-1 - * mapping). The `partitionValues` may contain duplicated partition values. + * In addition, its length must be the same as the number of Spark partitions (and thus is a 1-1 + * mapping), and each row in `partitionValues` must be unique. * - * For example, if `expressions` is `[years(ts_col)]`, then a valid value of `partitionValuesOpt` is - * `[0, 1, 2]`, which represents 3 input partitions with distinct partition values. All rows - * in each partition have the same value for column `ts_col` (which is of timestamp type), after - * being applied by the `years` transform. + * The `originalPartitionValues`, on the other hand, are partition values from the original input + * splits returned by data sources. It may contain duplicated values. * - * On the other hand, `[0, 0, 1]` is not a valid value for `partitionValuesOpt` since `0` is - * duplicated twice. + * For example, if a data source reports partition transform expressions `[years(ts_col)]` with 4 + * input splits whose corresponding partition values are `[0, 1, 2, 2]`, then the `expressions` + * in this case is `[years(ts_col)]`, while `partitionValues` is `[0, 1, 2]`, which + * represents 3 input partitions with distinct partition values. All rows in each partition have + * the same value for column `ts_col` (which is of timestamp type), after being applied by the + * `years` transform. This is generated after combining the two splits with partition value `2` + * into a single Spark partition. + * + * On the other hand, in this example `[0, 1, 2, 2]` is the value of `originalPartitionValues` + * which is calculated from the original input splits. * * @param expressions partition expressions for the partitioning. * @param numPartitions the number of partitions - * @param partitionValues the values for the cluster keys of the distribution, must be - * in ascending order. + * @param partitionValues the values for the final cluster keys (that is, after applying grouping + * on the input splits according to `expressions`) of the distribution, + * must be in ascending order, and must NOT contain duplicated values. 
+ * @param originalPartitionValues the original input partition values before any grouping has been + * applied, must be in ascending order, and may contain duplicated + * values */ case class KeyGroupedPartitioning( expressions: Seq[Expression], numPartitions: Int, - partitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning { + partitionValues: Seq[InternalRow] = Seq.empty, + originalPartitionValues: Seq[InternalRow] = Seq.empty) extends Partitioning { override def satisfies0(required: Distribution): Boolean = { super.satisfies0(required) || { @@ -393,7 +404,7 @@ object KeyGroupedPartitioning { def apply( expressions: Seq[Expression], partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { - KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues) + KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues) } def supportsExpressions(expressions: Seq[Expression]): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index a29d49b025bf7..932ac0f5a1b15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Par import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.connector.read._ -import org.apache.spark.sql.internal.SQLConf /** * Physical plan node for scanning a batch of data from a data source v2. @@ -101,7 +100,7 @@ case class BatchScanExec( "partition values that are not present in the original partitioning.") } - groupPartitions(newPartitions).getOrElse(Seq.empty).map(_._2) + groupPartitions(newPartitions).get.groupedParts.map(_.parts) case _ => // no validation is needed as the data source did not report any specific partitioning @@ -137,81 +136,63 @@ case class BatchScanExec( outputPartitioning match { case p: KeyGroupedPartitioning => - if (conf.v2BucketingPushPartValuesEnabled && - conf.v2BucketingPartiallyClusteredDistributionEnabled) { - assert(filteredPartitions.forall(_.size == 1), - "Expect partitions to be not grouped when " + - s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + - "is enabled") - - val groupedPartitions = groupPartitions(finalPartitions.map(_.head), - groupSplits = true).getOrElse(Seq.empty) - - // This means the input partitions are not grouped by partition values. We'll need to - // check `groupByPartitionValues` and decide whether to group and replicate splits - // within a partition. - if (spjParams.commonPartitionValues.isDefined && - spjParams.applyPartialClustering) { - // A mapping from the common partition values to how many splits the partition - // should contain. - val commonPartValuesMap = spjParams.commonPartitionValues + val groupedPartitions = filteredPartitions.map(splits => { + assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) + (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) + }) + + // When partially clustered, the input partitions are not grouped by partition + // values. Here we'll need to check `commonPartitionValues` and decide how to group + // and replicate splits within a partition. 
+ if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) { + // A mapping from the common partition values to how many splits the partition + // should contain. + val commonPartValuesMap = spjParams.commonPartitionValues .get .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) .toMap - val nestGroupedPartitions = groupedPartitions.map { - case (partValue, splits) => - // `commonPartValuesMap` should contain the part value since it's the super set. - val numSplits = commonPartValuesMap - .get(InternalRowComparableWrapper(partValue, p.expressions)) - assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + - "common partition values from Spark plan") - - val newSplits = if (spjParams.replicatePartitions) { - // We need to also replicate partitions according to the other side of join - Seq.fill(numSplits.get)(splits) - } else { - // Not grouping by partition values: this could be the side with partially - // clustered distribution. Because of dynamic filtering, we'll need to check if - // the final number of splits of a partition is smaller than the original - // number, and fill with empty splits if so. This is necessary so that both - // sides of a join will have the same number of partitions & splits. - splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) - } - (InternalRowComparableWrapper(partValue, p.expressions), newSplits) + val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => + // `commonPartValuesMap` should contain the part value since it's the super set. + val numSplits = commonPartValuesMap + .get(InternalRowComparableWrapper(partValue, p.expressions)) + assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + + "common partition values from Spark plan") + + val newSplits = if (spjParams.replicatePartitions) { + // We need to also replicate partitions according to the other side of join + Seq.fill(numSplits.get)(splits) + } else { + // Not grouping by partition values: this could be the side with partially + // clustered distribution. Because of dynamic filtering, we'll need to check if + // the final number of splits of a partition is smaller than the original + // number, and fill with empty splits if so. This is necessary so that both + // sides of a join will have the same number of partitions & splits. + splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) } + (InternalRowComparableWrapper(partValue, p.expressions), newSplits) + } - // Now fill missing partition keys with empty partitions - val partitionMapping = nestGroupedPartitions.toMap - finalPartitions = spjParams.commonPartitionValues.get.flatMap { - case (partValue, numSplits) => - // Use empty partition for those partition values that are not present. - partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), - Seq.fill(numSplits)(Seq.empty)) - } - } else { - // either `commonPartitionValues` is not defined, or it is defined but - // `applyPartialClustering` is false. - val partitionMapping = groupedPartitions.map { case (row, parts) => - InternalRowComparableWrapper(row, p.expressions) -> parts - }.toMap - - // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there - // could exist duplicated partition values, as partition grouping is not done - // at the beginning and postponed to this method. It is important to use unique - // partition values here so that grouped partitions won't get duplicated. 
- finalPartitions = p.uniquePartitionValues.map { partValue => - // Use empty partition for those partition values that are not present + // Now fill missing partition keys with empty partitions + val partitionMapping = nestGroupedPartitions.toMap + finalPartitions = spjParams.commonPartitionValues.get.flatMap { + case (partValue, numSplits) => + // Use empty partition for those partition values that are not present. partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) - } + InternalRowComparableWrapper(partValue, p.expressions), + Seq.fill(numSplits)(Seq.empty)) } } else { - val partitionMapping = finalPartitions.map { parts => - val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey() - InternalRowComparableWrapper(row, p.expressions) -> parts + // either `commonPartitionValues` is not defined, or it is defined but + // `applyPartialClustering` is false. + val partitionMapping = groupedPartitions.map { case (partValue, splits) => + InternalRowComparableWrapper(partValue, p.expressions) -> splits }.toMap - finalPartitions = p.partitionValues.map { partValue => + + // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there + // could exist duplicated partition values, as partition grouping is not done + // at the beginning and postponed to this method. It is important to use unique + // partition values here so that grouped partitions won't get duplicated. + finalPartitions = p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index f688d3514d9aa..94667fbd00c18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -62,8 +62,9 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { redact(result) } - def partitions: Seq[Seq[InputPartition]] = - groupedPartitions.map(_.map(_._2)).getOrElse(inputPartitions.map(Seq(_))) + def partitions: Seq[Seq[InputPartition]] = { + groupedPartitions.map(_.groupedParts.map(_.parts)).getOrElse(inputPartitions.map(Seq(_))) + } /** * Shorthand for calling redact() without specifying redacting rules @@ -94,8 +95,10 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { keyGroupedPartitioning match { case Some(exprs) if KeyGroupedPartitioning.supportsExpressions(exprs) => groupedPartitions - .map { partitionValues => - KeyGroupedPartitioning(exprs, partitionValues.size, partitionValues.map(_._1)) + .map { keyGroupedPartsInfo => + val keyGroupedParts = keyGroupedPartsInfo.groupedParts + KeyGroupedPartitioning(exprs, keyGroupedParts.size, keyGroupedParts.map(_.value), + keyGroupedPartsInfo.originalParts.map(_.partitionKey())) } .getOrElse(super.outputPartitioning) case _ => @@ -103,7 +106,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } } - @transient lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] = { + @transient lazy val groupedPartitions: Option[KeyGroupedPartitionInfo] = { // Early check if we actually need to materialize the input partitions. 
keyGroupedPartitioning match { case Some(_) => groupPartitions(inputPartitions) @@ -117,24 +120,21 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { * - all input partitions implement [[HasPartitionKey]] * - `keyGroupedPartitioning` is set * - * The result, if defined, is a list of tuples where the first element is a partition value, - * and the second element is a list of input partitions that share the same partition value. + * The result, if defined, is a [[KeyGroupedPartitionInfo]] which contains a list of + * [[KeyGroupedPartition]], as well as a list of partition values from the original input splits, + * sorted according to the partition keys in ascending order. * * A non-empty result means each partition is clustered on a single key and therefore eligible * for further optimizations to eliminate shuffling in some operations such as join and aggregate. */ - def groupPartitions( - inputPartitions: Seq[InputPartition], - groupSplits: Boolean = !conf.v2BucketingPushPartValuesEnabled || - !conf.v2BucketingPartiallyClusteredDistributionEnabled): - Option[Seq[(InternalRow, Seq[InputPartition])]] = { - + def groupPartitions(inputPartitions: Seq[InputPartition]): Option[KeyGroupedPartitionInfo] = { if (!SQLConf.get.v2BucketingEnabled) return None + keyGroupedPartitioning.flatMap { expressions => val results = inputPartitions.takeWhile { case _: HasPartitionKey => true case _ => false - }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p)) + }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p.asInstanceOf[HasPartitionKey])) if (results.length != inputPartitions.length || inputPartitions.isEmpty) { // Not all of the `InputPartitions` implements `HasPartitionKey`, therefore skip here. @@ -143,32 +143,25 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { // also sort the input partitions according to their partition key order. This ensures // a canonical order from both sides of a bucketed join, for example. 
val partitionDataTypes = expressions.map(_.dataType) - val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = { + val partitionOrdering: Ordering[(InternalRow, InputPartition)] = { RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1) } - - val partitions = if (groupSplits) { - // Group the splits by their partition value - results + val sortedKeyToPartitions = results.sorted(partitionOrdering) + val groupedPartitions = sortedKeyToPartitions .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2)) .groupBy(_._1) .toSeq - .map { - case (key, s) => (key.row, s.map(_._2)) - } - } else { - // No splits grouping, each split will become a separate Spark partition - results.map(t => (t._1, Seq(t._2))) - } + .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) } - Some(partitions.sorted(partitionOrdering)) + Some(KeyGroupedPartitionInfo(groupedPartitions, sortedKeyToPartitions.map(_._2))) } } } override def outputOrdering: Seq[SortOrder] = { // when multiple partitions are grouped together, ordering inside partitions is not preserved - val partitioningPreservesOrdering = groupedPartitions.forall(_.forall(_._2.length <= 1)) + val partitioningPreservesOrdering = groupedPartitions + .forall(_.groupedParts.forall(_.parts.length <= 1)) ordering.filter(_ => partitioningPreservesOrdering).getOrElse(super.outputOrdering) } @@ -217,3 +210,19 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } } } + +/** + * A key-grouped Spark partition, which could consist of multiple input splits + * + * @param value the partition value shared by all the input splits + * @param parts the input splits that are grouped into a single Spark partition + */ +private[v2] case class KeyGroupedPartition(value: InternalRow, parts: Seq[InputPartition]) + +/** + * Information about key-grouped partitions, which contains a list of grouped partitions as well + * as the original input partitions before the grouping. + */ +private[v2] case class KeyGroupedPartitionInfo( + groupedParts: Seq[KeyGroupedPartition], + originalParts: Seq[HasPartitionKey]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index ee0ea11816f9a..4067bbd22a01b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -288,12 +288,12 @@ case class EnsureRequirements( reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, leftPartitioning, None)) - case (Some(KeyGroupedPartitioning(clustering, _, _)), _) => + case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys) .orElse(reorderJoinKeysRecursively( leftKeys, rightKeys, None, rightPartitioning)) - case (_, Some(KeyGroupedPartitioning(clustering, _, _))) => + case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) => val leafExprs = clustering.flatMap(_.collectLeaves()) reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys) .orElse(reorderJoinKeysRecursively( @@ -483,7 +483,10 @@ case class EnsureRequirements( s"'$joinType'. 
Skipping partially clustered distribution.") replicateRightSide = false } else { - val partValues = if (replicateLeftSide) rightPartValues else leftPartValues + // In partially clustered distribution, we should use un-grouped partition values + val spec = if (replicateLeftSide) rightSpec else leftSpec + val partValues = spec.partitioning.originalPartitionValues + val numExpectedPartitions = partValues .map(InternalRowComparableWrapper(_, partitionExprs)) .groupBy(identity) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 750b96dc83dc0..509f1e6a1e4f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -301,7 +301,7 @@ object ShuffleExchangeExec { ascending = true, samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new ConstantPartitioner - case k @ KeyGroupedPartitioning(expressions, n, _) => + case k @ KeyGroupedPartitioning(expressions, n, _, _) => val valueMap = k.uniquePartitionValues.zipWithIndex.map { case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index) }.toMap @@ -332,7 +332,7 @@ object ShuffleExchangeExec { val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes) row => projection(row) case SinglePartition => identity - case KeyGroupedPartitioning(expressions, _, _) => + case KeyGroupedPartitioning(expressions, _, _, _) => row => bindReferences(expressions, outputAttributes).map(_.eval(row)) case _ => throw new IllegalStateException(s"Exchange not implemented for $newPartitioning") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala index f4317e632761c..1a0efa7c4aafb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DistributionAndOrderingSuiteBase.scala @@ -51,9 +51,9 @@ abstract class DistributionAndOrderingSuiteBase plan: QueryPlan[T]): Partitioning = partitioning match { case HashPartitioning(exprs, numPartitions) => HashPartitioning(exprs.map(resolveAttrs(_, plan)), numPartitions) - case KeyGroupedPartitioning(clustering, numPartitions, partitionValues) => - KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, - partitionValues) + case KeyGroupedPartitioning(clustering, numPartitions, partValues, originalPartValues) => + KeyGroupedPartitioning(clustering.map(resolveAttrs(_, plan)), numPartitions, partValues, + originalPartValues) case PartitioningCollection(partitionings) => PartitioningCollection(partitionings.map(resolvePartitioning(_, plan))) case RangePartitioning(ordering, numPartitions) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 308c62f4f55d3..3b933d0539f6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -131,7 +131,7 @@ class KeyGroupedPartitioningSuite extends 
DistributionAndOrderingSuiteBase { // Has exactly one partition. val partitionValues = Seq(31).map(v => InternalRow.fromSeq(Seq(v))) checkQueryPlan(df, distribution, - physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues)) + physical.KeyGroupedPartitioning(distribution.clustering, 1, partitionValues, partitionValues)) } test("non-clustered distribution: no V2 catalog") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala index 3c9b92e5f66b6..3b0bb088a1076 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -1127,7 +1127,7 @@ class EnsureRequirementsSuite extends SharedSparkSession { EnsureRequirements.apply(smjExec) match { case ShuffledHashJoinExec(_, _, _, _, _, DummySparkPlan(_, _, left: KeyGroupedPartitioning, _, _), - ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv), + ShuffleExchangeExec(KeyGroupedPartitioning(attrs, 4, pv, _), DummySparkPlan(_, _, SinglePartition, _, _), _, _), _) => assert(left.expressions == a1 :: Nil) assert(attrs == a1 :: Nil) From a99489d2220fa8c8229af60302b277350247152a Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Thu, 7 Sep 2023 18:19:54 +0800 Subject: [PATCH 3/5] [SPARK-45036][FOLLOWUP][SQL] SPJ: Make sure result partitions are sorted according to partition values ### What changes were proposed in this pull request? This PR makes sure the result grouped partitions from `DataSourceV2ScanExec#groupPartitions` are sorted according to the partition values. Previously in the #42757 we were assuming Scala would preserve the input ordering but apparently that's not the case. ### Why are the changes needed? See https://github.com/apache/spark/pull/42757#discussion_r1316926504 for diagnosis. The partition ordering is a fundamental property for SPJ and thus must be guaranteed. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? We have tests in `KeyGroupedPartitioningSuite` to cover this. ### Was this patch authored or co-authored using generative AI tooling? Closes #42839 from sunchao/SPARK-45036-followup. Authored-by: Chao Sun Signed-off-by: yangjie01 (cherry picked from commit af1615d026eaf4aeec27ccfe3c58011ebbcb3de1) --- .../datasources/v2/DataSourceV2ScanExecBase.scala | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala index 94667fbd00c18..b2f94cae2dfa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -143,17 +143,16 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { // also sort the input partitions according to their partition key order. This ensures // a canonical order from both sides of a bucketed join, for example. 
       val partitionDataTypes = expressions.map(_.dataType)
-      val partitionOrdering: Ordering[(InternalRow, InputPartition)] = {
-        RowOrdering.createNaturalAscendingOrdering(partitionDataTypes).on(_._1)
-      }
-      val sortedKeyToPartitions = results.sorted(partitionOrdering)
-      val groupedPartitions = sortedKeyToPartitions
+      val rowOrdering = RowOrdering.createNaturalAscendingOrdering(partitionDataTypes)
+      val sortedKeyToPartitions = results.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
+      val sortedGroupedPartitions = sortedKeyToPartitions
         .map(t => (InternalRowComparableWrapper(t._1, expressions), t._2))
         .groupBy(_._1)
         .toSeq
         .map { case (key, s) => KeyGroupedPartition(key.row, s.map(_._2)) }
+        .sorted(rowOrdering.on((k: KeyGroupedPartition) => k.value))
 
-      Some(KeyGroupedPartitionInfo(groupedPartitions, sortedKeyToPartitions.map(_._2)))
+      Some(KeyGroupedPartitionInfo(sortedGroupedPartitions, sortedKeyToPartitions.map(_._2)))
     }
   }
 }

From 8f9b7d71f17b887ec2ad0eaca0fc25bad0fff661 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Mon, 11 Sep 2023 21:18:46 +0300
Subject: [PATCH 4/5] [SPARK-44647][SQL] Support SPJ where join keys are less than cluster keys

### What changes were proposed in this pull request?
- Add new conf spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled
- Change the key compatibility checks in EnsureRequirements. Remove the checks that require all partition keys to be in the join keys, so that isKeyCompatible can be true in this case (if this flag is enabled).
- Change BatchScanExec/DataSourceV2Relation to group splits by join keys if they differ from partition keys (previously grouped only by partition values). Do the same for all auxiliary data structures, like commonPartValues.
- Implement partiallyClustered skew-handling.
  - Group only the replicate side (now by join key as well), and replicate by the total size of the other-side partitions that share the join key.
  - Add an additional sort for partitions based on join key, because once the replicate side is grouped, its partition ordering no longer matches that of the non-replicate side.

### Why are the changes needed?
- Support Storage Partition Join in cases where the join condition does not contain all the partition keys, but just some of them.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- Added tests in KeyGroupedPartitioningSuite.
- Found two existing problems, to be addressed in separate PRs:
  - Because of https://github.com/apache/spark/pull/37886 we have to select all join keys to trigger SPJ in this case, otherwise the DSV2 scan does not report KeyGroupedPartitioning and SPJ does not get triggered. Need to see how to relax this.
  - https://issues.apache.org/jira/browse/SPARK-44641 was found when testing this change. This PR refactors some of that code to add group-by-join-key, but doesn't change the underlying logic, so the issue continues to exist. Hopefully this will also get fixed in another way.

Closes #42306 from szehon-ho/spj_attempt_master.
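For illustration only, a minimal test sketch in the style of the `KeyGroupedPartitioningSuite` cases added elsewhere in this series. The inserted rows, the second partition column on `purchases`, and the zero-shuffle expectation are assumptions for the sketch, not part of this patch:

```scala
test("SPJ: join keys are a subset of the partition keys (illustrative)") {
  // Both tables are partitioned on two columns, but the join below only uses one of them.
  val items_partitions = Array(identity("id"), identity("arrive_time"))
  createTable(items, items_schema, items_partitions)
  sql(s"INSERT INTO testcat.ns.$items VALUES " +
    "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " +
    "(3, 'bb', 10.0, cast('2020-01-01' as timestamp))")

  val purchases_partitions = Array(identity("item_id"), identity("time"))
  createTable(purchases, purchases_schema, purchases_partitions)
  sql(s"INSERT INTO testcat.ns.$purchases VALUES " +
    "(1, 42.0, cast('2020-01-01' as timestamp)), " +
    "(3, 19.5, cast('2020-02-01' as timestamp))")

  withSQLConf(
      SQLConf.V2_BUCKETING_ENABLED.key -> "true",
      SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true",
      SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") {
    // Join only on id/item_id, a strict subset of the partition keys (id, arrive_time).
    val df = sql("SELECT id, name, i.price AS purchase_price, p.price AS sale_price " +
      s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " +
      "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price")

    // With the flag enabled, both sides are re-grouped by the join key only, so under
    // these assumptions no shuffle should be added.
    assert(collectShuffles(df.queryExecution.executedPlan).isEmpty)
    checkAnswer(df, Seq(Row(1, "aa", 40.0, 42.0), Row(3, "bb", 10.0, 19.5)))
  }
}
```

Whether the shuffle is actually eliminated also depends on both sides reporting compatible partition values after projecting onto the join key, as described above.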
Authored-by: Szehon Ho Signed-off-by: Dongjoon Hyun (cherry picked from commit 9520087409d5bd7e6a2651dacf2c295d564d5559) --- .../plans/physical/partitioning.scala | 59 +++- .../apache/spark/sql/internal/SQLConf.scala | 15 + .../datasources/v2/BatchScanExec.scala | 56 ++-- .../exchange/EnsureRequirements.scala | 15 +- .../KeyGroupedPartitioningSuite.scala | 266 +++++++++++++++++- 5 files changed, 379 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 1eacb80567658..7eaff1a4a78ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -380,7 +380,14 @@ case class KeyGroupedPartitioning( } else { // We'll need to find leaf attributes from the partition expressions first. val attributes = expressions.flatMap(_.collectLeaves()) - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // check that all join keys (required clustering keys) contained in partitioning + requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) + } else { + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } } case _ => @@ -389,8 +396,20 @@ case class KeyGroupedPartitioning( } } - override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = - KeyGroupedShuffleSpec(this, distribution) + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = { + val result = KeyGroupedShuffleSpec(this, distribution) + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // If allowing join keys to be subset of clustering keys, we should create a new + // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as + // the returned shuffle spec. 
+ val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) + val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, + partitionValues, originalPartitionValues) + result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) + } else { + result + } + } lazy val uniquePartitionValues: Seq[InternalRow] = { partitionValues @@ -403,8 +422,25 @@ case class KeyGroupedPartitioning( object KeyGroupedPartitioning { def apply( expressions: Seq[Expression], - partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { - KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues) + projectionPositions: Seq[Int], + partitionValues: Seq[InternalRow], + originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { + val projectedExpressions = projectionPositions.map(expressions(_)) + val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) + val projectedOriginalPartitionValues = + originalPartitionValues.map(project(expressions, projectionPositions, _)) + + KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length, + projectedPartitionValues, projectedOriginalPartitionValues) + } + + def project( + expressions: Seq[Expression], + positions: Seq[Int], + input: InternalRow): InternalRow = { + val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType)) + .toArray + new GenericInternalRow(projectedValues) } def supportsExpressions(expressions: Seq[Expression]): Boolean = { @@ -697,6 +733,14 @@ case class HashShuffleSpec( override def numPartitions: Int = partitioning.numPartitions } +/** + * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]]. + * + * @param partitioning key grouped partitioning + * @param distribution distribution + * @param joinKeyPosition position of join keys among cluster keys. + * This is set if joining on a subset of cluster keys is allowed. + */ case class CoalescedHashShuffleSpec( from: ShuffleSpec, partitions: Seq[CoalescedBoundary]) extends ShuffleSpec { @@ -719,7 +763,8 @@ case class CoalescedHashShuffleSpec( case class KeyGroupedShuffleSpec( partitioning: KeyGroupedPartitioning, - distribution: ClusteredDistribution) extends ShuffleSpec { + distribution: ClusteredDistribution, + joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec { /** * A sequence where each element is a set of positions of the partition expression to the cluster @@ -754,7 +799,7 @@ case class KeyGroupedShuffleSpec( // 3.3 each pair of partition expressions at the same index must share compatible // transform functions. // 4. the partition values from both sides are following the same order. 
-    case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
+    case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
       distribution.clustering.length == otherDistribution.clustering.length &&
         numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
           partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b8bf691ac2bbc..aca25d22fae99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1524,6 +1524,18 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)

+  val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS =
+    buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled")
+      .doc("Whether to allow storage-partition join in the case where join keys are " +
+        "a subset of the partition keys of the source tables. At planning time, " +
+        "Spark will group the partitions by only those keys that are in the join keys. " +
+        s"This is currently enabled only if ${REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key} " +
+        "is false."
+      )
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
@@ -4955,6 +4967,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def v2BucketingShuffleEnabled: Boolean =
     getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)

+  def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
+    getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)
+
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index 932ac0f5a1b15..094a7b20808ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -120,7 +120,12 @@ case class BatchScanExec(
           val newPartValues = spjParams.commonPartitionValues.get.flatMap {
             case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
           }
-          k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
+          val expressions = spjParams.joinKeyPositions match {
+            case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
+            case _ => k.expressions
+          }
+          k.copy(expressions = expressions, numPartitions = newPartValues.length,
+            partitionValues = newPartValues)
         case p => p
       }
     }
@@ -132,14 +137,29 @@ case class BatchScanExec(
       // return an empty RDD with 1 partition if dynamic filtering removed the only split
       sparkContext.parallelize(Array.empty[InternalRow], 1)
     } else {
-      var finalPartitions = filteredPartitions
-
-      outputPartitioning match {
+      val finalPartitions = outputPartitioning match {
         case p: KeyGroupedPartitioning =>
-          val groupedPartitions = filteredPartitions.map(splits => {
-            assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
-            (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
-          })
+          assert(spjParams.keyGroupedPartitioning.isDefined)
+
val expressions = spjParams.keyGroupedPartitioning.get + + // Re-group the input partitions if we are projecting on a subset of join keys + val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match { + case Some(projectPositions) => + val projectedExpressions = projectPositions.map(i => expressions(i)) + val parts = filteredPartitions.flatten.groupBy(part => { + val row = part.asInstanceOf[HasPartitionKey].partitionKey() + val projectedRow = KeyGroupedPartitioning.project( + expressions, projectPositions, row) + InternalRowComparableWrapper(projectedRow, projectedExpressions) + }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq + (parts, projectedExpressions) + case _ => + val groupedParts = filteredPartitions.map(splits => { + assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) + (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) + }) + (groupedParts, expressions) + } // When partially clustered, the input partitions are not grouped by partition // values. Here we'll need to check `commonPartitionValues` and decide how to group @@ -149,12 +169,12 @@ case class BatchScanExec( // should contain. val commonPartValuesMap = spjParams.commonPartitionValues .get - .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) + .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. val numSplits = commonPartValuesMap - .get(InternalRowComparableWrapper(partValue, p.expressions)) + .get(InternalRowComparableWrapper(partValue, partExpressions)) assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + "common partition values from Spark plan") @@ -169,37 +189,37 @@ case class BatchScanExec( // sides of a join will have the same number of partitions & splits. splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) } - (InternalRowComparableWrapper(partValue, p.expressions), newSplits) + (InternalRowComparableWrapper(partValue, partExpressions), newSplits) } // Now fill missing partition keys with empty partitions val partitionMapping = nestGroupedPartitions.toMap - finalPartitions = spjParams.commonPartitionValues.get.flatMap { + spjParams.commonPartitionValues.get.flatMap { case (partValue, numSplits) => // Use empty partition for those partition values that are not present. partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), + InternalRowComparableWrapper(partValue, partExpressions), Seq.fill(numSplits)(Seq.empty)) } } else { // either `commonPartitionValues` is not defined, or it is defined but // `applyPartialClustering` is false. val partitionMapping = groupedPartitions.map { case (partValue, splits) => - InternalRowComparableWrapper(partValue, p.expressions) -> splits + InternalRowComparableWrapper(partValue, partExpressions) -> splits }.toMap // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there // could exist duplicated partition values, as partition grouping is not done // at the beginning and postponed to this method. It is important to use unique // partition values here so that grouped partitions won't get duplicated. 
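            // Note: lookups below use `partExpressions` (not `p.expressions`) so that the
            // wrapper keys match the projected grouping when `joinKeyPositions` is set.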
- finalPartitions = p.uniquePartitionValues.map { partValue => + p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) + InternalRowComparableWrapper(partValue, partExpressions), Seq.empty) } } - case _ => + case _ => filteredPartitions } new DataSourceRDD( @@ -234,6 +254,7 @@ case class BatchScanExec( case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, + joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { @@ -247,6 +268,7 @@ case class StoragePartitionJoinParams( } override def hashCode(): Int = Objects.hashCode( + joinKeyPositions: Option[Seq[Int]], commonPartitionValues: Option[Seq[(InternalRow, Int)]], applyPartialClustering: java.lang.Boolean, replicatePartitions: java.lang.Boolean) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 4067bbd22a01b..81d457e8ff3d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -380,7 +380,8 @@ case class EnsureRequirements( val rightSpec = specs(1) var isCompatible = false - if (!conf.v2BucketingPushPartValuesEnabled) { + if (!conf.v2BucketingPushPartValuesEnabled && + !conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { isCompatible = leftSpec.isCompatibleWith(rightSpec) } else { logInfo("Pushing common partition values for storage-partitioned join") @@ -505,10 +506,10 @@ case class EnsureRequirements( } // Now we need to push-down the common partition key to the scan in each child - newLeft = populatePartitionValues( - left, mergedPartValues, applyPartialClustering, replicateLeftSide) - newRight = populatePartitionValues( - right, mergedPartValues, applyPartialClustering, replicateRightSide) + newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions, + applyPartialClustering, replicateLeftSide) + newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions, + applyPartialClustering, replicateRightSide) } } @@ -530,19 +531,21 @@ case class EnsureRequirements( private def populatePartitionValues( plan: SparkPlan, values: Seq[(InternalRow, Int)], + joinKeyPositions: Option[Seq[Int]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => scan.copy( spjParams = scan.spjParams.copy( commonPartitionValues = Some(values), + joinKeyPositions = joinKeyPositions, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions ) ) case node => node.mapChildren(child => populatePartitionValues( - child, values, applyPartialClustering, replicatePartitions)) + child, values, joinKeyPositions, applyPartialClustering, replicatePartitions)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 3b933d0539f6f..727e54c95f780 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -98,14 +98,17 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val catalystDistribution = physical.ClusteredDistribution( Seq(TransformExpression(YearsFunction, Seq(attr("ts"))))) val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) + val projectedPositions = catalystDistribution.clustering.indices checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, + partitionValues, partitionValues)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, + partitionValues, partitionValues)) } test("non-clustered distribution: no partition") { @@ -1299,6 +1302,222 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-44647: test join key is subset of cluster key " + + "with push values and partially-clustered") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id"), identity("data")) + createTable(table1, schema, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-01' as timestamp)), " + + "(2, 'cc', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp))") + + createTable(table2, schema, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(3, 'yy', cast('2020-01-01' as timestamp)), " + + "(3, 'yy', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(2, 'ww', cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + + val df = sql("SELECT t1.id AS id, t1.data AS t1data, t2.data AS t2data " + + s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + + "ON t1.id = t2.id ORDER BY t1.id, t1data, t2data") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered 
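+              //   (each join-key value is padded to a common number of splits on both
+              //   sides, so the scans report more, possibly empty, input partitions)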
+ case (true, true) => assert(scans == Seq(8, 8)) + // SPJ and not partially-clustered + case (true, false) => assert(scans == Seq(4, 4)) + // No SPJ + case _ => assert(scans == Seq(5, 4)) + } + + checkAnswer(df, Seq( + Row(2, "bb", "ww"), + Row(2, "cc", "ww"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy") + )) + } + } + } + } + } + + test("SPARK-44647: test join key is the second cluster key") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id"), identity("data")) + createTable(table1, schema, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-02' as timestamp)), " + + "(3, 'cc', cast('2020-01-03' as timestamp))") + + createTable(table2, schema, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'aa', cast('2020-01-01' as timestamp)), " + + "(5, 'bb', cast('2020-01-02' as timestamp)), " + + "(6, 'cc', cast('2020-01-03' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> + pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + + val df = sql("SELECT t1.id AS t1id, t2.id as t2id, t1.data AS data " + + s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + + "ON t1.data = t2.data ORDER BY t1id, t1id, data") + + checkAnswer(df, Seq(Row(1, 4, "aa"), Row(2, 5, "bb"), Row(3, 6, "cc"))) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (pushDownValues, allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true, true) => assert(scans == Seq(3, 3)) + // non-SPJ or SPJ/partially-clustered + case _ => assert(scans == Seq(3, 3)) + } + } + } + } + } + } + + test("SPARK-44647: test join key is the second partition key and a transform") { + val items_partitions = Array(bucket(8, "id"), days("arrive_time")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id"), days("time")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(1, 44.0, cast('2020-01-15' as 
timestamp)), " + + s"(1, 45.0, cast('2020-01-15' as timestamp)), " + + s"(2, 11.0, cast('2020-01-01' as timestamp)), " + + s"(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, " + + "p.item_id, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.arrive_time = p.time " + + "ORDER BY id, purchase_price, p.item_id, sale_price") + + // Currently SPJ for case where join key not same as partition key + // only supported when push-part-values enabled + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true) => assert(scans == Seq(5, 5)) + // SPJ and not partially-clustered + case (true, false) => assert(scans == Seq(3, 3)) + // No SPJ + case _ => assert(scans == Seq(4, 4)) + } + + checkAnswer(df, + Seq( + Row(1, "aa", 40.0, 1, 42.0), + Row(1, "aa", 40.0, 2, 11.0), + Row(1, "aa", 41.0, 1, 44.0), + Row(1, "aa", 41.0, 1, 45.0), + Row(2, "bb", 10.0, 1, 42.0), + Row(2, "bb", 10.0, 2, 11.0), + Row(2, "bb", 10.5, 1, 42.0), + Row(2, "bb", 10.5, 2, 11.0), + Row(3, "cc", 15.5, 3, 19.5) + ) + ) + } + } + } + } + } + test("SPARK-45652: SPJ should handle empty partition after dynamic filtering") { withSQLConf( SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", @@ -1341,3 +1560,46 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-44647: shuffle one side and join keys are less than partition keys") { + val items_partitions = Array(identity("id"), identity("name")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 89.0, cast('2020-01-03' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { pushdownValues => + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key + -> partiallyClustered.toString, + 
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "SPJ should be triggered") + checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), + Row(1, "aa", 30.0, 89.0), + Row(1, "aa", 40.0, 42.0), + Row(1, "aa", 40.0, 89.0), + Row(3, "bb", 10.0, 19.5))) + } + } + } + } +} From aa63827f238618f7d7fc7509e7346b44433181d1 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Thu, 2 May 2024 16:14:36 -0700 Subject: [PATCH 5/5] [SPARK-48065][SQL] SPJ: allowJoinKeysSubsetOfPartitionKeys is too strict ### What changes were proposed in this pull request? If spark.sql.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled is true, change KeyGroupedPartitioning.satisfies0(distribution) check from all clustering keys (here, join keys) being in partition keys, to the two sets overlapping. ### Why are the changes needed? If spark.sql.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled is true, then SPJ no longer triggers if there are more join keys than partition keys. But SPJ is supported in this case if flag is false. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added tests in KeyGroupedPartitioningSuite ### Was this patch authored or co-authored using generative AI tooling? No Closes #46325 from szehon-ho/fix_spj_less_join_key. Authored-by: Szehon Ho Signed-off-by: Chao Sun (cherry picked from commit 5ec62a760bcd718d6a600979df03bcabc2192d6b) --- .../plans/physical/partitioning.scala | 5 +- .../KeyGroupedPartitioningSuite.scala | 60 +++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 7eaff1a4a78ec..79341da0db788 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -382,8 +382,9 @@ case class KeyGroupedPartitioning( val attributes = expressions.flatMap(_.collectLeaves()) if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - // check that all join keys (required clustering keys) contained in partitioning - requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) && + // check that join keys (required clustering keys) + // overlap with partition keys (KeyGroupedPartitioning attributes) + requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && expressions.forall(_.collectLeaves().size == 1) } else { attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 727e54c95f780..b342a382a749e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1302,6 +1302,66 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-48065: SPJ: 
allowJoinKeysSubsetOfPartitionKeys is too strict") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id")) + createTable(table1, columns, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-01' as timestamp)), " + + "(2, 'cc', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp))") + + createTable(table2, columns, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(2, 'ww', cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.id AS id, t1.data AS t1data, t2.data AS t2data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.id = t2.id AND t1.data = t2.data ORDER BY t1.id, t1data, t2data + |""".stripMargin) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + if (partiallyClustered) { + assert(scans == Seq(8, 8)) + } else { + assert(scans == Seq(4, 4)) + } + checkAnswer(df, Seq( + Row(3, "dd", "dd"), + Row(3, "dd", "dd"), + Row(3, "dd", "dd"), + Row(3, "dd", "dd") + )) + } + } + } + } + test("SPARK-44647: test join key is subset of cluster key " + "with push values and partially-clustered") { val table1 = "tab1e1"