diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 040f1bfab65b..78d153c5a0e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -72,9 +72,14 @@ case object AllTuples extends Distribution {
 /**
  * Represents data where tuples that share the same values for the `clustering`
  * [[Expression Expressions]] will be co-located in the same partition.
+ *
+ * @param requireAllClusterKeys When true, a `Partitioning` that satisfies this distribution
+ *                              must match all `clustering` expressions in the same ordering.
  */
 case class ClusteredDistribution(
     clustering: Seq[Expression],
+    requireAllClusterKeys: Boolean = SQLConf.get.getConf(
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION),
     requiredNumPartitions: Option[Int] = None) extends Distribution {
   require(
     clustering != Nil,
@@ -88,6 +93,19 @@ case class ClusteredDistribution(
       s"the actual number of partitions is $numPartitions.")
     HashPartitioning(clustering, numPartitions)
   }
+
+  /**
+   * Checks if `expressions` match all `clustering` expressions in the same ordering.
+   *
+   * A `Partitioning` should call this to check its expressions when `requireAllClusterKeys`
+   * is set to true.
+   */
+  def areAllClusterKeysMatched(expressions: Seq[Expression]): Boolean = {
+    expressions.length == clustering.length &&
+      expressions.zip(clustering).forall {
+        case (l, r) => l.semanticEquals(r)
+      }
+  }
 }
 
 /**
@@ -261,8 +279,14 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
         expressions.length == h.expressions.length && expressions.zip(h.expressions).forall {
           case (l, r) => l.semanticEquals(r)
         }
-      case ClusteredDistribution(requiredClustering, _) =>
-        expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+      case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+        if (requireAllClusterKeys) {
+          // Checks whether `HashPartitioning` is partitioned on exactly the same
+          // clustering keys as `ClusteredDistribution`.
+          c.areAllClusterKeysMatched(expressions)
+        } else {
+          expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+        }
       case _ => false
     }
   }
@@ -322,8 +346,15 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
         // `RangePartitioning(a, b, c)` satisfies `OrderedDistribution(a, b)`.
         val minSize = Seq(requiredOrdering.size, ordering.size).min
         requiredOrdering.take(minSize) == ordering.take(minSize)
-      case ClusteredDistribution(requiredClustering, _) =>
-        ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
+      case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+        val expressions = ordering.map(_.child)
+        if (requireAllClusterKeys) {
+          // Checks whether `RangePartitioning` is partitioned on exactly the same
+          // clustering keys as `ClusteredDistribution`.
+          c.areAllClusterKeysMatched(expressions)
+        } else {
+          expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
+        }
       case _ => false
     }
   }
@@ -524,10 +555,7 @@ case class HashShuffleSpec(
     // will add shuffles with the default partitioning of `ClusteredDistribution`, which uses all
    // the join keys.
     if (SQLConf.get.getConf(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION)) {
-      partitioning.expressions.length == distribution.clustering.length &&
-        partitioning.expressions.zip(distribution.clustering).forall {
-          case (l, r) => l.semanticEquals(r)
-        }
+      distribution.areAllClusterKeysMatched(partitioning.expressions)
     } else {
       true
     }
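
For intuition, the matching rule introduced above is just "same length, same order, pairwise semantic equality": a strict variant of the subset check `ClusteredDistribution` used before. A minimal standalone sketch of the same logic, with plain strings standing in for Catalyst `Expression`s and `==` standing in for `semanticEquals` (the object and names here are illustrative, not Spark API):

```scala
// Standalone sketch of the areAllClusterKeysMatched rule. Strings stand in for
// Catalyst expressions; `==` stands in for semanticEquals.
object ClusterKeyMatchSketch {
  def areAllClusterKeysMatched(clustering: Seq[String], expressions: Seq[String]): Boolean =
    expressions.length == clustering.length &&
      expressions.zip(clustering).forall { case (l, r) => l == r }

  def main(args: Array[String]): Unit = {
    val clustering = Seq("key1", "key2")
    assert(areAllClusterKeysMatched(clustering, Seq("key1", "key2")))  // exact match
    assert(!areAllClusterKeysMatched(clustering, Seq("key2", "key1"))) // order matters
    assert(!areAllClusterKeysMatched(clustering, Seq("key1")))         // subsets are rejected
  }
}
```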
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 3a7ce650ea63..a050156518c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -407,6 +407,18 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION =
+    buildConf("spark.sql.requireAllClusterKeysForDistribution")
+      .internal()
+      .doc("When true, the planner requires all the clustering keys as the partition keys " +
+        "(with the same ordering) of the children, to eliminate the shuffle for operators " +
+        "that require their children to be clustered distributed, such as AGGREGATE and " +
+        "WINDOW nodes. This is to avoid data skew, which can lead to significant performance " +
+        "regression if the shuffle is eliminated.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val RADIX_SORT_ENABLED = buildConf("spark.sql.sort.enableRadixSort")
     .internal()
     .doc("When true, enable use of radix sort when possible. Radix sort is much faster but " +
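
Although the conf is marked internal, it behaves like any other SQL conf and can be toggled per session. A sketch of turning it on (the session setup is illustrative; only the config key comes from the patch above):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("require-all-cluster-keys-demo") // illustrative app name
  .getOrCreate()

// Default is false: a partitioning on any subset of the clustering keys may satisfy
// ClusteredDistribution. Setting it to true forces an exact, ordered match instead.
spark.conf.set("spark.sql.requireAllClusterKeysForDistribution", "true")
```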
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index e047d4c070be..a924a9ed02e5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -169,6 +169,24 @@ class DistributionSuite extends SparkFunSuite {
       ClusteredDistribution(Seq($"d", $"e")),
       false)
 
+    // When ClusteredDistribution.requireAllClusterKeys is set to true,
+    // HashPartitioning can satisfy ClusteredDistribution iff its hash expressions are
+    // exactly the same as the required clustering expressions.
+    checkSatisfied(
+      HashPartitioning(Seq($"a", $"b", $"c"), 10),
+      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
+      true)
+
+    checkSatisfied(
+      HashPartitioning(Seq($"b", $"c"), 10),
+      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
+      false)
+
+    checkSatisfied(
+      HashPartitioning(Seq($"b", $"a", $"c"), 10),
+      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
+      false)
+
     // HashPartitioning cannot satisfy OrderedDistribution
     checkSatisfied(
       HashPartitioning(Seq($"a", $"b", $"c"), 10),
@@ -249,22 +267,40 @@ class DistributionSuite extends SparkFunSuite {
       RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10),
       ClusteredDistribution(Seq($"c", $"d")),
       false)
+
+    // When ClusteredDistribution.requireAllClusterKeys is set to true,
+    // RangePartitioning can satisfy ClusteredDistribution iff its ordering expressions are
+    // exactly the same as the required clustering expressions.
+    checkSatisfied(
+      RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10),
+      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
+      true)
+
+    checkSatisfied(
+      RangePartitioning(Seq($"a".asc, $"b".asc), 10),
+      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
+      false)
+
+    checkSatisfied(
+      RangePartitioning(Seq($"b".asc, $"a".asc, $"c".asc), 10),
+      ClusteredDistribution(Seq($"a", $"b", $"c"), requireAllClusterKeys = true),
+      false)
   }
 
   test("Partitioning.numPartitions must match Distribution.requiredNumPartitions to satisfy it") {
     checkSatisfied(
       SinglePartition,
-      ClusteredDistribution(Seq($"a", $"b", $"c"), Some(10)),
+      ClusteredDistribution(Seq($"a", $"b", $"c"), requiredNumPartitions = Some(10)),
       false)
 
     checkSatisfied(
       HashPartitioning(Seq($"a", $"b", $"c"), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c"), Some(5)),
+      ClusteredDistribution(Seq($"a", $"b", $"c"), requiredNumPartitions = Some(5)),
       false)
 
     checkSatisfied(
       RangePartitioning(Seq($"a".asc, $"b".asc, $"c".asc), 10),
-      ClusteredDistribution(Seq($"a", $"b", $"c"), Some(5)),
+      ClusteredDistribution(Seq($"a", $"b", $"c"), requiredNumPartitions = Some(5)),
       false)
   }
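
The cases the suite asserts can also be reproduced by calling `satisfies` on the Catalyst classes directly. A sketch against the internal API (subject to change between releases; the attributes are stand-ins for the columns `a`, `b`, `c` used above):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.IntegerType

// Attribute stand-ins for the columns a, b, c from the suite above.
val Seq(a, b, c) = Seq("a", "b", "c").map(n => AttributeReference(n, IntegerType)())

val strict = ClusteredDistribution(Seq(a, b, c), requireAllClusterKeys = true)

HashPartitioning(Seq(a, b, c), 10).satisfies(strict) // true: exact, ordered match
HashPartitioning(Seq(b, c), 10).satisfies(strict)    // false: subset of the keys
HashPartitioning(Seq(b, a, c), 10).satisfies(strict) // false: wrong ordering
```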
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala
index cbd4ee698df2..51833012a128 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEUtils.scala
@@ -37,7 +37,7 @@ object AQEUtils {
       } else {
         None
       }
-      Some(ClusteredDistribution(h.expressions, numPartitions))
+      Some(ClusteredDistribution(h.expressions, requiredNumPartitions = numPartitions))
     case f: FilterExec => getRequiredDistribution(f.child)
     case s: SortExec if !s.global => getRequiredDistribution(s.child)
     case c: CollectMetricsExec => getRequiredDistribution(c.child)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 93ed5916bfb2..dfcb70737666 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -96,8 +96,10 @@ case class FlatMapGroupsWithStateExec(
     // NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
     // before making any changes.
     // TODO(SPARK-38204)
-    ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) ::
-      ClusteredDistribution(initialStateGroupAttrs, stateInfo.map(_.numPartitions)) ::
+    ClusteredDistribution(
+      groupingAttributes, requiredNumPartitions = stateInfo.map(_.numPartitions)) ::
+      ClusteredDistribution(
+        initialStateGroupAttrs, requiredNumPartitions = stateInfo.map(_.numPartitions)) ::
       Nil
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index bbc6fa05d514..f9ae65cdc47d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -287,6 +287,11 @@ abstract class StreamExecution(
       // Disable cost-based join optimization as we do not want stateful operations
       // to be rearranged
       sparkSessionForStream.conf.set(SQLConf.CBO_ENABLED.key, "false")
+      // Disable any config affecting the required child distribution of stateful operators.
+      // Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution for
+      // details.
+      sparkSessionForStream.conf.set(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key,
+        "false")
 
       updateStatusMessage("Initializing sources")
       // force initialization of the logical plan so that the sources can be created
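
The reset above matters because state store partitioning is fixed when a streaming query first runs: if a conf could change the required child distribution of a stateful operator, a restarted query might read its checkpointed state with the wrong partitioning. A sketch of the kind of stateful query the reset protects (the rate source, checkpoint path, and session setup are all illustrative):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("stateful-demo").getOrCreate()
import spark.implicits._

// Even if the user enables the flag, StreamExecution pins it back to false for the
// streaming plan, so this aggregation keeps the exact partitioning its checkpointed
// state was originally written with.
spark.conf.set("spark.sql.requireAllClusterKeysForDistribution", "true")

val counts = spark.readStream
  .format("rate").option("rowsPerSecond", "1").load()
  .groupBy($"value" % 10) // grouping keys drive the state store partitioning
  .count()

val query = counts.writeStream
  .format("console")
  .option("checkpointLocation", "/tmp/stateful-demo-checkpoint") // illustrative path
  .outputMode("complete")
  .start()
query.awaitTermination(5000) // run briefly for the demo
query.stop()
```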
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 3ab2ad47e98c..45c6430f9642 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -340,7 +340,8 @@ case class StateStoreRestoreExec(
     if (keyExpressions.isEmpty) {
       AllTuples :: Nil
     } else {
-      ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
+      ClusteredDistribution(keyExpressions,
+        requiredNumPartitions = stateInfo.map(_.numPartitions)) :: Nil
     }
   }
 
@@ -502,7 +503,8 @@ case class StateStoreSaveExec(
     if (keyExpressions.isEmpty) {
       AllTuples :: Nil
     } else {
-      ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
+      ClusteredDistribution(keyExpressions,
+        requiredNumPartitions = stateInfo.map(_.numPartitions)) :: Nil
     }
   }
 
@@ -582,7 +584,8 @@ case class SessionWindowStateStoreRestoreExec(
     // NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
     // before making any changes.
     // TODO(SPARK-38204)
-    ClusteredDistribution(keyWithoutSessionExpressions, stateInfo.map(_.numPartitions)) :: Nil
+    ClusteredDistribution(keyWithoutSessionExpressions,
+      requiredNumPartitions = stateInfo.map(_.numPartitions)) :: Nil
   }
 
   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
@@ -696,7 +699,8 @@ case class SessionWindowStateStoreSaveExec(
     // NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
     // before making any changes.
    // TODO(SPARK-38204)
-    ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
+    ClusteredDistribution(keyExpressions,
+      requiredNumPartitions = stateInfo.map(_.numPartitions)) :: Nil
   }
 
   override def shouldRunAnotherBatch(newMetadata: OffsetSeqMetadata): Boolean = {
@@ -757,7 +761,8 @@ case class StreamingDeduplicateExec(
     // NOTE: Please read through the NOTE on the classdoc of StatefulOpClusteredDistribution
     // before making any changes.
     // TODO(SPARK-38204)
-    ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil
+    ClusteredDistribution(keyExpressions,
+      requiredNumPartitions = stateInfo.map(_.numPartitions)) :: Nil
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 1491c5a4f26b..3cf61c3402bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -20,9 +20,12 @@ package org.apache.spark.sql
 import org.scalatest.matchers.must.Matchers.the
 
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.optimizer.TransposeWindow
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.exchange.Exchange
+import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.window.WindowExec
 import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -1071,4 +1074,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest
         Row("a", 1, "x", "x"),
         Row("b", 0, null, null)))
   }
+
+  test("SPARK-38237: require all cluster keys for child required distribution for window query") {
+    def partitionExpressionsColumns(expressions: Seq[Expression]): Seq[String] = {
+      expressions.flatMap {
+        case ref: AttributeReference => Some(ref.name)
+      }
+    }
+
+    def isShuffleExecByRequirement(
+        plan: ShuffleExchangeExec,
+        desiredClusterColumns: Seq[String]): Boolean = plan match {
+      case ShuffleExchangeExec(op: HashPartitioning, _, ENSURE_REQUIREMENTS) =>
+        partitionExpressionsColumns(op.expressions) === desiredClusterColumns
+      case _ => false
+    }
+
+    val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
+    val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")
+
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+      SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key -> "true") {
+
+      val windowed = df
+        // repartition by a subset of the window partitionBy keys, satisfying ClusteredDistribution
+        .repartition($"key1")
+        .select(
+          lead($"key1", 1).over(windowSpec),
+          lead($"value", 1).over(windowSpec))
+
+      checkAnswer(windowed, Seq(Row("b", 4), Row(null, null), Row(null, null), Row(null, null)))
+
+      val shuffleByRequirement = windowed.queryExecution.executedPlan.find {
+        case w: WindowExec =>
+          w.child.find {
+            case s: ShuffleExchangeExec => isShuffleExecByRequirement(s, Seq("key1", "key2"))
+            case _ => false
+          }.nonEmpty
+        case _ => false
+      }
+
+      assert(shuffleByRequirement.nonEmpty, "Can't find desired shuffle node in the query plan")
+    }
+  }
 }
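
In user-facing terms, the test above pins down the planner change: with the flag on, a child partitioned on only `key1` no longer satisfies a window over (`key1`, `key2`), so `EnsureRequirements` inserts a shuffle on both keys. A sketch that makes this visible with `explain()` (session setup is illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lead

val spark = SparkSession.builder().master("local[2]").appName("window-demo").getOrCreate()
import spark.implicits._

spark.conf.set("spark.sql.adaptive.enabled", "false")
spark.conf.set("spark.sql.requireAllClusterKeysForDistribution", "true")

val df = Seq(("a", 1, 1), ("a", 2, 2), ("b", 1, 3), ("b", 1, 4)).toDF("key1", "key2", "value")
val windowSpec = Window.partitionBy("key1", "key2").orderBy("value")

// With the flag on, the printed plan contains an extra
// Exchange hashpartitioning(key1, key2, ...) above the repartition on key1.
df.repartition($"key1")
  .select(lead($"value", 1).over(windowSpec))
  .explain()
```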
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 383b84dc0d8f..2ab1b6d4963a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -432,7 +432,7 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
   }
 
   test("EnsureRequirements should respect ClusteredDistribution's num partitioning") {
-    val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13))
+    val distribution = ClusteredDistribution(Literal(1) :: Nil, requiredNumPartitions = Some(13))
     // Number of partitions differ
     val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13)
     val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
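
As the test name says, `requiredNumPartitions` must match a partitioning's `numPartitions` exactly, independent of the clustering keys themselves. A sketch of the same check through the internal Catalyst API (subject to change between releases):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}

// The distribution from the test above: one clustering key, exactly 13 partitions.
val distribution = ClusteredDistribution(Literal(1) :: Nil, requiredNumPartitions = Some(13))

HashPartitioning(Literal(1) :: Nil, 13).satisfies(distribution) // true: 13 == 13
HashPartitioning(Literal(1) :: Nil, 5).satisfies(distribution)  // false: 5 != 13
```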