@@ -62,9 +62,16 @@ private[streaming] case class DeregisterReceiver(streamId: Int, msg: String, err
6262private [streaming] sealed trait ReceiverTrackerLocalMessage
6363
6464/**
65- * This message will trigger ReceiverTrackerEndpoint to start a Spark job for the receiver.
65+ * This message will trigger ReceiverTrackerEndpoint to restart a Spark job for the receiver.
6666 */
67- private [streaming] case class StartReceiver (receiver : Receiver [_])
67+ private [streaming] case class RestartReceiver (receiver : Receiver [_])
68+ extends ReceiverTrackerLocalMessage
69+
70+ /**
71+ * This message is sent to ReceiverTrackerEndpoint when we start to launch Spark jobs for receivers
72+ * at the first time.
73+ */
74+ private [streaming] case class StartAllReceivers (receiver : Seq [Receiver [_]])
6875 extends ReceiverTrackerLocalMessage
6976
7077/**
@@ -307,8 +314,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
307314
308315 private def scheduleReceiver (receiverId : Int ): Seq [String ] = {
309316 val preferredLocation = receiverPreferredLocations.getOrElse(receiverId, None )
310- val scheduledLocations = schedulingPolicy.scheduleReceiver (
311- receiverId, preferredLocation, receiverTrackingInfos, getExecutors(ssc) )
317+ val scheduledLocations = schedulingPolicy.rescheduleReceiver (
318+ receiverId, preferredLocation, receiverTrackingInfos, getExecutors)
312319 updateReceiverScheduledLocations(receiverId, scheduledLocations)
313320 scheduledLocations
314321 }
@@ -337,10 +344,14 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
337344 /**
338345 * Get the list of executors excluding driver
339346 */
340- private def getExecutors (ssc : StreamingContext ): List [String ] = {
341- val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList
342- val driver = ssc.sparkContext.getConf.get("spark.driver.host")
343- executors.diff(List(driver))
347+ private def getExecutors : List [String ] = {
348+ if (ssc.sc.isLocal) {
349+ List("localhost")
350+ } else {
351+ val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList
352+ val driver = ssc.sparkContext.getConf.get("spark.driver.host")
353+ executors.diff(List(driver))
354+ }
344355 }
345356
346357 /**
@@ -355,6 +366,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
355366 if (! ssc.sparkContext.isLocal) {
356367 ssc.sparkContext.makeRDD(1 to 50 , 50 ).map(x => (x, 1 )).reduceByKey(_ + _, 20 ).collect()
357368 }
369+ assert(getExecutors.nonEmpty)
358370 }
359371
360372 /**
@@ -370,12 +382,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
370382
371383 runDummySparkJob()
372384
373- // Distribute the receivers and start them
374385 logInfo("Starting " + receivers.length + " receivers")
375-
376- for (receiver <- receivers) {
377- endpoint.send(StartReceiver (receiver))
378- }
386+ endpoint.send(StartAllReceivers(receivers))
379387 }
380388
381389 /** Check if tracker has been marked for starting */
@@ -396,8 +404,22 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
396404
397405 override def receive : PartialFunction [Any , Unit ] = {
398406 // Local messages
399- case StartReceiver (receiver) =>
400- startReceiver(receiver)
407+ case StartAllReceivers (receivers) =>
408+ val scheduledLocations = schedulingPolicy.scheduleReceivers(receivers, getExecutors)
409+ for (receiver <- receivers) {
410+ val locations = scheduledLocations(receiver.streamId)
411+ updateReceiverScheduledLocations(receiver.streamId, locations)
412+ receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation
413+ startReceiver(receiver, locations)
414+ }
415+ case RestartReceiver (receiver) =>
416+ val scheduledLocations = schedulingPolicy.rescheduleReceiver(
417+ receiver.streamId,
418+ receiver.preferredLocation,
419+ receiverTrackingInfos,
420+ getExecutors)
421+ updateReceiverScheduledLocations(receiver.streamId, scheduledLocations)
422+ startReceiver(receiver, scheduledLocations)
401423 case c @ CleanupOldBlocks (cleanupThreshTime) =>
402424 receiverTrackingInfos.values.flatMap(_.endpoint).foreach(_.send(c))
403425 // Remote messages
@@ -425,28 +447,22 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
425447 context.reply(true )
426448 }
427449
428- private def startReceiver (receiver : Receiver [_]): Unit = {
429- val checkpointDirOption = Option (ssc.checkpointDir)
430- val serializableHadoopConf =
431- new SerializableConfiguration (ssc.sparkContext.hadoopConfiguration)
432-
433- // Function to start the receiver on the worker node
434- val startReceiverFunc = new StartReceiverFunc (checkpointDirOption, serializableHadoopConf)
435-
436- receiverPreferredLocations(receiver.streamId) = receiver.preferredLocation
450+ /**
451+ * Start a receiver along with its scheduled locations
452+ */
453+ private def startReceiver (receiver : Receiver [_], scheduledLocations : Seq [String ]): Unit = {
437454 val receiverId = receiver.streamId
438-
439455 if (! isTrackerStarted) {
440456 onReceiverJobFinish(receiverId)
441457 return
442458 }
443459
444- val scheduledLocations = schedulingPolicy.scheduleReceiver(
445- receiverId,
446- receiver.preferredLocation,
447- receiverTrackingInfos,
448- getExecutors(ssc))
449- updateReceiverScheduledLocations(receiver.streamId, scheduledLocations )
460+ val checkpointDirOption = Option (ssc.checkpointDir)
461+ val serializableHadoopConf =
462+ new SerializableConfiguration (ssc.sparkContext.hadoopConfiguration)
463+
464+ // Function to start the receiver on the worker node
465+ val startReceiverFunc = new StartReceiverFunc (checkpointDirOption, serializableHadoopConf )
450466
451467 // Create the RDD using the scheduledLocations to run the receiver in a Spark job
452468 val receiverRDD : RDD [Receiver [_]] =
@@ -465,15 +481,15 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
465481 onReceiverJobFinish(receiverId)
466482 } else {
467483 logInfo(s"Restarting Receiver $receiverId")
468- self.send(StartReceiver (receiver))
484+ self.send(RestartReceiver (receiver))
469485 }
470486 case Failure (e) =>
471487 if (! isTrackerStarted) {
472488 onReceiverJobFinish(receiverId)
473489 } else {
474490 logError("Receiver has been stopped. Try to restart it.", e)
475491 logInfo(s"Restarting Receiver $receiverId")
476- self.send(StartReceiver (receiver))
492+ self.send(RestartReceiver (receiver))
477493 }
478494 }(submitJobThreadPool)
479495 logInfo(s"Receiver ${receiver.streamId} started")
0 commit comments