Skip to content

Commit a9acfbf

Browse files
committed
Merge branch 'squash-pr-6294' into receiver-scheduling
2 parents 881edb9 + e530bcc commit a9acfbf

File tree

4 files changed

+110
-39
lines changed

4 files changed

+110
-39
lines changed

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import java.util.concurrent.CountDownLatch
2222

2323
import scala.collection.mutable.ArrayBuffer
2424
import scala.concurrent._
25+
import scala.util.control.NonFatal
2526

2627
import org.apache.spark.{Logging, SparkConf}
2728
import org.apache.spark.storage.StreamBlockId
@@ -36,7 +37,7 @@ private[streaming] abstract class ReceiverSupervisor(
3637
conf: SparkConf
3738
) extends Logging {
3839

39-
/** Enumeration to identify current state of the StreamingContext */
40+
/** Enumeration to identify current state of the Receiver */
4041
object ReceiverState extends Enumeration {
4142
type CheckpointState = Value
4243
val Initialized, Started, Stopped = Value
@@ -99,8 +100,8 @@ private[streaming] abstract class ReceiverSupervisor(
99100
/** Called when supervisor is stopped */
100101
protected def onStop(message: String, error: Option[Throwable]) { }
101102

102-
/** Called when receiver is started */
103-
protected def onReceiverStart() { }
103+
/** Called when receiver is started. Return if the driver accepts us */
104+
protected def onReceiverStart(): Boolean = true
104105

105106
/** Called when receiver is stopped */
106107
protected def onReceiverStop(message: String, error: Option[Throwable]) { }
@@ -123,13 +124,17 @@ private[streaming] abstract class ReceiverSupervisor(
123124
/** Start receiver */
124125
def startReceiver(): Unit = synchronized {
125126
try {
126-
logInfo("Starting receiver")
127-
receiver.onStart()
128-
logInfo("Called receiver onStart")
129-
onReceiverStart()
130-
receiverState = Started
127+
if (onReceiverStart()) {
128+
logInfo("Starting receiver")
129+
receiverState = Started
130+
receiver.onStart()
131+
logInfo("Called receiver onStart")
132+
} else {
133+
// The driver refused us
134+
stop("Registered unsuccessfully because the driver refused" + streamId, None)
135+
}
131136
} catch {
132-
case t: Throwable =>
137+
case NonFatal(t) =>
133138
stop("Error starting receiver " + streamId, Some(t))
134139
}
135140
}
@@ -138,12 +143,19 @@ private[streaming] abstract class ReceiverSupervisor(
138143
def stopReceiver(message: String, error: Option[Throwable]): Unit = synchronized {
139144
try {
140145
logInfo("Stopping receiver with message: " + message + ": " + error.getOrElse(""))
141-
receiverState = Stopped
142-
receiver.onStop()
143-
logInfo("Called receiver onStop")
144-
onReceiverStop(message, error)
146+
receiverState match {
147+
case Initialized =>
148+
logWarning("Skip stopping receiver because it has not yet stared")
149+
case Started =>
150+
receiverState = Stopped
151+
receiver.onStop()
152+
logInfo("Called receiver onStop")
153+
onReceiverStop(message, error)
154+
case Stopped =>
155+
logWarning("Receiver has been stopped")
156+
}
145157
} catch {
146-
case t: Throwable =>
158+
case NonFatal(t) =>
147159
logError("Error stopping receiver " + streamId + t.getStackTraceString)
148160
}
149161
}
@@ -177,7 +189,7 @@ private[streaming] abstract class ReceiverSupervisor(
177189
/** Return a list of candidate executors to run the receiver */
178190
def getAllowedLocations(): Seq[String] = Seq.empty
179191

180-
/** Check if receiver has been marked for stopping */
192+
/** Check if receiver has been marked for starting */
181193
def isReceiverStarted(): Boolean = {
182194
logDebug("state = " + receiverState)
183195
receiverState == Started

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ private[streaming] class ReceiverSupervisorImpl(
162162
env.rpcEnv.stop(endpoint)
163163
}
164164

165-
override protected def onReceiverStart() {
165+
override protected def onReceiverStart(): Boolean = {
166166
val msg = RegisterReceiver(
167167
streamId, receiver.getClass.getSimpleName, host, endpoint)
168168
trackerEndpoint.askWithRetry[Boolean](msg)

streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,26 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
8585
)
8686
private val listenerBus = ssc.scheduler.listenerBus
8787

88+
/** Enumeration to identify current state of the ReceiverTracker */
89+
object TrackerState extends Enumeration {
90+
type CheckpointState = Value
91+
val Initialized, Started, Stopping, Stopped = Value
92+
}
93+
import TrackerState._
94+
95+
/** State of the tracker. Protected by "trackerStateLock" */
96+
private var trackerState = Initialized
97+
98+
/** "trackerStateLock" is used to protect reading/writing "trackerState" */
99+
private val trackerStateLock = new AnyRef
100+
88101
// endpoint is created when generator starts.
89102
// This not being null means the tracker has been started and not stopped
90103
private var endpoint: RpcEndpointRef = null
91104

92105
private val schedulingPolicy: ReceiverSchedulingPolicy =
93106
new LoadBalanceReceiverSchedulingPolicyImpl()
94107

95-
@volatile private var stopping = false
96-
97108
/**
98109
* Track receivers' status for scheduling
99110
*/
@@ -107,9 +118,24 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
107118
/** Use a separate lock to avoid dead-lock */
108119
private val receiverTrackingInfosLock = new AnyRef
109120

121+
/** Check if tracker has been marked for starting */
122+
private def isTrackerStarted(): Boolean = trackerStateLock.synchronized {
123+
trackerState == Started
124+
}
125+
126+
/** Check if tracker has been marked for stopping */
127+
private def isTrackerStopping(): Boolean = trackerStateLock.synchronized {
128+
trackerState == Stopping
129+
}
130+
131+
/** Check if tracker has been marked for stopped */
132+
private def isTrackerStopped(): Boolean = trackerStateLock.synchronized {
133+
trackerState == Stopped
134+
}
135+
110136
/** Start the endpoint and receiver execution thread. */
111137
def start(): Unit = synchronized {
112-
if (endpoint != null) {
138+
if (isTrackerStarted) {
113139
throw new SparkException("ReceiverTracker already started")
114140
}
115141

@@ -118,20 +144,29 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
118144
"ReceiverTracker", new ReceiverTrackerEndpoint(ssc.env.rpcEnv))
119145
if (!skipReceiverLaunch) receiverExecutor.start()
120146
logInfo("ReceiverTracker started")
147+
trackerStateLock.synchronized {
148+
trackerState = Started
149+
}
121150
}
122151
}
123152

124153
/** Stop the receiver execution thread. */
125154
def stop(graceful: Boolean): Unit = synchronized {
126-
if (!receiverInputStreams.isEmpty && endpoint != null) {
155+
if (isTrackerStarted) {
127156
// First, stop the receivers
157+
trackerStateLock.synchronized {
158+
trackerState = Stopping
159+
}
128160
if (!skipReceiverLaunch) receiverExecutor.stop(graceful)
129161

130162
// Finally, stop the endpoint
131163
ssc.env.rpcEnv.stop(endpoint)
132164
endpoint = null
133165
receivedBlockTracker.stop()
134166
logInfo("ReceiverTracker stopped")
167+
trackerStateLock.synchronized {
168+
trackerState = Stopped
169+
}
135170
}
136171
}
137172

@@ -177,15 +212,24 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
177212
host: String,
178213
receiverEndpoint: RpcEndpointRef,
179214
senderAddress: RpcAddress
180-
) {
215+
): Boolean = {
181216
if (!receiverInputStreamIds.contains(streamId)) {
182217
throw new SparkException("Register received for unexpected id " + streamId)
183218
}
184-
receiverInfo(streamId) = ReceiverInfo(
185-
streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
186-
updateReceiverRunningLocation(streamId, host)
187-
listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
188-
logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
219+
trackerStateLock.synchronized {
220+
if (isTrackerStopping || isTrackerStopped) {
221+
false
222+
} else {
223+
// When updating "receiverInfo", we should make sure "trackerState" won't be changed at the
224+
// same time. Therefore the following line should be in "trackerStateLock.synchronized".
225+
receiverInfo(streamId) = ReceiverInfo(
226+
streamId, s"${typ}-${streamId}", receiverEndpoint, true, host)
227+
updateReceiverRunningLocation(streamId, host)
228+
listenerBus.post(StreamingListenerReceiverStarted(receiverInfo(streamId)))
229+
logInfo("Registered receiver for stream " + streamId + " from " + senderAddress)
230+
true
231+
}
232+
}
189233
}
190234

191235
/** Deregister a receiver */
@@ -253,8 +297,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
253297

254298
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
255299
case RegisterReceiver(streamId, typ, host, receiverEndpoint) =>
256-
registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
257-
context.reply(true)
300+
val successful =
301+
registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address)
302+
context.reply(successful)
258303
case AddBlock(receivedBlockInfo) =>
259304
context.reply(addBlock(receivedBlockInfo))
260305
case DeregisterReceiver(streamId, message, error) =>
@@ -285,8 +330,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
285330
}
286331

287332
def stop(graceful: Boolean) {
288-
stopping = true
289-
290333
// Send the stop signal to all the receivers
291334
stopReceivers()
292335

@@ -389,7 +432,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
389432
for (receiver <- receivers) {
390433
submitJobThread.execute(new Runnable {
391434
override def run(): Unit = {
392-
if (stopping) {
435+
if (isTrackerStopping()) {
393436
receiverExitLatch.countDown()
394437
return
395438
}
@@ -409,17 +452,17 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
409452
ssc.sc.makeRDD(Seq(receiver -> scheduledLocations))
410453
}
411454
val future = ssc.sparkContext.submitJob[Receiver[_], Unit, Unit](
412-
receiverRDD, startReceiver, (_, _) => Unit, ())
455+
receiverRDD, startReceiver, Seq(0), (_, _) => Unit, ())
413456
future.onComplete {
414457
case Success(_) =>
415-
if (stopping) {
458+
if (isTrackerStopping()) {
416459
receiverExitLatch.countDown()
417460
} else {
418461
logInfo(s"Restarting Receiver $receiverId")
419462
submitJobThread.execute(self)
420463
}
421464
case Failure(e) =>
422-
if (stopping) {
465+
if (isTrackerStopping()) {
423466
receiverExitLatch.countDown()
424467
} else {
425468
logError("Receiver has been stopped. Try to restart it.", e)
@@ -431,11 +474,14 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
431474
}
432475
})
433476
}
434-
// Wait until all receivers exit
435-
receiverExitLatch.await()
436-
running = false
437-
logInfo("All of the receivers have been terminated")
438-
submitJobThread.shutdownNow()
477+
try {
478+
// Wait until all receivers exit
479+
receiverExitLatch.await()
480+
logInfo("All of the receivers have been terminated")
481+
} finally {
482+
running = false
483+
submitJobThread.shutdownNow()
484+
}
439485
}
440486

441487
/** Stops the receivers. */

streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,19 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
273273
}
274274
}
275275

276+
test("stop gracefully even if a receiver misses StopReceiver") {
277+
val conf = new SparkConf().setMaster(master).setAppName(appName)
278+
sc = new SparkContext(conf)
279+
ssc = new StreamingContext(sc, Milliseconds(100))
280+
val input = ssc.receiverStream(new TestReceiver)
281+
input.foreachRDD(_ => {})
282+
ssc.start()
283+
// Call `ssc.stop` at once so that it's possible that the receiver will miss "StopReceiver"
284+
failAfter(30000 millis) {
285+
ssc.stop(stopSparkContext = true, stopGracefully = true)
286+
}
287+
}
288+
276289
test("stop slow receiver gracefully") {
277290
val conf = new SparkConf().setMaster(master).setAppName(appName)
278291
conf.set("spark.streaming.gracefulStopTimeout", "20000s")

0 commit comments

Comments
 (0)