Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ private[streaming] abstract class ReceiverSupervisor(
conf: SparkConf
) extends Logging {

/** Enumeration to identify current state of the StreamingContext */
/** Enumeration to identify current state of the Receiver */
object ReceiverState extends Enumeration {
type CheckpointState = Value
val Initialized, Started, Stopped = Value
Expand Down Expand Up @@ -93,7 +93,10 @@ private[streaming] abstract class ReceiverSupervisor(

/** Called when supervisor is stopped */
protected def onStop(message: String, error: Option[Throwable]): Unit = {
  // No-op by default.
}


/** Called when receiver is registered */
protected def onReceiverRegister(): Unit = {
  // No-op by default; startReceiver() invokes this hook before calling receiver.onStart().
}

/** Called when receiver is started */
protected def onReceiverStart(): Unit = {
  // No-op by default.
}

Expand All @@ -117,6 +120,7 @@ private[streaming] abstract class ReceiverSupervisor(
/** Start receiver */
def startReceiver(): Unit = synchronized {
try {
onReceiverRegister()
logInfo("Starting receiver")
receiver.onStart()
logInfo("Called receiver onStart")
Expand Down Expand Up @@ -161,7 +165,7 @@ private[streaming] abstract class ReceiverSupervisor(
}
}

/** Check if receiver has been marked for stopping */
/** Check if receiver has been started */
def isReceiverStarted(): Boolean = {
logDebug("state = " + receiverState)
receiverState == Started
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,24 @@ private[streaming] class ReceiverSupervisorImpl(
blockGenerator.stop()
env.rpcEnv.stop(endpoint)
}

override protected def onReceiverStart() {
/**
 * Registers this receiver with the ReceiverTracker by sending a single RegisterReceiver
 * request and waiting for the tracker's Boolean reply.
 *
 * @throws SparkException if the tracker replies `false`, i.e. it rejected the registration
 */
override protected def onReceiverRegister() {
  val msg = RegisterReceiver(
    streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint)
  // Ask exactly once and inspect the reply; issuing the ask a second time would make the
  // tracker process the same registration twice.
  val ret = trackerEndpoint.askWithRetry[Boolean](msg)
  if (!ret) {
    throw new SparkException("ReceiverTracker is stopping and doesn't accept registeration " +
      "from receivers.")
  }
}

/**
 * Tells the ReceiverTracker that this receiver has started and waits for its Boolean reply.
 *
 * @throws SparkException if the tracker replies `false`
 */
override protected def onReceiverStart(): Unit = {
  val startedMsg = ReceiverStarted(
    streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint)
  val accepted = trackerEndpoint.askWithRetry[Boolean](startedMsg)
  // A negative reply means the tracker refused the notification, so abort the start.
  if (!accepted) {
    throw new SparkException("ReceiverTracker is stopping and doesn't accept receiver started.")
  }
}

override protected def onReceiverStop(message: String, error: Option[Throwable]) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ private[streaming] case class RegisterReceiver(
host: String,
receiverEndpoint: RpcEndpointRef
) extends ReceiverTrackerMessage
/**
 * Message a receiver supervisor sends to the ReceiverTracker endpoint from its
 * onReceiverStart hook. The tracker answers with a Boolean: `true` once it has recorded the
 * receiver as started, `false` when the tracker is in the Stopping state.
 *
 * @param streamId id of the input stream the receiver belongs to
 * @param typ receiver type name (the sender passes the receiver class's simple name)
 * @param host host name of the sender (the sender passes its local host name)
 * @param receiverEndpoint RPC endpoint reference of the sending supervisor
 */
private[streaming] case class ReceiverStarted(
streamId: Int,
typ: String,
host: String,
receiverEndpoint: RpcEndpointRef
) extends ReceiverTrackerMessage
private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo)
extends ReceiverTrackerMessage
private[streaming] case class ReportError(streamId: Int, message: String, error: String)
Expand Down Expand Up @@ -67,13 +73,38 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
)
private val listenerBus = ssc.scheduler.listenerBus

/** Enumeration to identify current state of the ReceiverTracker */
object TrackerState extends Enumeration {
// NOTE(review): this alias is named CheckpointState although it aliases this enumeration's
// Value type; it looks copied from another enumeration. Confirm nothing depends on the name
// before renaming it.
type CheckpointState = Value
// Initialized is the initial value of trackerState; start() sets Started, and stop() sets
// Stopping first and Stopped once shutdown has finished.
val Initialized, Started, Stopping, Stopped = Value
}
import TrackerState._

/** State of the tracker */
@volatile private var trackerState = Initialized

// endpoint is created when generator starts.
// This not being null means the tracker has been started and not stopped
private var endpoint: RpcEndpointRef = null

/** Check if tracker has been started */
private def isTrackerStarted(): Boolean = trackerState == Started

/** Check if tracker has been marked for stopping */
private def isTrackerStopping(): Boolean = trackerState == Stopping

/** Check if tracker has been stopped */
private def isTrackerStopped(): Boolean = trackerState == Stopped

/** Start the endpoint and receiver execution thread. */
def start(): Unit = synchronized {
if (endpoint != null) {
if (isTrackerStarted) {
throw new SparkException("ReceiverTracker already started")
}

Expand All @@ -82,12 +113,14 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
"ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv))
if (!skipReceiverLaunch) receiverExecutor.start()
logInfo("ReceiverTracker started")
trackerState = Started
}
}

/** Stop the receiver execution thread. */
def stop(graceful: Boolean): Unit = synchronized {
if (!receiverInputStreams.isEmpty && endpoint != null) {
if (isTrackerStarted) {
trackerState = Stopping
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changing trackerState to Stopping and calling stopReceivers() should be atomic, or it's still possible that StopReceiver won't be sent because the registerReceiver method runs in another thread.

Also, you need to use a separate lock rather than adding synchronized to registerReceiver, because synchronized may block ReceiverTrackerEndpoint, which runs on an Akka thread. If the Akka thread is blocked for a long time, a deadlock may occur.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changing trackerState to Stopping and calling stopReceivers() should be atomic, or it's still possible that StopReceiver won't be sent because the registerReceiver method runs in another thread.

Oh, I see. ReceiverStarted will be rejected and force the receiver stop. It's really tricky. Do you think we can solve this issue without ReceiverStarted?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to prevent the receiver from moving forward once the tracker enters the stopping state, so that we do not risk losing data. Without ReceiverStarted, we might have to modify the current receiver initialization.

// First, stop the receivers
if (!skipReceiverLaunch) receiverExecutor.stop(graceful)

Expand All @@ -96,6 +129,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
endpoint = null
receivedBlockTracker.stop()
logInfo("ReceiverTracker stopped")
trackerState = Stopped
}
}

Expand Down Expand Up @@ -147,10 +181,27 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
receiverInfo(streamId) = ReceiverInfo(
streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
listenerBus.post(StreamingListenerReceiverRegistered(receiverInfo(streamId)))
logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
}


/** Receiver started */
private def receiverStarted(
    streamId: Int,
    typ: String,
    host: String,
    receiverEndpoint: RpcEndpointRef,
    senderAddress: RpcAddress
  ): Unit = {
  // Reject start notifications for stream ids this tracker does not know about.
  if (!receiverInputStreamIds.contains(streamId)) {
    throw new SparkException("Start received for unexpected id " + streamId)
  }
  // Build the info once, store it under the stream id, and hand the same value to listeners.
  val startedInfo = ReceiverInfo(streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
  receiverInfo(streamId) = startedInfo
  listenerBus.post(StreamingListenerReceiverStarted(startedInfo))
  logInfo("Receiver started for stream " + streamId + " from " + senderAddress)
}

/** Deregister a receiver */
private def deregisterReceiver(streamId: Int, message: String, error: String) {
val newReceiverInfo = receiverInfo.get(streamId) match {
Expand Down Expand Up @@ -216,8 +267,19 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case RegisterReceiver(streamId, typ, host, receiverEndpoint) =>
registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
context.reply(true)
if (!isTrackerStopping) {
registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
context.reply(true)
} else {
context.reply(false)
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to send something back, or trackerActor.ask(msg)(askTimeout) will wait until timeout.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I propose returning false if isTrackerStopping == true. Then, if onReceiverRegister receives false, it just throws an exception.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally intended to let it time out and throw an exception. Returning false and throwing an exception is good too. I will update it.

case ReceiverStarted(streamId, typ, host, receiverEndpoint) =>
if (!isTrackerStopping) {
receiverStarted(streamId, typ, host, receiverEndpoint, context.sender.address)
context.reply(true)
} else {
context.reply(false)
}
case AddBlock(receivedBlockInfo) =>
context.reply(addBlock(receivedBlockInfo))
case DeregisterReceiver(streamId, message, error) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends Streami

/**
 * Listener event carrying the [[BatchInfo]] of a batch.
 * NOTE(review): presumably posted when the batch starts processing; confirm against the
 * code that posts it.
 */
@DeveloperApi
case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent


/**
 * Listener event posted by the ReceiverTracker when it registers a receiver; the listener
 * bus delivers it to `StreamingListener.onReceiverRegistered`.
 *
 * @param receiverInfo the info the tracker stored for the registered receiver
 */
@DeveloperApi
case class StreamingListenerReceiverRegistered(receiverInfo: ReceiverInfo)
extends StreamingListenerEvent

/**
 * Listener event posted by the ReceiverTracker when it handles a receiver's start
 * notification; the listener bus delivers it to `StreamingListener.onReceiverStarted`.
 *
 * @param receiverInfo the info the tracker stored for the started receiver
 */
@DeveloperApi
case class StreamingListenerReceiverStarted(receiverInfo: ReceiverInfo)
extends StreamingListenerEvent
Expand All @@ -57,7 +61,10 @@ case class StreamingListenerReceiverStopped(receiverInfo: ReceiverInfo)
*/
@DeveloperApi
trait StreamingListener {


/** Called when a receiver has been registered */
def onReceiverRegistered(receiverRegistered: StreamingListenerReceiverRegistered): Unit = {
  // No-op by default.
}

/** Called when a receiver has been started */
def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
  // No-op by default.
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ private[spark] class StreamingListenerBus

override def onPostEvent(listener: StreamingListener, event: StreamingListenerEvent): Unit = {
event match {
case receiverRegistered: StreamingListenerReceiverRegistered =>
listener.onReceiverRegistered(receiverRegistered)
case receiverStarted: StreamingListenerReceiverStarted =>
listener.onReceiverStarted(receiverStarted)
case receiverError: StreamingListenerReceiverError =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
ssc.start()
try {
eventually(timeout(2000 millis), interval(20 millis)) {
collector.registeredReceiverStreamIds.size should equal (1)
collector.startedReceiverStreamIds.size should equal (1)
collector.startedReceiverStreamIds(0) should equal (0)
collector.stoppedReceiverStreamIds should have size 1
Expand Down Expand Up @@ -161,11 +162,16 @@ class BatchInfoCollector extends StreamingListener {

/** Listener that collects information on receiver registered, started, stopped and error events */
class ReceiverInfoCollector extends StreamingListener {
val registeredReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int]
val startedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int]
val stoppedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int]
val receiverErrors =
new ArrayBuffer[(Int, String, String)] with SynchronizedBuffer[(Int, String, String)]

override def onReceiverRegistered(
    receiverRegistered: StreamingListenerReceiverRegistered): Unit = {
  // Record the stream id of the receiver that registered.
  val registeredId = receiverRegistered.receiverInfo.streamId
  registeredReceiverStreamIds += registeredId
}

override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = {
  // Record the stream id of the receiver that started.
  val startedId = receiverStarted.receiverInfo.streamId
  startedReceiverStreamIds += startedId
}
Expand Down