1717
1818package org .apache .spark .deploy
1919
20- import scala .concurrent ._
20+ import scala .concurrent .ExecutionContext
21+ import scala .util .{Failure , Success }
2122
22- import akka .actor ._
23- import akka .pattern .ask
24- import akka .remote .{AssociationErrorEvent , DisassociatedEvent , RemotingLifecycleEvent }
2523import org .apache .log4j .{Level , Logger }
2624
25+ import org .apache .spark .rpc .{RpcEndpointRef , RpcAddress , RpcEnv , ThreadSafeRpcEndpoint }
2726import org .apache .spark .{Logging , SecurityManager , SparkConf }
2827import org .apache .spark .deploy .DeployMessages ._
2928import org .apache .spark .deploy .master .{DriverState , Master }
30- import org .apache .spark .util .{ ActorLogReceive , AkkaUtils , Utils }
29+ import org .apache .spark .util .Utils
3130
3231/**
3332 * Proxy that relays messages to the driver.
3433 */
35- private class ClientActor (driverArgs : ClientArguments , conf : SparkConf )
36- extends Actor with ActorLogReceive with Logging {
37-
38- var masterActor : ActorSelection = _
39- val timeout = AkkaUtils .askTimeout(conf)
40-
41- override def preStart (): Unit = {
42- masterActor = context.actorSelection(
43- Master .toAkkaUrl(driverArgs.master, AkkaUtils .protocol(context.system)))
44-
45- context.system.eventStream.subscribe(self, classOf [RemotingLifecycleEvent ])
34+ private class ClientEndpoint (
35+ override val rpcEnv : RpcEnv ,
36+ driverArgs : ClientArguments ,
37+ masterEndpoint : RpcEndpointRef ,
38+ conf : SparkConf )
39+ extends ThreadSafeRpcEndpoint with Logging {
40+
41+ private val forwardMessageThread = Utils .newDaemonFixedThreadPool(1 , " client-forward-message" )
42+ private implicit val forwardMessageExecutionContext =
43+ ExecutionContext .fromExecutor(forwardMessageThread,
44+ t => t match {
45+ case ie : InterruptedException => // Exit normally
46+ case e =>
47+ e.printStackTrace()
48+ System .exit(- 1 )
49+ })
4650
51+ override def onStart (): Unit = {
4752 println(s " Sending ${driverArgs.cmd} command to ${driverArgs.master}" )
4853
4954 driverArgs.cmd match {
@@ -79,22 +84,36 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
7984 driverArgs.supervise,
8085 command)
8186
82- masterActor ! RequestSubmitDriver (driverDescription)
87+ masterEndpoint.sendWithReply[SubmitDriverResponse ](RequestSubmitDriver (driverDescription)).
88+ onComplete {
89+ case Success (v) => self.send(v)
90+ case Failure (e) =>
91+ println(s " Error sending messages to master ${driverArgs.master}, exiting. " )
92+ e.printStackTrace()
93+ System .exit(- 1 )
94+ }
8395
8496 case " kill" =>
8597 val driverId = driverArgs.driverId
86- masterActor ! RequestKillDriver (driverId)
98+ masterEndpoint.sendWithReply[KillDriverResponse ](RequestKillDriver (driverId)).onComplete {
99+ case Success (v) => self.send(v)
100+ case Failure (e) =>
101+ println(s " Error sending messages to master ${driverArgs.master}, exiting. " )
102+ e.printStackTrace()
103+ System .exit(- 1 )
104+ }
87105 }
88106 }
89107
90108 /* Find out driver status then exit the JVM */
91109 def pollAndReportStatus (driverId : String ) {
110+ // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread
111+ // is fine.
92112 println(" ... waiting before polling master for driver state" )
93113 Thread .sleep(5000 )
94114 println(" ... polling master for driver state" )
95- val statusFuture = (masterActor ? RequestDriverStatus (driverId))(timeout)
96- .mapTo[DriverStatusResponse ]
97- val statusResponse = Await .result(statusFuture, timeout)
115+ val statusResponse =
116+ masterEndpoint.askWithReply[DriverStatusResponse ](RequestDriverStatus (driverId))
98117
99118 statusResponse.found match {
100119 case false =>
@@ -118,7 +137,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
118137 }
119138 }
120139
121- override def receiveWithLogging : PartialFunction [Any , Unit ] = {
140+ override def receive : PartialFunction [Any , Unit ] = {
122141
123142 case SubmitDriverResponse (success, driverId, message) =>
124143 println(message)
@@ -128,14 +147,27 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
128147 println(message)
129148 if (success) pollAndReportStatus(driverId) else System .exit(- 1 )
130149
131- case DisassociatedEvent (_, remoteAddress, _) =>
132- println(s " Error connecting to master ${driverArgs.master} ( $remoteAddress), exiting. " )
133- System .exit(- 1 )
150+ }
151+
152+ override def onDisconnected (remoteAddress : RpcAddress ): Unit = {
153+ println(s " Error connecting to master ${driverArgs.master} ( $remoteAddress), exiting. " )
154+ System .exit(- 1 )
155+ }
156+
157+ override def onNetworkError (cause : Throwable , remoteAddress : RpcAddress ): Unit = {
158+ println(s " Error connecting to master ${driverArgs.master} ( $remoteAddress), exiting. " )
159+ cause.printStackTrace()
160+ System .exit(- 1 )
161+ }
162+
163+ override def onError (cause : Throwable ): Unit = {
164+ println(s " Error processing messages, exiting. " )
165+ cause.printStackTrace()
166+ System .exit(- 1 )
167+ }
134168
135- case AssociationErrorEvent (cause, _, remoteAddress, _, _) =>
136- println(s " Error connecting to master ${driverArgs.master} ( $remoteAddress), exiting. " )
137- println(s " Cause was: $cause" )
138- System .exit(- 1 )
169+ override def onStop (): Unit = {
170+ forwardMessageThread.shutdownNow()
139171 }
140172}
141173
@@ -159,13 +191,14 @@ object Client {
159191 conf.set(" akka.loglevel" , driverArgs.logLevel.toString.replace(" WARN" , " WARNING" ))
160192 Logger .getRootLogger.setLevel(driverArgs.logLevel)
161193
162- val (actorSystem, _) = AkkaUtils .createActorSystem(
163- " driverClient" , Utils .localHostName(), 0 , conf, new SecurityManager (conf))
194+ val rpcEnv =
195+ RpcEnv .create( " driverClient" , Utils .localHostName(), 0 , conf, new SecurityManager (conf))
164196
165- // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely
166- Master .toAkkaUrl(driverArgs.master, AkkaUtils .protocol(actorSystem))
167- actorSystem.actorOf(Props (classOf [ClientActor ], driverArgs, conf))
197+ val masterAddress = RpcAddress .fromSparkURL(driverArgs.master)
198+ val masterEndpoint =
199+ rpcEnv.setupEndpointRef(Master .SYSTEM_NAME , masterAddress, Master .ENDPOINT_NAME )
200+ rpcEnv.setupEndpoint(" client" , new ClientEndpoint (rpcEnv, driverArgs, masterEndpoint, conf))
168201
169- actorSystem .awaitTermination()
202+ rpcEnv .awaitTermination()
170203 }
171204}
0 commit comments