
Commit ebfd91c

zsxwing authored and rxin committed
[SPARK-9467][SQL]Add SQLMetric to specialize accumulators to avoid boxing
This PR adds SQLMetric/SQLMetricParam/SQLMetricValue to specialize accumulators to avoid boxing. All SQL metrics should use these classes rather than `Accumulator`.

Author: zsxwing <[email protected]>

Closes apache#7996 from zsxwing/sql-accu and squashes the following commits:

14a5f0a [zsxwing] Address comments
367ca23 [zsxwing] Use localValue directly to avoid changing Accumulable
42f50c3 [zsxwing] Add SQLMetric to specialize accumulators to avoid boxing
1 parent e57d6b5 commit ebfd91c
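In essence, SQL operators now bump specialized metrics (LongSQLMetric / IntSQLMetric obtained from SQLMetrics) with primitive arguments instead of going through a generic Accumulator, so the per-row update path allocates nothing. Below is a standalone sketch of the boxing difference this avoids; the names (BoxingSketch, genericAdd, LongValue) are made up for illustration and are not Spark code:

```scala
// A self-contained sketch (not Spark code) of why the specialization matters.
// The generic path mirrors Accumulable/AccumulatorParam: the element travels
// through a type parameter, so a Long is boxed on every update. The specialized
// path mirrors LongSQLMetric/LongSQLMetricValue: the increment is a primitive
// Long and the running total is mutated in place.
object BoxingSketch {

  // Generic update, as an AccumulatorParam[Long] would see it: after erasure the
  // element occupies an AnyRef slot, so each call allocates a java.lang.Long.
  def genericAdd[T](acc: T, elem: T)(combine: (T, T) => T): T = combine(acc, elem)

  // Specialized update: a mutable wrapper plus a primitive-typed add, like
  // LongSQLMetricValue.add called from LongSQLMetric.+=.
  final class LongValue(private var v: Long) {
    def add(incr: Long): LongValue = { v += incr; this }
    def value: Long = v
  }

  def main(args: Array[String]): Unit = {
    var boxed: Long = 0L
    (1 to 5).foreach(i => boxed = genericAdd[Long](boxed, i.toLong)(_ + _))

    val unboxed = new LongValue(0L)
    (1 to 5).foreach(i => unboxed.add(i.toLong))

    println(s"generic total = $boxed, specialized total = ${unboxed.value}") // both 15
  }
}
```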

File tree

9 files changed: +338 −47 lines


core/src/main/scala/org/apache/spark/Accumulators.scala

Lines changed: 1 addition & 1 deletion
@@ -257,7 +257,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
  */
 class Accumulator[T] private[spark] (
     @transient private[spark] val initialValue: T,
-    private[spark] val param: AccumulatorParam[T],
+    param: AccumulatorParam[T],
     name: Option[String],
     internal: Boolean)
   extends Accumulable[T, T](initialValue, param, name, internal) {

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 0 additions & 15 deletions
@@ -1238,21 +1238,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     acc
   }
 
-  /**
-   * Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display
-   * in the Spark UI. Tasks can "add" values to the accumulator using the `+=` method. Only the
-   * driver can access the accumulator's `value`. The latest local value of such accumulator will be
-   * sent back to the driver via heartbeats.
-   *
-   * @tparam T type that can be added to the accumulator, must be thread safe
-   */
-  private[spark] def internalAccumulator[T](initialValue: T, name: String)(
-      implicit param: AccumulatorParam[T]): Accumulator[T] = {
-    val acc = new Accumulator(initialValue, param, Some(name), internal = true)
-    cleaner.foreach(_.registerAccumulatorForCleanup(acc))
-    acc
-  }
-
   /**
    * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values
    * with `+=`. Only the driver can access the accumuable's `value`.

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 21 additions & 12 deletions
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{Accumulator, Logging}
+import org.apache.spark.Logging
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.sql.SQLContext
@@ -32,6 +32,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.metric.{IntSQLMetric, LongSQLMetric, SQLMetric, SQLMetrics}
 import org.apache.spark.sql.types.DataType
 
 object SparkPlan {
@@ -84,22 +85,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    */
   protected[sql] def trackNumOfRowsEnabled: Boolean = false
 
-  private lazy val numOfRowsAccumulator = sparkContext.internalAccumulator(0L, "number of rows")
+  private lazy val defaultMetrics: Map[String, SQLMetric[_, _]] =
+    if (trackNumOfRowsEnabled) {
+      Map("numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))
+    }
+    else {
+      Map.empty
+    }
 
   /**
-   * Return all accumulators containing metrics of this SparkPlan.
+   * Return all metrics containing metrics of this SparkPlan.
    */
-  private[sql] def accumulators: Map[String, Accumulator[_]] = if (trackNumOfRowsEnabled) {
-    Map("numRows" -> numOfRowsAccumulator)
-  } else {
-    Map.empty
-  }
+  private[sql] def metrics: Map[String, SQLMetric[_, _]] = defaultMetrics
+
+  /**
+   * Return a IntSQLMetric according to the name.
+   */
+  private[sql] def intMetric(name: String): IntSQLMetric =
+    metrics(name).asInstanceOf[IntSQLMetric]
 
   /**
-   * Return the accumulator according to the name.
+   * Return a LongSQLMetric according to the name.
    */
-  private[sql] def accumulator[T](name: String): Accumulator[T] =
-    accumulators(name).asInstanceOf[Accumulator[T]]
+  private[sql] def longMetric(name: String): LongSQLMetric =
+    metrics(name).asInstanceOf[LongSQLMetric]
 
   // TODO: Move to `DistributedPlan`
   /** Specifies how data is partitioned across different nodes in the cluster. */
@@ -148,7 +157,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
     RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
       prepare()
       if (trackNumOfRowsEnabled) {
-        val numRows = accumulator[Long]("numRows")
+        val numRows = longMetric("numRows")
         doExecute().map { row =>
           numRows += 1
           row

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 6 additions & 5 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
@@ -81,13 +82,13 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
 case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
-  private[sql] override lazy val accumulators = Map(
-    "numInputRows" -> sparkContext.internalAccumulator(0L, "number of input rows"),
-    "numOutputRows" -> sparkContext.internalAccumulator(0L, "number of output rows"))
+  private[sql] override lazy val metrics = Map(
+    "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
   protected override def doExecute(): RDD[InternalRow] = {
-    val numInputRows = accumulator[Long]("numInputRows")
-    val numOutputRows = accumulator[Long]("numOutputRows")
+    val numInputRows = longMetric("numInputRows")
+    val numOutputRows = longMetric("numOutputRows")
     child.execute().mapPartitions { iter =>
       val predicate = newPredicate(condition, child.output)
       iter.filter { row =>
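The hunk above is cut off just inside `iter.filter`; presumably the closure bumps both counters per row, roughly as in this reconstruction (an assumption for illustration, not text taken from the diff):

```scala
// Hypothetical completion of the truncated filter body shown above.
iter.filter { row =>
  numInputRows += 1            // count every row the filter sees
  val keep = predicate(row)
  if (keep) numOutputRows += 1 // count only rows that satisfy the condition
  keep
}
```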
sql/core/src/main/scala/org/apache/spark/sql/metric/SQLMetrics.scala (new file)

Lines changed: 149 additions & 0 deletions
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.metric

import org.apache.spark.{Accumulable, AccumulableParam, SparkContext}

/**
 * Create a layer for specialized metric. We cannot add `@specialized` to
 * `Accumulable/AccumulableParam` because it will break Java source compatibility.
 *
 * An implementation of SQLMetric should override `+=` and `add` to avoid boxing.
 */
private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T](
    name: String, val param: SQLMetricParam[R, T])
  extends Accumulable[R, T](param.zero, param, Some(name), true)

/**
 * Create a layer for specialized metric. We cannot add `@specialized` to
 * `Accumulable/AccumulableParam` because it will break Java source compatibility.
 */
private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] {

  def zero: R
}

/**
 * Create a layer for specialized metric. We cannot add `@specialized` to
 * `Accumulable/AccumulableParam` because it will break Java source compatibility.
 */
private[sql] trait SQLMetricValue[T] extends Serializable {

  def value: T

  override def toString: String = value.toString
}

/**
 * A wrapper of Long to avoid boxing and unboxing when using Accumulator
 */
private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] {

  def add(incr: Long): LongSQLMetricValue = {
    _value += incr
    this
  }

  // Although there is a boxing here, it's fine because it's only called in SQLListener
  override def value: Long = _value
}

/**
 * A wrapper of Int to avoid boxing and unboxing when using Accumulator
 */
private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] {

  def add(term: Int): IntSQLMetricValue = {
    _value += term
    this
  }

  // Although there is a boxing here, it's fine because it's only called in SQLListener
  override def value: Int = _value
}

/**
 * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's
 * `+=` and `add`.
 */
private[sql] class LongSQLMetric private[metric](name: String)
  extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) {

  override def +=(term: Long): Unit = {
    localValue.add(term)
  }

  override def add(term: Long): Unit = {
    localValue.add(term)
  }
}

/**
 * A specialized int Accumulable to avoid boxing and unboxing when using Accumulator's
 * `+=` and `add`.
 */
private[sql] class IntSQLMetric private[metric](name: String)
  extends SQLMetric[IntSQLMetricValue, Int](name, IntSQLMetricParam) {

  override def +=(term: Int): Unit = {
    localValue.add(term)
  }

  override def add(term: Int): Unit = {
    localValue.add(term)
  }
}

private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] {

  override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t)

  override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue =
    r1.add(r2.value)

  override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero

  override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L)
}

private object IntSQLMetricParam extends SQLMetricParam[IntSQLMetricValue, Int] {

  override def addAccumulator(r: IntSQLMetricValue, t: Int): IntSQLMetricValue = r.add(t)

  override def addInPlace(r1: IntSQLMetricValue, r2: IntSQLMetricValue): IntSQLMetricValue =
    r1.add(r2.value)

  override def zero(initialValue: IntSQLMetricValue): IntSQLMetricValue = zero

  override def zero: IntSQLMetricValue = new IntSQLMetricValue(0)
}

private[sql] object SQLMetrics {

  def createIntMetric(sc: SparkContext, name: String): IntSQLMetric = {
    val acc = new IntSQLMetric(name)
    sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
    acc
  }

  def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = {
    val acc = new LongSQLMetric(name)
    sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc))
    acc
  }
}
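For orientation, this is roughly how a physical operator is expected to use the classes above, mirroring the SparkPlan and Filter hunks earlier in this commit. The operator name is hypothetical, and the snippet assumes it lives inside the org.apache.spark.sql.execution package (the metric classes are private[sql]), so treat it as a sketch rather than standalone code:

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
import org.apache.spark.sql.metric.SQLMetrics

// Hypothetical operator; the pattern matches the Filter change above.
case class CountingPassThrough(child: SparkPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output

  // Declare the operator's metrics once; SQLMetrics registers each one with the
  // ContextCleaner, as SparkContext.internalAccumulator used to do.
  private[sql] override lazy val metrics = Map(
    "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows"))

  protected override def doExecute(): RDD[InternalRow] = {
    val numRows = longMetric("numRows") // typed accessor added to SparkPlan
    child.execute().map { row =>
      numRows += 1 // specialized +=(Long): no boxing on the per-row path
      row
    }
  }
}
```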

sql/core/src/main/scala/org/apache/spark/sql/ui/SQLListener.scala

Lines changed: 9 additions & 8 deletions
@@ -21,11 +21,12 @@ import scala.collection.mutable
 
 import com.google.common.annotations.VisibleForTesting
 
-import org.apache.spark.{AccumulatorParam, JobExecutionStatus, Logging}
+import org.apache.spark.{JobExecutionStatus, Logging}
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.scheduler._
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue}
 
 private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging {
 
@@ -36,8 +37,6 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit
 
   // Old data in the following fields must be removed in "trimExecutionsIfNecessary".
   // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data
-
-  // VisibleForTesting
   private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]()
 
   /**
@@ -270,9 +269,10 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit
             accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield {
             accumulatorUpdate
           }
-        }.filter { case (id, _) => executionUIData.accumulatorMetrics.keySet(id) }
+        }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
         mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
-          executionUIData.accumulatorMetrics(accumulatorId).accumulatorParam)
+          executionUIData.accumulatorMetrics(accumulatorId).metricParam).
+          mapValues(_.asInstanceOf[SQLMetricValue[_]].value)
       case None =>
         // This execution has been dropped
         Map.empty
@@ -281,10 +281,11 @@ private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener wit
 
   private def mergeAccumulatorUpdates(
       accumulatorUpdates: Seq[(Long, Any)],
-      paramFunc: Long => AccumulatorParam[Any]): Map[Long, Any] = {
+      paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = {
     accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
       val param = paramFunc(accumulatorId)
-      (accumulatorId, values.map(_._2).reduceLeft(param.addInPlace))
+      (accumulatorId,
+        values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace))
     }
   }
 
@@ -336,7 +337,7 @@ private[ui] class SQLExecutionUIData(
 private[ui] case class SQLPlanMetric(
     name: String,
     accumulatorId: Long,
-    accumulatorParam: AccumulatorParam[Any])
+    metricParam: SQLMetricParam[SQLMetricValue[Any], Any])
 
 /**
  * Store all accumulatorUpdates for all tasks in a Spark stage.
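With this change, mergeAccumulatorUpdates folds each metric's per-task SQLMetricValue updates starting from param.zero, and the caller unwraps the primitive with .value for display. A self-contained sketch of that per-accumulator fold with made-up update data (MergeSketch is illustrative only, not Spark code):

```scala
// Group per-task updates by accumulator id and fold each group from zero,
// mirroring foldLeft(param.zero)(param.addInPlace) in mergeAccumulatorUpdates.
object MergeSketch {
  def main(args: Array[String]): Unit = {
    // (accumulatorId, partial value reported by one task)
    val updates: Seq[(Long, Long)] = Seq((1L, 3L), (1L, 4L), (2L, 10L))

    val merged: Map[Long, Long] = updates.groupBy(_._1).map { case (id, values) =>
      id -> values.map(_._2).foldLeft(0L)(_ + _) // start from zero, add in place
    }

    println(merged) // e.g. Map(1 -> 7, 2 -> 10)
  }
}
```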

sql/core/src/main/scala/org/apache/spark/sql/ui/SparkPlanGraph.scala

Lines changed: 4 additions & 4 deletions
@@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong
 
 import scala.collection.mutable
 
-import org.apache.spark.AccumulatorParam
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.metric.{SQLMetricParam, SQLMetricValue}
 
 /**
  * A graph used for storing information of an executionPlan of DataFrame.
@@ -61,9 +61,9 @@ private[sql] object SparkPlanGraph {
       nodeIdGenerator: AtomicLong,
       nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
       edges: mutable.ArrayBuffer[SparkPlanGraphEdge]): SparkPlanGraphNode = {
-    val metrics = plan.accumulators.toSeq.map { case (key, accumulator) =>
-      SQLPlanMetric(accumulator.name.getOrElse(key), accumulator.id,
-        accumulator.param.asInstanceOf[AccumulatorParam[Any]])
+    val metrics = plan.metrics.toSeq.map { case (key, metric) =>
+      SQLPlanMetric(metric.name.getOrElse(key), metric.id,
+        metric.param.asInstanceOf[SQLMetricParam[SQLMetricValue[Any], Any]])
     }
     val node = SparkPlanGraphNode(
       nodeIdGenerator.getAndIncrement(), plan.nodeName, plan.simpleString, metrics)
