diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index d7124616d3bfb..af7a1345ca034 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -874,7 +874,6 @@ class SparkContext(
   /** Shut down the SparkContext. */
   def stop() {
     ui.stop()
-    eventLogger.foreach(_.stop())
     // Do this only if not stopped already - best case effort.
     // prevent NPE if stopped more than once.
     val dagSchedulerCopy = dagScheduler
@@ -883,13 +882,14 @@ class SparkContext(
       metadataCleaner.cancel()
       cleaner.foreach(_.stop())
       dagSchedulerCopy.stop()
-      listenerBus.stop()
       taskScheduler = null
       // TODO: Cache.stop()?
       env.stop()
       SparkEnv.set(null)
       ShuffleMapTask.clearCache()
       ResultTask.clearCache()
+      listenerBus.stop()
+      eventLogger.foreach(_.stop())
       logInfo("Successfully stopped SparkContext")
     } else {
       logInfo("SparkContext already stopped")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
index 353a48661b0f7..a4901c89624d1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -36,6 +36,19 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
   private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY)
   private var queueFullErrorMessageLogged = false
   private var started = false
+  private val listenerThread = new Thread("SparkListenerBus") {
+    setDaemon(true)
+    override def run() {
+      while (true) {
+        val event = eventQueue.take
+        if (event == SparkListenerShutdown) {
+          // Get out of the while loop and shutdown the daemon thread
+          return
+        }
+        postToAll(event)
+      }
+    }
+  }
 
   /**
    * Start sending events to attached listeners.
@@ -48,20 +61,8 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
     if (started) {
       throw new IllegalStateException("Listener bus already started!")
     }
+    listenerThread.start()
     started = true
-    new Thread("SparkListenerBus") {
-      setDaemon(true)
-      override def run() {
-        while (true) {
-          val event = eventQueue.take
-          if (event == SparkListenerShutdown) {
-            // Get out of the while loop and shutdown the daemon thread
-            return
-          }
-          postToAll(event)
-        }
-      }
-    }.start()
   }
 
   def post(event: SparkListenerEvent) {
@@ -97,5 +98,6 @@ private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
       throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!")
     }
     post(SparkListenerShutdown)
+    listenerThread.join()
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 7c843772bc2e0..acf45d07bdfa9 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.scheduler
 
+import java.util.concurrent.Semaphore
+
 import scala.collection.mutable
 
 import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
@@ -72,6 +74,53 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     }
   }
 
+  test("bus.stop() waits for the event queue to completely drain") {
+    @volatile var drained = false
+
+    class BlockingListener(cond: AnyRef) extends SparkListener {
+      override def onJobEnd(jobEnd: SparkListenerJobEnd) = {
+        cond.synchronized { cond.wait() }
+        drained = true
+      }
+    }
+
+    val bus = new LiveListenerBus
+    val blockingListener = new BlockingListener(bus)
+    val sem = new Semaphore(0)
+
+    bus.addListener(blockingListener)
+    bus.post(SparkListenerJobEnd(0, JobSucceeded))
+    bus.start()
+    // the queue should not drain immediately
+    assert(!drained)
+
+    new Thread("ListenerBusStopper") {
+      override def run() {
+        // stop() blocks until notify() is called below
+        bus.stop()
+        sem.release()
+      }
+    }.start()
+
+    val startTime = System.currentTimeMillis()
+    val waitTime = 100
+    var done = false
+    while (!done) {
+      if (System.currentTimeMillis() > startTime + waitTime) {
+        bus.synchronized {
+          bus.notify()
+        }
+        done = true
+      } else {
+        Thread.sleep(10)
+        // bus.stop() should wait until the event queue is drained
+        assert(!drained)
+      }
+    }
+    sem.acquire()
+    assert(drained)
+  }
+
   test("basic creation of StageInfo") {
     val listener = new SaveStageAndTaskInfo
     sc.addSparkListener(listener)
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
index e698b9bf376e1..038afbcba80a3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala
@@ -73,6 +73,6 @@ object SparkHdfsLR {
     }
 
     println("Final w: " + w)
-    System.exit(0)
+    sc.stop()
   }
 }
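Taken together, the LiveListenerBus changes implement a classic producer-consumer shutdown pattern: a single daemon thread drains a blocking queue, a sentinel event (SparkListenerShutdown) tells it to exit, and stop() joins the thread so that every event posted before shutdown is fully processed before the call returns. A minimal standalone sketch of that pattern follows; SimpleBus, Message, and Shutdown are illustrative names for this sketch, not Spark API.

import java.util.concurrent.LinkedBlockingQueue

object SimpleBusDemo {
  sealed trait Event
  case class Message(body: String) extends Event
  case object Shutdown extends Event

  class SimpleBus {
    private val queue = new LinkedBlockingQueue[Event](10000)
    // Single daemon consumer, mirroring listenerThread in the patch.
    private val consumer = new Thread("SimpleBus") {
      setDaemon(true)
      override def run() {
        while (true) {
          queue.take() match {
            case Shutdown      => return // sentinel: exit the loop, thread dies
            case Message(body) => println("handled: " + body)
          }
        }
      }
    }

    def start() { consumer.start() }

    def post(event: Event) { queue.put(event) }

    // Post the sentinel, then block until the consumer has drained
    // everything queued ahead of it (FIFO order guarantees this).
    def stop() {
      post(Shutdown)
      consumer.join()
    }
  }

  def main(args: Array[String]) {
    val bus = new SimpleBus
    bus.start()
    (1 to 5).foreach(i => bus.post(Message("event " + i)))
    bus.stop() // returns only after all five messages have been handled
  }
}

Because stop() joins the consumer after posting the sentinel, and the queue is FIFO, every event enqueued before the sentinel is handled before stop() returns. That guarantee is what lets SparkContext.stop() move listenerBus.stop() and eventLogger.foreach(_.stop()) to the end of shutdown without dropping buffered events.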