diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
index 6fa5203a06f7..137501868f6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala
@@ -60,7 +60,9 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
     // Obviously a lot to do here still...

     // Collect physical plan candidates.
-    val candidates = strategies.iterator.flatMap(_(plan))
+    val candidates = strategies.iterator.flatMap({
+      _(plan).map(propagate(_, plan))
+    })

     // The candidates may contain placeholders marked as [[planLater]],
     // so try to replace them by their child plans.
@@ -102,4 +104,7 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {

   /** Prunes bad plans to prevent combinatorial explosion. */
   protected def prunePlans(plans: Iterator[PhysicalPlan]): Iterator[PhysicalPlan]
+
+  /** Propagate logicalPlan properties to PhysicalPlan */
+  protected def propagate(plan: PhysicalPlan, logicalPlan: LogicalPlan): PhysicalPlan
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a89ccca99d05..2864af2c923c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext

 import org.codehaus.commons.compiler.CompileException
 import org.codehaus.janino.InternalCompilerException
@@ -36,15 +35,16 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredic
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.statsEstimation.SparkPlanStats
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.ThreadUtils

 /**
  * The base class for physical operators.
  *
  * The naming convention is that physical operators end with "Exec" suffix, e.g. [[ProjectExec]].
  */
-abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
+abstract class SparkPlan extends QueryPlan[SparkPlan] with SparkPlanStats with Logging
+  with Serializable {

   /**
    * A handle to the SQL Context that was used to create this plan. Since many operators need
@@ -70,7 +70,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     if (sqlContext != null) {
       SparkSession.setActiveSession(sqlContext.sparkSession)
     }
-    super.makeCopy(newArgs)
+    val sparkPlan = super.makeCopy(newArgs)
+    if (this.stats.isDefined) {
+      sparkPlan.withStats(this.stats.get)
+    }
+    sparkPlan
   }

   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 3cd02b984d33..1e38cac5f092 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -56,7 +56,7 @@ private[execution] object SparkPlanInfo {
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>
-      new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
+      new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType, metric.stats)
     }

     // dump the file scan metadata (e.g file path) to event log
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 2a4a1c8ef343..10cc359948ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -65,6 +65,10 @@ class SparkPlanner(
     plans
   }

+  override protected def propagate(plan: SparkPlan, logicalPlan: LogicalPlan): SparkPlan = {
+    plan.withStats(logicalPlan.stats)
+  }
+
   /**
    * Used to build table scan operators where complex projection and filtering are done using
    * separate physical operators. This function returns the given scan operator with Project and
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index b827878f82fc..0846df86c427 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -178,8 +178,8 @@ case class InMemoryTableScanExec(
     relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])

   // Keeps relation's partition statistics because we don't serialize relation.
-  private val stats = relation.partitionStatistics
-  private def statsFor(a: Attribute) = stats.forAttribute(a)
+  private val partitionStats = relation.partitionStatistics
+  private def statsFor(a: Attribute) = partitionStats.forAttribute(a)

   // Currently, only use statistics from atomic types except binary type only.
   private object ExtractableLiteral {
@@ -274,7 +274,7 @@ case class InMemoryTableScanExec(
       filter.map(
         BindReferences.bindReference(
           _,
-          stats.schema,
+          partitionStats.schema,
           allowFailures = true))

     boundFilter.foreach(_ =>
@@ -297,7 +297,7 @@ case class InMemoryTableScanExec(
   private def filteredCachedBatches(): RDD[CachedBatch] = {
     // Using these variables here to avoid serialization of entire objects (if referenced directly)
     // within the map Partitions closure.
-    val schema = stats.schema
+    val schema = partitionStats.schema
     val schemaIndex = schema.zipWithIndex
     val buffers = relation.cacheBuilder.cachedColumnBuffers

@@ -316,7 +316,7 @@ case class InMemoryTableScanExec(
             val value = cachedBatch.stats.get(i, a.dataType)
             s"${a.name}: $value"
           }.mkString(", ")
-          s"Skipping partition based on stats $statsString"
+          s"Skipping partition based on partitionStats $statsString"
         }
         false
       } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index fd4a7897c7ad..866a615ae116 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -46,8 +47,10 @@ case class BroadcastHashJoinExec(
     right: SparkPlan)
   extends BinaryExecNode with HashJoin with CodegenSupport {

-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+  override lazy val metrics = {
+    Map("numOutputRows" ->
+      SQLMetrics.createMetric(sparkContext, "number of output rows", this.rowCountStats.toLong))
+  }

   override def requiredChildDistribution: Seq[Distribution] = {
     val mode = HashedRelationBroadcastMode(buildKeys)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
index adb81519dbc8..a2caad69ca9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetricInfo.scala
@@ -27,4 +27,5 @@ import org.apache.spark.annotation.DeveloperApi
 class SQLMetricInfo(
     val name: String,
     val accumulatorId: Long,
-    val metricType: String)
+    val metricType: String,
+    val stats: Long = -1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index 19809b07508d..65d68f8d51cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -33,7 +33,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
  * the executor side are automatically propagated and shown in the SQL UI through metrics. Updates
  * on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]].
  */
-class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
+class SQLMetric(val metricType: String, initValue: Long = 0L, val stats: Long = -1L) extends
+  AccumulatorV2[Long, Long] {
   // This is a workaround for SPARK-11013.
   // We may use -1 as initial value of the accumulator, if the accumulator is valid, we will
   // update it at the end of task and the value will be at least 0. Then we can filter out the -1
@@ -42,7 +43,7 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato
   private var _zeroValue = initValue

   override def copy(): SQLMetric = {
-    val newAcc = new SQLMetric(metricType, _value)
+    val newAcc = new SQLMetric(metricType, _value, stats)
     newAcc._zeroValue = initValue
     newAcc
   }
@@ -96,8 +97,8 @@ object SQLMetrics {
     metric.set((v * baseForAvgMetric).toLong)
   }

-  def createMetric(sc: SparkContext, name: String): SQLMetric = {
-    val acc = new SQLMetric(SUM_METRIC)
+  def createMetric(sc: SparkContext, name: String, stats: Long = -1): SQLMetric = {
+    val acc = new SQLMetric(SUM_METRIC, stats = stats)
     acc.register(sc, name = Some(name), countFailedValues = false)
     acc
   }
@@ -193,6 +194,14 @@ object SQLMetrics {
     }
   }

+  def stringStats(value: Long): String = {
+    if (value < 0) {
+      ""
+    } else {
+      s" est: ${stringValue(SUM_METRIC, Seq(value))}"
+    }
+  }
+
   /**
    * Updates metrics based on the driver side value. This is useful for certain metrics that
    * are only updated on the driver, e.g. subquery execution time, or number of files.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala
new file mode 100644
index 000000000000..3ac4c0668f34
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/statsEstimation/SparkPlanStats.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.statsEstimation
+
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * A trait to add statistics propagation to [[SparkPlan]].
+ */
+trait SparkPlanStats { self: SparkPlan =>
+
+  private var _stats: Option[Statistics] = None
+
+  def stats: Option[Statistics] = _stats
+
+  def withStats(value: Statistics): SparkPlan = {
+    _stats = Option(value)
+    this
+  }
+
+  def rowCountStats: BigInt = {
+    if (stats.isDefined && stats.get.rowCount.isDefined) {
+      stats.get.rowCount.get
+    } else {
+      -1
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
index e496de1b05e4..76c59c1f0691 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala
@@ -180,17 +180,20 @@ class SQLAppStatusListener(
   }

   private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
-    val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
+    val metricMap = exec.metrics.map { m => (m.accumulatorId, m) }.toMap
     val metrics = exec.stages.toSeq
       .flatMap { stageId => Option(stageMetrics.get(stageId)) }
       .flatMap(_.taskMetrics.values().asScala)
       .flatMap { metrics => metrics.ids.zip(metrics.values) }

     val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq)
-      .filter { case (id, _) => metricTypes.contains(id) }
+      .filter { case (id, _) => metricMap.contains(id) }
       .groupBy(_._1)
       .map { case (id, values) =>
-        id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2))
+        val metric = metricMap(id)
+        val value = SQLMetrics.stringValue(metric.metricType, values.map(_._2))
+        val stats = SQLMetrics.stringStats(metric.stats)
+        id -> (value + stats)
       }

     // Check the execution again for whether the aggregated metrics data has been calculated.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
index 241001a857c8..c3083d8e41bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala
@@ -142,4 +142,5 @@ class SparkPlanGraphNodeWrapper(
 case class SQLPlanMetric(
     name: String,
     accumulatorId: Long,
-    metricType: String)
+    metricType: String,
+    stats: Long = -1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index b864ad1c7108..26d6b74381a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -80,7 +80,7 @@ object SparkPlanGraph {
     planInfo.nodeName match {
       case "WholeStageCodegen" =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.stats)
         }

         val cluster = new SparkPlanGraphCluster(
@@ -114,7 +114,7 @@ object SparkPlanGraph {
         edges += SparkPlanGraphEdge(node.id, parent.id)
       case name =>
         val metrics = planInfo.metrics.map { metric =>
-          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+          SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.stats)
         }
         val node = new SparkPlanGraphNode(
           nodeIdGenerator.getAndIncrement(), planInfo.nodeName,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 630f02c8e2f8..231ef8eda778 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -1416,4 +1416,25 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
       assert(catalogStats.rowCount.isEmpty)
     }
   }
+
+  test("statistics for broadcastHashJoin numOutputRows statistic") {
+    withTempView("t1", "t2") {
+      withSQLConf(SQLConf.CBO_ENABLED.key -> "true",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "40",
+        SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
+        sql("CREATE TABLE t1 (key INT, a2 STRING, a3 DOUBLE)")
+        sql("INSERT INTO TABLE t1 SELECT 1, 'a', 10.0")
+        sql("INSERT INTO TABLE t1 SELECT 1, 'b', null")
+        sql("ANALYZE TABLE t1 COMPUTE STATISTICS FOR ALL COLUMNS")
+
+        sql("CREATE TABLE t2 (key INT, b2 STRING, b3 DOUBLE)")
+        sql("INSERT INTO TABLE t2 SELECT 1, 'a', 10.0")
+        sql("ANALYZE TABLE t2 COMPUTE STATISTICS FOR ALL COLUMNS")
+
+        val df = sql("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key")
+        assert(df.queryExecution.sparkPlan.isInstanceOf[BroadcastHashJoinExec])
+        assert(df.queryExecution.sparkPlan.rowCountStats == 2)
+      }
+    }
+  }
 }
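Note on how the pieces above fit together: the planner copies each logical plan's Statistics onto every candidate SparkPlan (QueryPlanner.plan calls propagate, which SparkPlanner implements via SparkPlanStats.withStats), an operator such as BroadcastHashJoinExec seeds its numOutputRows metric with rowCountStats, and the UI appends the estimate through SQLMetrics.stringStats. The stand-alone Scala sketch below is only an illustration of that flow under simplified assumptions; the Tiny* names are invented for this example and are not Spark APIs.

// Self-contained model of the propagation flow in this patch (not Spark code).
object StatsPropagationSketch {

  // Logical side: an optional row-count estimate, analogous to Statistics.rowCount.
  final case class TinyLogical(name: String, rowCount: Option[BigInt])

  // Physical side: a stats slot filled in after planning, analogous to SparkPlanStats.
  trait TinyPlanStats {
    private var _stats: Option[BigInt] = None
    def stats: Option[BigInt] = _stats
    def withStats(rowCount: BigInt): this.type = { _stats = Some(rowCount); this }
    def rowCountStats: BigInt = _stats.getOrElse(BigInt(-1))
  }

  final case class TinyPhysical(name: String) extends TinyPlanStats

  // Planner hook: every candidate produced by a strategy receives the logical
  // estimate, mirroring QueryPlanner.plan calling propagate(_, plan).
  def plan(
      logical: TinyLogical,
      strategies: Seq[TinyLogical => Option[TinyPhysical]]): Iterator[TinyPhysical] = {
    strategies.iterator.flatMap(strategy => strategy(logical).map(propagate(_, logical)))
  }

  private def propagate(physical: TinyPhysical, logical: TinyLogical): TinyPhysical = {
    logical.rowCount.foreach(rc => physical.withStats(rc))
    physical
  }

  // UI side: append the estimate to the aggregated metric value, in the spirit of
  // SQLMetrics.stringStats producing an " est: ..." suffix when stats are present.
  def renderMetric(actual: Long, estimate: BigInt): String = {
    if (estimate < 0) s"$actual" else s"$actual est: $estimate"
  }

  def main(args: Array[String]): Unit = {
    val join = TinyLogical("Join", rowCount = Some(BigInt(2)))
    val strategies = Seq[TinyLogical => Option[TinyPhysical]](
      l => Some(TinyPhysical(s"BroadcastHashJoin(${l.name})")))

    val chosen = plan(join, strategies).next()
    // Prints: BroadcastHashJoin(Join) numOutputRows: 2 est: 2
    println(s"${chosen.name} numOutputRows: " +
      renderMetric(actual = 2L, estimate = chosen.rowCountStats))
  }
}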