Skip to content

Commit b3cd0b0

Browse files
committed
Propagate the statistics from the logical plan to the physical plan in the Strategy.apply method
1 parent 24b5471 commit b3cd0b0

File tree

10 files changed

+28
-31
lines changed

10 files changed

+28
-31
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,6 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
6565
// The candidates may contain placeholders marked as [[planLater]],
6666
// so try to replace them by their child plans.
6767
val plans = candidates.flatMap { candidate =>
68-
propagateProperty(candidate, plan)
69-
7068
val placeholders = collectPlaceholders(candidate)
7169

7270
if (placeholders.isEmpty) {
@@ -96,9 +94,6 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
9694
pruned
9795
}
9896

99-
protected def propagateProperty(candidate: PhysicalPlan, plan: LogicalPlan): Unit = {
100-
}
101-
10297
/**
10398
* Collects placeholders marked using [[GenericStrategy#planLater planLater]]
10499
* by [[strategies]].

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,6 @@ class SparkPlanner(
5353
*/
5454
def extraPlanningStrategies: Seq[Strategy] = Nil
5555

56-
override protected def propagateProperty(candidate: SparkPlan, plan: LogicalPlan): Unit = {
57-
candidate.withStats(plan.stats)
58-
}
59-
6056
override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = {
6157
plan.collect {
6258
case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ import org.apache.spark.sql.types.StructType
4949
abstract class SparkStrategy extends GenericStrategy[SparkPlan] {
5050

5151
override protected def planLater(plan: LogicalPlan): SparkPlan = PlanLater(plan)
52+
53+
override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
54+
doApply(plan).map(sparkPlan => sparkPlan.withStats(plan.stats))
55+
}
56+
57+
protected def doApply(plan: LogicalPlan): Seq[SparkPlan]
5258
}
5359

5460
case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
@@ -67,7 +73,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
6773
* Plans special cases of limit operators.
6874
*/
6975
object SpecialLimits extends Strategy {
70-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
76+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
7177
case ReturnAnswer(rootPlan) => rootPlan match {
7278
case Limit(IntegerLiteral(limit), Sort(order, true, child))
7379
if limit < conf.topKSortFallbackThreshold =>
@@ -209,7 +215,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
209215
hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))
210216
}
211217

212-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
218+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
213219

214220
// If it is an equi-join, we first look at the join hints w.r.t. the following order:
215221
// 1. broadcast hint: pick broadcast hash join if the join type is supported. If both sides
@@ -383,7 +389,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
383389
* on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]]
384390
*/
385391
object StatefulAggregationStrategy extends Strategy {
386-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
392+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
387393
case _ if !plan.isStreaming => Nil
388394

389395
case EventTimeWatermark(columnName, delay, child) =>
@@ -423,7 +429,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
423429
* Used to plan the streaming deduplicate operator.
424430
*/
425431
object StreamingDeduplicationStrategy extends Strategy {
426-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
432+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
427433
case Deduplicate(keys, child) if child.isStreaming =>
428434
StreamingDeduplicateExec(keys, planLater(child)) :: Nil
429435

@@ -440,7 +446,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
440446
* Limit is unsupported for streams in Update mode.
441447
*/
442448
case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy {
443-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
449+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
444450
case ReturnAnswer(rootPlan) => rootPlan match {
445451
case Limit(IntegerLiteral(limit), child)
446452
if plan.isStreaming && outputMode == InternalOutputModes.Append =>
@@ -455,7 +461,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
455461
}
456462

457463
object StreamingJoinStrategy extends Strategy {
458-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
464+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = {
459465
plan match {
460466
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
461467
if left.isStreaming && right.isStreaming =>
@@ -476,7 +482,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
476482
* Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
477483
*/
478484
object Aggregation extends Strategy {
479-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
485+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
480486
case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child)
481487
if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) =>
482488
val aggregateExpressions = aggExpressions.map(expr =>
@@ -538,7 +544,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
538544
}
539545

540546
object Window extends Strategy {
541-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
547+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
542548
case PhysicalWindow(
543549
WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) =>
544550
execution.window.WindowExec(
@@ -556,7 +562,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
556562
protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
557563

558564
object InMemoryScans extends Strategy {
559-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
565+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
560566
case PhysicalOperation(projectList, filters, mem: InMemoryRelation) =>
561567
pruneFilterProject(
562568
projectList,
@@ -574,7 +580,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
574580
* be replaced with the real relation using the `Source` in `StreamExecution`.
575581
*/
576582
object StreamingRelationStrategy extends Strategy {
577-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
583+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
578584
case s: StreamingRelation =>
579585
StreamingRelationExec(s.sourceName, s.output) :: Nil
580586
case s: StreamingExecutionRelation =>
@@ -590,7 +596,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
590596
* in streaming plans. Conversion for batch plans is handled by [[BasicOperators]].
591597
*/
592598
object FlatMapGroupsWithStateStrategy extends Strategy {
593-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
599+
override def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
594600
case FlatMapGroupsWithState(
595601
func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
596602
timeout, child) =>
@@ -608,7 +614,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
608614
* Strategy to convert EvalPython logical operator to physical operator.
609615
*/
610616
object PythonEvals extends Strategy {
611-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
617+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
612618
case ArrowEvalPython(udfs, output, child) =>
613619
ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil
614620
case BatchEvalPython(udfs, output, child) =>
@@ -619,7 +625,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
619625
}
620626

621627
object BasicOperators extends Strategy {
622-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
628+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
623629
case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil
624630
case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
625631

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
261261
case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport {
262262
import DataSourceStrategy._
263263

264-
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
264+
override protected def doApply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
265265
case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) =>
266266
pruneFilterProjectRaw(
267267
l,

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ object FileSourceStrategy extends Strategy with Logging {
136136
}
137137
}
138138

139-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
139+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
140140
case PhysicalOperation(projects, filters,
141141
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
142142
// Filters on this relation fall into four categories based on where we can use them to avoid

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper {
102102

103103
import DataSourceV2Implicits._
104104

105-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
105+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
106106
case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
107107
val scanBuilder = relation.newScanBuilder()
108108

sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
3939
}
4040

4141
object TestStrategy extends Strategy {
42-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
42+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
4343
case Project(Seq(attr), _) if attr.name == "a" =>
4444
FastOperator(attr.toAttribute :: Nil) :: Nil
4545
case _ => Nil

sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) {
218218
}
219219

220220
case class MySparkStrategy(spark: SparkSession) extends SparkStrategy {
221-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
221+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
222222
}
223223

224224
case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface {
@@ -272,7 +272,7 @@ case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) {
272272
}
273273

274274
case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy {
275-
override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
275+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
276276
}
277277

278278
object MyExtensions2 {

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class SparkPlannerSuite extends SharedSQLContext {
3333

3434
var planned = 0
3535
object TestStrategy extends Strategy {
36-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
36+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
3737
case ReturnAnswer(child) =>
3838
planned += 1
3939
planLater(child) :: planLater(NeverPlanned) :: Nil

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ private[hive] trait HiveStrategies {
223223
val sparkSession: SparkSession
224224

225225
object Scripts extends Strategy {
226-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
226+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
227227
case ScriptTransformation(input, script, output, child, ioschema) =>
228228
val hiveIoSchema = HiveScriptIOSchema(ioschema)
229229
ScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil
@@ -236,7 +236,7 @@ private[hive] trait HiveStrategies {
236236
* applied.
237237
*/
238238
object HiveTableScans extends Strategy {
239-
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
239+
override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
240240
case PhysicalOperation(projectList, predicates, relation: HiveTableRelation) =>
241241
// Filter out all predicates that only deal with partition keys, these are given to the
242242
// hive table scan operator to be used for partition pruning.

0 commit comments

Comments (0)