diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 4943f29395d12..f724e71d1f7db 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -36,7 +36,7 @@ private[streaming] abstract class ReceiverSupervisor( conf: SparkConf ) extends Logging { - /** Enumeration to identify current state of the StreamingContext */ + /** Enumeration to identify current state of the Receiver */ object ReceiverState extends Enumeration { type CheckpointState = Value val Initialized, Started, Stopped = Value @@ -93,7 +93,10 @@ private[streaming] abstract class ReceiverSupervisor( /** Called when supervisor is stopped */ protected def onStop(message: String, error: Option[Throwable]) { } - + + /** Called when receiver is registered */ + protected def onReceiverRegister() { } + /** Called when receiver is started */ protected def onReceiverStart() { } @@ -117,6 +120,7 @@ private[streaming] abstract class ReceiverSupervisor( /** Start receiver */ def startReceiver(): Unit = synchronized { try { + onReceiverRegister() logInfo("Starting receiver") receiver.onStart() logInfo("Called receiver onStart") @@ -161,7 +165,7 @@ private[streaming] abstract class ReceiverSupervisor( } } - /** Check if receiver has been marked for stopping */ + /** Check if receiver has been marked for starting */ def isReceiverStarted(): Boolean = { logDebug("state = " + receiverState) receiverState == Started diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 92938379b9c17..1965c81824a9a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -166,11 +166,24 @@ private[streaming] class ReceiverSupervisorImpl( blockGenerator.stop() env.rpcEnv.stop(endpoint) } - - override protected def onReceiverStart() { + + override protected def onReceiverRegister() { val msg = RegisterReceiver( streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) - trackerEndpoint.askWithRetry[Boolean](msg) + val ret = trackerEndpoint.askWithRetry[Boolean](msg) + if (!ret) { + throw new SparkException("ReceiverTracker is stopping and doesn't accept registeration " + + "from receivers.") + } + } + + override protected def onReceiverStart() { + val msg = ReceiverStarted( + streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint) + val ret = trackerEndpoint.askWithRetry[Boolean](msg) + if (!ret) { + throw new SparkException("ReceiverTracker is stopping and doesn't accept receiver started.") + } } override protected def onReceiverStop(message: String, error: Option[Throwable]) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index f73f7e705ee0d..805965e82d83b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -37,6 +37,12 @@ private[streaming] case class RegisterReceiver( host: String, receiverEndpoint: RpcEndpointRef ) extends ReceiverTrackerMessage +private[streaming] case class ReceiverStarted( + streamId: Int, + typ: String, + host: String, + receiverEndpoint: RpcEndpointRef + ) extends ReceiverTrackerMessage private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) extends ReceiverTrackerMessage private[streaming] case class ReportError(streamId: Int, message: String, error: String) @@ -67,13 +73,38 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false ) private val listenerBus = ssc.scheduler.listenerBus + /** Enumeration to identify current state of the ReceiverTracker */ + object TrackerState extends Enumeration { + type CheckpointState = Value + val Initialized, Started, Stopping, Stopped = Value + } + import TrackerState._ + + /** State of the tracker */ + @volatile private var trackerState = Initialized + // endpoint is created when generator starts. // This not being null means the tracker has been started and not stopped private var endpoint: RpcEndpointRef = null + /** Check if tracker has been marked for starting */ + private def isTrackerStarted(): Boolean = { + trackerState == Started + } + + /** Check if tracker has been marked for stopping */ + private def isTrackerStopping(): Boolean = { + trackerState == Stopping + } + + /** Check if tracker has been marked for stopped */ + private def isTrackerStopped(): Boolean = { + trackerState == Stopped + } + /** Start the endpoint and receiver execution thread. */ def start(): Unit = synchronized { - if (endpoint != null) { + if (isTrackerStarted) { throw new SparkException("ReceiverTracker already started") } @@ -82,12 +113,14 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false "ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv)) if (!skipReceiverLaunch) receiverExecutor.start() logInfo("ReceiverTracker started") + trackerState = Started } } /** Stop the receiver execution thread. */ def stop(graceful: Boolean): Unit = synchronized { - if (!receiverInputStreams.isEmpty && endpoint != null) { + if (isTrackerStarted) { + trackerState = Stopping // First, stop the receivers if (!skipReceiverLaunch) receiverExecutor.stop(graceful) @@ -96,6 +129,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false endpoint = null receivedBlockTracker.stop() logInfo("ReceiverTracker stopped") + trackerState = Stopped } } @@ -147,10 +181,27 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } receiverInfo(streamId) = ReceiverInfo( streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) - listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + listenerBus.post(StreamingListenerReceiverRegistered(receiverInfo(streamId))) logInfo("Registered receiver for stream " + streamId + " from " + senderAddress) } - + + /** Receiver started */ + private def receiverStarted( + streamId: Int, + typ: String, + host: String, + receiverEndpoint: RpcEndpointRef, + senderAddress: RpcAddress + ) { + if (!receiverInputStreamIds.contains(streamId)) { + throw new SparkException("Start received for unexpected id " + streamId) + } + receiverInfo(streamId) = ReceiverInfo( + streamId, s"${typ}-${streamId}", receiverEndpoint, true, host) + listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId))) + logInfo("Receiver started for stream " + streamId + " from " + senderAddress) + } + /** Deregister a receiver */ private def deregisterReceiver(streamId: Int, message: String, error: String) { val newReceiverInfo = receiverInfo.get(streamId) match { @@ -216,8 +267,19 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterReceiver(streamId, typ, host, receiverEndpoint) => - registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) - context.reply(true) + if (!isTrackerStopping) { + registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + context.reply(true) + } else { + context.reply(false) + } + case ReceiverStarted(streamId, typ, host, receiverEndpoint) => + if (!isTrackerStopping) { + receiverStarted(streamId, typ, host, receiverEndpoint, context.sender.address) + context.reply(true) + } else { + context.reply(false) + } case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) case DeregisterReceiver(streamId, message, error) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index 74dbba453f026..0838dbb6ec39a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -37,7 +37,11 @@ case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends Streami @DeveloperApi case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent - + +@DeveloperApi +case class StreamingListenerReceiverRegistered(receiverInfo: ReceiverInfo) + extends StreamingListenerEvent + @DeveloperApi case class StreamingListenerReceiverStarted(receiverInfo: ReceiverInfo) extends StreamingListenerEvent @@ -57,7 +61,10 @@ case class StreamingListenerReceiverStopped(receiverInfo: ReceiverInfo) */ @DeveloperApi trait StreamingListener { - + + /** Called when a receiver has been registered */ + def onReceiverRegistered(receiverRegistered: StreamingListenerReceiverRegistered) { } + /** Called when a receiver has been started */ def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index b07d6cf347ca7..9febd7f446293 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -31,6 +31,8 @@ private[spark] class StreamingListenerBus override def onPostEvent(listener: StreamingListener, event: StreamingListenerEvent): Unit = { event match { + case receiverRegistered: StreamingListenerReceiverRegistered => + listener.onReceiverRegistered(receiverRegistered) case receiverStarted: StreamingListenerReceiverStarted => listener.onReceiverStarted(receiverStarted) case receiverError: StreamingListenerReceiverError => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 312cce408cfe7..e8e20edc08bef 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -117,6 +117,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { ssc.start() try { eventually(timeout(2000 millis), interval(20 millis)) { + collector.registeredReceiverStreamIds.size should equal (1) collector.startedReceiverStreamIds.size should equal (1) collector.startedReceiverStreamIds(0) should equal (0) collector.stoppedReceiverStreamIds should have size 1 @@ -161,11 +162,16 @@ class BatchInfoCollector extends StreamingListener { /** Listener that collects information on processed batches */ class ReceiverInfoCollector extends StreamingListener { + val registeredReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] val startedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] val stoppedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int] val receiverErrors = new ArrayBuffer[(Int, String, String)] with SynchronizedBuffer[(Int, String, String)] + override def onReceiverRegistered(receiverRegistered: StreamingListenerReceiverRegistered) { + registeredReceiverStreamIds += receiverRegistered.receiverInfo.streamId + } + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { startedReceiverStreamIds += receiverStarted.receiverInfo.streamId }