Skip to content

Commit 7eefc9d

Browse files
mateizpwendell
authored andcommitted
SPARK-1708. Add a ClassTag on Serializer and things that depend on it
This pull request contains a rebased patch from @heathermiller (heathermiller#1) to add ClassTags on Serializer and types that depend on it (Broadcast and AccumulableCollection). Putting these in the public API signatures now will allow us to use Scala Pickling for serialization down the line without breaking binary compatibility. One question remaining is whether we also want them on Accumulator -- Accumulator is passed as part of a bigger Task or TaskResult object via the closure serializer so it doesn't seem super useful to add the ClassTag there. Broadcast and AccumulableCollection in contrast were being serialized directly. CC @rxin, @pwendell, @heathermiller Author: Matei Zaharia <[email protected]> Closes #700 from mateiz/spark-1708 and squashes the following commits: 1a3d8b0 [Matei Zaharia] Use fake ClassTag in Java 3b449ed [Matei Zaharia] test fix 2209a27 [Matei Zaharia] Code style fixes 9d48830 [Matei Zaharia] Add a ClassTag on Serializer and things that depend on it
1 parent 8e94d27 commit 7eefc9d

File tree

22 files changed

+103
-72
lines changed

22 files changed

+103
-72
lines changed

core/src/main/scala/org/apache/spark/Accumulators.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.io.{ObjectInputStream, Serializable}
2121

2222
import scala.collection.generic.Growable
2323
import scala.collection.mutable.Map
24+
import scala.reflect.ClassTag
2425

2526
import org.apache.spark.serializer.JavaSerializer
2627

@@ -164,9 +165,9 @@ trait AccumulableParam[R, T] extends Serializable {
164165
def zero(initialValue: R): R
165166
}
166167

167-
private[spark]
168-
class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
169-
extends AccumulableParam[R,T] {
168+
private[spark] class
169+
GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
170+
extends AccumulableParam[R, T] {
170171

171172
def addAccumulator(growable: R, elem: T): R = {
172173
growable += elem

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,7 @@ class SparkContext(config: SparkConf) extends Logging {
756756
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
757757
* standard mutable collections. So you can use this with mutable Map, Set, etc.
758758
*/
759-
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
759+
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
760760
(initialValue: R): Accumulable[R, T] = {
761761
val param = new GrowableAccumulableParam[R,T]
762762
new Accumulable(initialValue, param)
@@ -767,7 +767,7 @@ class SparkContext(config: SparkConf) extends Logging {
767767
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
768768
* The variable will be sent to each cluster only once.
769769
*/
770-
def broadcast[T](value: T): Broadcast[T] = {
770+
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
771771
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
772772
cleaner.foreach(_.registerBroadcastForCleanup(bc))
773773
bc

core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
447447
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
448448
* The variable will be sent to each cluster only once.
449449
*/
450-
def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
450+
def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)(fakeClassTag)
451451

452452
/** Shut down the SparkContext. */
453453
def stop() {

core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import java.io.Serializable
2121

2222
import org.apache.spark.SparkException
2323

24+
import scala.reflect.ClassTag
25+
2426
/**
2527
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
2628
* cached on each machine rather than shipping a copy of it with tasks. They can be used, for
@@ -50,7 +52,7 @@ import org.apache.spark.SparkException
5052
* @param id A unique identifier for the broadcast variable.
5153
* @tparam T Type of the data contained in the broadcast variable.
5254
*/
53-
abstract class Broadcast[T](val id: Long) extends Serializable {
55+
abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
5456

5557
/**
5658
* Flag signifying whether the broadcast variable is valid

core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.broadcast
1919

20+
import scala.reflect.ClassTag
21+
2022
import org.apache.spark.SecurityManager
2123
import org.apache.spark.SparkConf
2224
import org.apache.spark.annotation.DeveloperApi
@@ -31,7 +33,7 @@ import org.apache.spark.annotation.DeveloperApi
3133
@DeveloperApi
3234
trait BroadcastFactory {
3335
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
34-
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
36+
def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
3537
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
3638
def stop(): Unit
3739
}

core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package org.apache.spark.broadcast
1919

2020
import java.util.concurrent.atomic.AtomicLong
2121

22+
import scala.reflect.ClassTag
23+
2224
import org.apache.spark._
2325

2426
private[spark] class BroadcastManager(
@@ -56,7 +58,7 @@ private[spark] class BroadcastManager(
5658

5759
private val nextBroadcastId = new AtomicLong(0)
5860

59-
def newBroadcast[T](value_ : T, isLocal: Boolean) = {
61+
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = {
6062
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
6163
}
6264

core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import java.io.{BufferedInputStream, BufferedOutputStream}
2222
import java.net.{URL, URLConnection, URI}
2323
import java.util.concurrent.TimeUnit
2424

25+
import scala.reflect.ClassTag
26+
2527
import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
2628
import org.apache.spark.io.CompressionCodec
2729
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
@@ -34,7 +36,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
3436
* (through a HTTP server running at the driver) and stored in the BlockManager of the
3537
* executor to speed up future accesses.
3638
*/
37-
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
39+
private[spark] class HttpBroadcast[T: ClassTag](
40+
@transient var value_ : T, isLocal: Boolean, id: Long)
3841
extends Broadcast[T](id) with Logging with Serializable {
3942

4043
def getValue = value_
@@ -173,7 +176,7 @@ private[spark] object HttpBroadcast extends Logging {
173176
files += file.getAbsolutePath
174177
}
175178

176-
def read[T](id: Long): T = {
179+
def read[T: ClassTag](id: Long): T = {
177180
logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
178181
val url = serverUri + "/" + BroadcastBlockId(id).name
179182

core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.broadcast
1919

20+
import scala.reflect.ClassTag
21+
2022
import org.apache.spark.{SecurityManager, SparkConf}
2123

2224
/**
@@ -29,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory {
2931
HttpBroadcast.initialize(isDriver, conf, securityMgr)
3032
}
3133

32-
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
34+
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
3335
new HttpBroadcast[T](value_, isLocal, id)
3436

3537
def stop() { HttpBroadcast.stop() }

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.broadcast
1919

2020
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
2121

22+
import scala.reflect.ClassTag
2223
import scala.math
2324
import scala.util.Random
2425

@@ -44,7 +45,8 @@ import org.apache.spark.util.Utils
4445
* copies of the broadcast data (one per executor) as done by the
4546
* [[org.apache.spark.broadcast.HttpBroadcast]].
4647
*/
47-
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
48+
private[spark] class TorrentBroadcast[T: ClassTag](
49+
@transient var value_ : T, isLocal: Boolean, id: Long)
4850
extends Broadcast[T](id) with Logging with Serializable {
4951

5052
def getValue = value_

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.broadcast
1919

20+
import scala.reflect.ClassTag
21+
2022
import org.apache.spark.{SecurityManager, SparkConf}
2123

2224
/**
@@ -30,7 +32,7 @@ class TorrentBroadcastFactory extends BroadcastFactory {
3032
TorrentBroadcast.initialize(isDriver, conf)
3133
}
3234

33-
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
35+
def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
3436
new TorrentBroadcast[T](value_, isLocal, id)
3537

3638
def stop() { TorrentBroadcast.stop() }

0 commit comments

Comments
 (0)