
Commit 8f60e45

Support plan fragment level SQL configs in AQE
1 parent da78949 commit 8f60e45
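
This commit adds an @Experimental AdaptiveRuleContext, a per-thread handle that AQE rules can use to set SQL configs for the next plan fragment. The configs are applied, via withSQLConf, to every AQE rule batch that plans the fragment (the runtime optimizer, planner, query post-planner strategy rules, query stage preparation and optimizer rules, and columnar rules) and are cleared before the next fragment is planned. The context also exposes isSubquery and isFinalStage flags so a rule can tell which kind of fragment it is operating on.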

File tree: 3 files changed (+298, -8 lines)

- sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala (new)
- sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
- sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala (new)
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContext.scala (new file, 88 lines)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.adaptive

import scala.collection.mutable

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.SQLConfHelper

/**
 * Provides the functionality to modify the next plan fragment's SQL configs in AQE rules.
 * The configs will be cleaned up before the next plan fragment is executed.
 * To get an instance, use: {{{ AdaptiveRuleContext.get() }}}
 *
 * @param isSubquery whether the input query plan is a subquery
 * @param isFinalStage whether the next stage is the final stage
 */
@Experimental
case class AdaptiveRuleContext(isSubquery: Boolean, isFinalStage: Boolean) {

  /**
   * SQL configs for the next plan fragment. The configs will affect all rules in AQE,
   * i.e., the runtime optimizer, planner, queryStagePreparationRules,
   * queryStageOptimizerRules, and columnarRules.
   * These configs will be cleared before the next plan fragment is fetched.
   */
  private val nextPlanFragmentConf = new mutable.HashMap[String, String]()

  private[sql] def withFinalStage(isFinalStage: Boolean): AdaptiveRuleContext = {
    if (this.isFinalStage == isFinalStage) {
      this
    } else {
      val newRuleContext = copy(isFinalStage = isFinalStage)
      newRuleContext.setConfigs(this.configs())
      newRuleContext
    }
  }

  def setConfig(key: String, value: String): Unit = {
    nextPlanFragmentConf.put(key, value)
  }

  def setConfigs(kvs: Map[String, String]): Unit = {
    kvs.foreach(kv => nextPlanFragmentConf.put(kv._1, kv._2))
  }

  private[sql] def configs(): Map[String, String] = nextPlanFragmentConf.toMap

  private[sql] def clearConfigs(): Unit = nextPlanFragmentConf.clear()
}

object AdaptiveRuleContext extends SQLConfHelper {
  private val ruleContextThreadLocal = new ThreadLocal[AdaptiveRuleContext]

  /**
   * Returns the rule context if a rule is applied inside AQE; otherwise returns None.
   */
  def get(): Option[AdaptiveRuleContext] = Option(ruleContextThreadLocal.get())

  private[sql] def withRuleContext[T](ruleContext: AdaptiveRuleContext)(block: => T): T = {
    assert(ruleContext != null)
    val origin = ruleContextThreadLocal.get()
    ruleContextThreadLocal.set(ruleContext)
    try {
      val conf = ruleContext.configs()
      withSQLConf(conf.toSeq: _*) {
        block
      }
    } finally {
      ruleContextThreadLocal.set(origin)
    }
  }
}
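
For illustration, here is a minimal sketch of how a session extension could use this API. The rule and extension names are hypothetical; the injection point (injectQueryStagePrepRule) is the one exercised by the test suite below.

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveRuleContext

// Hypothetical rule: ask AQE to plan non-final fragments with fewer shuffle
// partitions. AdaptiveRuleContext.get() is defined only while AQE applies rules.
object SmallShuffleHint extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    AdaptiveRuleContext.get().foreach { ctx =>
      if (!ctx.isFinalStage) {
        ctx.setConfig("spark.sql.shuffle.partitions", "2")
      }
    }
    plan
  }
}

// Hypothetical wiring through the extensions API.
class SmallShuffleExtension extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectQueryStagePrepRule(_ => SmallShuffleHint)
  }
}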

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala (34 additions, 8 deletions)

@@ -85,6 +85,25 @@ case class AdaptiveSparkPlanExec(
     case _ => logDebug(_)
   }
 
+  @transient private var ruleContext = new AdaptiveRuleContext(
+    isSubquery = isSubquery,
+    isFinalStage = false)
+
+  private def withRuleContext[T](f: => T): T =
+    AdaptiveRuleContext.withRuleContext(ruleContext) { f }
+
+  private def applyPhysicalRulesWithRuleContext(
+      plan: => SparkPlan,
+      rules: Seq[Rule[SparkPlan]],
+      loggerAndBatchName: Option[(PlanChangeLogger[SparkPlan], String)] = None): SparkPlan = {
+    // Force the by-name `plan` first, so that any rules it applies run before this
+    // batch and the configs they set are propagated to it.
+    val newPlan = plan
+    withRuleContext {
+      applyPhysicalRules(newPlan, rules, loggerAndBatchName)
+    }
+  }
+
   @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()
 
   // The logical plan optimizer for re-optimizing the current logical plan.
@@ -161,7 +180,9 @@
     collapseCodegenStagesRule
   )
 
-  private def optimizeQueryStage(plan: SparkPlan, isFinalStage: Boolean): SparkPlan = {
+  private def optimizeQueryStage(
+      plan: SparkPlan,
+      isFinalStage: Boolean): SparkPlan = withRuleContext {
     val rules = if (isFinalStage &&
         !conf.getConf(SQLConf.ADAPTIVE_EXECUTION_APPLY_FINAL_STAGE_SHUFFLE_OPTIMIZATIONS)) {
       queryStageOptimizerRules.filterNot(_.isInstanceOf[AQEShuffleReadRule])
@@ -197,15 +218,15 @@
   }
 
   private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
-    applyPhysicalRules(
+    applyPhysicalRulesWithRuleContext(
       plan,
       context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
       Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
     )
   }
 
   @transient val initialPlan = context.session.withActive {
-    applyPhysicalRules(
+    applyPhysicalRulesWithRuleContext(
       applyQueryPostPlannerStrategyRules(inputPlan),
       queryStagePreparationRules,
       Some((planChangeLogger, "AQE Preparations")))
@@ -282,6 +303,7 @@
     val errors = new mutable.ArrayBuffer[Throwable]()
     var stagesToReplace = Seq.empty[QueryStageExec]
     while (!result.allChildStagesMaterialized) {
+      ruleContext.clearConfigs()
       currentPhysicalPlan = result.newPlan
       if (result.newStages.nonEmpty) {
         stagesToReplace = result.newStages ++ stagesToReplace
@@ -373,11 +395,13 @@
       result = createQueryStages(currentPhysicalPlan)
     }
 
+    ruleContext = ruleContext.withFinalStage(isFinalStage = true)
     // Run the final plan when there's no more unfinished stages.
-    currentPhysicalPlan = applyPhysicalRules(
+    currentPhysicalPlan = applyPhysicalRulesWithRuleContext(
       optimizeQueryStage(result.newPlan, isFinalStage = true),
       postStageCreationRules(supportsColumnar),
       Some((planChangeLogger, "AQE Post Stage Creation")))
+    ruleContext.clearConfigs()
     _isFinalPlan = true
     executionId.foreach(onUpdatePlan(_, Seq(currentPhysicalPlan)))
     currentPhysicalPlan
@@ -595,7 +619,7 @@
     val queryStage = plan match {
       case e: Exchange =>
         val optimized = e.withNewChildren(Seq(optimizeQueryStage(e.child, isFinalStage = false)))
-        val newPlan = applyPhysicalRules(
+        val newPlan = applyPhysicalRulesWithRuleContext(
           optimized,
           postStageCreationRules(outputsColumnar = plan.supportsColumnar),
           Some((planChangeLogger, "AQE Post Stage Creation")))
@@ -722,9 +746,11 @@
   private def reOptimize(logicalPlan: LogicalPlan): Option[(SparkPlan, LogicalPlan)] = {
     try {
       logicalPlan.invalidateStatsCache()
-      val optimized = optimizer.execute(logicalPlan)
-      val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
-      val newPlan = applyPhysicalRules(
+      val optimized = withRuleContext { optimizer.execute(logicalPlan) }
+      val sparkPlan = withRuleContext {
+        context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
+      }
+      val newPlan = applyPhysicalRulesWithRuleContext(
         applyQueryPostPlannerStrategyRules(sparkPlan),
         preprocessingRules ++ queryStagePreparationRules,
         Some((planChangeLogger, "AQE Replanning")))
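
Since withRuleContext applies the pending fragment configs through withSQLConf, any rule that runs inside AQE observes them through its usual conf handle. A minimal sketch with a hypothetical probe rule (not part of this commit):

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical probe: Rule extends SQLConfHelper, so `conf` here reflects any
// fragment-level values set by earlier rules for the current plan fragment.
object FragmentConfProbe extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    logInfo(s"Shuffle partitions for this fragment: ${conf.numShufflePartitions}")
    plan
  }
}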
sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveRuleContextSuite.scala (new file, 176 lines)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{SparkSession, SparkSessionExtensionsProvider}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, RangeExec, SparkPlan, SparkStrategy}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

class AdaptiveRuleContextSuite extends SparkFunSuite with AdaptiveSparkPlanHelper {

  private def stop(spark: SparkSession): Unit = {
    spark.stop()
    SparkSession.clearActiveSession()
    SparkSession.clearDefaultSession()
  }

  private def withSession(
      builders: Seq[SparkSessionExtensionsProvider])(f: SparkSession => Unit): Unit = {
    val builder = SparkSession.builder().master("local[1]")
    builders.foreach(builder.withExtensions)
    val spark = builder.getOrCreate()
    try f(spark) finally {
      stop(spark)
    }
  }

  test("test adaptive rule context") {
    withSession(
      Seq(_.injectRuntimeOptimizerRule(_ => MyRuleContextForRuntimeOptimization),
        _.injectPlannerStrategy(_ => MyRuleContextForPlannerStrategy),
        _.injectQueryPostPlannerStrategyRule(_ => MyRuleContextForPostPlannerStrategyRule),
        _.injectQueryStagePrepRule(_ => MyRuleContextForPreQueryStageRule),
        _.injectQueryStageOptimizerRule(_ => MyRuleContextForQueryStageRule),
        _.injectColumnar(_ => MyRuleContextForColumnarRule))) { spark =>
      val df = spark.range(1, 10, 1, 3).selectExpr("id % 3 as c").groupBy("c").count()
      df.collect()
      assert(collectFirst(df.queryExecution.executedPlan) {
        case s: ShuffleExchangeExec if s.numPartitions == 2 => s
      }.isDefined)
    }
  }

  test("test adaptive rule context with subquery") {
    withSession(
      Seq(_.injectQueryStagePrepRule(_ => MyRuleContextForQueryStageWithSubquery))) { spark =>
      spark.sql("select (select count(*) from range(10)), id from range(10)").collect()
    }
  }
}

object MyRuleContext {
  def checkAndGetRuleContext(): AdaptiveRuleContext = {
    val ruleContextOpt = AdaptiveRuleContext.get()
    assert(ruleContextOpt.isDefined)
    ruleContextOpt.get
  }

  def checkRuleContextForQueryStage(plan: SparkPlan): SparkPlan = {
    val ruleContext = checkAndGetRuleContext()
    assert(!ruleContext.isSubquery)
    val stage = plan.find(_.isInstanceOf[ShuffleQueryStageExec])
    if (stage.isDefined && stage.get.asInstanceOf[ShuffleQueryStageExec].isMaterialized) {
      assert(ruleContext.isFinalStage)
      assert(!ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
    } else {
      assert(!ruleContext.isFinalStage)
      assert(ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
    }
    plan
  }
}

object MyRuleContextForRuntimeOptimization extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    MyRuleContext.checkAndGetRuleContext()
    plan
  }
}

object MyRuleContextForPlannerStrategy extends SparkStrategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
    plan match {
      case _: LogicalQueryStage =>
        val ruleContext = MyRuleContext.checkAndGetRuleContext()
        assert(!ruleContext.configs().get("spark.sql.shuffle.partitions").contains("2"))
        Nil
      case _ => Nil
    }
  }
}

object MyRuleContextForPostPlannerStrategyRule extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    val ruleContext = MyRuleContext.checkAndGetRuleContext()
    if (plan.find(_.isInstanceOf[RangeExec]).isDefined) {
      ruleContext.setConfig("spark.sql.shuffle.partitions", "2")
    }
    plan
  }
}

object MyRuleContextForPreQueryStageRule extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    val ruleContext = MyRuleContext.checkAndGetRuleContext()
    assert(!ruleContext.isFinalStage)
    plan
  }
}

object MyRuleContextForQueryStageRule extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    MyRuleContext.checkRuleContextForQueryStage(plan)
  }
}

object MyRuleContextForColumnarRule extends ColumnarRule {
  override def preColumnarTransitions: Rule[SparkPlan] = {
    plan: SparkPlan => {
      if (plan.isInstanceOf[AdaptiveSparkPlanExec]) {
        // skip if we are not inside AQE
        assert(AdaptiveRuleContext.get().isEmpty)
        plan
      } else {
        MyRuleContext.checkRuleContextForQueryStage(plan)
      }
    }
  }

  override def postColumnarTransitions: Rule[SparkPlan] = {
    plan: SparkPlan => {
      if (plan.isInstanceOf[AdaptiveSparkPlanExec]) {
        // skip if we are not inside AQE
        assert(AdaptiveRuleContext.get().isEmpty)
        plan
      } else {
        MyRuleContext.checkRuleContextForQueryStage(plan)
      }
    }
  }
}

object MyRuleContextForQueryStageWithSubquery extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    val ruleContext = MyRuleContext.checkAndGetRuleContext()
    if (plan.exists(_.isInstanceOf[HashAggregateExec])) {
      assert(ruleContext.isSubquery)
      if (plan.exists(_.isInstanceOf[RangeExec])) {
        assert(!ruleContext.isFinalStage)
      } else {
        assert(ruleContext.isFinalStage)
      }
    } else {
      assert(!ruleContext.isSubquery)
    }
    plan
  }
}
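
Note how the first test ties the pieces together: MyRuleContextForPostPlannerStrategyRule sets spark.sql.shuffle.partitions to 2 through the rule context, the shuffle for the aggregation is then planned with 2 partitions (hence the ShuffleExchangeExec assertion), and MyRuleContext.checkRuleContextForQueryStage verifies that the config is still visible while intermediate stages are planned but has been cleared by the time the final stage runs.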
