@@ -21,13 +21,11 @@ import java.io._
 import java.util.concurrent.ConcurrentHashMap
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.collection.mutable.{HashSet, HashMap, Map}
-import scala.concurrent.Await
+import scala.collection.mutable.{HashSet, Map}
 import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
 
-import akka.actor._
-import akka.pattern.ask
-
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.BlockManagerId
@@ -38,34 +36,35 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
 
-/** Actor class for MapOutputTrackerMaster */
-private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
-  extends Actor with ActorLogReceive with Logging {
+/** RpcEndpoint class for MapOutputTrackerMaster */
+private[spark] class MapOutputTrackerMasterEndpoint(
+    override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf)
+  extends RpcEndpoint with Logging {
   val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
 
-  override def receiveWithLogging: PartialFunction[Any, Unit] = {
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case GetMapOutputStatuses(shuffleId: Int) =>
-      val hostPort = sender.path.address.hostPort
+      val hostPort = context.sender.address.hostPort
       logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
       val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
       val serializedSize = mapOutputStatuses.size
       if (serializedSize > maxAkkaFrameSize) {
         val msg = s"Map output statuses were $serializedSize bytes which " +
           s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
 
-        /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
-         * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
-         * will ultimately remove this entire code path. */
+        /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.
+         * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */
         val exception = new SparkException(msg)
         logError(msg, exception)
-        throw exception
+        context.sendFailure(exception)
+      } else {
+        context.reply(mapOutputStatuses)
       }
-      sender ! mapOutputStatuses
 
     case StopMapOutputTracker =>
-      logInfo("MapOutputTrackerActor stopped!")
-      sender ! true
-      context.stop(self)
+      logInfo("MapOutputTrackerMasterEndpoint stopped!")
+      context.reply(true)
+      stop()
   }
 }
 
@@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
  * (driver and executor) use different HashMap to store its metadata.
  */
 private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
-  private val timeout = AkkaUtils.askTimeout(conf)
-  private val retryAttempts = AkkaUtils.numRetries(conf)
-  private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
 
-  /** Set to the MapOutputTrackerActor living on the driver. */
-  var trackerActor: ActorRef = _
+  /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */
+  var trackerEndpoint: RpcEndpointRef = _
 
   /**
    * This HashMap has different behavior for the driver and the executors.
@@ -105,22 +101,22 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
   private val fetching = new HashSet[Int]
 
   /**
-   * Send a message to the trackerActor and get its result within a default timeout, or
+   * Send a message to the trackerEndpoint and get its result within a default timeout, or
    * throw a SparkException if this fails.
    */
-  protected def askTracker(message: Any): Any = {
+  protected def askTracker[T: ClassTag](message: Any): T = {
     try {
-      AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout)
+      trackerEndpoint.askWithReply[T](message)
     } catch {
       case e: Exception =>
         logError("Error communicating with MapOutputTracker", e)
         throw new SparkException("Error communicating with MapOutputTracker", e)
     }
   }
 
-  /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
+  /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */
   protected def sendTracker(message: Any) {
-    val response = askTracker(message)
+    val response = askTracker[Boolean](message)
     if (response != true) {
       throw new SparkException(
         "Error reply received from MapOutputTracker. Expecting true, got " + response.toString)
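askTracker now carries the expected reply type, so call sites drop their asInstanceOf casts (the fetch path in the next hunk shows this). A small self-contained sketch of what the ClassTag bound buys; typedAsk is a hypothetical stand-in for the cast the RPC layer performs:

    import scala.reflect.ClassTag
    import org.apache.spark.SparkException

    // Hypothetical helper mirroring askTracker's shape: with a ClassTag in
    // scope, `case t: T` is checked against T's runtime class (boxed primitives
    // included), so an untyped reply can be narrowed safely to the caller's type.
    def typedAsk[T: ClassTag](reply: Any): T = reply match {
      case t: T => t
      case other => throw new SparkException(
        s"expected ${implicitly[ClassTag[T]].runtimeClass.getName}, got $other")
    }

    // typedAsk[Array[Byte]](Array[Byte](1, 2, 3))  // returns the array
    // typedAsk[Boolean]("nope")                     // throws SparkException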
@@ -157,11 +153,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
 
     if (fetchedStatuses == null) {
       // We won the race to fetch the output locs; do so
-      logInfo("Doing the fetch; tracker actor = " + trackerActor)
+      logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
       // This try-finally prevents hangs due to timeouts:
       try {
-        val fetchedBytes =
-          askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
+        val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
         fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
         logInfo("Got the output locations")
         mapStatuses.put(shuffleId, fetchedStatuses)
@@ -328,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   override def stop() {
     sendTracker(StopMapOutputTracker)
     mapStatuses.clear()
-    trackerActor = null
+    trackerEndpoint = null
     metadataCleaner.cancel()
     cachedSerializedStatuses.clear()
   }
@@ -350,6 +345,8 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
 
 private[spark] object MapOutputTracker extends Logging {
 
+  val ENDPOINT_NAME = "MapOutputTracker"
+
   // Serialize an array of map output locations into an efficient byte format so that we can send
   // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
   // generally be pretty compressible because many map outputs will be on the same hostname.
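A well-known endpoint name lets the driver register the tracker endpoint and executors resolve a reference to it by the same string, instead of both sides assembling an Akka actor URL. An illustrative sketch of the driver-side wiring, assuming RpcEnv.setupEndpoint from the rpc package imported above; how SparkEnv actually builds and threads the RpcEnv is outside this diff:

    import org.apache.spark.SparkConf
    import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}

    // Register the master endpoint under the shared name and keep the returned
    // ref; executors would look the endpoint up by MapOutputTracker.ENDPOINT_NAME.
    def registerTrackerEndpoint(rpcEnv: RpcEnv, conf: SparkConf): RpcEndpointRef = {
      val tracker = new MapOutputTrackerMaster(conf)
      val ref = rpcEnv.setupEndpoint(
        MapOutputTracker.ENDPOINT_NAME,
        new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
      tracker.trackerEndpoint = ref
      ref
    }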