@@ -20,7 +20,8 @@ package org.apache.spark
 import java.io.{ObjectInputStream, Serializable}
 
 import scala.collection.generic.Growable
-import scala.collection.mutable.Map
+import scala.collection.Map
+import scala.collection.mutable
 import scala.ref.WeakReference
 import scala.reflect.ClassTag
 
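The hunk above replaces the blanket `scala.collection.mutable.Map` import with the read-only `scala.collection.Map` trait plus an explicit `mutable` prefix, so mutation is visible at the use site. A minimal standalone sketch of that convention, with illustrative names not taken from the patch:

```scala
import scala.collection.Map
import scala.collection.mutable

object MapImportSketch {
  // Accepts any Map, mutable or immutable; the parameter is read-only here.
  def total(updates: Map[Long, Long]): Long = updates.values.sum

  // Only the structure that is actually mutated names the mutable variant explicitly.
  val registry = mutable.Map[Long, String]()
}
```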
@@ -39,25 +40,44 @@ import org.apache.spark.util.Utils
  * @param initialValue initial value of accumulator
  * @param param helper object defining how to add elements of type `R` and `T`
  * @param name human-readable name for use in Spark's web UI
+ * @param internal if this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported
+ *                 to the driver via heartbeats. For internal [[Accumulable]]s, `R` must be
+ *                 thread safe so that they can be reported correctly.
  * @tparam R the full accumulated data (result type)
  * @tparam T partial data that can be added in
  */
-class Accumulable[R, T] (
+class Accumulable[R, T] private[spark] (
     @transient initialValue: R,
     param: AccumulableParam[R, T],
-    val name: Option[String])
+    val name: Option[String],
+    internal: Boolean)
   extends Serializable {
 
+  private[spark] def this(
+      @transient initialValue: R, param: AccumulableParam[R, T], internal: Boolean) = {
+    this(initialValue, param, None, internal)
+  }
+
+  def this(@transient initialValue: R, param: AccumulableParam[R, T], name: Option[String]) =
+    this(initialValue, param, name, false)
+
   def this(@transient initialValue: R, param: AccumulableParam[R, T]) =
     this(initialValue, param, None)
 
   val id: Long = Accumulators.newId
 
-  @transient private var value_ = initialValue // Current value on master
+  @volatile @transient private var value_ : R = initialValue // Current value on master
   val zero = param.zero(initialValue)  // Zero value to be passed to workers
   private var deserialized = false
 
-  Accumulators.register(this, true)
+  Accumulators.register(this)
+
+  /**
+   * If this [[Accumulable]] is internal. Internal [[Accumulable]]s will be reported to the driver
+   * via heartbeats. For internal [[Accumulable]]s, `R` must be thread safe so that they can be
+   * reported correctly.
+   */
+  private[spark] def isInternal: Boolean = internal
 
   /**
    * Add more data to this accumulator / accumulable
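The hunk above adds an `internal` flag and a `private[spark]` primary constructor to `Accumulable`. A hypothetical usage sketch, not part of this patch: because the new `internal` constructor is `private[spark]`, it assumes compilation inside the `org.apache.spark` package, and the names `AtomicLongParam` and `peakMemory` are illustrative only.

```scala
package org.apache.spark

import java.util.concurrent.atomic.AtomicLong

// `R` = AtomicLong is thread safe, as the scaladoc above requires for internal
// accumulators that the heartbeat thread reads while task threads update them.
object AtomicLongParam extends AccumulableParam[AtomicLong, Long] {
  override def addAccumulator(r: AtomicLong, t: Long): AtomicLong = { r.addAndGet(t); r }
  override def addInPlace(r1: AtomicLong, r2: AtomicLong): AtomicLong = { r1.addAndGet(r2.get()); r1 }
  override def zero(initialValue: AtomicLong): AtomicLong = new AtomicLong(0L)
}

object InternalAccumulableSketch {
  // The third argument is the new `internal` flag, so this accumulable would be
  // reported to the driver via heartbeats.
  val peakMemory = new Accumulable[AtomicLong, Long](new AtomicLong(0L), AtomicLongParam, true)
}
```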
@@ -132,7 +152,8 @@ class Accumulable[R, T] (
     in.defaultReadObject()
     value_ = zero
     deserialized = true
-    Accumulators.register(this, false)
+    val taskContext = TaskContext.get()
+    taskContext.registerAccumulator(this)
   }
 
   override def toString: String = if (value_ == null) "null" else value_.toString
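This hunk moves deserialization-time registration off the global thread-local in `Accumulators` and onto the running `TaskContext`. Below is a standalone sketch of that registration-on-deserialization pattern; the `DemoTaskContext` and `DemoAccum` names are illustrative, not Spark's API.

```scala
import java.io.ObjectInputStream
import scala.collection.mutable

class DemoTaskContext {
  // Per-task registry, replacing a process-wide thread-local map.
  val accumulators = mutable.Map[Long, DemoAccum]()
  def registerAccumulator(a: DemoAccum): Unit = accumulators(a.id) = a
}

object DemoTaskContext {
  private val local = new ThreadLocal[DemoTaskContext]
  def set(tc: DemoTaskContext): Unit = local.set(tc)
  def get(): DemoTaskContext = local.get()
}

class DemoAccum(val id: Long) extends Serializable {
  // Java serialization hook: the freshly deserialized copy attaches itself to
  // the task context that is about to run it.
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    DemoTaskContext.get().registerAccumulator(this)
  }
}
```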
@@ -284,16 +305,7 @@ private[spark] object Accumulators extends Logging {
    * It keeps weak references to these objects so that accumulators can be garbage-collected
    * once the RDDs and user-code that reference them are cleaned up.
    */
-  val originals = Map[Long, WeakReference[Accumulable[_, _]]]()
-
-  /**
-   * This thread-local map holds per-task copies of accumulators; it is used to collect the set
-   * of accumulator updates to send back to the driver when tasks complete. After tasks complete,
-   * this map is cleared by `Accumulators.clear()` (see Executor.scala).
-   */
-  private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
-    override protected def initialValue() = Map[Long, Accumulable[_, _]]()
-  }
+  val originals = mutable.Map[Long, WeakReference[Accumulable[_, _]]]()
 
   private var lastId: Long = 0
 
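With the thread-local copies removed, `originals` is the one map left in the object and is now explicitly a `mutable.Map` of weak references, so registered accumulators can still be garbage-collected. A small illustrative sketch of that weak-reference registry idiom (the object and method names below are not Spark's):

```scala
import scala.collection.mutable
import scala.ref.WeakReference

object WeakRegistrySketch {
  private val originals = mutable.Map[Long, WeakReference[AnyRef]]()

  def register(id: Long, a: AnyRef): Unit = synchronized {
    originals(id) = new WeakReference(a)
  }

  // Returns None once the referent has been garbage-collected.
  def lookup(id: Long): Option[AnyRef] = synchronized {
    originals.get(id).flatMap(_.get)
  }
}
```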
@@ -302,19 +314,8 @@ private[spark] object Accumulators extends Logging {
     lastId
   }
 
-  def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
-    if (original) {
-      originals(a.id) = new WeakReference[Accumulable[_, _]](a)
-    } else {
-      localAccums.get()(a.id) = a
-    }
-  }
-
-  // Clear the local (non-original) accumulators for the current thread
-  def clear() {
-    synchronized {
-      localAccums.get.clear()
-    }
+  def register(a: Accumulable[_, _]): Unit = synchronized {
+    originals(a.id) = new WeakReference[Accumulable[_, _]](a)
   }
 
   def remove(accId: Long) {
@@ -323,15 +324,6 @@
     }
   }
 
-  // Get the values of the local accumulators for the current thread (by ID)
-  def values: Map[Long, Any] = synchronized {
-    val ret = Map[Long, Any]()
-    for ((id, accum) <- localAccums.get) {
-      ret(id) = accum.localValue
-    }
-    return ret
-  }
-
   // Add values to the original accumulators with some given IDs
   def add(values: Map[Long, Any]): Unit = synchronized {
     for ((id, value) <- values) {