Skip to content

Commit f5d6543

Browse files
committed
Update BlockManagerMaster to use RpcEndpoint
1 parent 30e3f9f commit f5d6543

File tree

7 files changed

+102
-115
lines changed

7 files changed

+102
-115
lines changed

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -295,10 +295,16 @@ object SparkEnv extends Logging {
295295
}
296296
}
297297

298-
def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = {
298+
def registerOrLookupEndpoint(
299+
name: String, endpointCreator: => RpcEndpoint, threadSafe: Boolean = false):
300+
RpcEndpointRef = {
299301
if (isDriver) {
300302
logInfo("Registering " + name)
301-
rpcEnv.setupEndpoint(name, endpointCreator)
303+
if (threadSafe) {
304+
rpcEnv.setupThreadSafeEndpoint(name, endpointCreator)
305+
} else {
306+
rpcEnv.setupEndpoint(name, endpointCreator)
307+
}
302308
} else {
303309
RpcUtils.makeDriverRef(name, conf, rpcEnv)
304310
}
@@ -334,9 +340,9 @@ object SparkEnv extends Logging {
334340
new NioBlockTransferService(conf, securityManager)
335341
}
336342

337-
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
338-
"BlockManagerMaster",
339-
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
343+
val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
344+
BlockManagerMaster.DRIVER_AKKA_ACTOR_NAME,
345+
new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus), true), conf, isDriver)
340346

341347
// NB: blockManager is not valid until initialize() is called later.
342348
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,10 @@ import java.util.concurrent.{TimeUnit, Executors}
2323
import java.util.concurrent.atomic.AtomicInteger
2424

2525
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
26-
import scala.concurrent.Await
2726
import scala.concurrent.duration._
2827
import scala.language.postfixOps
2928
import scala.util.control.NonFatal
3029

31-
import akka.pattern.ask
32-
import akka.util.Timeout
33-
3430
import org.apache.spark._
3531
import org.apache.spark.broadcast.Broadcast
3632
import org.apache.spark.executor.TaskMetrics
@@ -165,11 +161,8 @@ class DAGScheduler(
165161
taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics)
166162
blockManagerId: BlockManagerId): Boolean = {
167163
listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
168-
implicit val timeout = Timeout(600 seconds)
169-
170-
Await.result(
171-
blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId),
172-
timeout.duration).asInstanceOf[Boolean]
164+
blockManagerMaster.driverEndpoint.askWithReply[Boolean](
165+
BlockManagerHeartbeat(blockManagerId), 600 seconds)
173166
}
174167

175168
// Called by TaskScheduler when an executor fails.

core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,17 @@ import scala.concurrent.ExecutionContext.Implicits.global
2222

2323
import akka.actor._
2424

25+
import org.apache.spark.rpc.RpcEndpointRef
2526
import org.apache.spark.{Logging, SparkConf, SparkException}
2627
import org.apache.spark.storage.BlockManagerMessages._
2728
import org.apache.spark.util.AkkaUtils
2829

2930
private[spark]
3031
class BlockManagerMaster(
31-
var driverActor: ActorRef,
32+
var driverEndpoint: RpcEndpointRef,
3233
conf: SparkConf,
3334
isDriver: Boolean)
3435
extends Logging {
35-
private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf)
36-
private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf)
37-
38-
val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
3936

4037
val timeout = AkkaUtils.askTimeout(conf)
4138

@@ -59,20 +56,20 @@ class BlockManagerMaster(
5956
memSize: Long,
6057
diskSize: Long,
6158
tachyonSize: Long): Boolean = {
62-
val res = askDriverWithReply[Boolean](
59+
val res = driverEndpoint.askWithReply[Boolean](
6360
UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize))
6461
logDebug(s"Updated info of block $blockId")
6562
res
6663
}
6764

6865
/** Get locations of the blockId from the driver */
6966
def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
70-
askDriverWithReply[Seq[BlockManagerId]](GetLocations(blockId))
67+
driverEndpoint.askWithReply[Seq[BlockManagerId]](GetLocations(blockId))
7168
}
7269

7370
/** Get locations of multiple blockIds from the driver */
7471
def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = {
75-
askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
72+
driverEndpoint.askWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds))
7673
}
7774

7875
/**
@@ -85,24 +82,24 @@ class BlockManagerMaster(
8582

8683
/** Get ids of other nodes in the cluster from the driver */
8784
def getPeers(blockManagerId: BlockManagerId): Seq[BlockManagerId] = {
88-
askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
85+
driverEndpoint.askWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
8986
}
9087

9188
def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
92-
askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
89+
driverEndpoint.askWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
9390
}
9491

9592
/**
9693
* Remove a block from the slaves that have it. This can only be used to remove
9794
* blocks that the driver knows about.
9895
*/
9996
def removeBlock(blockId: BlockId) {
100-
askDriverWithReply(RemoveBlock(blockId))
97+
driverEndpoint.askWithReply[Boolean](RemoveBlock(blockId))
10198
}
10299

103100
/** Remove all blocks belonging to the given RDD. */
104101
def removeRdd(rddId: Int, blocking: Boolean) {
105-
val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
102+
val future = driverEndpoint.askWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
106103
future.onFailure {
107104
case e: Exception =>
108105
logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}")
@@ -114,7 +111,7 @@ class BlockManagerMaster(
114111

115112
/** Remove all blocks belonging to the given shuffle. */
116113
def removeShuffle(shuffleId: Int, blocking: Boolean) {
117-
val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
114+
val future = driverEndpoint.askWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId))
118115
future.onFailure {
119116
case e: Exception =>
120117
logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}")
@@ -126,7 +123,7 @@ class BlockManagerMaster(
126123

127124
/** Remove all blocks belonging to the given broadcast. */
128125
def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) {
129-
val future = askDriverWithReply[Future[Seq[Int]]](
126+
val future = driverEndpoint.askWithReply[Future[Seq[Int]]](
130127
RemoveBroadcast(broadcastId, removeFromMaster))
131128
future.onFailure {
132129
case e: Exception =>
@@ -145,11 +142,11 @@ class BlockManagerMaster(
145142
* amount of remaining memory.
146143
*/
147144
def getMemoryStatus: Map[BlockManagerId, (Long, Long)] = {
148-
askDriverWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
145+
driverEndpoint.askWithReply[Map[BlockManagerId, (Long, Long)]](GetMemoryStatus)
149146
}
150147

151148
def getStorageStatus: Array[StorageStatus] = {
152-
askDriverWithReply[Array[StorageStatus]](GetStorageStatus)
149+
driverEndpoint.askWithReply[Array[StorageStatus]](GetStorageStatus)
153150
}
154151

155152
/**
@@ -169,7 +166,7 @@ class BlockManagerMaster(
169166
* should not block on waiting for a block manager, which can in turn be waiting for the
170167
* master endpoint for a response to a prior message.
171168
*/
172-
val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
169+
val response = driverEndpoint.askWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
173170
val (blockManagerIds, futures) = response.unzip
174171
val result = Await.result(Future.sequence(futures), timeout)
175172
if (result == null) {
@@ -193,33 +190,28 @@ class BlockManagerMaster(
193190
filter: BlockId => Boolean,
194191
askSlaves: Boolean): Seq[BlockId] = {
195192
val msg = GetMatchingBlockIds(filter, askSlaves)
196-
val future = askDriverWithReply[Future[Seq[BlockId]]](msg)
193+
val future = driverEndpoint.askWithReply[Future[Seq[BlockId]]](msg)
197194
Await.result(future, timeout)
198195
}
199196

200197
/** Stop the driver endpoint; called only on the Spark driver node. */
201198
def stop() {
202-
if (driverActor != null && isDriver) {
199+
if (driverEndpoint != null && isDriver) {
203200
tell(StopBlockManagerMaster)
204-
driverActor = null
201+
driverEndpoint = null
205202
logInfo("BlockManagerMaster stopped")
206203
}
207204
}
208205

209206
/** Ask the master endpoint and expect a reply of true; throws a SparkException otherwise. */
210207
private def tell(message: Any) {
211-
if (!askDriverWithReply[Boolean](message)) {
208+
if (!driverEndpoint.askWithReply[Boolean](message)) {
212209
throw new SparkException("BlockManagerMasterActor returned false, expected true.")
213210
}
214211
}
215212

216-
/**
217-
* Send a message to the driver actor and get its result within a default timeout, or
218-
* throw a SparkException if this fails.
219-
*/
220-
private def askDriverWithReply[T](message: Any): T = {
221-
AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS,
222-
timeout)
223-
}
213+
}
224214

215+
private[spark] object BlockManagerMaster {
216+
val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
225217
}

core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala renamed to core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,29 @@ import java.util.{HashMap => JHashMap}
2121

2222
import scala.collection.mutable
2323
import scala.collection.JavaConversions._
24-
import scala.concurrent.Future
25-
import scala.concurrent.duration._
24+
import scala.concurrent.{ExecutionContext, Future}
2625

27-
import akka.actor.{Actor, ActorRef}
26+
import akka.actor.ActorRef
2827
import akka.pattern.ask
2928

30-
import org.apache.spark.{Logging, SparkConf, SparkException}
29+
import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint}
30+
import org.apache.spark.{Logging, SparkConf}
3131
import org.apache.spark.annotation.DeveloperApi
3232
import org.apache.spark.scheduler._
3333
import org.apache.spark.storage.BlockManagerMessages._
34-
import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils}
34+
import org.apache.spark.util.{AkkaUtils, Utils}
3535

3636
/**
3737
* BlockManagerMasterEndpoint is an RpcEndpoint on the master node to track statuses of
3838
* all slaves' block managers.
3939
*/
4040
private[spark]
41-
class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus)
42-
extends Actor with ActorLogReceive with Logging {
41+
class BlockManagerMasterEndpoint(
42+
override val rpcEnv: RpcEnv,
43+
val isLocal: Boolean,
44+
conf: SparkConf,
45+
listenerBus: LiveListenerBus)
46+
extends RpcEndpoint with Logging {
4347

4448
// Mapping from block manager id to the block manager's information.
4549
private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo]
@@ -52,66 +56,67 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
5256

5357
private val akkaTimeout = AkkaUtils.askTimeout(conf)
5458

55-
override def receiveWithLogging: PartialFunction[Any, Unit] = {
59+
private val askThreadPool = Utils.newDaemonCachedThreadPool("block-manager-ask-thread-pool")
60+
private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
61+
62+
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
5663
case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
5764
register(blockManagerId, maxMemSize, slaveActor)
58-
sender ! true
65+
context.reply(true)
5966

6067
case UpdateBlockInfo(
6168
blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) =>
62-
sender ! updateBlockInfo(
63-
blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)
69+
context.reply(updateBlockInfo(
70+
blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize))
6471

6572
case GetLocations(blockId) =>
66-
sender ! getLocations(blockId)
73+
context.reply(getLocations(blockId))
6774

6875
case GetLocationsMultipleBlockIds(blockIds) =>
69-
sender ! getLocationsMultipleBlockIds(blockIds)
76+
context.reply(getLocationsMultipleBlockIds(blockIds))
7077

7178
case GetPeers(blockManagerId) =>
72-
sender ! getPeers(blockManagerId)
79+
context.reply(getPeers(blockManagerId))
7380

7481
case GetActorSystemHostPortForExecutor(executorId) =>
75-
sender ! getActorSystemHostPortForExecutor(executorId)
82+
context.reply(getActorSystemHostPortForExecutor(executorId))
7683

7784
case GetMemoryStatus =>
78-
sender ! memoryStatus
85+
context.reply(memoryStatus)
7986

8087
case GetStorageStatus =>
81-
sender ! storageStatus
88+
context.reply(storageStatus)
8289

8390
case GetBlockStatus(blockId, askSlaves) =>
84-
sender ! blockStatus(blockId, askSlaves)
91+
context.reply(blockStatus(blockId, askSlaves))
8592

8693
case GetMatchingBlockIds(filter, askSlaves) =>
87-
sender ! getMatchingBlockIds(filter, askSlaves)
94+
context.reply(getMatchingBlockIds(filter, askSlaves))
8895

8996
case RemoveRdd(rddId) =>
90-
sender ! removeRdd(rddId)
97+
context.reply(removeRdd(rddId))
9198

9299
case RemoveShuffle(shuffleId) =>
93-
sender ! removeShuffle(shuffleId)
100+
context.reply(removeShuffle(shuffleId))
94101

95102
case RemoveBroadcast(broadcastId, removeFromDriver) =>
96-
sender ! removeBroadcast(broadcastId, removeFromDriver)
103+
context.reply(removeBroadcast(broadcastId, removeFromDriver))
97104

98105
case RemoveBlock(blockId) =>
99106
removeBlockFromWorkers(blockId)
100-
sender ! true
107+
context.reply(true)
101108

102109
case RemoveExecutor(execId) =>
103110
removeExecutor(execId)
104-
sender ! true
111+
context.reply(true)
105112

106113
case StopBlockManagerMaster =>
107-
sender ! true
108-
context.stop(self)
114+
context.reply(true)
115+
stop()
109116

110117
case BlockManagerHeartbeat(blockManagerId) =>
111-
sender ! heartbeatReceived(blockManagerId)
118+
context.reply(heartbeatReceived(blockManagerId))
112119

113-
case other =>
114-
logWarning("Got unknown message: " + other)
115120
}
116121

117122
private def removeRdd(rddId: Int): Future[Seq[Int]] = {
@@ -129,7 +134,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
129134

130135
// Ask the slaves to remove the RDD, and put the result in a sequence of Futures.
131136
// The implicit askExecutionContext is used as the ExecutionContext for the Future sequence.
132-
import context.dispatcher
133137
val removeMsg = RemoveRdd(rddId)
134138
Future.sequence(
135139
blockManagerInfo.values.map { bm =>
@@ -140,7 +144,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
140144

141145
private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
142146
// Nothing to do in the BlockManagerMasterEndpoint data structures
143-
import context.dispatcher
144147
val removeMsg = RemoveShuffle(shuffleId)
145148
Future.sequence(
146149
blockManagerInfo.values.map { bm =>
@@ -155,7 +158,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
155158
* from the executors, but not from the driver.
156159
*/
157160
private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = {
158-
import context.dispatcher
159161
val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver)
160162
val requiredBlockManagers = blockManagerInfo.values.filter { info =>
161163
removeFromDriver || !info.blockManagerId.isDriver
@@ -247,7 +249,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
247249
private def blockStatus(
248250
blockId: BlockId,
249251
askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
250-
import context.dispatcher
251252
val getBlockStatus = GetBlockStatus(blockId)
252253
/*
253254
* Rather than blocking on the block status query, master endpoint should simply return
@@ -276,7 +277,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
276277
private def getMatchingBlockIds(
277278
filter: BlockId => Boolean,
278279
askSlaves: Boolean): Future[Seq[BlockId]] = {
279-
import context.dispatcher
280280
val getMatchingBlockIds = GetMatchingBlockIds(filter)
281281
Future.sequence(
282282
blockManagerInfo.values.map { info =>

0 commit comments

Comments
 (0)