@@ -60,7 +60,9 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {
// Obviously a lot to do here still...

// Collect physical plan candidates.
- val candidates = strategies.iterator.flatMap(_(plan))
+ val candidates = strategies.iterator.flatMap({
+ _(plan).map(propagate(_, plan))
+ })

// The candidates may contain placeholders marked as [[planLater]],
// so try to replace them by their child plans.
@@ -102,4 +104,7 @@ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] {

/** Prunes bad plans to prevent combinatorial explosion. */
protected def prunePlans(plans: Iterator[PhysicalPlan]): Iterator[PhysicalPlan]

+ /** Propagates logical plan properties to the physical plan. */
+ protected def propagate(plan: PhysicalPlan, logicalPlan: LogicalPlan): PhysicalPlan
}
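
For orientation, here is a minimal self-contained sketch of the pattern this hunk introduces: every candidate a strategy emits is routed through the new `propagate` hook together with the logical plan it was derived from. The types below are stand-ins, not the Spark classes.

```scala
// Stand-in model of the propagate hook: not Spark code, just the shape of the change.
object PropagateSketch {
  case class LogicalNode(name: String, rowCount: Option[BigInt])
  case class PhysicalNode(name: String, rowCount: Option[BigInt] = None)

  type Strategy = LogicalNode => Seq[PhysicalNode]

  abstract class Planner {
    def strategies: Seq[Strategy]
    // Mirrors QueryPlanner.propagate: copy logical-plan properties onto each candidate.
    protected def propagate(plan: PhysicalNode, logicalPlan: LogicalNode): PhysicalNode

    def candidates(plan: LogicalNode): Iterator[PhysicalNode] =
      strategies.iterator.flatMap(s => s(plan).map(propagate(_, plan)))
  }

  def main(args: Array[String]): Unit = {
    val planner = new Planner {
      val strategies = Seq[Strategy](l => Seq(PhysicalNode(l.name + "Exec")))
      protected def propagate(p: PhysicalNode, l: LogicalNode) = p.copy(rowCount = l.rowCount)
    }
    // Prints: List(PhysicalNode(JoinExec,Some(2)))
    println(planner.candidates(LogicalNode("Join", Some(BigInt(2)))).toList)
  }
}
```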

@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

import scala.collection.mutable.ArrayBuffer
- import scala.concurrent.ExecutionContext

import org.codehaus.commons.compiler.CompileException
import org.codehaus.janino.InternalCompilerException
@@ -36,15 +35,16 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredic
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetric
+ import org.apache.spark.sql.execution.statsEstimation.SparkPlanStats
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.ThreadUtils

/**
* The base class for physical operators.
*
* The naming convention is that physical operators end with "Exec" suffix, e.g. [[ProjectExec]].
*/
- abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
+ abstract class SparkPlan extends QueryPlan[SparkPlan] with SparkPlanStats with Logging
+ with Serializable {

/**
* A handle to the SQL Context that was used to create this plan. Since many operators need
@@ -70,7 +70,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
if (sqlContext != null) {
SparkSession.setActiveSession(sqlContext.sparkSession)
}
- super.makeCopy(newArgs)
+ val sparkPlan = super.makeCopy(newArgs)
+ if (this.stats.isDefined) {
+ sparkPlan.withStats(this.stats.get)
+ }
+ sparkPlan
}

/**
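
The extra lines in `makeCopy` exist because the attached statistics live in a field of the new trait rather than in the constructor arguments, so a copy rebuilt from constructor arguments alone would drop them. A stand-in sketch of that copy-then-reattach pattern (not the Spark types):

```scala
// Stand-in for the "re-attach stats when copying" pattern in makeCopy above.
object CopyKeepsStatsSketch {
  final case class Stats(rowCount: Option[BigInt])

  final class Node(val name: String) {
    private var _stats: Option[Stats] = None
    def stats: Option[Stats] = _stats
    def withStats(s: Stats): Node = { _stats = Option(s); this }

    // A structural copy only rebuilds constructor args,
    // so the mutable stats field starts out empty on the copy...
    private def structuralCopy(): Node = new Node(name)

    // ...and has to be carried over explicitly, as the patched makeCopy does.
    def makeCopy(): Node = {
      val copied = structuralCopy()
      if (stats.isDefined) copied.withStats(stats.get)
      copied
    }
  }

  def main(args: Array[String]): Unit = {
    val n = new Node("ProjectExec").withStats(Stats(Some(BigInt(10))))
    println(n.makeCopy().stats) // Some(Stats(Some(10)))
  }
}
```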

@@ -56,7 +56,7 @@ private[execution] object SparkPlanInfo {
case _ => plan.children ++ plan.subqueries
}
val metrics = plan.metrics.toSeq.map { case (key, metric) =>
- new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType)
+ new SQLMetricInfo(metric.name.getOrElse(key), metric.id, metric.metricType, metric.stats)
}

// dump the file scan metadata (e.g file path) to event log

@@ -65,6 +65,10 @@ class SparkPlanner(
plans
}

+ override protected def propagate(plan: SparkPlan, logicalPlan: LogicalPlan): SparkPlan = {
+ plan.withStats(logicalPlan.stats)
+ }

/**
* Used to build table scan operators where complex projection and filtering are done using
* separate physical operators. This function returns the given scan operator with Project and

@@ -178,8 +178,8 @@ case class InMemoryTableScanExec(
relation.cachedPlan.outputOrdering.map(updateAttribute(_).asInstanceOf[SortOrder])

// Keeps relation's partition statistics because we don't serialize relation.
- private val stats = relation.partitionStatistics
- private def statsFor(a: Attribute) = stats.forAttribute(a)
+ private val partitionStats = relation.partitionStatistics
+ private def statsFor(a: Attribute) = partitionStats.forAttribute(a)

// Currently, only use statistics from atomic types except binary type only.
private object ExtractableLiteral {
@@ -274,7 +274,7 @@ case class InMemoryTableScanExec(
filter.map(
BindReferences.bindReference(
_,
- stats.schema,
+ partitionStats.schema,
allowFailures = true))

boundFilter.foreach(_ =>
@@ -297,7 +297,7 @@ case class InMemoryTableScanExec(
private def filteredCachedBatches(): RDD[CachedBatch] = {
// Using these variables here to avoid serialization of entire objects (if referenced directly)
// within the map Partitions closure.
- val schema = stats.schema
+ val schema = partitionStats.schema
val schemaIndex = schema.zipWithIndex
val buffers = relation.cacheBuilder.cachedColumnBuffers

@@ -316,7 +316,7 @@ case class InMemoryTableScanExec(
val value = cachedBatch.stats.get(i, a.dataType)
s"${a.name}: $value"
}.mkString(", ")
- s"Skipping partition based on stats $statsString"
+ s"Skipping partition based on partitionStats $statsString"
}
false
} else {

@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans._
+ import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -46,8 +47,10 @@ case class BroadcastHashJoinExec(
right: SparkPlan)
extends BinaryExecNode with HashJoin with CodegenSupport {

- override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+ override lazy val metrics = {
+ Map("numOutputRows" ->
+ SQLMetrics.createMetric(sparkContext, "number of output rows", this.rowCountStats.toLong))
+ }

override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildKeys)
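
Here the join seeds its `numOutputRows` metric with the optimizer's row-count estimate. `rowCountStats` falls back to `-1` when no estimate exists, and negative `stats` values are rendered as "no estimate" downstream, so the sentinel survives the `toLong` conversion. A small standalone sketch of that convention (plain Scala, not the Spark API):

```scala
// Sketch of the "-1 means no estimate" convention used when seeding the metric.
object RowCountSeedSketch {
  // Mirrors the fallback in rowCountStats: -1 when either the plan has no
  // statistics or the statistics carry no row count.
  def rowCountStats(rowCount: Option[BigInt]): BigInt = rowCount.getOrElse(BigInt(-1))

  def main(args: Array[String]): Unit = {
    println(rowCountStats(Some(BigInt(2))).toLong) // 2  -> shown as "est: 2" in the UI
    println(rowCountStats(None).toLong)            // -1 -> no estimate is rendered
  }
}
```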

@@ -27,4 +27,5 @@ import org.apache.spark.annotation.DeveloperApi
class SQLMetricInfo(
val name: String,
val accumulatorId: Long,
- val metricType: String)
+ val metricType: String,
+ val stats: Long = -1)

@@ -33,7 +33,8 @@ import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
* the executor side are automatically propagated and shown in the SQL UI through metrics. Updates
* on the driver side must be explicitly posted using [[SQLMetrics.postDriverMetricUpdates()]].
*/
- class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long] {
+ class SQLMetric(val metricType: String, initValue: Long = 0L, val stats: Long = -1L) extends
+ AccumulatorV2[Long, Long] {
// This is a workaround for SPARK-11013.
// We may use -1 as initial value of the accumulator, if the accumulator is valid, we will
// update it at the end of task and the value will be at least 0. Then we can filter out the -1
@@ -42,7 +43,7 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends Accumulato
private var _zeroValue = initValue

override def copy(): SQLMetric = {
- val newAcc = new SQLMetric(metricType, _value)
+ val newAcc = new SQLMetric(metricType, _value, stats)
newAcc._zeroValue = initValue
newAcc
}
@@ -96,8 +97,8 @@ object SQLMetrics {
metric.set((v * baseForAvgMetric).toLong)
}

- def createMetric(sc: SparkContext, name: String): SQLMetric = {
- val acc = new SQLMetric(SUM_METRIC)
+ def createMetric(sc: SparkContext, name: String, stats: Long = -1): SQLMetric = {
+ val acc = new SQLMetric(SUM_METRIC, stats = stats)
acc.register(sc, name = Some(name), countFailedValues = false)
acc
}
@@ -193,6 +194,14 @@ object SQLMetrics {
}
}

+ def stringStats(value: Long): String = {
+ if (value < 0) {
+ ""
+ } else {
+ s" est: ${stringValue(SUM_METRIC, Seq(value))}"
+ }
+ }

/**
* Updates metrics based on the driver side value. This is useful for certain metrics that
* are only updated on the driver, e.g. subquery execution time, or number of files.
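
For a sense of what the new `stringStats` helper produces, here is a simplified stand-in that treats `stringValue(SUM_METRIC, Seq(v))` as a plain sum; the stand-in keeps the same empty-vs-" est: N" behaviour as the added code.

```scala
// Simplified stand-in for SQLMetrics.stringStats; stringValue is reduced to a plain sum.
object StringStatsSketch {
  private def stringValue(values: Seq[Long]): String = values.sum.toString

  def stringStats(value: Long): String =
    if (value < 0) "" else s" est: ${stringValue(Seq(value))}"

  def main(args: Array[String]): Unit = {
    println("\"" + stringStats(-1) + "\"") // ""        -> nothing appended for unknown estimates
    println("\"" + stringStats(2) + "\"")  // " est: 2" -> appended after the actual metric value
  }
}
```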

@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.statsEstimation

import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.execution.SparkPlan

/**
* A trait to add statistics propagation to [[SparkPlan]].
*/
trait SparkPlanStats { self: SparkPlan =>

private var _stats: Option[Statistics] = None

def stats: Option[Statistics] = _stats

def withStats(value: Statistics): SparkPlan = {
_stats = Option(value)
this
}

def rowCountStats: BigInt = {
if (stats.isDefined && stats.get.rowCount.isDefined) {
stats.get.rowCount.get
} else {
-1
}
}
}
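
Taken together with the planner hook above, the trait forms a small attach-then-read protocol: the planner calls `withStats` once when a physical candidate is produced, and consumers such as `BroadcastHashJoinExec` later read `rowCountStats`. A self-contained model of that protocol with stand-in types:

```scala
// Stand-in model of the SparkPlanStats protocol: attach once, read later.
object PlanStatsProtocolSketch {
  final case class Statistics(sizeInBytes: BigInt, rowCount: Option[BigInt] = None)

  trait PlanStats { self: Plan =>
    private var _stats: Option[Statistics] = None
    def stats: Option[Statistics] = _stats
    def withStats(value: Statistics): Plan = { _stats = Option(value); this }
    def rowCountStats: BigInt = stats.flatMap(_.rowCount).getOrElse(BigInt(-1))
  }

  class Plan(val name: String) extends PlanStats

  def main(args: Array[String]): Unit = {
    val plan = new Plan("BroadcastHashJoinExec")
    println(plan.rowCountStats)                                               // -1 before planning
    plan.withStats(Statistics(sizeInBytes = BigInt(40), rowCount = Some(BigInt(2))))
    println(plan.rowCountStats)                                               // 2 after propagation
  }
}
```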

@@ -180,17 +180,20 @@ class SQLAppStatusListener(
}

private def aggregateMetrics(exec: LiveExecutionData): Map[Long, String] = {
- val metricTypes = exec.metrics.map { m => (m.accumulatorId, m.metricType) }.toMap
+ val metricMap = exec.metrics.map { m => (m.accumulatorId, m) }.toMap
val metrics = exec.stages.toSeq
.flatMap { stageId => Option(stageMetrics.get(stageId)) }
.flatMap(_.taskMetrics.values().asScala)
.flatMap { metrics => metrics.ids.zip(metrics.values) }

val aggregatedMetrics = (metrics ++ exec.driverAccumUpdates.toSeq)
- .filter { case (id, _) => metricTypes.contains(id) }
+ .filter { case (id, _) => metricMap.contains(id) }
.groupBy(_._1)
.map { case (id, values) =>
- id -> SQLMetrics.stringValue(metricTypes(id), values.map(_._2))
+ val metric = metricMap(id)
+ val value = SQLMetrics.stringValue(metric.metricType, values.map(_._2))
+ val stats = SQLMetrics.stringStats(metric.stats)
+ id -> (value + stats)
}

// Check the execution again for whether the aggregated metrics data has been calculated.
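
The net effect on the rendered metric text: the aggregated actual value now gets the estimate appended whenever one was recorded, and stays unchanged otherwise. A tiny string-level sketch (no listener types involved):

```scala
// Sketch of the value-plus-estimate composition done in aggregateMetrics above.
object MetricTextSketch {
  def render(actual: Long, statsEstimate: Long): String = {
    val value = actual.toString
    val stats = if (statsEstimate < 0) "" else s" est: $statsEstimate"
    value + stats
  }

  def main(args: Array[String]): Unit = {
    println(render(actual = 3, statsEstimate = 2))  // "3 est: 2"
    println(render(actual = 3, statsEstimate = -1)) // "3"
  }
}
```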

@@ -142,4 +142,5 @@ class SparkPlanGraphNodeWrapper(
case class SQLPlanMetric(
name: String,
accumulatorId: Long,
- metricType: String)
+ metricType: String,
+ stats: Long = -1)

@@ -80,7 +80,7 @@ object SparkPlanGraph {
planInfo.nodeName match {
case "WholeStageCodegen" =>
val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+ SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.stats)
}

val cluster = new SparkPlanGraphCluster(
@@ -114,7 +114,7 @@ object SparkPlanGraph {
edges += SparkPlanGraphEdge(node.id, parent.id)
case name =>
val metrics = planInfo.metrics.map { metric =>
- SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType)
+ SQLPlanMetric(metric.name, metric.accumulatorId, metric.metricType, metric.stats)
}
val node = new SparkPlanGraphNode(
nodeIdGenerator.getAndIncrement(), planInfo.nodeName,

@@ -1416,4 +1416,25 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
assert(catalogStats.rowCount.isEmpty)
}
}

+ test("statistics for broadcastHashJoin numOutputRows statistic") {
+ withTable("t1", "t2") {
+ withSQLConf(SQLConf.CBO_ENABLED.key -> "true",
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "40",
+ SQLConf.PREFER_SORTMERGEJOIN.key -> "false") {
+ sql("CREATE TABLE t1 (key INT, a2 STRING, a3 DOUBLE)")
+ sql("INSERT INTO TABLE t1 SELECT 1, 'a', 10.0")
+ sql("INSERT INTO TABLE t1 SELECT 1, 'b', null")
+ sql("ANALYZE TABLE t1 COMPUTE STATISTICS FOR ALL COLUMNS")

+ sql("CREATE TABLE t2 (key INT, b2 STRING, b3 DOUBLE)")
+ sql("INSERT INTO TABLE t2 SELECT 1, 'a', 10.0")
+ sql("ANALYZE TABLE t2 COMPUTE STATISTICS FOR ALL COLUMNS")

+ val df = sql("SELECT * FROM t1 JOIN t2 ON t1.key = t2.key")
+ assert(df.queryExecution.sparkPlan.isInstanceOf[BroadcastHashJoinExec])
+ assert(df.queryExecution.sparkPlan.rowCountStats == 2)
+ }
+ }
+ }
}