Skip to content

Commit e530bcc

Browse files
committed
[SPARK-5681][Streaming] Use a lock to eliminate the race condition when stopping receivers and registering receivers happen at the same time #6294
1 parent c63036c commit e530bcc

File tree

4 files changed

+104
-28
lines changed

4 files changed

+104
-28
lines changed

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.concurrent.CountDownLatch
2222

2323
import scala.collection.mutable.ArrayBuffer
2424
import scala.concurrent._
25+
import scala.util.control.NonFatal
2526

2627
import org.apache.spark.{Logging, SparkConf}
2728
import org.apache.spark.storage.StreamBlockId
@@ -36,7 +37,7 @@ private[streaming] abstract class ReceiverSupervisor(
3637
conf: SparkConf
3738
) extends Logging {
3839

39-
/** Enumeration to identify current state of the StreamingContext */
40+
/** Enumeration to identify current state of the Receiver */
4041
object ReceiverState extends Enumeration {
4142
type CheckpointState = Value
4243
val Initialized, Started, Stopped = Value
@@ -97,8 +98,8 @@ private[streaming] abstract class ReceiverSupervisor(
9798
/** Called when supervisor is stopped */
9899
protected def onStop(message: String, error: Option[Throwable]) { }
99100

100-
/** Called when receiver is started */
101-
protected def onReceiverStart() { }
101+
/** Called when receiver is started. Return if the driver accepts us */
102+
protected def onReceiverStart(): Boolean = true
102103

103104
/** Called when receiver is stopped */
104105
protected def onReceiverStop(message: String, error: Option[Throwable]) { }
@@ -121,13 +122,17 @@ private[streaming] abstract class ReceiverSupervisor(
121122
/** Start receiver */
122123
def startReceiver(): Unit = synchronized {
123124
try {
124-
logInfo("Starting receiver")
125-
receiver.onStart()
126-
logInfo("Called receiver onStart")
127-
onReceiverStart()
128-
receiverState = Started
125+
if (onReceiverStart()) {
126+
logInfo("Starting receiver")
127+
receiverState = Started
128+
receiver.onStart()
129+
logInfo("Called receiver onStart")
130+
} else {
131+
// The driver refused us
132+
stop("Registered unsuccessfully because the driver refused" + streamId, None)
133+
}
129134
} catch {
130-
case t: Throwable =>
135+
case NonFatal(t) =>
131136
stop("Error starting receiver " + streamId, Some(t))
132137
}
133138
}
@@ -136,12 +141,19 @@ private[streaming] abstract class ReceiverSupervisor(
136141
def stopReceiver(message: String, error: Option[Throwable]): Unit = synchronized {
137142
try {
138143
logInfo("Stopping receiver with message: " + message + ": " + error.getOrElse(""))
139-
receiverState = Stopped
140-
receiver.onStop()
141-
logInfo("Called receiver onStop")
142-
onReceiverStop(message, error)
144+
receiverState match {
145+
case Initialized =>
146+
logWarning("Skip stopping receiver because it has not yet stared")
147+
case Started =>
148+
receiverState = Stopped
149+
receiver.onStop()
150+
logInfo("Called receiver onStop")
151+
onReceiverStop(message, error)
152+
case Stopped =>
153+
logWarning("Receiver has been stopped")
154+
}
143155
} catch {
144-
case t: Throwable =>
156+
case NonFatal(t) =>
145157
logError("Error stopping receiver " + streamId + t.getStackTraceString)
146158
}
147159
}
@@ -167,7 +179,7 @@ private[streaming] abstract class ReceiverSupervisor(
167179
}(futureExecutionContext)
168180
}
169181

170-
/** Check if receiver has been marked for stopping */
182+
/** Check if receiver has been marked for starting */
171183
def isReceiverStarted(): Boolean = {
172184
logDebug("state = " + receiverState)
173185
receiverState == Started

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ private[streaming] class ReceiverSupervisorImpl(
167167
env.rpcEnv.stop(endpoint)
168168
}
169169

170-
override protected def onReceiverStart() {
170+
override protected def onReceiverStart(): Boolean = {
171171
val msg = RegisterReceiver(
172172
streamId, receiver.getClass.getSimpleName, Utils.localHostName(), endpoint)
173173
trackerEndpoint.askWithRetry[Boolean](msg)

streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala

Lines changed: 63 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,41 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
6767
)
6868
private val listenerBus = ssc.scheduler.listenerBus
6969

70+
/** Enumeration to identify current state of the ReceiverTracker */
71+
object TrackerState extends Enumeration {
72+
type CheckpointState = Value
73+
val Initialized, Started, Stopping, Stopped = Value
74+
}
75+
import TrackerState._
76+
77+
/** State of the tracker. Protected by "trackerStateLock" */
78+
private var trackerState = Initialized
79+
80+
/** "trackerStateLock" is used to protect reading/writing "trackerState" */
81+
private val trackerStateLock = new AnyRef
82+
7083
// endpoint is created when generator starts.
7184
// This not being null means the tracker has been started and not stopped
7285
private var endpoint: RpcEndpointRef = null
7386

87+
/** Check if tracker has been marked for starting */
88+
private def isTrackerStarted(): Boolean = trackerStateLock.synchronized {
89+
trackerState == Started
90+
}
91+
92+
/** Check if tracker has been marked for stopping */
93+
private def isTrackerStopping(): Boolean = trackerStateLock.synchronized {
94+
trackerState == Stopping
95+
}
96+
97+
/** Check if tracker has been marked for stopped */
98+
private def isTrackerStopped(): Boolean = trackerStateLock.synchronized {
99+
trackerState == Stopped
100+
}
101+
74102
/** Start the endpoint and receiver execution thread. */
75103
def start(): Unit = synchronized {
76-
if (endpoint != null) {
104+
if (isTrackerStarted) {
77105
throw new SparkException("ReceiverTracker already started")
78106
}
79107

@@ -82,20 +110,29 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
82110
"ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv))
83111
if (!skipReceiverLaunch) receiverExecutor.start()
84112
logInfo("ReceiverTracker started")
113+
trackerStateLock.synchronized {
114+
trackerState = Started
115+
}
85116
}
86117
}
87118

88119
/** Stop the receiver execution thread. */
89120
def stop(graceful: Boolean): Unit = synchronized {
90-
if (!receiverInputStreams.isEmpty && endpoint != null) {
121+
if (isTrackerStarted) {
91122
// First, stop the receivers
123+
trackerStateLock.synchronized {
124+
trackerState = Stopping
125+
}
92126
if (!skipReceiverLaunch) receiverExecutor.stop(graceful)
93127

94128
// Finally, stop the endpoint
95129
ssc.env.rpcEnv.stop(endpoint)
96130
endpoint = null
97131
receivedBlockTracker.stop()
98132
logInfo("ReceiverTracker stopped")
133+
trackerStateLock.synchronized {
134+
trackerState = Stopped
135+
}
99136
}
100137
}
101138

@@ -141,14 +178,24 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
141178
host: String,
142179
receiverEndpoint: RpcEndpointRef,
143180
senderAddress: RpcAddress
144-
) {
181+
): Boolean = {
145182
if (!receiverInputStreamIds.contains(streamId)) {
146183
throw new SparkException("Register received for unexpected id " + streamId)
147184
}
148-
receiverInfo(streamId) = ReceiverInfo(
149-
streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
150-
listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
151-
logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
185+
186+
trackerStateLock.synchronized {
187+
if (isTrackerStopping || isTrackerStopped) {
188+
false
189+
} else {
190+
// When updating "receiverInfo", we should make sure "trackerState" won't be changed at the
191+
// same time. Therefore the following line should be in "trackerStateLock.synchronized".
192+
receiverInfo(streamId) = ReceiverInfo(
193+
streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
194+
listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
195+
logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
196+
true
197+
}
198+
}
152199
}
153200

154201
/** Deregister a receiver */
@@ -216,8 +263,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
216263

217264
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
218265
case RegisterReceiver(streamId, typ, host, receiverEndpoint) =>
219-
registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
220-
context.reply(true)
266+
val successful =
267+
registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
268+
context.reply(successful)
221269
case AddBlock(receivedBlockInfo) =>
222270
context.reply(addBlock(receivedBlockInfo))
223271
case DeregisterReceiver(streamId, message, error) =>
@@ -317,9 +365,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
317365
// Distribute the receivers and start them
318366
logInfo("Starting " + receivers.length + " receivers")
319367
running = true
320-
ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
321-
running = false
322-
logInfo("All of the receivers have been terminated")
368+
try {
369+
ssc.sparkContext.runJob(tempRDD, ssc.sparkContext.clean(startReceiver))
370+
logInfo("All of the receivers have been terminated")
371+
} finally {
372+
running = false
373+
}
323374
}
324375

325376
/** Stops the receivers. */

streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,19 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
255255
}
256256
}
257257

258+
test("stop gracefully even if a receiver misses StopReceiver") {
259+
val conf = new SparkConf().setMaster(master).setAppName(appName)
260+
sc = new SparkContext(conf)
261+
ssc = new StreamingContext(sc, Milliseconds(100))
262+
val input = ssc.receiverStream(new TestReceiver)
263+
input.foreachRDD(_ => {})
264+
ssc.start()
265+
// Call `ssc.stop` at once so that it's possible that the receiver will miss "StopReceiver"
266+
failAfter(30000 millis) {
267+
ssc.stop(stopSparkContext = true, stopGracefully = true)
268+
}
269+
}
270+
258271
test("stop slow receiver gracefully") {
259272
val conf = new SparkConf().setMaster(master).setAppName(appName)
260273
conf.set("spark.streaming.gracefulStopTimeout", "20000s")

0 commit comments

Comments
 (0)