
Commit f15806a

zsxwing authored and rxin committed
[SPARK-6602][Core] Replace direct use of Akka with Spark RPC interface - part 1
This PR replaced the following `Actor`s to `RpcEndpoint`:

1. HeartbeatReceiver
1. ExecutorActor
1. BlockManagerMasterActor
1. BlockManagerSlaveActor
1. CoarseGrainedExecutorBackend and subclasses
1. CoarseGrainedSchedulerBackend.DriverActor

This is the first PR. I will split the work of SPARK-6602 to several PRs for code review.

Author: zsxwing <[email protected]>

Closes apache#5268 from zsxwing/rpc-rewrite and squashes the following commits:

287e9f8 [zsxwing] Fix the code style
26c56b7 [zsxwing] Merge branch 'master' into rpc-rewrite
9cc825a [zsxwing] Rmove setupThreadSafeEndpoint and add ThreadSafeRpcEndpoint
30a9036 [zsxwing] Make self return null after stopping RpcEndpointRef; fix docs and error messages
705245d [zsxwing] Fix some bugs after rebasing the changes on the master
003cf80 [zsxwing] Update CoarseGrainedExecutorBackend and CoarseGrainedSchedulerBackend to use RpcEndpoint
7d0e6dc [zsxwing] Update BlockManagerSlaveActor to use RpcEndpoint
f5d6543 [zsxwing] Update BlockManagerMaster to use RpcEndpoint
30e3f9f [zsxwing] Update ExecutorActor to use RpcEndpoint
478b443 [zsxwing] Update HeartbeatReceiver to use RpcEndpoint
1 parent 7bca62f commit f15806a

30 files changed: +616 additions, -542 deletions
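Every file in this commit follows the same migration pattern: an Akka `Actor` with `receiveWithLogging` becomes an `RpcEndpoint` (here a `ThreadSafeRpcEndpoint`), one-way messages are handled in `receive`, ask-style messages in `receiveAndReply`, and the `preStart`/`postStop` lifecycle hooks become `onStart`/`onStop`. As a reading aid only, here is a minimal sketch of that shape against the RPC interface used in this commit; the `Ping`/`Pong` messages and the `PingEndpoint` name are made up for illustration and appear nowhere in the change:

import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

// Illustrative messages only; not part of this commit.
case object Ping
case object Pong

class PingEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {

  override def onStart(): Unit = {
    // Replaces Actor.preStart: start timers or helper threads here.
  }

  // One-way messages, delivered with RpcEndpointRef.send (replaces `ref ! msg`).
  override def receive: PartialFunction[Any, Unit] = {
    case Ping => // no reply expected
  }

  // Request/response messages, delivered with RpcEndpointRef.askWithReply;
  // context.reply(...) replaces Akka's `sender ! ...`.
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Ping => context.reply(Pong)
  }

  override def onStop(): Unit = {
    // Replaces Actor.postStop: release resources here.
  }
}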

core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala

Lines changed: 43 additions & 23 deletions
@@ -17,15 +17,15 @@
 
 package org.apache.spark
 
-import scala.concurrent.duration._
-import scala.collection.mutable
+import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors}
 
-import akka.actor.{Actor, Cancellable}
+import scala.collection.mutable
 
 import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext}
 import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
-import org.apache.spark.util.ActorLogReceive
+import org.apache.spark.util.Utils
 
 /**
  * A heartbeat from executors to the driver. This is a shared message used by several internal
@@ -51,9 +51,11 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
  * Lives in the driver to receive heartbeats from executors..
  */
 private[spark] class HeartbeatReceiver(sc: SparkContext)
-  extends Actor with ActorLogReceive with Logging {
+  extends ThreadSafeRpcEndpoint with Logging {
+
+  override val rpcEnv: RpcEnv = sc.env.rpcEnv
 
-  private var scheduler: TaskScheduler = null
+  private[spark] var scheduler: TaskScheduler = null
 
   // executor ID -> timestamp of when the last heartbeat from this executor was received
   private val executorLastSeen = new mutable.HashMap[String, Long]
@@ -69,34 +71,44 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
     sc.conf.getOption("spark.network.timeoutInterval").map(_.toLong * 1000).
       getOrElse(sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000))
 
-  private var timeoutCheckingTask: Cancellable = null
-
-  override def preStart(): Unit = {
-    import context.dispatcher
-    timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
-      checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts)
-    super.preStart()
+  private var timeoutCheckingTask: ScheduledFuture[_] = null
+
+  private val timeoutCheckingThread = Executors.newSingleThreadScheduledExecutor(
+    Utils.namedThreadFactory("heartbeat-timeout-checking-thread"))
+
+  private val killExecutorThread = Executors.newSingleThreadExecutor(
+    Utils.namedThreadFactory("kill-executor-thread"))
+
+  override def onStart(): Unit = {
+    timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable {
+      override def run(): Unit = Utils.tryLogNonFatalError {
+        Option(self).foreach(_.send(ExpireDeadHosts))
+      }
+    }, 0, checkTimeoutIntervalMs, TimeUnit.MILLISECONDS)
   }
-
-  override def receiveWithLogging: PartialFunction[Any, Unit] = {
+
+  override def receive: PartialFunction[Any, Unit] = {
+    case ExpireDeadHosts =>
+      expireDeadHosts()
     case TaskSchedulerIsSet =>
       scheduler = sc.taskScheduler
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) =>
       if (scheduler != null) {
        val unknownExecutor = !scheduler.executorHeartbeatReceived(
          executorId, taskMetrics, blockManagerId)
        val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
        executorLastSeen(executorId) = System.currentTimeMillis()
-        sender ! response
+        context.reply(response)
      } else {
        // Because Executor will sleep several seconds before sending the first "Heartbeat", this
        // case rarely happens. However, if it really happens, log it and ask the executor to
        // register itself again.
        logWarning(s"Dropping $heartbeat because TaskScheduler is not ready yet")
-        sender ! HeartbeatResponse(reregisterBlockManager = true)
+        context.reply(HeartbeatResponse(reregisterBlockManager = true))
      }
-    case ExpireDeadHosts =>
-      expireDeadHosts()
   }
 
   private def expireDeadHosts(): Unit = {
@@ -109,17 +121,25 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
         scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " +
           s"timed out after ${now - lastSeenMs} ms"))
         if (sc.supportDynamicAllocation) {
-          sc.killExecutor(executorId)
+          // Asynchronously kill the executor to avoid blocking the current thread
+          killExecutorThread.submit(new Runnable {
+            override def run(): Unit = sc.killExecutor(executorId)
+          })
         }
         executorLastSeen.remove(executorId)
       }
     }
   }
 
-  override def postStop(): Unit = {
+  override def onStop(): Unit = {
     if (timeoutCheckingTask != null) {
-      timeoutCheckingTask.cancel()
+      timeoutCheckingTask.cancel(true)
     }
-    super.postStop()
+    timeoutCheckingThread.shutdownNow()
+    killExecutorThread.shutdownNow()
   }
 }
+
+object HeartbeatReceiver {
+  val ENDPOINT_NAME = "HeartbeatReceiver"
+}
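For orientation, the `Heartbeat` / `HeartbeatResponse` exchange handled by `receiveAndReply` above is an ordinary ask-style call on the new interface. A hedged sketch of the calling side, assuming a reference to this endpoint has already been obtained (for example via `rpcEnv.setupEndpointRef(..., HeartbeatReceiver.ENDPOINT_NAME)`); `heartbeatRef`, `executorId`, `taskMetrics` and `blockManagerId` are stand-ins, not code from this commit:

// Ask the driver's HeartbeatReceiver endpoint and block for its answer.
val response = heartbeatRef.askWithReply[HeartbeatResponse](
  Heartbeat(executorId, taskMetrics, blockManagerId))
if (response.reregisterBlockManager) {
  // The driver did not recognize this executor; it should re-register
  // its BlockManager before the next heartbeat.
}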

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 12 additions & 11 deletions
@@ -32,8 +32,6 @@ import scala.collection.generic.Growable
 import scala.collection.mutable.HashMap
 import scala.reflect.{ClassTag, classTag}
 
-import akka.actor.Props
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable,
@@ -48,12 +46,13 @@ import org.apache.mesos.MesosNativeLibrary
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
-import org.apache.spark.executor.TriggerThreadDump
+import org.apache.spark.executor.{ExecutorEndpoint, TriggerThreadDump}
 import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat,
   FixedLengthBinaryInputFormat}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd._
+import org.apache.spark.rpc.RpcAddress
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
   SparkDeploySchedulerBackend, SimrSchedulerBackend}
@@ -360,14 +359,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient
 
   // We need to register "HeartbeatReceiver" before "createTaskScheduler" because Executor will
   // retrieve "HeartbeatReceiver" in the constructor. (SPARK-6640)
-  private val heartbeatReceiver = env.actorSystem.actorOf(
-    Props(new HeartbeatReceiver(this)), "HeartbeatReceiver")
+  private val heartbeatReceiver = env.rpcEnv.setupEndpoint(
+    HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this))
 
   // Create and start the scheduler
   private[spark] var (schedulerBackend, taskScheduler) =
     SparkContext.createTaskScheduler(this, master)
 
-  heartbeatReceiver ! TaskSchedulerIsSet
+  heartbeatReceiver.send(TaskSchedulerIsSet)
 
   @volatile private[spark] var dagScheduler: DAGScheduler = _
   try {
@@ -455,10 +454,12 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient
       if (executorId == SparkContext.DRIVER_IDENTIFIER) {
         Some(Utils.getThreadDump())
       } else {
-        val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get
-        val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
-        Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef,
-          AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf)))
+        val (host, port) = env.blockManager.master.getRpcHostPortForExecutor(executorId).get
+        val endpointRef = env.rpcEnv.setupEndpointRef(
+          SparkEnv.executorActorSystemName,
+          RpcAddress(host, port),
+          ExecutorEndpoint.EXECUTOR_ENDPOINT_NAME)
+        Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump))
       }
     } catch {
       case e: Exception =>
@@ -1418,7 +1419,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient
     dagScheduler = null
     listenerBus.stop()
     eventLogger.foreach(_.stop())
-    env.actorSystem.stop(heartbeatReceiver)
+    env.rpcEnv.stop(heartbeatReceiver)
     progressBar.foreach(_.stop())
     taskScheduler = null
     // TODO: Cache.stop()?
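The two call sites above are the canonical replacements for `actorOf` and `actorSelection`: `rpcEnv.setupEndpoint` registers a local endpoint under a name and returns its ref, while `rpcEnv.setupEndpointRef` resolves a ref to an endpoint registered in another process. A minimal sketch of both sides; `env`, `host`, `port` and the `MyEndpoint` class are assumed to exist in surrounding code, and the messages are hypothetical:

import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef}

// Hypothetical messages for the sketch.
case object StatusRequest
case class StatusReply(ok: Boolean)

// Local side: register an endpoint and keep the returned ref (replaces actorOf).
val localRef: RpcEndpointRef =
  env.rpcEnv.setupEndpoint("MyEndpoint", new MyEndpoint(env.rpcEnv))
localRef.send(StatusRequest)                    // fire-and-forget, replaces `ref ! msg`

// Remote side: resolve a ref by system name, address and endpoint name
// (replaces actorSelection / AkkaUtils.makeExecutorRef), then block on a reply.
val remoteRef = env.rpcEnv.setupEndpointRef(
  SparkEnv.executorActorSystemName, RpcAddress(host, port), "MyEndpoint")
val reply = remoteRef.askWithReply[StatusReply](StatusRequest)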

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 8 additions & 5 deletions
@@ -295,7 +295,9 @@ object SparkEnv extends Logging {
       }
     }
 
-    def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = {
+    def registerOrLookupEndpoint(
+        name: String, endpointCreator: => RpcEndpoint):
+      RpcEndpointRef = {
       if (isDriver) {
         logInfo("Registering " + name)
         rpcEnv.setupEndpoint(name, endpointCreator)
@@ -334,12 +336,13 @@ object SparkEnv extends Logging {
       new NioBlockTransferService(conf, securityManager)
     }
 
-    val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
-      "BlockManagerMaster",
-      new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
+    val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
+      BlockManagerMaster.DRIVER_ENDPOINT_NAME,
+      new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)),
+      conf, isDriver)
 
     // NB: blockManager is not valid until initialize() is called later.
-    val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
+    val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
       serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager,
       numUsableCores)
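`registerOrLookupEndpoint` plays the same role the old `registerOrLookup` helper played for actors: on the driver it creates and registers the endpoint, on executors it only resolves a reference to the driver's instance. The lookup branch is not shown in this hunk, so the sketch below is an assumption about its shape (driver host and port taken from the usual `spark.driver.host` / `spark.driver.port` settings); `conf`, `rpcEnv` and `isDriver` come from the enclosing `create` method:

def registerOrLookupEndpoint(
    name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = {
  if (isDriver) {
    logInfo("Registering " + name)
    rpcEnv.setupEndpoint(name, endpointCreator)
  } else {
    // Assumed lookup branch: executors resolve the endpoint the driver
    // registered under `name` at the driver's RPC address.
    val driverHost = conf.get("spark.driver.host", "localhost")
    val driverPort = conf.getInt("spark.driver.port", 7077)
    rpcEnv.setupEndpointRef(
      SparkEnv.driverActorSystemName, RpcAddress(driverHost, driverPort), name)
  }
}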

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala

Lines changed: 43 additions & 36 deletions
@@ -21,39 +21,45 @@ import java.net.URL
 import java.nio.ByteBuffer
 
 import scala.collection.mutable
-import scala.concurrent.Await
+import scala.util.{Failure, Success}
 
-import akka.actor.{Actor, ActorSelection, Props}
-import akka.pattern.Patterns
-import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
-
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
+import org.apache.spark.rpc._
+import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.deploy.worker.WorkerWatcher
 import org.apache.spark.scheduler.TaskDescription
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
+import org.apache.spark.util.{SignalLogger, Utils}
 
 private[spark] class CoarseGrainedExecutorBackend(
+    override val rpcEnv: RpcEnv,
     driverUrl: String,
     executorId: String,
     hostPort: String,
     cores: Int,
     userClassPath: Seq[URL],
     env: SparkEnv)
-  extends Actor with ActorLogReceive with ExecutorBackend with Logging {
+  extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
 
   Utils.checkHostPort(hostPort, "Expected hostport")
 
   var executor: Executor = null
-  var driver: ActorSelection = null
+  @volatile var driver: Option[RpcEndpointRef] = None
 
-  override def preStart() {
+  override def onStart() {
+    import scala.concurrent.ExecutionContext.Implicits.global
     logInfo("Connecting to driver: " + driverUrl)
-    driver = context.actorSelection(driverUrl)
-    driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls)
-    context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+    rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
+      driver = Some(ref)
+      ref.sendWithReply[RegisteredExecutor.type](
+        RegisterExecutor(executorId, self, hostPort, cores, extractLogUrls))
+    } onComplete {
+      case Success(msg) => Utils.tryLogNonFatalError {
+        Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor
+      }
+      case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e)
+    }
   }
 
   def extractLogUrls: Map[String, String] = {
@@ -62,7 +68,7 @@ private[spark] class CoarseGrainedExecutorBackend(
       .map(e => (e._1.substring(prefix.length).toLowerCase, e._2))
   }
 
-  override def receiveWithLogging: PartialFunction[Any, Unit] = {
+  override def receive: PartialFunction[Any, Unit] = {
     case RegisteredExecutor =>
       logInfo("Successfully registered with driver")
       val (hostname, _) = Utils.parseHostPort(hostPort)
@@ -92,23 +98,28 @@ private[spark] class CoarseGrainedExecutorBackend(
         executor.killTask(taskId, interruptThread)
       }
 
-    case x: DisassociatedEvent =>
-      if (x.remoteAddress == driver.anchorPath.address) {
-        logError(s"Driver $x disassociated! Shutting down.")
-        System.exit(1)
-      } else {
-        logWarning(s"Received irrelevant DisassociatedEvent $x")
-      }
-
     case StopExecutor =>
       logInfo("Driver commanded a shutdown")
       executor.stop()
-      context.stop(self)
-      context.system.shutdown()
+      stop()
+      rpcEnv.shutdown()
+  }
+
+  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+    if (driver.exists(_.address == remoteAddress)) {
+      logError(s"Driver $remoteAddress disassociated! Shutting down.")
+      System.exit(1)
+    } else {
+      logWarning(s"An unknown ($remoteAddress) driver disconnected.")
+    }
   }
 
   override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
-    driver ! StatusUpdate(executorId, taskId, state, data)
+    val msg = StatusUpdate(executorId, taskId, state, data)
+    driver match {
+      case Some(driverRef) => driverRef.send(msg)
+      case None => logWarning(s"Drop $msg because has not yet connected to driver")
+    }
  }
 }
 
@@ -132,16 +143,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
      // Bootstrap to fetch the driver's Spark properties.
      val executorConf = new SparkConf
      val port = executorConf.getInt("spark.executor.port", 0)
-      val (fetcher, _) = AkkaUtils.createActorSystem(
+      val fetcher = RpcEnv.create(
        "driverPropsFetcher",
        hostname,
        port,
        executorConf,
        new SecurityManager(executorConf))
-      val driver = fetcher.actorSelection(driverUrl)
-      val timeout = AkkaUtils.askTimeout(executorConf)
-      val fut = Patterns.ask(driver, RetrieveSparkProps, timeout)
-      val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++
+      val driver = fetcher.setupEndpointRefByURI(driverUrl)
+      val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++
        Seq[(String, String)](("spark.app.id", appId))
      fetcher.shutdown()
 
@@ -162,16 +171,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
      val boundPort = env.conf.getInt("spark.executor.port", 0)
      assert(boundPort != 0)
 
-      // Start the CoarseGrainedExecutorBackend actor.
+      // Start the CoarseGrainedExecutorBackend endpoint.
      val sparkHostPort = hostname + ":" + boundPort
-      env.actorSystem.actorOf(
-        Props(classOf[CoarseGrainedExecutorBackend],
-          driverUrl, executorId, sparkHostPort, cores, userClassPath, env),
-        name = "Executor")
+      env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend(
+        env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env))
      workerUrl.foreach { url =>
        env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
      }
-      env.actorSystem.awaitTermination()
+      env.rpcEnv.awaitTermination()
    }
  }
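The `driverPropsFetcher` change above shows the client-only use of the interface: create a short-lived `RpcEnv`, resolve the driver's endpoint from its URI, ask it once, and shut the env down. A condensed sketch of that pattern, with `driverUrl` assumed given and the name "propsFetcher" and the local hostname chosen arbitrarily:

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveSparkProps

val conf = new SparkConf
// A throwaway RpcEnv bound to an ephemeral port, used only for one request.
val fetcher = RpcEnv.create("propsFetcher", "localhost", 0, conf, new SecurityManager(conf))
val driverRef = fetcher.setupEndpointRefByURI(driverUrl)
// Blocking request/response; replaces Patterns.ask + Await.result on the Akka path.
val props = driverRef.askWithReply[Seq[(String, String)]](RetrieveSparkProps)
fetcher.shutdown()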
