From 7ef5e30447087de626882e9ee3460a835f455e24 Mon Sep 17 00:00:00 2001
From: ulysses-you
Date: Wed, 29 Nov 2023 15:31:28 +0800
Subject: [PATCH 1/2] Support inject adaptive query post planner strategy
 rules in SparkSessionExtensions

---
 .../spark/sql/SparkSessionExtensions.scala    | 21 ++++++++++
 .../adaptive/AdaptiveRulesHolder.scala        |  5 ++-
 .../adaptive/AdaptiveSparkPlanExec.scala      | 14 ++++++-
 .../internal/BaseSessionStateBuilder.scala    |  3 +-
 .../sql/SparkSessionExtensionSuite.scala      | 42 ++++++++++++++++++-
 5 files changed, 80 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index b7c86ab7de6b..fb8cd3e92797 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
  * <li>Customized Parser.</li>
  * <li>(External) Catalog listeners.</li>
  * <li>Columnar Rules.</li>
+ * <li>Adaptive Query Post Planner Strategy Rules.</li>
  * <li>Adaptive Query Stage Preparation Rules.</li>
  * <li>Adaptive Query Execution Runtime Optimizer Rules.</li>
  * <li>Adaptive Query Stage Optimizer Rules.</li>
@@ -114,12 +115,15 @@ class SparkSessionExtensions {
   type ColumnarRuleBuilder = SparkSession => ColumnarRule
   type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
   type QueryStageOptimizerRuleBuilder = SparkSession => Rule[SparkPlan]
+  type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan]
 
   private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
   private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
   private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder]
   private[this] val queryStageOptimizerRuleBuilders =
     mutable.Buffer.empty[QueryStageOptimizerRuleBuilder]
+  private[this] val queryPostPlannerStrategyRuleBuilders =
+    mutable.Buffer.empty[QueryPostPlannerStrategyBuilder]
 
   /**
    * Build the override rules for columnar execution.
@@ -149,6 +153,14 @@ class SparkSessionExtensions {
     queryStageOptimizerRuleBuilders.map(_.apply(session)).toSeq
   }
 
+  /**
+   * Build the override rules for the query post planner strategy phase of adaptive query execution.
+   */
+  private[sql] def buildQueryPostPlannerStrategyRules(
+      session: SparkSession): Seq[Rule[SparkPlan]] = {
+    queryPostPlannerStrategyRuleBuilders.map(_.apply(session)).toSeq
+  }
+
   /**
    * Inject a rule that can override the columnar execution of an executor.
    */
@@ -185,6 +197,15 @@ class SparkSessionExtensions {
     queryStageOptimizerRuleBuilders += builder
   }
 
+  /**
+   * Inject a rule that is applied between `plannerStrategy` and `queryStagePrepRules`, so
+   * it can get the whole plan before injecting exchanges.
+   * Note that these rules can only be applied within AQE.
+   */
+  def injectQueryPostPlannerStrategyRule(builder: QueryPostPlannerStrategyBuilder): Unit = {
+    queryPostPlannerStrategyRuleBuilders += builder
+  }
+
   private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
 
   /**
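
For extension authors, the new injector is used the same way as the existing ones. A
minimal sketch of registering a post planner strategy rule (the LogWholePlan rule and
MyExtensions class are illustrative names, not part of this patch):

    import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan

    // A do-nothing rule that only logs the exchange-free plan it is handed.
    case class LogWholePlan(session: SparkSession) extends Rule[SparkPlan] {
      override def apply(plan: SparkPlan): SparkPlan = {
        logInfo(s"physical plan before exchange insertion:\n$plan")
        plan
      }
    }

    class MyExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        // The builder is a SparkSession => Rule[SparkPlan] function; the case class
        // companion of LogWholePlan already has that shape.
        extensions.injectQueryPostPlannerStrategyRule(LogWholePlan)
      }
    }

The extension can then be enabled with SparkSession.builder().withExtensions(new
MyExtensions()) or by setting spark.sql.extensions to the class name.
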
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
index 8391fe44f559..ee2cd8a4953b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
@@ -29,9 +29,12 @@ import org.apache.spark.sql.execution.SparkPlan
  *                                 query stage
  * @param queryStageOptimizerRules applied to a new query stage before its execution. It makes sure
  *                                 all children query stages are materialized
+ * @param queryPostPlannerStrategyRules applied between `plannerStrategy` and `queryStagePrepRules`,
+ *                                      so they can get the whole plan before injecting exchanges.
  */
 class AdaptiveRulesHolder(
     val queryStagePrepRules: Seq[Rule[SparkPlan]],
     val runtimeOptimizerRules: Seq[Rule[LogicalPlan]],
-    val queryStageOptimizerRules: Seq[Rule[SparkPlan]]) {
+    val queryStageOptimizerRules: Seq[Rule[SparkPlan]],
+    val queryPostPlannerStrategyRules: Seq[Rule[SparkPlan]]) {
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 87499b6afef9..b2a1141f2f5b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -193,9 +193,19 @@ case class AdaptiveSparkPlanExec(
     optimized
   }
 
+  private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
+    applyPhysicalRules(
+      plan,
+      context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
+      Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
+    )
+  }
+
   @transient val initialPlan = context.session.withActive {
     applyPhysicalRules(
-      inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations")))
+      applyQueryPostPlannerStrategyRules(inputPlan),
+      queryStagePreparationRules,
+      Some((planChangeLogger, "AQE Preparations")))
   }
 
   @volatile private var currentPhysicalPlan = initialPlan
@@ -706,7 +716,7 @@ case class AdaptiveSparkPlanExec(
     val optimized = optimizer.execute(logicalPlan)
     val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
     val newPlan = applyPhysicalRules(
-      sparkPlan,
+      applyQueryPostPlannerStrategyRules(sparkPlan),
       preprocessingRules ++ queryStagePreparationRules,
       Some((planChangeLogger, "AQE Replanning")))
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index d198e8f5d1f2..00c72294ca07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -332,7 +332,8 @@ abstract class BaseSessionStateBuilder(
     new AdaptiveRulesHolder(
       extensions.buildQueryStagePrepRules(session),
       extensions.buildRuntimeOptimizerRules(session),
-      extensions.buildQueryStageOptimizerRules(session))
+      extensions.buildQueryStageOptimizerRules(session),
+      extensions.buildQueryPostPlannerStrategyRules(session))
   }
 
   protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
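
With this wiring the injected rules run twice: once when AdaptiveSparkPlanExec builds
initialPlan, and again on every AQE re-planning, in both cases on the whole physical plan
before the preparation rules (EnsureRequirements among them) insert any exchanges. A
spark-shell sketch for observing the new batch, assuming the standard plan change log
configs are honored by AQE's PlanChangeLogger (they are read from SQLConf, but verify on
your build):

    // Raise the plan change log level so rule batches show up in the driver log.
    spark.conf.set("spark.sql.planChangeLog.level", "WARN")
    spark.conf.set("spark.sql.adaptive.enabled", "true")
    spark.range(100).groupBy($"id" % 10).count().collect()
    // With a post planner strategy rule injected, look for the batch named
    // "AQE Query Post Planner Strategy Rules" in the output.
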
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index c1b5d2761f7b..b80183e43af6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -29,15 +29,17 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.connector.write.WriterCommitMessage
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec, WriteFilesSpec}
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -516,6 +518,33 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
       }
     }
   }
+
+  test("SPARK-46170: Support inject adaptive query post planner strategy rules in " +
+    "SparkSessionExtensions") {
+    val extensions = create { extensions =>
+      extensions.injectQueryPostPlannerStrategyRule(_ => MyQueryPostPlannerStrategyRule)
+    }
+    withSession(extensions) { session =>
+      assert(session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules
+        .contains(MyQueryPostPlannerStrategyRule))
+      import session.sqlContext.implicits._
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+        SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
+        val input = Seq((10), (20), (10)).toDF("c1")
+        val df = input.groupBy("c1").count()
+        df.collect()
+        assert(df.rdd.partitions.length == 1)
+        assert(find(df.queryExecution.executedPlan) {
+          case s: ShuffleExchangeExec if s.outputPartitioning == SinglePartition => true
+          case _ => false
+        }.isDefined)
+        assert(find(df.queryExecution.executedPlan) {
+          case _: SortExec => true
+          case _ => false
+        }.isDefined)
+      }
+    }
+  }
 }
 
 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1190,3 +1219,14 @@ object RequireAtLeaseTwoPartitions extends Rule[SparkPlan] {
     }
   }
 }
+
+object MyQueryPostPlannerStrategyRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    plan.transformUp {
+      case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Partial) =>
+        ShuffleExchangeExec(SinglePartition, h)
+      case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Final) =>
+        SortExec(h.groupingExpressions.map(k => SortOrder.apply(k, Ascending)), false, h)
+    }
+  }
+}
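
The test rule shuffles every partial HashAggregateExec into a single partition and puts a
sort above the final aggregate, which is why the test can assert on a SinglePartition
shuffle, a SortExec, and an RDD with exactly one partition. Roughly the executed plan
shape the assertions expect (a sketch; the real plan is additionally wrapped in AQE nodes
such as AdaptiveSparkPlanExec and ShuffleQueryStageExec):

    // SortExec(c1 ASC)                                <- added above the Final aggregate
    // +- HashAggregateExec(keys = [c1], mode = Final)
    //    +- ShuffleExchangeExec(SinglePartition)      <- added above the Partial aggregate
    //       +- HashAggregateExec(keys = [c1], mode = Partial)
    //          +- LocalTableScanExec [c1]
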
From 93a727ab81976b4bc73463b08435e5fae676e01a Mon Sep 17 00:00:00 2001
From: ulysses-you
Date: Thu, 30 Nov 2023 10:11:25 +0800
Subject: [PATCH 2/2] address comments

---
 .../spark/sql/SparkSessionExtensions.scala    | 38 +++++++++----------
 .../sql/SparkSessionExtensionSuite.scala      | 10 ++---
 2 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index fb8cd3e92797..677dba008257 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -113,17 +113,17 @@ class SparkSessionExtensions {
   type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
   type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder)
   type ColumnarRuleBuilder = SparkSession => ColumnarRule
+  type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan]
   type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
   type QueryStageOptimizerRuleBuilder = SparkSession => Rule[SparkPlan]
-  type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan]
 
   private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
+  private[this] val queryPostPlannerStrategyRuleBuilders =
+    mutable.Buffer.empty[QueryPostPlannerStrategyBuilder]
   private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
   private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder]
   private[this] val queryStageOptimizerRuleBuilders =
     mutable.Buffer.empty[QueryStageOptimizerRuleBuilder]
-  private[this] val queryPostPlannerStrategyRuleBuilders =
-    mutable.Buffer.empty[QueryPostPlannerStrategyBuilder]
 
   /**
    * Build the override rules for columnar execution.
@@ -132,6 +132,14 @@ class SparkSessionExtensions {
     columnarRuleBuilders.map(_.apply(session)).toSeq
   }
 
+  /**
+   * Build the override rules for the query post planner strategy phase of adaptive query execution.
+   */
+  private[sql] def buildQueryPostPlannerStrategyRules(
+      session: SparkSession): Seq[Rule[SparkPlan]] = {
+    queryPostPlannerStrategyRuleBuilders.map(_.apply(session)).toSeq
+  }
+
   /**
    * Build the override rules for the query stage preparation phase of adaptive query execution.
    */
@@ -154,18 +162,19 @@ class SparkSessionExtensions {
   }
 
   /**
-   * Build the override rules for the query post planner strategy phase of adaptive query execution.
+   * Inject a rule that can override the columnar execution of an executor.
    */
-  private[sql] def buildQueryPostPlannerStrategyRules(
-      session: SparkSession): Seq[Rule[SparkPlan]] = {
-    queryPostPlannerStrategyRuleBuilders.map(_.apply(session)).toSeq
+  def injectColumnar(builder: ColumnarRuleBuilder): Unit = {
+    columnarRuleBuilders += builder
   }
 
   /**
-   * Inject a rule that can override the columnar execution of an executor.
+   * Inject a rule that is applied between `plannerStrategy` and `queryStagePrepRules`, so
+   * it can get the whole plan before injecting exchanges.
+   * Note that these rules can only be applied within AQE.
    */
-  def injectColumnar(builder: ColumnarRuleBuilder): Unit = {
-    columnarRuleBuilders += builder
+  def injectQueryPostPlannerStrategyRule(builder: QueryPostPlannerStrategyBuilder): Unit = {
+    queryPostPlannerStrategyRuleBuilders += builder
   }
 
   /**
@@ -197,15 +206,6 @@ class SparkSessionExtensions {
     queryStageOptimizerRuleBuilders += builder
   }
 
-  /**
-   * Inject a rule that is applied between `plannerStrategy` and `queryStagePrepRules`, so
-   * it can get the whole plan before injecting exchanges.
-   * Note that these rules can only be applied within AQE.
-   */
-  def injectQueryPostPlannerStrategyRule(builder: QueryPostPlannerStrategyBuilder): Unit = {
-    queryPostPlannerStrategyRuleBuilders += builder
-  }
-
   private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index b80183e43af6..18c1f4dcc4e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -527,20 +527,18 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
     withSession(extensions) { session =>
       assert(session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules
         .contains(MyQueryPostPlannerStrategyRule))
-      import session.sqlContext.implicits._
+      import session.implicits._
       withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
         SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
-        val input = Seq((10), (20), (10)).toDF("c1")
+        val input = Seq(10, 20, 10).toDF("c1")
         val df = input.groupBy("c1").count()
         df.collect()
         assert(df.rdd.partitions.length == 1)
-        assert(find(df.queryExecution.executedPlan) {
+        assert(collectFirst(df.queryExecution.executedPlan) {
           case s: ShuffleExchangeExec if s.outputPartitioning == SinglePartition => true
-          case _ => false
         }.isDefined)
-        assert(find(df.queryExecution.executedPlan) {
+        assert(collectFirst(df.queryExecution.executedPlan) {
           case _: SortExec => true
-          case _ => false
         }.isDefined)
       }
     }