
Commit 0b5d028

zsxwing authored and rxin committed
[SPARK-6602][Core] Update MapOutputTrackerMasterActor to MapOutputTrackerMasterEndpoint
This is the second PR for [SPARK-6602]. It updated MapOutputTrackerMasterActor and its unit tests. cc rxin

Author: zsxwing <[email protected]>

Closes #5371 from zsxwing/rpc-rewrite-part2 and squashes the following commits:

fcf3816 [zsxwing] Fix the code style
4013a22 [zsxwing] Add doc for uncaught exceptions in RpcEnv
93c6c20 [zsxwing] Add an example of UnserializableException and add ErrorMonitor to monitor errors from Akka
134fe7b [zsxwing] Update MapOutputTrackerMasterActor to MapOutputTrackerMasterEndpoint
1 parent acffc43 commit 0b5d028

File tree

7 files changed: +221 -212 lines changed

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 29 additions & 32 deletions

@@ -21,13 +21,11 @@ import java.io._
 import java.util.concurrent.ConcurrentHashMap
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.collection.mutable.{HashSet, HashMap, Map}
-import scala.concurrent.Await
+import scala.collection.mutable.{HashSet, Map}
 import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
 
-import akka.actor._
-import akka.pattern.ask
-
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.BlockManagerId
@@ -38,34 +36,35 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
 
-/** Actor class for MapOutputTrackerMaster */
-private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
-  extends Actor with ActorLogReceive with Logging {
+/** RpcEndpoint class for MapOutputTrackerMaster */
+private[spark] class MapOutputTrackerMasterEndpoint(
+    override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf)
+  extends RpcEndpoint with Logging {
   val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
 
-  override def receiveWithLogging: PartialFunction[Any, Unit] = {
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case GetMapOutputStatuses(shuffleId: Int) =>
-      val hostPort = sender.path.address.hostPort
+      val hostPort = context.sender.address.hostPort
       logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
       val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
       val serializedSize = mapOutputStatuses.size
       if (serializedSize > maxAkkaFrameSize) {
         val msg = s"Map output statuses were $serializedSize bytes which " +
           s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
 
-        /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
-         * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
-         * will ultimately remove this entire code path. */
+        /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.
+         * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */
         val exception = new SparkException(msg)
         logError(msg, exception)
-        throw exception
+        context.sendFailure(exception)
+      } else {
+        context.reply(mapOutputStatuses)
       }
-      sender ! mapOutputStatuses
 
     case StopMapOutputTracker =>
-      logInfo("MapOutputTrackerActor stopped!")
-      sender ! true
-      context.stop(self)
+      logInfo("MapOutputTrackerMasterEndpoint stopped!")
+      context.reply(true)
+      stop()
   }
 }
 
@@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
  * (driver and executor) use different HashMap to store its metadata.
  */
 private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
-  private val timeout = AkkaUtils.askTimeout(conf)
-  private val retryAttempts = AkkaUtils.numRetries(conf)
-  private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
 
-  /** Set to the MapOutputTrackerActor living on the driver. */
-  var trackerActor: ActorRef = _
+  /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */
+  var trackerEndpoint: RpcEndpointRef = _
 
   /**
    * This HashMap has different behavior for the driver and the executors.
@@ -105,22 +101,22 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
   private val fetching = new HashSet[Int]
 
   /**
-   * Send a message to the trackerActor and get its result within a default timeout, or
+   * Send a message to the trackerEndpoint and get its result within a default timeout, or
    * throw a SparkException if this fails.
    */
-  protected def askTracker(message: Any): Any = {
+  protected def askTracker[T: ClassTag](message: Any): T = {
     try {
-      AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout)
+      trackerEndpoint.askWithReply[T](message)
     } catch {
       case e: Exception =>
        logError("Error communicating with MapOutputTracker", e)
        throw new SparkException("Error communicating with MapOutputTracker", e)
    }
  }
 
-  /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
+  /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */
   protected def sendTracker(message: Any) {
-    val response = askTracker(message)
+    val response = askTracker[Boolean](message)
     if (response != true) {
       throw new SparkException(
         "Error reply received from MapOutputTracker. Expecting true, got " + response.toString)
@@ -157,11 +153,10 @@
 
       if (fetchedStatuses == null) {
        // We won the race to fetch the output locs; do so
-        logInfo("Doing the fetch; tracker actor = " + trackerActor)
+        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
        // This try-finally prevents hangs due to timeouts:
        try {
-          val fetchedBytes =
-            askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
+          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
          logInfo("Got the output locations")
          mapStatuses.put(shuffleId, fetchedStatuses)
@@ -328,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   override def stop() {
     sendTracker(StopMapOutputTracker)
     mapStatuses.clear()
-    trackerActor = null
+    trackerEndpoint = null
     metadataCleaner.cancel()
     cachedSerializedStatuses.clear()
   }
@@ -350,6 +345,8 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
 
 private[spark] object MapOutputTracker extends Logging {
 
+  val ENDPOINT_NAME = "MapOutputTracker"
+
   // Serialize an array of map output locations into an efficient byte format so that we can send
   // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
   // generally be pretty compressible because many map outputs will be on the same hostname.
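The substance of this file's change is the reply discipline: instead of Akka's "sender ! result" (with a thrown exception restarting the actor), the endpoint answers through the RpcCallContext, calling context.reply on success and context.sendFailure on error. Below is a minimal sketch of that pattern. It is not part of this commit, the names Echo and EchoEndpoint are hypothetical, and it assumes only the RpcEndpoint/RpcCallContext API visible in the diff above.

import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

case class Echo(payload: String)

class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Echo(payload) if payload.nonEmpty =>
      // Success path: reply completes the caller's ask with this value.
      context.reply(payload.toUpperCase)
    case Echo(_) =>
      // Failure path: the exception travels back to the caller instead of being
      // thrown inside (and restarting) an actor.
      context.sendFailure(new IllegalArgumentException("empty payload"))
  }
}

// Caller side, mirroring the new typed askTracker[T] helper (hypothetical usage):
// val ref = rpcEnv.setupEndpoint("Echo", new EchoEndpoint(rpcEnv))
// val answer = ref.askWithReply[String](Echo("hi"))  // the ask fails if sendFailure was used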

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 4 additions & 14 deletions

@@ -24,7 +24,6 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.util.Properties
 
-import akka.actor._
 import com.google.common.collect.MapMaker
 
 import org.apache.spark.annotation.DeveloperApi
@@ -41,7 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
 import org.apache.spark.storage._
-import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
+import org.apache.spark.util.{RpcUtils, Utils}
 
 /**
  * :: DeveloperApi ::
@@ -286,15 +285,6 @@
     val closureSerializer = instantiateClassFromConf[Serializer](
       "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")
 
-    def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
-      if (isDriver) {
-        logInfo("Registering " + name)
-        actorSystem.actorOf(Props(newActor), name = name)
-      } else {
-        AkkaUtils.makeDriverRef(name, conf, actorSystem)
-      }
-    }
-
     def registerOrLookupEndpoint(
         name: String, endpointCreator: => RpcEndpoint):
       RpcEndpointRef = {
@@ -314,9 +304,9 @@
 
     // Have to assign trackerActor after initialization as MapOutputTrackerActor
     // requires the MapOutputTracker itself
-    mapOutputTracker.trackerActor = registerOrLookup(
-      "MapOutputTracker",
-      new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
+    mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(
+        rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
 
     // Let the user specify short names for shuffle managers
     val shortShuffleMgrNames = Map(
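The deleted registerOrLookup and its replacement registerOrLookupEndpoint share one idea: the driver creates and registers the endpoint under a well-known name (here MapOutputTracker.ENDPOINT_NAME), while executors only resolve a reference to the driver-side endpoint by that name. The body of registerOrLookupEndpoint is not shown in this diff; the sketch below only illustrates that shape. RegisterOrLookupSketch and the lookup parameter are hypothetical, and the executor-side resolution in SparkEnv is assumed to go through a helper such as RpcUtils.makeDriverRef.

import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

object RegisterOrLookupSketch {
  def registerOrLookupEndpoint(
      rpcEnv: RpcEnv,
      isDriver: Boolean,
      name: String,
      endpointCreator: => RpcEndpoint,
      lookup: String => RpcEndpointRef): RpcEndpointRef = {
    if (isDriver) {
      // Driver: instantiate the endpoint and register it under `name` so remote callers can find it.
      rpcEnv.setupEndpoint(name, endpointCreator)
    } else {
      // Executor: create nothing, just resolve a reference to the endpoint registered on the driver.
      lookup(name)
    }
  }
}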

core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala

Lines changed: 3 additions & 1 deletion

@@ -30,7 +30,9 @@ import org.apache.spark.util.{AkkaUtils, Utils}
 /**
  * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to
  * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote
- * nodes, and deliver them to corresponding [[RpcEndpoint]]s.
+ * nodes, and deliver them to corresponding [[RpcEndpoint]]s. For uncaught exceptions caught by
+ * [[RpcEnv]], [[RpcEnv]] will use [[RpcCallContext.sendFailure]] to send exceptions back to the
+ * sender, or logging them if no such sender or `NotSerializableException`.
 *
 * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri.
 */
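The added Scaladoc describes behavior rather than a new API, so a toy endpoint makes it concrete: if a handler throws, the RpcEnv catches the exception and, for ask-style messages that have a sender, forwards it with RpcCallContext.sendFailure; otherwise (no sender, or the exception cannot be serialized) it is only logged. The sketch below is not from the commit and FlakyEndpoint is a hypothetical name.

import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

class FlakyEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case "boom" =>
      // Deliberately uncaught: per the doc above, the RpcEnv should deliver this
      // exception to the asking side via sendFailure rather than dropping it.
      throw new IllegalStateException("boom")
    case other =>
      context.reply(other)
  }
}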

core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala

Lines changed: 17 additions & 2 deletions

@@ -17,16 +17,16 @@
 
 package org.apache.spark.rpc.akka
 
-import java.net.URI
 import java.util.concurrent.ConcurrentHashMap
 
-import scala.concurrent.{Await, Future}
+import scala.concurrent.Future
 import scala.concurrent.duration._
 import scala.language.postfixOps
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address}
+import akka.event.Logging.Error
 import akka.pattern.{ask => akkaAsk}
 import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
 import org.apache.spark.{SparkException, Logging, SparkConf}
@@ -242,10 +242,25 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
   def create(config: RpcEnvConfig): RpcEnv = {
     val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
       config.name, config.host, config.port, config.conf, config.securityManager)
+    actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor")
     new AkkaRpcEnv(actorSystem, config.conf, boundPort)
   }
 }
 
+/**
+ * Monitor errors reported by Akka and log them.
+ */
+private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging {
+
+  override def preStart(): Unit = {
+    context.system.eventStream.subscribe(self, classOf[Error])
+  }
+
+  override def receiveWithLogging: Actor.Receive = {
+    case Error(cause: Throwable, _, _, message: String) => logError(message, cause)
+  }
+}
+
 private[akka] class AkkaRpcEndpointRef(
     @transient defaultAddress: RpcAddress,
     @transient _actorRef: => ActorRef,
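ErrorMonitor relies on a standard Akka facility: Logging.Error events are published on the actor system's event stream, so any actor that subscribes to classOf[Error] receives them. A standalone sketch of the same trick in plain Akka, independent of this commit (SimpleErrorMonitor is a hypothetical name):

import akka.actor.{Actor, ActorSystem, Props}
import akka.event.Logging.Error

class SimpleErrorMonitor extends Actor {
  override def preStart(): Unit = {
    // Receive every Error event published on the actor system's event stream.
    context.system.eventStream.subscribe(self, classOf[Error])
  }

  override def receive: Receive = {
    case Error(cause, logSource, _, message) =>
      // Forward to any logging facility; here we just print.
      println(s"[$logSource] $message: $cause")
  }
}

// Usage: create it once per ActorSystem, as AkkaRpcEnvFactory.create does above.
// val system = ActorSystem("demo")
// system.actorOf(Props(classOf[SimpleErrorMonitor]), "ErrorMonitor")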
