SparkSessionExtensions.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
* <li>Customized Parser.</li>
* <li>(External) Catalog listeners.</li>
* <li>Columnar Rules.</li>
* <li>Adaptive Query Post Planner Strategy Rules.</li>
* <li>Adaptive Query Stage Preparation Rules.</li>
* <li>Adaptive Query Execution Runtime Optimizer Rules.</li>
* <li>Adaptive Query Stage Optimizer Rules.</li>
@@ -112,10 +113,13 @@ class SparkSessionExtensions
type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
type TableFunctionDescription = (FunctionIdentifier, ExpressionInfo, TableFunctionBuilder)
type ColumnarRuleBuilder = SparkSession => ColumnarRule
type QueryPostPlannerStrategyBuilder = SparkSession => Rule[SparkPlan]
type QueryStagePrepRuleBuilder = SparkSession => Rule[SparkPlan]
type QueryStageOptimizerRuleBuilder = SparkSession => Rule[SparkPlan]

private[this] val columnarRuleBuilders = mutable.Buffer.empty[ColumnarRuleBuilder]
private[this] val queryPostPlannerStrategyRuleBuilders =
mutable.Buffer.empty[QueryPostPlannerStrategyBuilder]
private[this] val queryStagePrepRuleBuilders = mutable.Buffer.empty[QueryStagePrepRuleBuilder]
private[this] val runtimeOptimizerRules = mutable.Buffer.empty[RuleBuilder]
private[this] val queryStageOptimizerRuleBuilders =
@@ -128,6 +132,14 @@ class SparkSessionExtensions
columnarRuleBuilders.map(_.apply(session)).toSeq
}

/**
* Build the override rules for the query post planner strategy phase of adaptive query execution.
*/
private[sql] def buildQueryPostPlannerStrategyRules(
session: SparkSession): Seq[Rule[SparkPlan]] = {
queryPostPlannerStrategyRuleBuilders.map(_.apply(session)).toSeq
}

/**
* Build the override rules for the query stage preparation phase of adaptive query execution.
*/
@@ -156,6 +168,15 @@ class SparkSessionExtensions
columnarRuleBuilders += builder
}

/**
* Inject a rule that is applied between `plannerStrategy` and `queryStagePrepRules`, so
* it can get the whole plan before injecting exchanges.
* Note that these rules can only be applied within AQE.
*/
def injectQueryPostPlannerStrategyRule(builder: QueryPostPlannerStrategyBuilder): Unit = {
queryPostPlannerStrategyRuleBuilders += builder
}

/**
* Inject a rule that can override the query stage preparation phase of adaptive query
* execution.
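The new builder type and injection method mirror the existing columnar and query-stage hooks, so a post planner strategy rule is registered like any other extension. Below is a minimal sketch of that wiring; NoopPostPlannerRule and MyExtensions are hypothetical names used only to show the SparkSession => Rule[SparkPlan] shape the builder expects.

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical rule: does nothing, it only illustrates the builder signature
// that injectQueryPostPlannerStrategyRule expects.
case class NoopPostPlannerRule(session: SparkSession) extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan
}

// Extensions function that can be passed to SparkSession.builder.withExtensions
// or listed in the spark.sql.extensions configuration.
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectQueryPostPlannerStrategyRule(session => NoopPostPlannerRule(session))
  }
}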
AdaptiveRulesHolder.scala
@@ -29,9 +29,12 @@ import org.apache.spark.sql.execution.SparkPlan
* query stage
* @param queryStageOptimizerRules applied to a new query stage before its execution. It makes sure
* all children query stages are materialized
* @param queryPostPlannerStrategyRules applied between `plannerStrategy` and `queryStagePrepRules`,
* so they can get the whole plan before injecting exchanges.
*/
class AdaptiveRulesHolder(
val queryStagePrepRules: Seq[Rule[SparkPlan]],
val runtimeOptimizerRules: Seq[Rule[LogicalPlan]],
val queryStageOptimizerRules: Seq[Rule[SparkPlan]]) {
val queryStageOptimizerRules: Seq[Rule[SparkPlan]],
val queryPostPlannerStrategyRules: Seq[Rule[SparkPlan]]) {
}
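As a hedged illustration of what "getting the whole plan before injecting exchanges" buys a rule, the sketch below (hypothetical LogWholePlanRule) traverses the complete planner output in a single pass; at this phase the preparation rules (EnsureRequirements and friends) have not yet added exchanges or split the plan into query stages.

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Hypothetical rule used only to show the phase: it sees the planner output as one
// tree, before exchanges exist and before AQE breaks the plan into query stages.
// logInfo comes from the Logging trait mixed into Rule.
object LogWholePlanRule extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = {
    logInfo(s"Post-planner plan has ${plan.collect { case p => p }.size} physical nodes")
    plan
  }
}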
AdaptiveSparkPlanExec.scala
@@ -193,9 +193,19 @@ case class AdaptiveSparkPlanExec(
optimized
}

private def applyQueryPostPlannerStrategyRules(plan: SparkPlan): SparkPlan = {
applyPhysicalRules(
plan,
context.session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules,
Some((planChangeLogger, "AQE Query Post Planner Strategy Rules"))
)
}

@transient val initialPlan = context.session.withActive {
applyPhysicalRules(
inputPlan, queryStagePreparationRules, Some((planChangeLogger, "AQE Preparations")))
applyQueryPostPlannerStrategyRules(inputPlan),
queryStagePreparationRules,
Some((planChangeLogger, "AQE Preparations")))
}

@volatile private var currentPhysicalPlan = initialPlan
@@ -706,7 +716,7 @@ case class AdaptiveSparkPlanExec(
val optimized = optimizer.execute(logicalPlan)
val sparkPlan = context.session.sessionState.planner.plan(ReturnAnswer(optimized)).next()
val newPlan = applyPhysicalRules(
sparkPlan,
applyQueryPostPlannerStrategyRules(sparkPlan),
preprocessingRules ++ queryStagePreparationRules,
Some((planChangeLogger, "AQE Replanning")))

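Both call sites apply the new rules ahead of the preparation rules, so a registered rule sees the initial plan and every re-planned plan in full. A hedged, shell-style way to observe the effect, assuming a session named spark that already has such an extension installed and AQE enabled:

import org.apache.spark.sql.functions.col

// Before execution, explain() shows the AdaptiveSparkPlan's initial plan, which already
// reflects the injected post planner strategy rule; after collect(), it shows the final
// plan produced by the runtime re-optimization.
val df = spark.range(100).groupBy(col("id") % 10).count()
df.explain()
df.collect()
df.explain()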
BaseSessionStateBuilder.scala
@@ -318,7 +318,8 @@ abstract class BaseSessionStateBuilder(
new AdaptiveRulesHolder(
extensions.buildQueryStagePrepRules(session),
extensions.buildRuntimeOptimizerRules(session),
extensions.buildQueryStageOptimizerRules(session))
extensions.buildQueryStageOptimizerRules(session),
extensions.buildQueryPostPlannerStrategyRules(session))
}

protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
SparkSessionExtensionSuite.scala
@@ -29,15 +29,17 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIden
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{Final, Partial}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Limit, LocalRelation, LogicalPlan, Statistics, UnresolvedHint}
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.write.WriterCommitMessage
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources.{FileFormat, WriteFilesExec, WriteFilesSpec}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
@@ -516,6 +518,31 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
}
}
}

test("SPARK-46170: Support inject adaptive query post planner strategy rules in " +
"SparkSessionExtensions") {
val extensions = create { extensions =>
extensions.injectQueryPostPlannerStrategyRule(_ => MyQueryPostPlannerStrategyRule)
}
withSession(extensions) { session =>
assert(session.sessionState.adaptiveRulesHolder.queryPostPlannerStrategyRules
.contains(MyQueryPostPlannerStrategyRule))
import session.implicits._
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false") {
val input = Seq(10, 20, 10).toDF("c1")
val df = input.groupBy("c1").count()
df.collect()
assert(df.rdd.partitions.length == 1)
assert(collectFirst(df.queryExecution.executedPlan) {
case s: ShuffleExchangeExec if s.outputPartitioning == SinglePartition => true
}.isDefined)
assert(collectFirst(df.queryExecution.executedPlan) {
case _: SortExec => true
}.isDefined)
}
}
}
}

case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -1190,3 +1217,14 @@ object RequireAtLeaseTwoPartitions extends Rule[SparkPlan] {
}
}
}

object MyQueryPostPlannerStrategyRule extends Rule[SparkPlan] {
override def apply(plan: SparkPlan): SparkPlan = {
plan.transformUp {
case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Partial) =>
ShuffleExchangeExec(SinglePartition, h)
case h: HashAggregateExec if h.aggregateExpressions.map(_.mode).contains(Final) =>
SortExec(h.groupingExpressions.map(k => SortOrder.apply(k, Ascending)), false, h)
}
}
}
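Finally, a hedged sketch of enabling such an extension outside the test harness; com.example.MyExtensions is a placeholder for an extensions class like the one sketched earlier, and adaptive query execution must be enabled for the post planner strategy rules to run at all (they only apply within AQE, per the doc comment above).

import org.apache.spark.sql.SparkSession

// spark.sql.extensions takes a comma-separated list of classes implementing
// SparkSessionExtensions => Unit.
val spark = SparkSession.builder()
  .master("local[2]")
  .appName("post-planner-strategy-demo")
  .config("spark.sql.extensions", "com.example.MyExtensions")
  .config("spark.sql.adaptive.enabled", "true")
  .getOrCreate()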