Skip to content

Commit fff63f9

Browse files
committed
Use a lock to eliminate the race condition when stopping receivers and registering receivers happen at the same time.
1 parent e0ef72a commit fff63f9

File tree

7 files changed

+53
-82
lines changed

7 files changed

+53
-82
lines changed

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,7 @@ private[streaming] abstract class ReceiverSupervisor(
9393

9494
/** Called when supervisor is stopped */
9595
protected def onStop(message: String, error: Option[Throwable]) { }
96-
97-
/** Called when receiver is registered */
98-
protected def onReceiverRegister() { }
99-
96+
10097
/** Called when receiver is started */
10198
protected def onReceiverStart() { }
10299

@@ -120,7 +117,6 @@ private[streaming] abstract class ReceiverSupervisor(
120117
/** Start receiver */
121118
def startReceiver(): Unit = synchronized {
122119
try {
123-
onReceiverRegister()
124120
logInfo("Starting receiver")
125121
receiver.onStart()
126122
logInfo("Called receiver onStart")

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -166,23 +166,13 @@ private[streaming] class ReceiverSupervisorImpl(
166166
blockGenerator.stop()
167167
env.rpcEnv.stop(endpoint)
168168
}
169-
170-
override protected def onReceiverRegister() {
171-
val msg = RegisterReceiver(
172-
streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint)
173-
val ret = trackerEndpoint.askWithRetry[Boolean](msg)
174-
if (!ret) {
175-
throw new SparkException("ReceiverTracker is stopping and doesn't accept registeration " +
176-
"from receivers.")
177-
}
178-
}
179-
169+
180170
override protected def onReceiverStart() {
181-
val msg = ReceiverStarted(
171+
val msg = RegisterReceiver(
182172
streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint)
183-
val ret = trackerEndpoint.askWithRetry[Boolean](msg)
184-
if (!ret) {
185-
throw new SparkException("ReceiverTracker is stopping and doesn't accept receiver started.")
173+
val successful = trackerEndpoint.askWithRetry[Boolean](msg)
174+
if (!successful) {
175+
stop("Registered unsuccessfully", None)
186176
}
187177
}
188178

streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,6 @@ private[streaming] case class RegisterReceiver(
3737
host: String,
3838
receiverEndpoint: RpcEndpointRef
3939
) extends ReceiverTrackerMessage
40-
private[streaming] case class ReceiverStarted(
41-
streamId: Int,
42-
typ: String,
43-
host: String,
44-
receiverEndpoint: RpcEndpointRef
45-
) extends ReceiverTrackerMessage
4640
private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo)
4741
extends ReceiverTrackerMessage
4842
private[streaming] case class ReportError(streamId: Int, message: String, error: String)
@@ -83,6 +77,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
8377
/** State of the tracker */
8478
@volatile private var trackerState = Initialized
8579

80+
/**
81+
* There is a race condition when stopping receivers and registering receivers happen at the same
82+
* time. This lock is used to eliminate the race condition. See SPARK-5681.
83+
*/
84+
private val stoppingTrackerLock = new AnyRef
85+
8686
// endpoint is created when generator starts.
8787
// This not being null means the tracker has been started and not stopped
8888
private var endpoint: RpcEndpointRef = null
@@ -120,8 +120,16 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
120120
/** Stop the receiver execution thread. */
121121
def stop(graceful: Boolean): Unit = synchronized {
122122
if (isTrackerStarted) {
123-
trackerState = Stopping
124123
// First, stop the receivers
124+
// acquire "stoppingTrackerLock" so that setting trackerState to "Stopping" and registering
125+
// receivers won't happen at the same time
126+
stoppingTrackerLock.synchronized {
127+
trackerState = Stopping
128+
if (!skipReceiverLaunch) {
129+
// Send the stop signal to all the receivers
130+
receiverExecutor.stopReceivers()
131+
}
132+
}
125133
if (!skipReceiverLaunch) receiverExecutor.stop(graceful)
126134

127135
// Finally, stop the endpoint
@@ -175,33 +183,25 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
175183
host: String,
176184
receiverEndpoint: RpcEndpointRef,
177185
senderAddress: RpcAddress
178-
) {
186+
): Boolean = {
179187
if (!receiverInputStreamIds.contains(streamId)) {
180188
throw new SparkException("Register received for unexpected id " + streamId)
181189
}
182-
receiverInfo(streamId) = ReceiverInfo(
183-
streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
184-
listenerBus.post(StreamingListenerReceiverRegistered(receiverInfo(streamId)))
185-
logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
186-
}
187-
188-
/** Receiver started */
189-
private def receiverStarted(
190-
streamId: Int,
191-
typ: String,
192-
host: String,
193-
receiverEndpoint: RpcEndpointRef,
194-
senderAddress: RpcAddress
195-
) {
196-
if (!receiverInputStreamIds.contains(streamId)) {
197-
throw new SparkException("Start received for unexpected id " + streamId)
190+
// acquire "stoppingTrackerLock" so that setting trackerState to "Stopping" and registering
191+
// receivers won't happen at the same time
192+
stoppingTrackerLock.synchronized {
193+
if (isTrackerStopping || isTrackerStopped) {
194+
false
195+
} else {
196+
receiverInfo(streamId) = ReceiverInfo(
197+
streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
198+
listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
199+
logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
200+
true
201+
}
198202
}
199-
receiverInfo(streamId) = ReceiverInfo(
200-
streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
201-
listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
202-
logInfo("Receiver started for stream " + streamId + " from " + senderAddress)
203203
}
204-
204+
205205
/** Deregister a receiver */
206206
private def deregisterReceiver(streamId: Int, message: String, error: String) {
207207
val newReceiverInfo = receiverInfo.get(streamId) match {
@@ -267,19 +267,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
267267

268268
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
269269
case RegisterReceiver(streamId, typ, host, receiverEndpoint) =>
270-
if (!isTrackerStopping) {
270+
val successful =
271271
registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
272-
context.reply(true)
273-
} else {
274-
context.reply(false)
275-
}
276-
case ReceiverStarted(streamId, typ, host, receiverEndpoint) =>
277-
if (!isTrackerStopping) {
278-
receiverStarted(streamId, typ, host, receiverEndpoint, context.sender.address)
279-
context.reply(true)
280-
} else {
281-
context.reply(false)
282-
}
272+
context.reply(successful)
283273
case AddBlock(receivedBlockInfo) =>
284274
context.reply(addBlock(receivedBlockInfo))
285275
case DeregisterReceiver(streamId, message, error) =>
@@ -308,9 +298,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
308298
}
309299

310300
def stop(graceful: Boolean) {
311-
// Send the stop signal to all the receivers
312-
stopReceivers()
313-
314301
// Wait for the Spark job that runs the receivers to be over
315302
// That is, for the receivers to quit gracefully.
316303
thread.join(10000)
@@ -385,7 +372,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
385372
}
386373

387374
/** Stops the receivers. */
388-
private def stopReceivers() {
375+
def stopReceivers() {
389376
// Signal the receivers to stop
390377
receiverInfo.values.flatMap { info => Option(info.endpoint)}
391378
.foreach { _.send(StopReceiver) }

streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,7 @@ case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends Streami
3737

3838
@DeveloperApi
3939
case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent
40-
41-
@DeveloperApi
42-
case class StreamingListenerReceiverRegistered(receiverInfo: ReceiverInfo)
43-
extends StreamingListenerEvent
44-
40+
4541
@DeveloperApi
4642
case class StreamingListenerReceiverStarted(receiverInfo: ReceiverInfo)
4743
extends StreamingListenerEvent
@@ -61,10 +57,7 @@ case class StreamingListenerReceiverStopped(receiverInfo: ReceiverInfo)
6157
*/
6258
@DeveloperApi
6359
trait StreamingListener {
64-
65-
/** Called when a receiver has been registered */
66-
def onReceiverRegistered(receiverRegistered: StreamingListenerReceiverRegistered) { }
67-
60+
6861
/** Called when a receiver has been started */
6962
def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { }
7063

streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ private[spark] class StreamingListenerBus
3131

3232
override def onPostEvent(listener: StreamingListener, event: StreamingListenerEvent): Unit = {
3333
event match {
34-
case receiverRegistered: StreamingListenerReceiverRegistered =>
35-
listener.onReceiverRegistered(receiverRegistered)
3634
case receiverStarted: StreamingListenerReceiverStarted =>
3735
listener.onReceiverStarted(receiverStarted)
3836
case receiverError: StreamingListenerReceiverError =>

streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,19 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
224224
}
225225
}
226226

227+
test("stop gracefully even if a receiver misses StopReceiver") {
228+
val conf = new SparkConf().setMaster(master).setAppName(appName)
229+
sc = new SparkContext(conf)
230+
ssc = new StreamingContext(sc, Milliseconds(100))
231+
val input = ssc.receiverStream(new TestReceiver)
232+
input.foreachRDD(_ => {})
233+
ssc.start()
234+
// Call `ssc.stop` at once so that it's possible that the receiver will miss "StopReceiver"
235+
failAfter(30000 millis) {
236+
ssc.stop(stopSparkContext = true, stopGracefully = true)
237+
}
238+
}
239+
227240
test("stop slow receiver gracefully") {
228241
val conf = new SparkConf().setMaster(master).setAppName(appName)
229242
conf.set("spark.streaming.gracefulStopTimeout", "20000s")

streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
117117
ssc.start()
118118
try {
119119
eventually(timeout(2000 millis), interval(20 millis)) {
120-
collector.registeredReceiverStreamIds.size should equal (1)
121120
collector.startedReceiverStreamIds.size should equal (1)
122121
collector.startedReceiverStreamIds(0) should equal (0)
123122
collector.stoppedReceiverStreamIds should have size 1
@@ -162,16 +161,11 @@ class BatchInfoCollector extends StreamingListener {
162161

163162
/** Listener that collects information on processed batches */
164163
class ReceiverInfoCollector extends StreamingListener {
165-
val registeredReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int]
166164
val startedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int]
167165
val stoppedReceiverStreamIds = new ArrayBuffer[Int] with SynchronizedBuffer[Int]
168166
val receiverErrors =
169167
new ArrayBuffer[(Int, String, String)] with SynchronizedBuffer[(Int, String, String)]
170168

171-
override def onReceiverRegistered(receiverRegistered: StreamingListenerReceiverRegistered) {
172-
registeredReceiverStreamIds += receiverRegistered.receiverInfo.streamId
173-
}
174-
175169
override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) {
176170
startedReceiverStreamIds += receiverStarted.receiverInfo.streamId
177171
}

0 commit comments

Comments
 (0)