@@ -130,7 +130,6 @@ private[spark] class ApplicationMaster(
130130 private var nextAllocationInterval = initialAllocationInterval
131131
132132 private var rpcEnv : RpcEnv = null
133- private var amEndpoint : RpcEndpointRef = _
134133
135134 // In cluster mode, used to tell the AM when the user's SparkContext has been initialized.
136135 private val sparkContextPromise = Promise [SparkContext ]()
@@ -405,32 +404,26 @@ private[spark] class ApplicationMaster(
405404 securityMgr,
406405 localResources)
407406
407+ // Initialize the AM endpoint *after* the allocator has been initialized. This ensures
408+ // that when the driver sends an initial executor request (e.g. after an AM restart),
409+ // the allocator is ready to service requests.
410+ rpcEnv.setupEndpoint(" YarnAM" , new AMEndpoint (rpcEnv, driverRef))
411+
408412 allocator.allocateResources()
409413 reporterThread = launchReporterThread()
410414 }
411415
412416 /**
413- * Create an [[RpcEndpoint ]] that communicates with the driver.
414- *
415- * In cluster mode, the AM and the driver belong to same process
416- * so the AMEndpoint need not monitor lifecycle of the driver.
417- *
418- * @return A reference to the driver's RPC endpoint.
417+ * @return An [[RpcEndpointRef]] pointing at the driver's scheduler backend endpoint.
419418 */
420- private def runAMEndpoint (
421- host : String ,
422- port : String ,
423- isClusterMode : Boolean ): RpcEndpointRef = {
424- val driverEndpoint = rpcEnv.setupEndpointRef(
419+ private def createSchedulerRef (host : String , port : String ): RpcEndpointRef = {
420+ rpcEnv.setupEndpointRef(
425421 RpcAddress (host, port.toInt),
426422 YarnSchedulerBackend .ENDPOINT_NAME )
427- amEndpoint =
428- rpcEnv.setupEndpoint(" YarnAM" , new AMEndpoint (rpcEnv, driverEndpoint, isClusterMode))
429- driverEndpoint
430423 }
431424
432425 private def runDriver (securityMgr : SecurityManager ): Unit = {
433- addAmIpFilter()
426+ addAmIpFilter(None )
434427 userClassThread = startUserApplication()
435428
436429 // This is a bit hacky, but we need to wait until the spark.driver.port property has
@@ -442,10 +435,9 @@ private[spark] class ApplicationMaster(
442435 Duration (totalWaitTime, TimeUnit .MILLISECONDS ))
443436 if (sc != null ) {
444437 rpcEnv = sc.env.rpcEnv
445- val driverRef = runAMEndpoint (
438+ val driverRef = createSchedulerRef (
446439 sc.getConf.get(" spark.driver.host" ),
447- sc.getConf.get(" spark.driver.port" ),
448- isClusterMode = true )
440+ sc.getConf.get(" spark.driver.port" ))
449441 registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl), securityMgr)
450442 registered = true
451443 } else {
@@ -471,7 +463,7 @@ private[spark] class ApplicationMaster(
471463 rpcEnv = RpcEnv .create(" sparkYarnAM" , hostname, hostname, - 1 , sparkConf, securityMgr,
472464 amCores, true )
473465 val driverRef = waitForSparkDriver()
474- addAmIpFilter()
466+ addAmIpFilter(Some (driverRef) )
475467 registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption(" spark.driver.appUIAddress" ),
476468 securityMgr)
477469 registered = true
@@ -620,20 +612,21 @@ private[spark] class ApplicationMaster(
620612
621613 sparkConf.set(" spark.driver.host" , driverHost)
622614 sparkConf.set(" spark.driver.port" , driverPort.toString)
623-
624- runAMEndpoint(driverHost, driverPort.toString, isClusterMode = false )
615+ createSchedulerRef(driverHost, driverPort.toString)
625616 }
626617
627618 /** Add the Yarn IP filter that is required for properly securing the UI. */
628- private def addAmIpFilter () = {
619+ private def addAmIpFilter (driver : Option [ RpcEndpointRef ] ) = {
629620 val proxyBase = System .getenv(ApplicationConstants .APPLICATION_WEB_PROXY_BASE_ENV )
630621 val amFilter = " org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
631622 val params = client.getAmIpFilterParams(yarnConf, proxyBase)
632- if (isClusterMode) {
633- System .setProperty(" spark.ui.filters" , amFilter)
634- params.foreach { case (k, v) => System .setProperty(s " spark. $amFilter.param. $k" , v) }
635- } else {
636- amEndpoint.send(AddWebUIFilter (amFilter, params.toMap, proxyBase))
623+ driver match {
624+ case Some (d) =>
625+ d.send(AddWebUIFilter (amFilter, params.toMap, proxyBase))
626+
627+ case None =>
628+ System .setProperty(" spark.ui.filters" , amFilter)
629+ params.foreach { case (k, v) => System .setProperty(s " spark. $amFilter.param. $k" , v) }
637630 }
638631 }
639632
@@ -704,20 +697,13 @@ private[spark] class ApplicationMaster(
704697 /**
705698 * An [[RpcEndpoint ]] that communicates with the driver's scheduler backend.
706699 */
707- private class AMEndpoint (
708- override val rpcEnv : RpcEnv , driver : RpcEndpointRef , isClusterMode : Boolean )
700+ private class AMEndpoint (override val rpcEnv : RpcEnv , driver : RpcEndpointRef )
709701 extends RpcEndpoint with Logging {
710702
711703 override def onStart (): Unit = {
712704 driver.send(RegisterClusterManager (self))
713705 }
714706
715- override def receive : PartialFunction [Any , Unit ] = {
716- case x : AddWebUIFilter =>
717- logInfo(s " Add WebUI Filter. $x" )
718- driver.send(x)
719- }
720-
721707 override def receiveAndReply (context : RpcCallContext ): PartialFunction [Any , Unit ] = {
722708 case r : RequestExecutors =>
723709 Option (allocator) match {
0 commit comments