Skip to content

Commit 1b36e3c

Browse files
ulysses-you authored and yaooqinn committed
[SPARK-46170][SQL][3.5] Support inject adaptive query post planner strategy rules in SparkSessionExtensions
This PR backports #44074 to branch-3.5, since 3.5 is an LTS version. ### What changes were proposed in this pull request? This PR adds a new extension entry point, `queryPostPlannerStrategyRules`, to `SparkSessionExtensions`. The injected rules are applied between `plannerStrategy` and `queryStagePrepRules` in AQE, so they can see the whole plan before exchanges are injected. ### Why are the changes needed? 3.5 is an LTS version. ### Does this PR introduce _any_ user-facing change? No, this is a developer-facing API only. ### How was this patch tested? Added a test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #44074 from ulysses-you/post-planner. Authored-by: ulysses-you <ulyssesyou18gmail.com> Closes #45037 from ulysses-you/SPARK-46170. Authored-by: ulysses-you <[email protected]> Signed-off-by: Kent Yao <[email protected]>
1 parent 3f426b5 commit 1b36e3c

File tree

5 files changed

+78
-5
lines changed

5 files changed

+78
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
4747
* <li>Customized Parser.</li>
4848
* <li>(External) Catalog listeners.</li>
4949
* <li>Columnar Rules.</li>
50+
* <li>Adaptive Query Post Planner Strategy Rules.</li>
5051
* <li>Adaptive Query Stage Preparation Rules.</li>
5152
* <li>Adaptive Query Execution Runtime Optimizer Rules.</li>
5253
* <li>Adaptive Query Stage Optimizer Rules.</li>
@@ -112,10 +113,13 @@ class SparkSessionExtensions {
112113
type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
113114
type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder)
114115
type ColumnarRuleBuilder = SparkSession => ColumnarRule
116+
type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan]
115117
type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
116118
type QueryStageOptimizerRuleBuilder = SparkSession => Rule[SparkPlan]
117119

118120
private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
121+
private[this] val queryPostPlannerStrategyRuleBuilders =
122+
mutable.Buffer.empty[QueryPostPlannerStrategyBuilder]
119123
private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
120124
private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder]
121125
private[this] val queryStageOptimizerRuleBuilders =
@@ -128,6 +132,14 @@ class SparkSessionExtensions {
128132
columnarRuleBuilders.map(_.apply(session)).toSeq
129133
}
130134

135+
/**
136+
* Build the override rules for the query post planner strategy phase of adaptive query execution.
137+
*/
138+
private[sql] def buildQueryPostPlannerStrategyRules(
139+
session: SparkSession): Seq[Rule[SparkPlan]] = {
140+
queryPostPlannerStrategyRuleBuilders.map(_.apply(session)).toSeq
141+
}
142+
131143
/**
132144
* Build the override rules for the query stage preparation phase of adaptive query execution.
133145
*/
@@ -156,6 +168,15 @@ class SparkSessionExtensions {
156168
columnarRuleBuilders += builder
157169
}
158170

171+
/**
172+
* Inject a rule that is applied between `plannerStrategy` and `queryStagePrepRules`, so
173+
* it can get the whole plan before injecting exchanges.
174+
* Note, these rules can only be applied within AQE.
175+
*/
176+
def injectQueryPostPlannerStrategyRule(builder: QueryPostPlannerStrategyBuilder): Unit = {
177+
queryPostPlannerStrategyRuleBuilders += builder
178+
}
179+
159180
/**
160181
* Inject a rule that can override the query stage preparation phase of adaptive query
161182
* execution.

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRulesHolder.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ import org.apache.spark.sql.execution.SparkPlan
2929
* query stage
3030
* @param queryStageOptimizerRules applied to a new query stage before its execution. It makes sure
3131
* all children query stages are materialized
32+
* @param queryPostPlannerStrategyRules applied between `plannerStrategy` and `queryStagePrepRules`,
33+
* so it can get the whole plan before injecting exchanges.
3234
*/
3335
class AdaptiveRulesHolder(
3436
val queryStagePrepRules: Seq[Rule[SparkPlan]],
3537
val runtimeOptimizerRules: Seq[Rule[LogicalPlan]],
36-
val queryStageOptimizerRules: Seq[Rule[SparkPlan]]) {
38+
val queryStageOptimizerRules: Seq[Rule[SparkPlan]],
39+
val queryPostPlannerStrategyRules: Seq[Rule[SparkPlan]]) {
3740
}

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,19 @@ case class AdaptiveSparkPlanExec(
193193
optimized
194194
}
195195

196+
private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
197+
applyPhysicalRules(
198+
plan,
199+
context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
200+
Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
201+
)
202+
}
203+
196204
@transient val initialPlan = context.session.withActive {
197205
applyPhysicalRules(
198-
inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations")))
206+
applyQueryPostPlannerStrategyRules(inputPlan),
207+
queryStagePreparationRules,
208+
Some((planChangeLogger, "AQE Preparations")))
199209
}
200210

201211
@volatile private var currentPhysicalPlan = initialPlan
@@ -706,7 +716,7 @@ case class AdaptiveSparkPlanExec(
706716
val optimized = optimizer.execute(logicalPlan)
707717
val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
708718
val newPlan = applyPhysicalRules(
709-
sparkPlan,
719+
applyQueryPostPlannerStrategyRules(sparkPlan),
710720
preprocessingRules ++ queryStagePreparationRules,
711721
Some((planChangeLogger, "AQE Replanning")))
712722

sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,8 @@ abstract class BaseSessionStateBuilder(
318318
new AdaptiveRulesHolder(
319319
extensions.buildQueryStagePrepRules(session),
320320
extensions.buildRuntimeOptimizerRules(session),
321-
extensions.buildQueryStageOptimizerRules(session))
321+
extensions.buildQueryStageOptimizerRules(session),
322+
extensions.buildQueryPostPlannerStrategyRules(session))
322323
}
323324

324325
protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,17 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
2929
import org.apache.spark.sql.catalyst.catalog.BucketSpec
3030
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
3131
import org.apache.spark.sql.catalyst.expressions._
32+
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
3233
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
3334
import org.apache.spark.sql.catalyst.plans.SQLHelper
3435
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
35-
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
36+
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
3637
import org.apache.spark.sql.catalyst.rules.Rule
3738
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
3839
import org.apache.spark.sql.connector.write.WriterCommitMessage
3940
import org.apache.spark.sql.execution._
4041
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
42+
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
4143
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec, WriteFilesSpec}
4244
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
4345
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -516,6 +518,31 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
516518
}
517519
}
518520
}
521+
522+
test("SPARK-46170: Support inject adaptive query post planner strategy rules in " +
523+
"SparkSessionExtensions") {
524+
val extensions = create { extensions =>
525+
extensions.injectQueryPostPlannerStrategyRule(_ => MyQueryPostPlannerStrategyRule)
526+
}
527+
withSession(extensions) { session =>
528+
assert(session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules
529+
.contains(MyQueryPostPlannerStrategyRule))
530+
import session.implicits._
531+
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
532+
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
533+
val input = Seq(10, 20, 10).toDF("c1")
534+
val df = input.groupBy("c1").count()
535+
df.collect()
536+
assert(df.rdd.partitions.length == 1)
537+
assert(collectFirst(df.queryExecution.executedPlan) {
538+
case s: ShuffleExchangeExec if s.outputPartitioning == SinglePartition => true
539+
}.isDefined)
540+
assert(collectFirst(df.queryExecution.executedPlan) {
541+
case _: SortExec => true
542+
}.isDefined)
543+
}
544+
}
545+
}
519546
}
520547

521548
case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1190,3 +1217,14 @@ object RequireAtLeaseTwoPartitions extends Rule[SparkPlan] {
11901217
}
11911218
}
11921219
}
1220+
1221+
object MyQueryPostPlannerStrategyRule extends Rule[SparkPlan] {
1222+
override def apply(plan: SparkPlan): SparkPlan = {
1223+
plan.transformUp {
1224+
case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Partial) =>
1225+
ShuffleExchangeExec(SinglePartition, h)
1226+
case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Final) =>
1227+
SortExec(h.groupingExpressions.map(k => SortOrder.apply(k, Ascending)), false, h)
1228+
}
1229+
}
1230+
}

0 commit comments

Comments
 (0)