Skip to content

Commit 737f071

Browse files
tedyu authored and tdas committed
[SPARK-11761] Prevent the call to StreamingContext#stop() in the listener bus's thread
See the discussion toward the tail of #9723. From zsxwing: ``` The user should not call stop or do other long-running work in a listener, since it will block the listener thread and prevent SparkContext/StreamingContext from stopping. I cannot see an approach, since we need to stop the listener bus's thread before stopping SparkContext/StreamingContext totally. ``` The proposed solution is to prevent the call to StreamingContext#stop() in the listener bus's thread. Author: tedyu <[email protected]> Closes #9741 from tedyu/master. (cherry picked from commit 446738e) Signed-off-by: Tathagata Das <[email protected]>
1 parent c13f723 commit 737f071

File tree

3 files changed

+67
-19
lines changed

3 files changed

+67
-19
lines changed

core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.util
1919

2020
import java.util.concurrent._
2121
import java.util.concurrent.atomic.AtomicBoolean
22+
import scala.util.DynamicVariable
2223

2324
import org.apache.spark.SparkContext
2425

@@ -60,25 +61,27 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri
6061
private val listenerThread = new Thread(name) {
6162
setDaemon(true)
6263
override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) {
63-
while (true) {
64-
eventLock.acquire()
65-
self.synchronized {
66-
processingEvent = true
67-
}
68-
try {
69-
val event = eventQueue.poll
70-
if (event == null) {
71-
// Get out of the while loop and shutdown the daemon thread
72-
if (!stopped.get) {
73-
throw new IllegalStateException("Polling `null` from eventQueue means" +
74-
" the listener bus has been stopped. So `stopped` must be true")
75-
}
76-
return
77-
}
78-
postToAll(event)
79-
} finally {
64+
AsynchronousListenerBus.withinListenerThread.withValue(true) {
65+
while (true) {
66+
eventLock.acquire()
8067
self.synchronized {
81-
processingEvent = false
68+
processingEvent = true
69+
}
70+
try {
71+
val event = eventQueue.poll
72+
if (event == null) {
73+
// Get out of the while loop and shutdown the daemon thread
74+
if (!stopped.get) {
75+
throw new IllegalStateException("Polling `null` from eventQueue means" +
76+
" the listener bus has been stopped. So `stopped` must be true")
77+
}
78+
return
79+
}
80+
postToAll(event)
81+
} finally {
82+
self.synchronized {
83+
processingEvent = false
84+
}
8285
}
8386
}
8487
}
@@ -177,3 +180,10 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri
177180
*/
178181
def onDropEvent(event: E): Unit
179182
}
183+
184+
private[spark] object AsynchronousListenerBus {
185+
/* Allows for Context to check whether stop() call is made within listener thread
186+
*/
187+
val withinListenerThread: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
188+
}
189+

streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import org.apache.spark.streaming.dstream._
4444
import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver}
4545
import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener}
4646
import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab}
47-
import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils}
47+
import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookManager, ThreadUtils, Utils}
4848

4949
/**
5050
* Main entry point for Spark Streaming functionality. It provides methods used to create
@@ -693,6 +693,10 @@ class StreamingContext private[streaming] (
693693
*/
694694
def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = {
695695
var shutdownHookRefToRemove: AnyRef = null
696+
if (AsynchronousListenerBus.withinListenerThread.value) {
697+
throw new SparkException("Cannot stop StreamingContext within listener thread of" +
698+
" AsynchronousListenerBus")
699+
}
696700
synchronized {
697701
try {
698702
state match {

streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, Synch
2121
import scala.concurrent.Future
2222
import scala.concurrent.ExecutionContext.Implicits.global
2323

24+
import org.apache.spark.SparkException
2425
import org.apache.spark.storage.StorageLevel
2526
import org.apache.spark.streaming.dstream.DStream
2627
import org.apache.spark.streaming.receiver.Receiver
@@ -161,6 +162,14 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
161162
}
162163
}
163164

165+
test("don't call ssc.stop in listener") {
166+
ssc = new StreamingContext("local[2]", "ssc", Milliseconds(1000))
167+
val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver)
168+
inputStream.foreachRDD(_.count)
169+
170+
startStreamingContextAndCallStop(ssc)
171+
}
172+
164173
test("onBatchCompleted with successful batch") {
165174
ssc = new StreamingContext("local[2]", "test", Milliseconds(1000))
166175
val inputStream = ssc.receiverStream(new StreamingListenerSuiteReceiver)
@@ -207,6 +216,17 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers {
207216
assert(failureReasons(1).contains("This is another failed job"))
208217
}
209218

219+
private def startStreamingContextAndCallStop(_ssc: StreamingContext): Unit = {
220+
val contextStoppingCollector = new StreamingContextStoppingCollector(_ssc)
221+
_ssc.addStreamingListener(contextStoppingCollector)
222+
val batchCounter = new BatchCounter(_ssc)
223+
_ssc.start()
224+
// Make sure running at least one batch
225+
batchCounter.waitUntilBatchesCompleted(expectedNumCompletedBatches = 1, timeout = 10000)
226+
_ssc.stop()
227+
assert(contextStoppingCollector.sparkExSeen)
228+
}
229+
210230
private def startStreamingContextAndCollectFailureReasons(
211231
_ssc: StreamingContext, isFailed: Boolean = false): Map[Int, String] = {
212232
val failureReasonsCollector = new FailureReasonsCollector()
@@ -320,3 +340,17 @@ class FailureReasonsCollector extends StreamingListener {
320340
}
321341
}
322342
}
343+
/**
344+
* A StreamingListener that calls StreamingContext.stop().
345+
*/
346+
class StreamingContextStoppingCollector(val ssc: StreamingContext) extends StreamingListener {
347+
@volatile var sparkExSeen = false
348+
override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
349+
try {
350+
ssc.stop()
351+
} catch {
352+
case se: SparkException =>
353+
sparkExSeen = true
354+
}
355+
}
356+
}

0 commit comments

Comments
 (0)