@@ -49,6 +49,12 @@ import org.apache.spark.sql.types.StructType
4949abstract class SparkStrategy extends GenericStrategy [SparkPlan ] {
5050
5151 override protected def planLater (plan : LogicalPlan ): SparkPlan = PlanLater (plan)
52+
53+ override def apply (plan : LogicalPlan ): Seq [SparkPlan ] = {
54+ doApply(plan).map(sparkPlan => sparkPlan.withStats(plan.stats))
55+ }
56+
57+ protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ]
5258}
5359
5460case class PlanLater (plan : LogicalPlan ) extends LeafExecNode {
@@ -67,7 +73,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
6773 * Plans special cases of limit operators.
6874 */
6975 object SpecialLimits extends Strategy {
70- override def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
76+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
7177 case ReturnAnswer (rootPlan) => rootPlan match {
7278 case Limit (IntegerLiteral (limit), Sort (order, true , child))
7379 if limit < conf.topKSortFallbackThreshold =>
@@ -209,7 +215,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
209215 hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL ))
210216 }
211217
212- def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
218+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
213219
214220 // If it is an equi-join, we first look at the join hints w.r.t. the following order:
215221 // 1. broadcast hint: pick broadcast hash join if the join type is supported. If both sides
@@ -383,7 +389,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
383389 * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution ]]
384390 */
385391 object StatefulAggregationStrategy extends Strategy {
386- override def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
392+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
387393 case _ if ! plan.isStreaming => Nil
388394
389395 case EventTimeWatermark (columnName, delay, child) =>
@@ -423,7 +429,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
423429 * Used to plan the streaming deduplicate operator.
424430 */
425431 object StreamingDeduplicationStrategy extends Strategy {
426- override def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
432+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
427433 case Deduplicate (keys, child) if child.isStreaming =>
428434 StreamingDeduplicateExec (keys, planLater(child)) :: Nil
429435
@@ -440,7 +446,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
440446 * Limit is unsupported for streams in Update mode.
441447 */
442448 case class StreamingGlobalLimitStrategy (outputMode : OutputMode ) extends Strategy {
443- override def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
449+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
444450 case ReturnAnswer (rootPlan) => rootPlan match {
445451 case Limit (IntegerLiteral (limit), child)
446452 if plan.isStreaming && outputMode == InternalOutputModes .Append =>
@@ -455,7 +461,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
455461 }
456462
457463 object StreamingJoinStrategy extends Strategy {
458- override def apply (plan : LogicalPlan ): Seq [SparkPlan ] = {
464+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = {
459465 plan match {
460466 case ExtractEquiJoinKeys (joinType, leftKeys, rightKeys, condition, left, right, _)
461467 if left.isStreaming && right.isStreaming =>
@@ -476,7 +482,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
476482 * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface.
477483 */
478484 object Aggregation extends Strategy {
479- def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
485+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
480486 case PhysicalAggregation (groupingExpressions, aggExpressions, resultExpressions, child)
481487 if aggExpressions.forall(expr => expr.isInstanceOf [AggregateExpression ]) =>
482488 val aggregateExpressions = aggExpressions.map(expr =>
@@ -538,7 +544,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
538544 }
539545
540546 object Window extends Strategy {
541- def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
547+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
542548 case PhysicalWindow (
543549 WindowFunctionType .SQL , windowExprs, partitionSpec, orderSpec, child) =>
544550 execution.window.WindowExec (
@@ -556,7 +562,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
556562 protected lazy val singleRowRdd = sparkContext.parallelize(Seq (InternalRow ()), 1 )
557563
558564 object InMemoryScans extends Strategy {
559- def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
565+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
560566 case PhysicalOperation (projectList, filters, mem : InMemoryRelation ) =>
561567 pruneFilterProject(
562568 projectList,
@@ -574,7 +580,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
574580 * be replaced with the real relation using the `Source` in `StreamExecution`.
575581 */
576582 object StreamingRelationStrategy extends Strategy {
577- def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
583+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
578584 case s : StreamingRelation =>
579585 StreamingRelationExec (s.sourceName, s.output) :: Nil
580586 case s : StreamingExecutionRelation =>
@@ -590,7 +596,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
590596 * in streaming plans. Conversion for batch plans is handled by [[BasicOperators ]].
591597 */
592598 object FlatMapGroupsWithStateStrategy extends Strategy {
593- override def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
599+ override def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
594600 case FlatMapGroupsWithState (
595601 func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _,
596602 timeout, child) =>
@@ -608,7 +614,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
608614 * Strategy to convert EvalPython logical operator to physical operator.
609615 */
610616 object PythonEvals extends Strategy {
611- override def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
617+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
612618 case ArrowEvalPython (udfs, output, child) =>
613619 ArrowEvalPythonExec (udfs, output, planLater(child)) :: Nil
614620 case BatchEvalPython (udfs, output, child) =>
@@ -619,7 +625,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
619625 }
620626
621627 object BasicOperators extends Strategy {
622- def apply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
628+ override protected def doApply (plan : LogicalPlan ): Seq [SparkPlan ] = plan match {
623629 case d : DataWritingCommand => DataWritingCommandExec (d, planLater(d.query)) :: Nil
624630 case r : RunnableCommand => ExecutedCommandExec (r) :: Nil
625631
0 commit comments