From 8f60e453a6490f509357ebaae6b3aa51a27759db Mon Sep 17 00:00:00 2001
From: ulysses-you
Date: Mon, 4 Dec 2023 14:12:44 +0800
Subject: [PATCH 1/2] Support plan fragment level SQL configs in AQE

---
 .../adaptive/AdaptiveRuleContext.scala        |  94 +++++++++
 .../adaptive/AdaptiveSparkPlanExec.scala      |  42 ++++-
 .../adaptive/AdaptiveRuleContextSuite.scala   | 178 ++++++++++++++++++
 3 files changed, 306 insertions(+), 8 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala
new file mode 100644
index 0000000000000..709fa0c17872e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.SQLConfHelper
+
+/**
+ * Provides the functionality to modify the SQL configs of the next plan fragment in AQE rules.
+ * The configs are cleared before the next plan fragment is executed.
+ * To get the current instance, use: {{{ AdaptiveRuleContext.get() }}}
+ *
+ * @param isSubquery whether the input query plan is a subquery
+ * @param isFinalStage whether the next stage is the final stage
+ */
+@Experimental
+case class AdaptiveRuleContext(isSubquery: Boolean, isFinalStage: Boolean) {
+
+  /**
+   * Sets SQL configs for the next plan fragment. The configs affect all rules in AQE,
+   * i.e., the runtime optimizer, planner, queryStagePreparationRules,
+   * queryStageOptimizerRules and columnarRules.
+   * These configs are cleared before the next plan fragment is fetched.
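+   *
+   * For example, a rule can lower the shuffle partition number of the next plan
+   * fragment (a minimal sketch; the config value is only illustrative):
+   * {{{
+   *   AdaptiveRuleContext.get().foreach(_.setConfig("spark.sql.shuffle.partitions", "2"))
+   * }}}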
+   */
+  private val nextPlanFragmentConf = new mutable.HashMap[String, String]()
+
+  private[sql] def withFinalStage(isFinalStage: Boolean): AdaptiveRuleContext = {
+    if (this.isFinalStage == isFinalStage) {
+      this
+    } else {
+      val newRuleContext = copy(isFinalStage = isFinalStage)
+      newRuleContext.setConfigs(this.configs())
+      newRuleContext
+    }
+  }
+
+  def setConfig(key: String, value: String): Unit = {
+    nextPlanFragmentConf.put(key, value)
+  }
+
+  def setConfigs(kvs: Map[String, String]): Unit = {
+    kvs.foreach(kv => nextPlanFragmentConf.put(kv._1, kv._2))
+  }
+
+  private[sql] def configs(): Map[String, String] = nextPlanFragmentConf.toMap
+
+  private[sql] def clearConfigs(): Unit = nextPlanFragmentConf.clear()
+}
+
+object AdaptiveRuleContext extends SQLConfHelper {
+  private val ruleContextThreadLocal = new ThreadLocal[AdaptiveRuleContext]
+
+  /**
+   * Returns the rule context if the rule is applied inside AQE, otherwise returns None.
+   */
+  def get(): Option[AdaptiveRuleContext] = Option(ruleContextThreadLocal.get())
+
+  private[sql] def withRuleContext[T](ruleContext: AdaptiveRuleContext)(block: => T): T = {
+    assert(ruleContext != null)
+    val origin = ruleContextThreadLocal.get()
+    ruleContextThreadLocal.set(ruleContext)
+    try {
+      val conf = ruleContext.configs()
+      withSQLConf(conf.toSeq: _*) {
+        block
+      }
+    } finally {
+      ruleContextThreadLocal.set(origin)
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index f30ffaf515664..f21960aeedd64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -85,6 +85,25 @@ case class AdaptiveSparkPlanExec(
     case _ => logDebug(_)
   }
 
+  @transient private var ruleContext = new AdaptiveRuleContext(
+    isSubquery = isSubquery,
+    isFinalStage = false)
+
+  private def withRuleContext[T](f: => T): T =
+    AdaptiveRuleContext.withRuleContext(ruleContext) { f }
+
+  private def applyPhysicalRulesWithRuleContext(
+      plan: => SparkPlan,
+      rules: Seq[Rule[SparkPlan]],
+      loggerAndBatchName: Option[(PlanChangeLogger[SparkPlan], String)] = None): SparkPlan = {
+    // Apply the previous rules, if any, before applying the next batch of rules,
+    // so that the configs set by the previous rules can be propagated to this batch.
+    val newPlan = plan
+    withRuleContext {
+      applyPhysicalRules(newPlan, rules, loggerAndBatchName)
+    }
+  }
+
   @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
 
   // The logical plan optimizer for re-optimizing the current logical plan.
@@ -161,7 +180,9 @@ case class AdaptiveSparkPlanExec(
     collapseCodegenStagesRule
   )
 
-  private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
+  private def optimizeQueryStage(
+      plan: SparkPlan,
+      isFinalStage: Boolean): SparkPlan = withRuleContext {
     val rules = if (isFinalStage &&
         !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) {
       queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule])
@@ -197,7 +218,7 @@ case class AdaptiveSparkPlanExec(
   }
 
   private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
-    applyPhysicalRules(
+    applyPhysicalRulesWithRuleContext(
       plan,
       context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
       Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
@@ -205,7 +226,7 @@ case class AdaptiveSparkPlanExec(
   }
 
   @transient val initialPlan = context.session.withActive {
-    applyPhysicalRules(
+    applyPhysicalRulesWithRuleContext(
       applyQueryPostPlannerStrategyRules(inputPlan),
       queryStagePreparationRules,
       Some((planChangeLogger, "AQE Preparations")))
@@ -282,6 +303,7 @@ case class AdaptiveSparkPlanExec(
     val errors = new mutable.ArrayBuffer[Throwable]()
     var stagesToReplace = Seq.empty[QueryStageExec]
     while (!result.allChildStagesMaterialized) {
+      ruleContext.clearConfigs()
       currentPhysicalPlan = result.newPlan
       if (result.newStages.nonEmpty) {
         stagesToReplace = result.newStages ++ stagesToReplace
@@ -373,11 +395,13 @@ case class AdaptiveSparkPlanExec(
         result = createQueryStages(currentPhysicalPlan)
       }
 
+      ruleContext = ruleContext.withFinalStage(isFinalStage = true)
       // Run the final plan when there's no more unfinished stages.
-      currentPhysicalPlan = applyPhysicalRules(
+      currentPhysicalPlan = applyPhysicalRulesWithRuleContext(
         optimizeQueryStage(result.newPlan, isFinalStage = true),
         postStageCreationRules(supportsColumnar),
         Some((planChangeLogger, "AQE Post Stage Creation")))
+      ruleContext.clearConfigs()
       _isFinalPlan = true
       executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
       currentPhysicalPlan
@@ -595,7 +619,7 @@ case class AdaptiveSparkPlanExec(
     val queryStage = plan match {
       case e: Exchange =>
         val optimized = e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false)))
-        val newPlan = applyPhysicalRules(
+        val newPlan = applyPhysicalRulesWithRuleContext(
           optimized,
           postStageCreationRules(outputsColumnar = plan.supportsColumnar),
           Some((planChangeLogger, "AQE Post Stage Creation")))
@@ -722,9 +746,11 @@ case class AdaptiveSparkPlanExec(
   private def reOptimize(logicalPlan: LogicalPlan): Option[(SparkPlan, LogicalPlan)] = {
     try {
       logicalPlan.invalidateStatsCache()
-      val optimized = optimizer.execute(logicalPlan)
-      val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
-      val newPlan = applyPhysicalRules(
+      val optimized = withRuleContext { optimizer.execute(logicalPlan) }
+      val sparkPlan = withRuleContext {
+        context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
+      }
+      val newPlan = applyPhysicalRulesWithRuleContext(
         applyQueryPostPlannerStrategyRules(sparkPlan),
         preprocessingRules ++ queryStagePreparationRules,
         Some((planChangeLogger, "AQE Replanning")))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala
new file mode 100644
index 0000000000000..04c9e6c946b45
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensionsProvider}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{ColumnarRule, RangeExec, SparkPlan, SparkStrategy}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+
+class AdaptiveRuleContextSuite extends SparkFunSuite with AdaptiveSparkPlanHelper {
+
+  private def stop(spark: SparkSession): Unit = {
+    spark.stop()
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
+  }
+
+  private def withSession(
+      builders: Seq[SparkSessionExtensionsProvider])(f: SparkSession => Unit): Unit = {
+    val builder = SparkSession.builder().master("local[1]")
+    builders.foreach(builder.withExtensions)
+    val spark = builder.getOrCreate()
+    try f(spark) finally {
+      stop(spark)
+    }
+  }
+
+  test("test adaptive rule context") {
+    withSession(
+      Seq(_.injectRuntimeOptimizerRule(_ => MyRuleContextForRuntimeOptimization),
+        _.injectPlannerStrategy(_ => MyRuleContextForPlannerStrategy),
+        _.injectQueryPostPlannerStrategyRule(_ => MyRuleContextForPostPlannerStrategyRule),
+        _.injectQueryStagePrepRule(_ => MyRuleContextForPreQueryStageRule),
+        _.injectQueryStageOptimizerRule(_ => MyRuleContextForQueryStageRule),
+        _.injectColumnar(_ => MyRuleContextForColumnarRule))) { spark =>
+      val df = spark.range(1, 10, 1, 3).selectExpr("id % 3 as c").groupBy("c").count()
+      df.collect()
+      assert(collectFirst(df.queryExecution.executedPlan) {
+        case s: ShuffleExchangeExec if s.numPartitions == 2 => s
+      }.isDefined)
+    }
+  }
+
+  test("test adaptive rule context with subquery") {
+    withSession(
+      Seq(_.injectQueryStagePrepRule(_ => MyRuleContextForQueryStageWithSubquery))) { spark =>
+      spark.sql("select (select count(*) from range(10)), id from range(10)").collect()
+    }
+  }
+}
+
+object MyRuleContext {
+  def checkAndGetRuleContext(): AdaptiveRuleContext = {
+    val ruleContextOpt = AdaptiveRuleContext.get()
+    assert(ruleContextOpt.isDefined)
+    ruleContextOpt.get
+  }
+
+  def checkRuleContextForQueryStage(plan: SparkPlan): SparkPlan = {
+    val ruleContext = checkAndGetRuleContext()
+    assert(!ruleContext.isSubquery)
+    val stage = plan.find(_.isInstanceOf[ShuffleQueryStageExec])
+    if (stage.isDefined && stage.get.asInstanceOf[ShuffleQueryStageExec].isMaterialized) {
+      assert(ruleContext.isFinalStage)
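+      // The value "2" was set by MyRuleContextForPostPlannerStrategyRule for the
+      // previous plan fragment and is cleared once the final stage is reached.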
+      assert(!ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
+    } else {
+      assert(!ruleContext.isFinalStage)
+      assert(ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
+    }
+    plan
+  }
+}
+
+object MyRuleContextForRuntimeOptimization extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    MyRuleContext.checkAndGetRuleContext()
+    plan
+  }
+}
+
+object MyRuleContextForPlannerStrategy extends SparkStrategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
+    plan match {
+      case _: LogicalQueryStage =>
+        val ruleContext = MyRuleContext.checkAndGetRuleContext()
+        assert(!ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
+        Nil
+      case _ => Nil
+    }
+  }
+}
+
+object MyRuleContextForPostPlannerStrategyRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    val ruleContext = MyRuleContext.checkAndGetRuleContext()
+    if (plan.find(_.isInstanceOf[RangeExec]).isDefined) {
+      ruleContext.setConfig("spark.sql.shuffle.partitions", "2")
+    }
+    plan
+  }
+}
+
+object MyRuleContextForPreQueryStageRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    val ruleContext = MyRuleContext.checkAndGetRuleContext()
+    assert(!ruleContext.isFinalStage)
+    plan
+  }
+}
+
+object MyRuleContextForQueryStageRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    MyRuleContext.checkRuleContextForQueryStage(plan)
+  }
+}
+
+object MyRuleContextForColumnarRule extends ColumnarRule {
+  override def preColumnarTransitions: Rule[SparkPlan] = {
+    plan: SparkPlan => {
+      if (plan.isInstanceOf[AdaptiveSparkPlanExec]) {
+        // the rule is applied on the root plan, i.e., outside AQE, so no rule context
+        assert(AdaptiveRuleContext.get().isEmpty)
+        plan
+      } else {
+        MyRuleContext.checkRuleContextForQueryStage(plan)
+      }
+    }
+  }
+
+  override def postColumnarTransitions: Rule[SparkPlan] = {
+    plan: SparkPlan => {
+      if (plan.isInstanceOf[AdaptiveSparkPlanExec]) {
+        // the rule is applied on the root plan, i.e., outside AQE, so no rule context
+        assert(AdaptiveRuleContext.get().isEmpty)
+        plan
+      } else {
+        MyRuleContext.checkRuleContextForQueryStage(plan)
+      }
+    }
+  }
+}
+
+object MyRuleContextForQueryStageWithSubquery extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    val ruleContext = MyRuleContext.checkAndGetRuleContext()
+    if (plan.exists(_.isInstanceOf[HashAggregateExec])) {
+      assert(ruleContext.isSubquery)
+      if (plan.exists(_.isInstanceOf[RangeExec])) {
+        assert(!ruleContext.isFinalStage)
+      } else {
+        assert(ruleContext.isFinalStage)
+      }
+    } else {
+      assert(!ruleContext.isSubquery)
+    }
+    plan
+  }
+}

From a6cfa722ba80d66b7c174a80628d15846ce8df1f Mon Sep 17 00:00:00 2001
From: Kent Yao
Date: Fri, 24 May 2024 14:22:28 +0800
Subject: [PATCH 2/2] Update
 sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala

---
 .../spark/sql/execution/adaptive/AdaptiveRuleContext.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala
index 709fa0c17872e..fce20b79e1136 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.adaptive
 
 import scala.collection.mutable
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.sql.catalyst.SQLConfHelper
 
 /**
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
  * @param isFinalStage whether the next stage is the final stage
  */
 @Experimental
+@DeveloperApi
 case class AdaptiveRuleContext(isSubquery: Boolean, isFinalStage: Boolean) {
 
   /**
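For context, a minimal end-to-end sketch of how an extension might use the new API (the rule name, config value, and query below are illustrative and not part of the patch):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.adaptive.AdaptiveRuleContext

    // Hypothetical rule: lower the shuffle partition number for every non-final
    // plan fragment. The override is cleared automatically after each fragment.
    object LowerShufflePartitionsRule extends Rule[SparkPlan] {
      override def apply(plan: SparkPlan): SparkPlan = {
        // AdaptiveRuleContext.get() is defined only while AQE is applying the rule.
        AdaptiveRuleContext.get().foreach { ctx =>
          if (!ctx.isFinalStage) {
            ctx.setConfig("spark.sql.shuffle.partitions", "4")
          }
        }
        plan
      }
    }

    val spark = SparkSession.builder()
      .master("local[2]")
      .withExtensions(_.injectQueryStagePrepRule(_ => LowerShufflePartitionsRule))
      .getOrCreate()

    spark.range(0, 100, 1, 4).groupBy("id").count().collect()

The per-fragment configs are applied through withSQLConf while the AQE rules run and are cleared before the next fragment, so the override never leaks into the session-level configuration.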