diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index b3b54af972cb4..e753b6683fec2 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -56,22 +56,30 @@ private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: Stri // A counter that represents the number of events produced and consumed in the queue private val eventLock = new Semaphore(0) + // limit on the number of events to process before exiting. -1 means no limit + private var eventLimit = -1 private val listenerThread = new Thread(name) { setDaemon(true) override def run(): Unit = Utils.tryOrStopSparkContext(sparkContext) { - while (true) { + while (eventLimit != 0) { eventLock.acquire() self.synchronized { processingEvent = true } try { if (stopped.get()) { - // Get out of the while loop and shutdown the daemon thread - return + eventLimit = eventQueue.size + if (eventLimit == 0) { + // Get out of the while loop and shutdown the daemon thread + return + } } val event = eventQueue.poll assert(event != null, "event queue was empty but the listener bus was not stopped") + if (eventLimit > 0) { + eventLimit-=1 + } postToAll(event) } finally { self.synchronized {