From 4055bc87424a4dfaac7694a8664e1765ef78515e Mon Sep 17 00:00:00 2001 From: pengbo Date: Wed, 17 Apr 2019 13:32:27 +0800 Subject: [PATCH 1/5] SPARK-27482: Show BroadcastHashJoinExec numOutputRows statistics info on SparkSQL UI page --- .../spark/sql/execution/SparkPlan.scala | 6 +-- .../spark/sql/execution/SparkPlanInfo.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../columnar/InMemoryTableScanExec.scala | 10 ++--- .../joins/BroadcastHashJoinExec.scala | 10 +++-- .../sql/execution/metric/SQLMetricInfo.scala | 3 +- .../sql/execution/metric/SQLMetrics.scala | 17 +++++++-- .../statsEstimation/SparkPlanStats.scala | 37 +++++++++++++++++++ .../execution/ui/SQLAppStatusListener.scala | 9 +++-- .../sql/execution/ui/SQLAppStatusStore.scala | 3 +- .../sql/execution/ui/SparkPlanGraph.scala | 4 +- .../execution/WholeStageCodegenSuite.scala | 4 +- .../spark/sql/hive/StatisticsSuite.scala | 21 +++++++++++ 13 files changed, 103 insertions(+), 26 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index a89ccca99d05..e5a3ef59c064 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext import org.codehaus.commons.compiler.CompileException import org.codehaus.janino.InternalCompilerException @@ -36,15 +35,16 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredic import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.execution.statsEstimation.SparkPlanStats import org.apache.spark.sql.types.DataType -import org.apache.spark.util.ThreadUtils /** * The base class for physical operators. * * The naming convention is that physical operators end with "Exec" suffix, e.g. [[ProjectExec]]. */ -abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable { +abstract class SparkPlan extends QueryPlan[SparkPlan] with SparkPlanStats with Logging + with Serializable { /** * A handle to the SQL Context that was used to create this plan. 
Since many operators need diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 3cd02b984d33..1e38cac5f092 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -56,7 +56,7 @@ private[execution] object SparkPlanInfo { case _ => plan.children ++ plan.subqueries } val metrics = plan.metrics.toSeq.map { case (key, metric) => - new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType) + new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType, metric.stats) } // dump the file scan metadata (e.g file path) to event log diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index efd05a3e2b3e..c053c244456b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -241,7 +241,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { buildSide, condition, planLater(left), - planLater(right))) + planLater(right), + Option(plan.stats))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index b827878f82fc..0846df86c427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -178,8 +178,8 @@ case class InMemoryTableScanExec( relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder]) // Keeps relation's partition statistics because we don't serialize relation. - private val stats = relation.partitionStatistics - private def statsFor(a: Attribute) = stats.forAttribute(a) + private val partitionStats = relation.partitionStatistics + private def statsFor(a: Attribute) = partitionStats.forAttribute(a) // Currently, only use statistics from atomic types except binary type only. private object ExtractableLiteral { @@ -274,7 +274,7 @@ case class InMemoryTableScanExec( filter.map( BindReferences.bindReference( _, - stats.schema, + partitionStats.schema, allowFailures = true)) boundFilter.foreach(_ => @@ -297,7 +297,7 @@ case class InMemoryTableScanExec( private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. 
- val schema = stats.schema + val schema = partitionStats.schema val schemaIndex = schema.zipWithIndex val buffers = relation.cacheBuilder.cachedColumnBuffers @@ -316,7 +316,7 @@ case class InMemoryTableScanExec( val value = cachedBatch.stats.get(i, a.dataType) s"${a.name}: $value" }.mkString(", ") - s"Skipping partition based on stats $statsString" + s"Skipping partition based on partitionStats $statsString" } false } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index fd4a7897c7ad..4c170d6eafc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -43,11 +44,14 @@ case class BroadcastHashJoinExec( buildSide: BuildSide, condition: Option[Expression], left: SparkPlan, - right: SparkPlan) + right: SparkPlan, + override val stats: Option[Statistics] = None) extends BinaryExecNode with HashJoin with CodegenSupport { - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + override lazy val metrics = { + Map("numOutputRows" -> + SQLMetrics.createMetric(sparkContext, "number of output rows", this.rowCountStats.toLong)) + } override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildKeys) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala index adb81519dbc8..a2caad69ca9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala @@ -27,4 +27,5 @@ import org.apache.spark.annotation.DeveloperApi class SQLMetricInfo( val name: String, val accumulatorId: Long, - val metricType: String) + val metricType: String, + val stats: Long = -1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 19809b07508d..65d68f8d51cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -33,7 +33,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils} * the executor side are automatically propagated and shown in the SQL UI through metrics. Updates * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]]. 
*/ -class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] { +class SQLMetric(val metricType: String, initValue: Long = 0L, val stats: Long = -1L) extends + AccumulatorV2[Long, Long] { // This is a workaround for SPARK-11013. // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will // update it at the end of task and the value will be at least 0. Then we can filter out the -1 @@ -42,7 +43,7 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato private var _zeroValue = initValue override def copy(): SQLMetric = { - val newAcc = new SQLMetric(metricType, _value) + val newAcc = new SQLMetric(metricType, _value, stats) newAcc._zeroValue = initValue newAcc } @@ -96,8 +97,8 @@ object SQLMetrics { metric.set((v * baseForAvgMetric).toLong) } - def createMetric(sc: SparkContext, name: String): SQLMetric = { - val acc = new SQLMetric(SUM_METRIC) + def createMetric(sc: SparkContext, name: String, stats: Long = -1): SQLMetric = { + val acc = new SQLMetric(SUM_METRIC, stats = stats) acc.register(sc, name = Some(name), countFailedValues = false) acc } @@ -193,6 +194,14 @@ object SQLMetrics { } } + def stringStats(value: Long): String = { + if (value < 0) { + "" + } else { + s" est: ${stringValue(SUM_METRIC, Seq(value))}" + } + } + /** * Updates metrics based on the driver side value. This is useful for certain metrics that * are only updated on the driver, e.g. subquery execution time, or number of files. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala new file mode 100644 index 000000000000..9b107e4c8944 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.statsEstimation + +import org.apache.spark.sql.catalyst.plans.logical.Statistics +import org.apache.spark.sql.execution.SparkPlan + +/** + * A trait to add statistics propagation to [[SparkPlan]]. 
+ */ +trait SparkPlanStats { self: SparkPlan => + + def stats: Option[Statistics] = None + + def rowCountStats: BigInt = { + if (stats.isDefined && stats.get.rowCount.isDefined) { + stats.get.rowCount.get + } else { + -1 + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index e496de1b05e4..76c59c1f0691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -180,17 +180,20 @@ class SQLAppStatusListener( } private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = { - val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap + val metricMap = exec.metrics.map { m => (m.accumulatorId, m) }.toMap val metrics = exec.stages.toSeq .flatMap { stageId => Option(stageMetrics.get(stageId)) } .flatMap(_.taskMetrics.values().asScala) .flatMap { metrics => metrics.ids.zip(metrics.values) } val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq) - .filter { case (id, _) => metricTypes.contains(id) } + .filter { case (id, _) => metricMap.contains(id) } .groupBy(_._1) .map { case (id, values) => - id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2)) + val metric = metricMap(id) + val value = SQLMetrics.stringValue(metric.metricType, values.map(_._2)) + val stats = SQLMetrics.stringStats(metric.stats) + id -> (value + stats) } // Check the execution again for whether the aggregated metrics data has been calculated. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index 241001a857c8..c3083d8e41bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -142,4 +142,5 @@ class SparkPlanGraphNodeWrapper( case class SQLPlanMetric( name: String, accumulatorId: Long, - metricType: String) + metricType: String, + stats : Long = -1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index b864ad1c7108..26d6b74381a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -80,7 +80,7 @@ object SparkPlanGraph { planInfo.nodeName match { case "WholeStageCodegen" => val metrics = planInfo.metrics.map { metric => - SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType) + SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.stats) } val cluster = new SparkPlanGraphCluster( @@ -114,7 +114,7 @@ object SparkPlanGraph { edges += SparkPlanGraphEdge(node.id, parent.id) case name => val metrics = planInfo.metrics.map { metric => - SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType) + SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.stats) } val node = new SparkPlanGraphNode( nodeIdGenerator.getAndIncrement(), planInfo.nodeName, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 
3c9a0908147a..7ebd7c76b2c6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -351,7 +351,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { .join(baseTable, "idx") assert(distinctWithId.queryExecution.executedPlan.collectFirst { case WholeStageCodegenExec( - ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true }.isDefined) checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0))) @@ -362,7 +362,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { .join(baseTable, "idx") assert(groupByWithId.queryExecution.executedPlan.collectFirst { case WholeStageCodegenExec( - ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true }.isDefined) checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 630f02c8e2f8..231ef8eda778 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -1416,4 +1416,25 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto assert(catalogStats.rowCount.isEmpty) } } + + test("statistics for broadcastHashJoin numOutputRows statistic") { + withTempView("t1", "t2") { + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "40", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false") { + sql("CREATE TABLE t1 (key INT, a2 STRING, a3 DOUBLE)") + sql("INSERT INTO TABLE t1 SELECT 1, 'a', 10.0") + sql("INSERT INTO TABLE t1 SELECT 1, 'b', null") + sql("ANALYZE TABLE t1 COMPUTE STATISTICS FOR ALL COLUMNS") + + sql("CREATE TABLE t2 (key INT, b2 STRING, b3 DOUBLE)") + sql("INSERT INTO TABLE t2 SELECT 1, 'a', 10.0") + sql("ANALYZE TABLE t2 COMPUTE STATISTICS FOR ALL COLUMNS") + + val df = sql("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key") + assert(df.queryExecution.sparkPlan.isInstanceOf[BroadcastHashJoinExec]) + assert(df.queryExecution.sparkPlan.rowCountStats == 2) + } + } + } } From a2e23b5b8315726855d13d2e5859b2b49110c232 Mon Sep 17 00:00:00 2001 From: pengbo Date: Thu, 18 Apr 2019 14:57:16 +0800 Subject: [PATCH 2/5] add setPlanProperty in QueryPlanner.plan to propagate the statistics from logical plan to physical plan --- .../spark/sql/catalyst/planning/QueryPlanner.scala | 5 +++++ .../scala/org/apache/spark/sql/execution/SparkPlan.scala | 6 +++++- .../org/apache/spark/sql/execution/SparkPlanner.scala | 4 ++++ .../org/apache/spark/sql/execution/SparkStrategies.scala | 3 +-- .../sql/execution/joins/BroadcastHashJoinExec.scala | 3 +-- .../sql/execution/statsEstimation/SparkPlanStats.scala | 9 ++++++++- .../spark/sql/execution/WholeStageCodegenSuite.scala | 4 ++-- 7 files changed, 26 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 6fa5203a06f7..0181adf0bf87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -65,6 +65,8 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { // The candidates may contain placeholders marked as [[planLater]], // so try to replace them by their child plans. val plans = candidates.flatMap { candidate => + setPlanProperty(candidate, plan) + val placeholders = collectPlaceholders(candidate) if (placeholders.isEmpty) { @@ -94,6 +96,9 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { pruned } + protected def setPlanProperty(candidate: PhysicalPlan, plan: LogicalPlan): Unit = { + } + /** * Collects placeholders marked using [[GenericStrategy#planLater planLater]] * by [[strategies]]. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index e5a3ef59c064..2864af2c923c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -70,7 +70,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with SparkPlanStats with L if (sqlContext != null) { SparkSession.setActiveSession(sqlContext.sparkSession) } - super.makeCopy(newArgs) + val sparkPlan = super.makeCopy(newArgs) + if (this.stats.isDefined) { + sparkPlan.withStats(this.stats.get) + } + sparkPlan } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 2a4a1c8ef343..8a22c8a00561 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -53,6 +53,10 @@ class SparkPlanner( */ def extraPlanningStrategies: Seq[Strategy] = Nil + override protected def setPlanProperty(candidate: SparkPlan, plan: LogicalPlan): Unit = { + candidate.withStats(plan.stats) + } + override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = { plan.collect { case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c053c244456b..efd05a3e2b3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -241,8 +241,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { buildSide, condition, planLater(left), - planLater(right), - Option(plan.stats))) + planLater(right))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala index 4c170d6eafc3..866a615ae116 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala @@ -44,8 +44,7 @@ case class BroadcastHashJoinExec( buildSide: BuildSide, condition: Option[Expression], left: SparkPlan, - right: SparkPlan, - override val stats: Option[Statistics] = None) + right: SparkPlan) extends BinaryExecNode with HashJoin with CodegenSupport { override lazy val metrics = { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala index 9b107e4c8944..3ac4c0668f34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala @@ -25,7 +25,14 @@ import org.apache.spark.sql.execution.SparkPlan */ trait SparkPlanStats { self: SparkPlan => - def stats: Option[Statistics] = None + private var _stats: Option[Statistics] = None + + def stats: Option[Statistics] = _stats + + def withStats(value: Statistics): SparkPlan = { + _stats = Option(value) + this + } def rowCountStats: BigInt = { if (stats.isDefined && stats.get.rowCount.isDefined) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 7ebd7c76b2c6..3c9a0908147a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -351,7 +351,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { .join(baseTable, "idx") assert(distinctWithId.queryExecution.executedPlan.collectFirst { case WholeStageCodegenExec( - ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true }.isDefined) checkAnswer(distinctWithId, Seq(Row(1, 0), Row(1, 0))) @@ -362,7 +362,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { .join(baseTable, "idx") assert(groupByWithId.queryExecution.executedPlan.collectFirst { case WholeStageCodegenExec( - ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _, _))) => true + ProjectExec(_, BroadcastHashJoinExec(_, _, _, _, _, _: HashAggregateExec, _))) => true }.isDefined) checkAnswer(groupByWithId, Seq(Row(1, 2, 0), Row(1, 2, 0))) } From 24b54718834cc5cd5a2b311d3476bc11edf5e39c Mon Sep 17 00:00:00 2001 From: pengbo Date: Fri, 19 Apr 2019 00:05:21 +0800 Subject: [PATCH 3/5] renaming method to propagateProperty --- .../org/apache/spark/sql/catalyst/planning/QueryPlanner.scala | 4 ++-- .../scala/org/apache/spark/sql/execution/SparkPlanner.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 0181adf0bf87..246873d8c3b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -65,7 +65,7 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { // The candidates may contain placeholders marked as [[planLater]], // so try to replace them by their child plans. 
val plans = candidates.flatMap { candidate => - setPlanProperty(candidate, plan) + propagateProperty(candidate, plan) val placeholders = collectPlaceholders(candidate) @@ -96,7 +96,7 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { pruned } - protected def setPlanProperty(candidate: PhysicalPlan, plan: LogicalPlan): Unit = { + protected def propagateProperty(candidate: PhysicalPlan, plan: LogicalPlan): Unit = { } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 8a22c8a00561..47039e50d23e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -53,7 +53,7 @@ class SparkPlanner( */ def extraPlanningStrategies: Seq[Strategy] = Nil - override protected def setPlanProperty(candidate: SparkPlan, plan: LogicalPlan): Unit = { + override protected def propagateProperty(candidate: SparkPlan, plan: LogicalPlan): Unit = { candidate.withStats(plan.stats) } From b3cd0b0e6593a892a927113a3c8db8ac5014564f Mon Sep 17 00:00:00 2001 From: pengbo Date: Fri, 26 Apr 2019 11:50:07 +0800 Subject: [PATCH 4/5] propagate the statistics from logical plan to physical plan in Strategy.apply method --- .../sql/catalyst/planning/QueryPlanner.scala | 5 --- .../spark/sql/execution/SparkPlanner.scala | 4 --- .../spark/sql/execution/SparkStrategies.scala | 32 +++++++++++-------- .../datasources/DataSourceStrategy.scala | 2 +- .../datasources/FileSourceStrategy.scala | 2 +- .../datasources/v2/DataSourceV2Strategy.scala | 2 +- .../spark/sql/ExtraStrategiesSuite.scala | 2 +- .../sql/SparkSessionExtensionSuite.scala | 4 +-- .../sql/execution/SparkPlannerSuite.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 4 +-- 10 files changed, 28 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 246873d8c3b5..6fa5203a06f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -65,8 +65,6 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { // The candidates may contain placeholders marked as [[planLater]], // so try to replace them by their child plans. val plans = candidates.flatMap { candidate => - propagateProperty(candidate, plan) - val placeholders = collectPlaceholders(candidate) if (placeholders.isEmpty) { @@ -96,9 +94,6 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { pruned } - protected def propagateProperty(candidate: PhysicalPlan, plan: LogicalPlan): Unit = { - } - /** * Collects placeholders marked using [[GenericStrategy#planLater planLater]] * by [[strategies]]. 
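Note on the approach taken by this patch (PATCH 4/5): it replaces the planner-level propagateProperty hook with a template method on SparkStrategy itself, so that apply attaches the logical plan's statistics to every candidate and delegates the actual pattern matching to a new protected doApply. The following is a rough, self-contained sketch of that pattern only; Stats, Logical and Physical are simplified stand-ins for Spark's Statistics, LogicalPlan and SparkPlan, not the real classes.

// Simplified stand-ins for Statistics, LogicalPlan and SparkPlan.
final case class Stats(rowCount: Option[BigInt])
final case class Logical(name: String, stats: Stats)
final class Physical(val name: String) {
  var stats: Option[Stats] = None
  def withStats(s: Stats): Physical = { stats = Some(s); this }
}

// Template method: apply wraps the strategy-specific doApply and attaches
// the logical plan's statistics to every physical candidate it returns.
abstract class Strategy {
  def apply(plan: Logical): Seq[Physical] =
    doApply(plan).map(_.withStats(plan.stats))
  protected def doApply(plan: Logical): Seq[Physical]
}

object ExampleStrategy extends Strategy {
  override protected def doApply(plan: Logical): Seq[Physical] =
    Seq(new Physical(plan.name + "Exec"))
}

object TemplateMethodDemo extends App {
  val candidates = ExampleStrategy(Logical("Join", Stats(Some(BigInt(2)))))
  println(candidates.map(_.stats.flatMap(_.rowCount))) // List(Some(2))
}

As the PATCH 5/5 subject below indicates, this template-method approach is later reverted in favor of a propagate hook in QueryPlanner.plan, since extension strategies that override apply directly would otherwise bypass the doApply path.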
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 47039e50d23e..2a4a1c8ef343 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -53,10 +53,6 @@ class SparkPlanner( */ def extraPlanningStrategies: Seq[Strategy] = Nil - override protected def propagateProperty(candidate: SparkPlan, plan: LogicalPlan): Unit = { - candidate.withStats(plan.stats) - } - override protected def collectPlaceholders(plan: SparkPlan): Seq[(SparkPlan, LogicalPlan)] = { plan.collect { case placeholder @ PlanLater(logicalPlan) => placeholder -> logicalPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c0a28fad9d39..01c054430361 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -49,6 +49,12 @@ import org.apache.spark.sql.types.StructType abstract class SparkStrategy extends GenericStrategy[SparkPlan] { override protected def planLater(plan: LogicalPlan): SparkPlan = PlanLater(plan) + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + doApply(plan).map(sparkPlan => sparkPlan.withStats(plan.stats)) + } + + protected def doApply(plan: LogicalPlan): Seq[SparkPlan] } case class PlanLater(plan: LogicalPlan) extends LeafExecNode { @@ -67,7 +73,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Plans special cases of limit operators. */ object SpecialLimits extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { case Limit(IntegerLiteral(limit), Sort(order, true, child)) if limit < conf.topKSortFallbackThreshold => @@ -209,7 +215,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // If it is an equi-join, we first look at the join hints w.r.t. the following order: // 1. broadcast hint: pick broadcast hash join if the join type is supported. If both sides @@ -383,7 +389,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]] */ object StatefulAggregationStrategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case _ if !plan.isStreaming => Nil case EventTimeWatermark(columnName, delay, child) => @@ -423,7 +429,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Used to plan the streaming deduplicate operator. 
*/ object StreamingDeduplicationStrategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case Deduplicate(keys, child) if child.isStreaming => StreamingDeduplicateExec(keys, planLater(child)) :: Nil @@ -440,7 +446,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Limit is unsupported for streams in Update mode. */ case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { case Limit(IntegerLiteral(limit), child) if plan.isStreaming && outputMode == InternalOutputModes.Append => @@ -455,7 +461,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } object StreamingJoinStrategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if left.isStreaming && right.isStreaming => @@ -476,7 +482,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object Aggregation extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => val aggregateExpressions = aggExpressions.map(expr => @@ -538,7 +544,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } object Window extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalWindow( WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) => execution.window.WindowExec( @@ -556,7 +562,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object InMemoryScans extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => pruneFilterProject( projectList, @@ -574,7 +580,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * be replaced with the real relation using the `Source` in `StreamExecution`. */ object StreamingRelationStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case s: StreamingRelation => StreamingRelationExec(s.sourceName, s.output) :: Nil case s: StreamingExecutionRelation => @@ -590,7 +596,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. 
*/ object FlatMapGroupsWithStateStrategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case FlatMapGroupsWithState( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, timeout, child) => @@ -608,7 +614,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Strategy to convert EvalPython logical operator to physical operator. */ object PythonEvals extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ArrowEvalPython(udfs, output, child) => ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil case BatchEvalPython(udfs, output, child) => @@ -619,7 +625,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } object BasicOperators extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil case r: RunnableCommand => ExecutedCommandExec(r) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c907ac21af38..982de85f614d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -261,7 +261,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport { import DataSourceStrategy._ - def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) => pruneFilterProjectRaw( l, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index c8a42f043f15..df2a73a5c6fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -136,7 +136,7 @@ object FileSourceStrategy extends Strategy with Logging { } } - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => // Filters on this relation fall into four categories based on where we can use them to avoid diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 7681dc8dfb37..4b5be502bdd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -102,7 +102,7 @@ object 
DataSourceV2Strategy extends Strategy with PredicateHelper { import DataSourceV2Implicits._ - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.newScanBuilder() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index a41b46554862..f645ef9d144d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -39,7 +39,7 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan { } object TestStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case Project(Seq(attr), _) if attr.name == "a" => FastOperator(attr.toAttribute :: Nil) :: Nil case _ => Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 881268440ccd..fe416a989d99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -218,7 +218,7 @@ case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) { } case class MySparkStrategy(spark: SparkSession) extends SparkStrategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty } case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface { @@ -272,7 +272,7 @@ case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) { } case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty } object MyExtensions2 { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala index 5828f9783da4..e3fe1771128a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala @@ -33,7 +33,7 @@ class SparkPlannerSuite extends SharedSQLContext { var planned = 0 object TestStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(child) => planned += 1 planLater(child) :: planLater(NeverPlanned) :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 58b711006e47..0fb7ca41e577 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -223,7 +223,7 @@ private[hive] trait HiveStrategies { val sparkSession: SparkSession object Scripts extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def 
doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ScriptTransformation(input, script, output, child, ioschema) => val hiveIoSchema = HiveScriptIOSchema(ioschema) ScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil @@ -236,7 +236,7 @@ private[hive] trait HiveStrategies { * applied. */ object HiveTableScans extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, predicates, relation: HiveTableRelation) => // Filter out all predicates that only deal with partition keys, these are given to the // hive table scan operator to be used for partition pruning. From 83483b17d95d63eb9ffcd9f61b0957c9081bc76e Mon Sep 17 00:00:00 2001 From: "mingbo.pb" Date: Thu, 9 May 2019 17:46:49 +0800 Subject: [PATCH 5/5] back to use propagate template method in QueryPlanner.plan for extension strategy compability --- .../sql/catalyst/planning/QueryPlanner.scala | 7 +++- .../spark/sql/execution/SparkPlanner.scala | 4 +++ .../spark/sql/execution/SparkStrategies.scala | 32 ++++++++----------- .../datasources/DataSourceStrategy.scala | 2 +- .../datasources/FileSourceStrategy.scala | 2 +- .../datasources/v2/DataSourceV2Strategy.scala | 2 +- .../spark/sql/ExtraStrategiesSuite.scala | 2 +- .../sql/SparkSessionExtensionSuite.scala | 4 +-- .../sql/execution/SparkPlannerSuite.scala | 2 +- .../spark/sql/hive/HiveStrategies.scala | 4 +-- 10 files changed, 32 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 6fa5203a06f7..137501868f6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -60,7 +60,9 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { // Obviously a lot to do here still... // Collect physical plan candidates. - val candidates = strategies.iterator.flatMap(_(plan)) + val candidates = strategies.iterator.flatMap({ + _(plan).map(propagate(_, plan)) + }) // The candidates may contain placeholders marked as [[planLater]], // so try to replace them by their child plans. @@ -102,4 +104,7 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { /** Prunes bad plans to prevent combinatorial explosion. */ protected def prunePlans(plans: Iterator[PhysicalPlan]): Iterator[PhysicalPlan] + + /** Propagate logicalPlan properties to PhysicalPlan */ + protected def propagate(plan: PhysicalPlan, logicalPlan: LogicalPlan): PhysicalPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 2a4a1c8ef343..10cc359948ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -65,6 +65,10 @@ class SparkPlanner( plans } + override protected def propagate(plan: SparkPlan, logicalPlan: LogicalPlan): SparkPlan = { + plan.withStats(logicalPlan.stats) + } + /** * Used to build table scan operators where complex projection and filtering are done using * separate physical operators. 
This function returns the given scan operator with Project and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 01c054430361..c0a28fad9d39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -49,12 +49,6 @@ import org.apache.spark.sql.types.StructType abstract class SparkStrategy extends GenericStrategy[SparkPlan] { override protected def planLater(plan: LogicalPlan): SparkPlan = PlanLater(plan) - - override def apply(plan: LogicalPlan): Seq[SparkPlan] = { - doApply(plan).map(sparkPlan => sparkPlan.withStats(plan.stats)) - } - - protected def doApply(plan: LogicalPlan): Seq[SparkPlan] } case class PlanLater(plan: LogicalPlan) extends LeafExecNode { @@ -73,7 +67,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Plans special cases of limit operators. */ object SpecialLimits extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { case Limit(IntegerLiteral(limit), Sort(order, true, child)) if limit < conf.topKSortFallbackThreshold => @@ -215,7 +209,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL)) } - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // If it is an equi-join, we first look at the join hints w.r.t. the following order: // 1. broadcast hint: pick broadcast hash join if the join type is supported. If both sides @@ -389,7 +383,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]] */ object StatefulAggregationStrategy extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case _ if !plan.isStreaming => Nil case EventTimeWatermark(columnName, delay, child) => @@ -429,7 +423,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Used to plan the streaming deduplicate operator. */ object StreamingDeduplicationStrategy extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case Deduplicate(keys, child) if child.isStreaming => StreamingDeduplicateExec(keys, planLater(child)) :: Nil @@ -446,7 +440,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Limit is unsupported for streams in Update mode. 
*/ case class StreamingGlobalLimitStrategy(outputMode: OutputMode) extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(rootPlan) => rootPlan match { case Limit(IntegerLiteral(limit), child) if plan.isStreaming && outputMode == InternalOutputModes.Append => @@ -461,7 +455,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } object StreamingJoinStrategy extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if left.isStreaming && right.isStreaming => @@ -482,7 +476,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object Aggregation extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) => val aggregateExpressions = aggExpressions.map(expr => @@ -544,7 +538,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } object Window extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalWindow( WindowFunctionType.SQL, windowExprs, partitionSpec, orderSpec, child) => execution.window.WindowExec( @@ -562,7 +556,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object InMemoryScans extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, filters, mem: InMemoryRelation) => pruneFilterProject( projectList, @@ -580,7 +574,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * be replaced with the real relation using the `Source` in `StreamExecution`. */ object StreamingRelationStrategy extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case s: StreamingRelation => StreamingRelationExec(s.sourceName, s.output) :: Nil case s: StreamingExecutionRelation => @@ -596,7 +590,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. */ object FlatMapGroupsWithStateStrategy extends Strategy { - override def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case FlatMapGroupsWithState( func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateEnc, outputMode, _, timeout, child) => @@ -614,7 +608,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * Strategy to convert EvalPython logical operator to physical operator. 
*/ object PythonEvals extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ArrowEvalPython(udfs, output, child) => ArrowEvalPythonExec(udfs, output, planLater(child)) :: Nil case BatchEvalPython(udfs, output, child) => @@ -625,7 +619,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } object BasicOperators extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case d: DataWritingCommand => DataWritingCommandExec(d, planLater(d.query)) :: Nil case r: RunnableCommand => ExecutedCommandExec(r) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 982de85f614d..c907ac21af38 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -261,7 +261,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport { import DataSourceStrategy._ - override protected def doApply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) => pruneFilterProjectRaw( l, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index df2a73a5c6fe..c8a42f043f15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -136,7 +136,7 @@ object FileSourceStrategy extends Strategy with Logging { } } - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) => // Filters on this relation fall into four categories based on where we can use them to avoid diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 4b5be502bdd2..7681dc8dfb37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -102,7 +102,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { import DataSourceV2Implicits._ - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.newScanBuilder() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index f645ef9d144d..a41b46554862 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -39,7 +39,7 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan { } object TestStrategy extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case Project(Seq(attr), _) if attr.name == "a" => FastOperator(attr.toAttribute :: Nil) :: Nil case _ => Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index fe416a989d99..881268440ccd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -218,7 +218,7 @@ case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) { } case class MySparkStrategy(spark: SparkSession) extends SparkStrategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty } case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface { @@ -272,7 +272,7 @@ case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) { } case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty } object MyExtensions2 { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala index e3fe1771128a..5828f9783da4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala @@ -33,7 +33,7 @@ class SparkPlannerSuite extends SharedSQLContext { var planned = 0 object TestStrategy extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ReturnAnswer(child) => planned += 1 planLater(child) :: planLater(NeverPlanned) :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 0fb7ca41e577..58b711006e47 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -223,7 +223,7 @@ private[hive] trait HiveStrategies { val sparkSession: SparkSession object Scripts extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ScriptTransformation(input, script, output, child, ioschema) => val hiveIoSchema = HiveScriptIOSchema(ioschema) ScriptTransformationExec(input, script, output, planLater(child), hiveIoSchema) :: Nil @@ -236,7 +236,7 @@ private[hive] trait HiveStrategies { * applied. 
*/ object HiveTableScans extends Strategy { - override protected def doApply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(projectList, predicates, relation: HiveTableRelation) => // Filter out all predicates that only deal with partition keys, these are given to the // hive table scan operator to be used for partition pruning.
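End to end, the series threads a row-count estimate from the logical plan's Statistics (via SparkPlanStats.rowCountStats) into the SQLMetric created for numOutputRows, and SQLAppStatusListener appends it to the aggregated value shown on the SQL UI page. Below is a rough, self-contained sketch of that display logic; the helpers are simplified counterparts of rowCountStats and SQLMetrics.stringStats rather than the real implementations (which format the number through stringValue), and -1 is the sentinel the patches use for "no estimate available".

object EstimatedMetricDemo extends App {
  // A physical plan's row-count estimate: the value carried by the
  // propagated statistics if present, otherwise the -1 sentinel.
  def rowCountStats(rowCount: Option[BigInt]): BigInt =
    rowCount.getOrElse(BigInt(-1))

  // Simplified counterpart of SQLMetrics.stringStats: render nothing when
  // there is no estimate, otherwise an " est: <n>" suffix.
  def stringStats(stats: Long): String =
    if (stats < 0) "" else s" est: $stats"

  // What the listener does when aggregating: append the estimate to the
  // aggregated metric value so both appear next to numOutputRows in the UI.
  def render(aggregatedValue: Long, stats: Long): String =
    s"$aggregatedValue${stringStats(stats)}"

  println(render(2, rowCountStats(Some(BigInt(2))).toLong)) // "2 est: 2"
  println(render(2, rowCountStats(None).toLong))            // "2"
}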