@@ -21,39 +21,43 @@ import java.net.URL
 import java.nio.ByteBuffer

 import scala.collection.mutable
-import scala.concurrent.Await
+import scala.util.{Failure, Success}

-import akka.actor.{Actor, ActorSelection, Props}
-import akka.pattern.Patterns
-import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
-
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
+import org.apache.spark.rpc._
+import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.deploy.worker.WorkerWatcher
 import org.apache.spark.scheduler.TaskDescription
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
+import org.apache.spark.util.{SignalLogger, Utils}
 private[spark] class CoarseGrainedExecutorBackend(
+    override val rpcEnv: RpcEnv,
     driverUrl: String,
     executorId: String,
     hostPort: String,
     cores: Int,
     userClassPath: Seq[URL],
     env: SparkEnv)
-  extends Actor with ActorLogReceive with ExecutorBackend with Logging {
+  extends RpcEndpoint with ExecutorBackend with Logging {

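+  // The transport is injected rather than implicit: setupEndpoint in the
+  // companion object below passes in the RpcEnv this endpoint is registered
+  // with, replacing the ActorContext an Akka Actor picks up automatically.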
   Utils.checkHostPort(hostPort, "Expected hostport")

   var executor: Executor = null
-  var driver: ActorSelection = null
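+  // Written from onStart's registration callback (an ExecutionContext thread)
+  // and read by statusUpdate, hence @volatile.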
+  @volatile var driver: Option[RpcEndpointRef] = None

-  override def preStart() {
+  override def onStart() {
51+ import scala .concurrent .ExecutionContext .Implicits .global
5352 logInfo(" Connecting to driver: " + driverUrl)
54- driver = context.actorSelection(driverUrl)
55- driver ! RegisterExecutor (executorId, hostPort, cores, extractLogUrls)
56- context.system.eventStream.subscribe(self, classOf [RemotingLifecycleEvent ])
53+ rpcEnv.asyncSetupEndpointRefByUrl(driverUrl).flatMap { ref =>
54+ driver = Some (ref)
55+ ref.sendWithReply[RegisteredExecutor .type ](
56+ RegisterExecutor (executorId, self, hostPort, cores, extractLogUrls))
57+ } onComplete {
58+ case Success (msg) => self.send(msg)
59+ case Failure (e) => logError(s " Cannot register to driver: $driverUrl" , e)
60+ }
5761 }

   def extractLogUrls: Map[String, String] = {
@@ -62,7 +66,7 @@ private[spark] class CoarseGrainedExecutorBackend(
       .map(e => (e._1.substring(prefix.length).toLowerCase, e._2))
   }

-  override def receiveWithLogging: PartialFunction[Any, Unit] = {
+  override def receive: PartialFunction[Any, Unit] = {
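+  // receive is invoked by the RpcEnv's dispatcher; the per-message logging that
+  // ActorLogReceive used to wrap around the handler is no longer applied.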
     case RegisteredExecutor =>
       logInfo("Successfully registered with driver")
       val (hostname, _) = Utils.parseHostPort(hostPort)
@@ -92,23 +96,28 @@ private[spark] class CoarseGrainedExecutorBackend(
         executor.killTask(taskId, interruptThread)
       }

-    case x: DisassociatedEvent =>
-      if (x.remoteAddress == driver.anchorPath.address) {
-        logError(s"Driver $x disassociated! Shutting down.")
-        System.exit(1)
-      } else {
-        logWarning(s"Received irrelevant DisassociatedEvent $x")
-      }
-
     case StopExecutor =>
       logInfo("Driver commanded a shutdown")
       executor.stop()
-      context.stop(self)
-      context.system.shutdown()
+      stop()
+      rpcEnv.shutdown()
+  }
+
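+  // The RpcEnv reports every dropped remote connection here, not just the
+  // driver's, so the address check replaces the DisassociatedEvent filtering
+  // deleted above.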
+  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+    if (driver.exists(_.address == remoteAddress)) {
+      logError(s"Driver $remoteAddress disassociated! Shutting down.")
+      System.exit(1)
+    } else {
+      logWarning(s"Received irrelevant DisassociatedEvent $remoteAddress")
+    }
   }

   override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
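+    // May race with registration: until onStart's callback has populated
+    // `driver`, updates are dropped with a warning rather than queued.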
-    driver ! StatusUpdate(executorId, taskId, state, data)
+    val msg = StatusUpdate(executorId, taskId, state, data)
+    driver match {
+      case Some(driverRef) => driverRef.send(msg)
+      case None => logWarning(s"Drop $msg because executor has not yet connected to driver")
+    }
   }
113122}
114123
@@ -132,16 +141,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
     // Bootstrap to fetch the driver's Spark properties.
     val executorConf = new SparkConf
     val port = executorConf.getInt("spark.executor.port", 0)
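+    // A short-lived RpcEnv takes over the temporary ActorSystem's job: it lives
+    // just long enough to ask the driver for its properties, then shuts down.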
-    val (fetcher, _) = AkkaUtils.createActorSystem(
+    val fetcher = RpcEnv.create(
       "driverPropsFetcher",
       hostname,
       port,
       executorConf,
       new SecurityManager(executorConf))
-    val driver = fetcher.actorSelection(driverUrl)
-    val timeout = AkkaUtils.askTimeout(executorConf)
-    val fut = Patterns.ask(driver, RetrieveSparkProps, timeout)
-    val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++
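+    // askWithReply blocks until the typed reply arrives (presumably under the
+    // RpcEnv's default timeout), collapsing the ask/Await pair removed above.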
+    val driver = fetcher.setupEndpointRefByUrl(driverUrl)
+    val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++
       Seq[(String, String)](("spark.app.id", appId))
     fetcher.shutdown()

@@ -162,16 +169,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
     val boundPort = env.conf.getInt("spark.executor.port", 0)
     assert(boundPort != 0)

-    // Start the CoarseGrainedExecutorBackend actor.
+    // Start the CoarseGrainedExecutorBackend endpoint.
     val sparkHostPort = hostname + ":" + boundPort
167- env.actorSystem.actorOf(
168- Props (classOf [CoarseGrainedExecutorBackend ],
169- driverUrl, executorId, sparkHostPort, cores, userClassPath, env),
170- name = " Executor" )
174+ env.rpcEnv.setupEndpoint(" Executor" , new CoarseGrainedExecutorBackend (
175+ env.rpcEnv, driverUrl, executorId, sparkHostPort, cores, userClassPath, env))
     workerUrl.foreach { url =>
       env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url))
     }
-    env.actorSystem.awaitTermination()
+    env.rpcEnv.awaitTermination()
   }
 }