From e842f97b6c58fc2a679c4d9e39c84b2ae229dc72 Mon Sep 17 00:00:00 2001
From: ulysses-you
Date: Thu, 30 Nov 2023 17:14:59 +0800
Subject: [PATCH] [SPARK-46170][SQL] Support inject adaptive query post planner
 strategy rules in SparkSessionExtensions

### What changes were proposed in this pull request?

This PR adds a new extension entry point, `queryPostPlannerStrategyRules`, to
`SparkSessionExtensions`. The injected rules are applied between
`plannerStrategy` and `queryStagePrepRules` in AQE, so they can see the whole
physical plan before exchanges are injected. A usage sketch follows the patch.

### Why are the changes needed?

Part of https://github.com/apache/spark/pull/44013.

### Does this PR introduce _any_ user-facing change?

No, this is a developer-facing API only.

### How was this patch tested?

Added a unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #44074 from ulysses-you/post-planner.

Authored-by: ulysses-you
Signed-off-by: youxiduo
---
 .../spark/sql/SparkSessionExtensions.scala    | 21 ++++++++++
 .../adaptive/AdaptiveRulesHolder.scala        |  5 ++-
 .../adaptive/AdaptiveSparkPlanExec.scala      | 14 ++++++-
 .../internal/BaseSessionStateBuilder.scala    |  3 +-
 .../sql/SparkSessionExtensionSuite.scala      | 40 ++++++++++++++++++-
 5 files changed, 78 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index b7c86ab7de6b4..677dba0082575 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
 * <li>Customized Parser.</li>
 * <li>(External) Catalog listeners.</li>
 * <li>Columnar Rules.</li>
+ * <li>Adaptive Query Post Planner Strategy Rules.</li>
 * <li>Adaptive Query Stage Preparation Rules.</li>
 * <li>Adaptive Query Execution Runtime Optimizer Rules.</li>
 * <li>Adaptive Query Stage Optimizer Rules.</li>
@@ -112,10 +113,13 @@ class SparkSessionExtensions {
   type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
   type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder)
   type ColumnarRuleBuilder = SparkSession => ColumnarRule
+  type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan]
   type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
   type QueryStageOptimizerRuleBuilder = SparkSession => Rule[SparkPlan]

   private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
+  private[this] val queryPostPlannerStrategyRuleBuilders =
+    mutable.Buffer.empty[QueryPostPlannerStrategyBuilder]
   private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
   private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder]
   private[this] val queryStageOptimizerRuleBuilders =
@@ -128,6 +132,14 @@ class SparkSessionExtensions {
     columnarRuleBuilders.map(_.apply(session)).toSeq
   }

+  /**
+   * Build the override rules for the query post planner strategy phase of adaptive query execution.
+   */
+  private[sql] def buildQueryPostPlannerStrategyRules(
+      session: SparkSession): Seq[Rule[SparkPlan]] = {
+    queryPostPlannerStrategyRuleBuilders.map(_.apply(session)).toSeq
+  }
+
   /**
    * Build the override rules for the query stage preparation phase of adaptive query execution.
    */
@@ -156,6 +168,15 @@ class SparkSessionExtensions {
     columnarRuleBuilders += builder
   }

+  /**
+   * Inject a rule that is applied between `plannerStrategy` and `queryStagePrepRules`, so
+   * it can see the whole plan before exchanges are injected.
+   * Note that these rules are only applied within AQE.
+   */
+  def injectQueryPostPlannerStrategyRule(builder: QueryPostPlannerStrategyBuilder): Unit = {
+    queryPostPlannerStrategyRuleBuilders += builder
+  }
+
   /**
    * Inject a rule that can override the query stage preparation phase of adaptive query
    * execution.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
index 8391fe44f5598..ee2cd8a4953bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
@@ -29,9 +29,12 @@ import org.apache.spark.sql.execution.SparkPlan
 *                               query stage
 * @param queryStageOptimizerRules applied to a new query stage before its execution. It makes sure
 *                                 all children query stages are materialized
+ * @param queryPostPlannerStrategyRules applied between `plannerStrategy` and `queryStagePrepRules`,
+ *                                      so they can see the whole plan before exchanges are injected.
 */
 class AdaptiveRulesHolder(
     val queryStagePrepRules: Seq[Rule[SparkPlan]],
     val runtimeOptimizerRules: Seq[Rule[LogicalPlan]],
-    val queryStageOptimizerRules: Seq[Rule[SparkPlan]]) {
+    val queryStageOptimizerRules: Seq[Rule[SparkPlan]],
+    val queryPostPlannerStrategyRules: Seq[Rule[SparkPlan]]) {
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index fa671c8faf8b3..96b83a91cc739 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -193,9 +193,19 @@ case class AdaptiveSparkPlanExec(
     optimized
   }

+  private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
+    applyPhysicalRules(
+      plan,
+      context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
+      Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
+    )
+  }
+
   @transient val initialPlan = context.session.withActive {
     applyPhysicalRules(
-      inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations")))
+      applyQueryPostPlannerStrategyRules(inputPlan),
+      queryStagePreparationRules,
+      Some((planChangeLogger, "AQE Preparations")))
   }

   @volatile private var currentPhysicalPlan = initialPlan
@@ -706,7 +716,7 @@ case class AdaptiveSparkPlanExec(
     val optimized = optimizer.execute(logicalPlan)
     val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
     val newPlan = applyPhysicalRules(
-      sparkPlan,
+      applyQueryPostPlannerStrategyRules(sparkPlan),
       preprocessingRules ++ queryStagePreparationRules,
       Some((planChangeLogger, "AQE Replanning")))

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 5543b409d1702..3a07dbf5480db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -318,7 +318,8 @@ abstract class BaseSessionStateBuilder(
     new AdaptiveRulesHolder(
       extensions.buildQueryStagePrepRules(session),
       extensions.buildRuntimeOptimizerRules(session),
-      extensions.buildQueryStageOptimizerRules(session))
+      extensions.buildQueryStageOptimizerRules(session),
+      extensions.buildQueryPostPlannerStrategyRules(session))
   }

   protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 21518085ca4c5..8b4ac474f8753 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -29,15 +29,17 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.connector.write.WriterCommitMessage
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec, WriteFilesSpec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -516,6 +518,31 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       }
     }
   }
+
+  test("SPARK-46170: Support inject adaptive query post planner strategy rules in " +
+    "SparkSessionExtensions") {
+    val extensions = create { extensions =>
+      extensions.injectQueryPostPlannerStrategyRule(_ => MyQueryPostPlannerStrategyRule)
+    }
+    withSession(extensions) { session =>
+      assert(session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules
+        .contains(MyQueryPostPlannerStrategyRule))
+      import session.implicits._
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+          SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
+        val input = Seq(10, 20, 10).toDF("c1")
+        val df = input.groupBy("c1").count()
+        df.collect()
+        assert(df.rdd.partitions.length == 1)
+        assert(collectFirst(df.queryExecution.executedPlan) {
+          case s: ShuffleExchangeExec if s.outputPartitioning == SinglePartition => true
+        }.isDefined)
+        assert(collectFirst(df.queryExecution.executedPlan) {
+          case _: SortExec => true
+        }.isDefined)
+      }
+    }
+  }
 }

 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1190,3 +1217,14 @@ object RequireAtLeaseTwoPartitions extends Rule[SparkPlan] {
     }
   }
 }
+
+object MyQueryPostPlannerStrategyRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    plan.transformUp {
+      case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Partial) =>
+        ShuffleExchangeExec(SinglePartition, h)
+      case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Final) =>
+        SortExec(h.groupingExpressions.map(k => SortOrder.apply(k, Ascending)), false, h)
+    }
+  }
+}
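
Usage sketch: a minimal, spark-shell style example of how an application could
register a rule through the new entry point. `PrintPostPlannerPlan` is a
hypothetical rule written only for illustration (it is not part of this
patch), and the snippet assumes a Spark build that includes this change:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan

    // Hypothetical rule: logs the whole physical plan it receives, before
    // AQE has injected any exchanges, and returns it unchanged.
    object PrintPostPlannerPlan extends Rule[SparkPlan] {
      override def apply(plan: SparkPlan): SparkPlan = {
        logInfo(s"Physical plan before exchange injection:\n$plan")
        plan
      }
    }

    val spark = SparkSession.builder()
      .master("local[*]")
      // Register the rule through the entry point added by this patch.
      .withExtensions(_.injectQueryPostPlannerStrategyRule(_ => PrintPostPlannerPlan))
      .getOrCreate()

    // The injected rule only runs within AQE, so keep it enabled.
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.range(100).selectExpr("id % 10 AS k").groupBy("k").count().collect()

Because the hook runs before `queryStagePrepRules`, the rule above observes a
plan that contains no exchange nodes yet; that is the same property the
test's `MyQueryPostPlannerStrategyRule` relies on when it inserts its own
`ShuffleExchangeExec`.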