diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index b7c86ab7de6b4..677dba0082575 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
 * Customized Parser.
* (External) Catalog listeners.
* Columnar Rules.
+ * Adaptive Query Post Planner Strategy Rules.
* Adaptive Query Stage Preparation Rules.
* Adaptive Query Execution Runtime Optimizer Rules.
* Adaptive Query Stage Optimizer Rules.
@@ -112,10 +113,13 @@ class SparkSessionExtensions {
type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder)
type ColumnarRuleBuilder = SparkSession => ColumnarRule
+ type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan]
type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
type QueryStageOptimizerRuleBuilder = SparkSession => Rule[SparkPlan]
private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
+ private[this] val queryPostPlannerStrategyRuleBuilders =
+ mutable.Buffer.empty[QueryPostPlannerStrategyBuilder]
private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder]
private[this] val queryStageOptimizerRuleBuilders =
@@ -128,6 +132,14 @@ class SparkSessionExtensions {
columnarRuleBuilders.map(_.apply(session)).toSeq
}
+ /**
+ * Build the override rules for the query post planner strategy phase of adaptive query execution.
+ */
+ private[sql] def buildQueryPostPlannerStrategyRules(
+ session: SparkSession): Seq[Rule[SparkPlan]] = {
+ queryPostPlannerStrategyRuleBuilders.map(_.apply(session)).toSeq
+ }
+
/**
* Build the override rules for the query stage preparation phase of adaptive query execution.
*/
@@ -156,6 +168,15 @@ class SparkSessionExtensions {
columnarRuleBuilders += builder
}
+  /**
+   * Inject a rule that is applied between `plannerStrategy` and `queryStagePrepRules`, so
+   * it can see the whole physical plan before exchanges are injected.
+   * Note that these rules can only be applied within AQE.
+   */
+ def injectQueryPostPlannerStrategyRule(builder: QueryPostPlannerStrategyBuilder): Unit = {
+ queryPostPlannerStrategyRuleBuilders += builder
+ }
+
/**
* Inject a rule that can override the query stage preparation phase of adaptive query
* execution.
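
For reference, the new injection point is wired up the same way as the existing ones: build a `Rule[SparkPlan]` and register it through `SparkSessionExtensions`. A minimal sketch, assuming a hypothetical provider class and rule that are not part of this patch:

```scala
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, SparkSessionExtensionsProvider}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical rule: it sees the whole physical plan before any exchanges exist.
case class InspectWholePlanRule(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    logInfo(s"Physical plan before exchange insertion:\n$plan")
    plan
  }
}

// Registered through spark.sql.extensions=com.example.MyExtensions, or by
// passing the same function to SparkSession.builder().withExtensions(...).
class MyExtensions extends SparkSessionExtensionsProvider {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectQueryPostPlannerStrategyRule(InspectWholePlanRule)
  }
}
```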
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
index 8391fe44f5598..ee2cd8a4953bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala
@@ -29,9 +29,12 @@ import org.apache.spark.sql.execution.SparkPlan
* query stage
* @param queryStageOptimizerRules applied to a new query stage before its execution. It makes sure
* all children query stages are materialized
+ * @param queryPostPlannerStrategyRules applied between `plannerStrategy` and `queryStagePrepRules`,
+ *                                      so they can see the whole plan before exchanges are injected
*/
class AdaptiveRulesHolder(
val queryStagePrepRules: Seq[Rule[SparkPlan]],
val runtimeOptimizerRules: Seq[Rule[LogicalPlan]],
- val queryStageOptimizerRules: Seq[Rule[SparkPlan]]) {
+ val queryStageOptimizerRules: Seq[Rule[SparkPlan]],
+ val queryPostPlannerStrategyRules: Seq[Rule[SparkPlan]]) {
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index fa671c8faf8b3..96b83a91cc739 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -193,9 +193,19 @@ case class AdaptiveSparkPlanExec(
optimized
}
+ private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
+ applyPhysicalRules(
+ plan,
+ context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
+ Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
+ )
+ }
+
@transient val initialPlan = context.session.withActive {
applyPhysicalRules(
- inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations")))
+ applyQueryPostPlannerStrategyRules(inputPlan),
+ queryStagePreparationRules,
+ Some((planChangeLogger, "AQE Preparations")))
}
@volatile private var currentPhysicalPlan = initialPlan
@@ -706,7 +716,7 @@ case class AdaptiveSparkPlanExec(
val optimized = optimizer.execute(logicalPlan)
val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
val newPlan = applyPhysicalRules(
- sparkPlan,
+ applyQueryPostPlannerStrategyRules(sparkPlan),
preprocessingRules ++ queryStagePreparationRules,
Some((planChangeLogger, "AQE Replanning")))
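
Both call sites above run before `queryStagePreparationRules` (which include `EnsureRequirements`), so a post planner strategy rule always sees the freshly planned, exchange-free plan: once when `initialPlan` is built and again on every AQE re-planning pass. A sketch of a rule that exploits that whole-plan view; the rule and its heuristic are illustrative, not part of this patch:

```scala
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

// Illustrative: request a single-partition result by shuffling at the root.
// Because this runs before the preparation rules, AQE treats the inserted
// exchange as an ordinary stage boundary.
object SinglePartitionResultRule extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan match {
    case s: ShuffleExchangeExec => s // already shuffled at the root, keep as-is
    case other => ShuffleExchangeExec(SinglePartition, other)
  }
}
```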
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 5543b409d1702..3a07dbf5480db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -318,7 +318,8 @@ abstract class BaseSessionStateBuilder(
new AdaptiveRulesHolder(
extensions.buildQueryStagePrepRules(session),
extensions.buildRuntimeOptimizerRules(session),
- extensions.buildQueryStageOptimizerRules(session))
+ extensions.buildQueryStageOptimizerRules(session),
+ extensions.buildQueryPostPlannerStrategyRules(session))
}
protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 21518085ca4c5..8b4ac474f8753 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -29,15 +29,17 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec, WriteFilesSpec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -516,6 +518,31 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
}
}
}
+
+ test("SPARK-46170: Support inject adaptive query post planner strategy rules in " +
+ "SparkSessionExtensions") {
+ val extensions = create { extensions =>
+ extensions.injectQueryPostPlannerStrategyRule(_ => MyQueryPostPlannerStrategyRule)
+ }
+ withSession(extensions) { session =>
+ assert(session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules
+ .contains(MyQueryPostPlannerStrategyRule))
+ import session.implicits._
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+ SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
+ val input = Seq(10, 20, 10).toDF("c1")
+ val df = input.groupBy("c1").count()
+ df.collect()
+ assert(df.rdd.partitions.length == 1)
+ assert(collectFirst(df.queryExecution.executedPlan) {
+ case s: ShuffleExchangeExec if s.outputPartitioning == SinglePartition => true
+ }.isDefined)
+ assert(collectFirst(df.queryExecution.executedPlan) {
+ case _: SortExec => true
+ }.isDefined)
+ }
+ }
+ }
}
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1190,3 +1217,17 @@ object RequireAtLeaseTwoPartitions extends Rule[SparkPlan] {
}
}
}
+
+object MyQueryPostPlannerStrategyRule extends Rule[SparkPlan] {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    plan.transformUp {
+      // Shuffle everything into a single partition after the partial
+      // aggregate, so the final aggregate runs in one task.
+      case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Partial) =>
+        ShuffleExchangeExec(SinglePartition, h)
+      // Add a local sort on the grouping keys above the final aggregate.
+      case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Final) =>
+        SortExec(h.groupingExpressions.map(k => SortOrder.apply(k, Ascending)), false, h)
+    }
+  }
+}
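
Outside the test harness, the same behavior can be reproduced with a locally built session. A sketch, assuming the example rule object above is on the classpath:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  // Same registration the test performs through create/withSession.
  .withExtensions(_.injectQueryPostPlannerStrategyRule(_ => MyQueryPostPlannerStrategyRule))
  .getOrCreate()

import spark.implicits._
spark.conf.set("spark.sql.shuffle.partitions", "3")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false")

val df = Seq(10, 20, 10).toDF("c1").groupBy("c1").count()
df.collect()
// The rule forces the final aggregate into a single partition and adds a sort.
assert(df.rdd.partitions.length == 1)
```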