Commit 26432df
[SPARK-18751][CORE] Fix deadlock when SparkContext.stop is called in Utils.tryOrStopSparkContext
## What changes were proposed in this pull request?

When `SparkContext.stop` is called in `Utils.tryOrStopSparkContext` from any of the following three places, it deadlocks, because `stop` waits for the very thread that invoked it to exit:

- ContextCleaner.keepCleaning
- LiveListenerBus.listenerThread.run
- TaskSchedulerImpl.start

This PR adds `SparkContext.stopInNewThread` and uses it to eliminate the potential deadlock. I also removed my changes in #15775 since they are no longer necessary.

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <[email protected]>

Closes #16178 from zsxwing/fix-stop-deadlock.
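To make the failure mode concrete, here is a minimal standalone sketch (hypothetical code, not from this patch; all names are invented) of a component whose `shutdown` joins its own worker thread. Invoking `shutdown` from that worker makes the thread wait on itself forever, which is exactly the hang that `ContextCleaner.keepCleaning` or the listener bus could trigger through `SparkContext.stop`.

```scala
// Hypothetical reduction of the bug: `shutdown` joins the worker thread,
// so calling `shutdown` *from* the worker makes the thread wait on itself.
object SelfJoinDeadlock {
  private val worker: Thread = new Thread("worker") {
    override def run(): Unit = {
      // Simulates ContextCleaner.keepCleaning reacting to a fatal
      // error by shutting the whole component down from the inside.
      shutdown()
    }
  }

  // Like SparkContext.stop, waits for internal threads to finish.
  def shutdown(): Unit = {
    worker.join() // never returns when the caller *is* `worker`
  }

  def main(args: Array[String]): Unit = {
    worker.start()
    worker.join() // hangs forever: `worker` is blocked inside shutdown()
  }
}
```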
Parent commit: c3d3a9d

8 files changed (+23, -42 lines)

core/src/main/scala/org/apache/spark/SparkContext.scala (20 additions, 15 deletions)

```diff
@@ -1760,25 +1760,30 @@ class SparkContext(config: SparkConf) extends Logging {
   def listJars(): Seq[String] = addedJars.keySet.toSeq
 
   /**
-   * Shut down the SparkContext.
+   * When stopping SparkContext inside Spark components, it's easy to cause dead-lock since Spark
+   * may wait for some internal threads to finish. It's better to use this method to stop
+   * SparkContext instead.
    */
-  def stop(): Unit = {
-    if (env.rpcEnv.isInRPCThread) {
-      // `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread.
-      // We should launch a new thread to call `stop` to avoid dead-lock.
-      new Thread("stop-spark-context") {
-        setDaemon(true)
-
-        override def run(): Unit = {
-          _stop()
+  private[spark] def stopInNewThread(): Unit = {
+    new Thread("stop-spark-context") {
+      setDaemon(true)
+
+      override def run(): Unit = {
+        try {
+          SparkContext.this.stop()
+        } catch {
+          case e: Throwable =>
+            logError(e.getMessage, e)
+            throw e
         }
-      }.start()
-    } else {
-      _stop()
-    }
+      }
+    }.start()
   }
 
-  private def _stop() {
+  /**
+   * Shut down the SparkContext.
+   */
+  def stop(): Unit = {
     if (LiveListenerBus.withinListenerThread.value) {
       throw new SparkException(
         s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}")
```

core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala (0 additions, 5 deletions)

```diff
@@ -146,11 +146,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
    * @param uri URI with location of the file.
    */
   def openChannel(uri: String): ReadableByteChannel
-
-  /**
-   * Return if the current thread is a RPC thread.
-   */
-  def isInRPCThread: Boolean
 }
 
 /**
```

core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala (0 additions, 1 deletion)

```diff
@@ -201,7 +201,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   /** Message loop used for dispatching messages. */
   private class MessageLoop extends Runnable {
     override def run(): Unit = {
-      NettyRpcEnv.rpcThreadFlag.value = true
       try {
         while (true) {
           try {
```

core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala (0 additions, 5 deletions)

```diff
@@ -407,14 +407,9 @@ private[netty] class NettyRpcEnv(
     }
 
   }
-
-  override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value
 }
 
 private[netty] object NettyRpcEnv extends Logging {
-
-  private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false)
-
   /**
    * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
    * Use `currentEnv` to wrap the deserialization codes. E.g.,
```
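For background on the removed mechanism: `rpcThreadFlag` was a `scala.util.DynamicVariable`, which stores its value in an `InheritableThreadLocal`, so each `MessageLoop` thread could flag itself without affecting other threads. A minimal standalone sketch of that behavior (hypothetical code, not from the patch):

```scala
import scala.util.DynamicVariable

object RpcThreadFlagSketch {
  // Same shape as the removed flag: defaults to false on every thread.
  val rpcThreadFlag = new DynamicVariable[Boolean](false)

  def main(args: Array[String]): Unit = {
    // A worker marks itself the way Dispatcher.MessageLoop did,
    // by assigning the thread-local value directly.
    val loop = new Thread("message-loop") {
      override def run(): Unit = {
        rpcThreadFlag.value = true
        println(s"in message loop: ${rpcThreadFlag.value}") // true
      }
    }
    loop.start()
    loop.join()
    // Other threads still see the default value.
    println(s"in main: ${rpcThreadFlag.value}") // false
  }
}
```

One subtlety of this approach is that the underlying thread-local is inheritable, so any thread spawned by a flagged message loop would also report `true`, which made the per-thread flag an imprecise way to detect "running inside an RPC thread".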

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala (1 addition, 1 deletion)

```diff
@@ -1661,7 +1661,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
     } catch {
       case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
     }
-    dagScheduler.sc.stop()
+    dagScheduler.sc.stopInNewThread()
   }
 
   override def onStop(): Unit = {
```

core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala (1 addition, 1 deletion)

```diff
@@ -139,7 +139,7 @@ private[spark] class StandaloneSchedulerBackend(
         scheduler.error(reason)
       } finally {
         // Ensure the application terminates, as we can no longer run jobs.
-        sc.stop()
+        sc.stopInNewThread()
       }
     }
   }
```

core/src/main/scala/org/apache/spark/util/Utils.scala (1 addition, 1 deletion)

```diff
@@ -1249,7 +1249,7 @@ private[spark] object Utils extends Logging {
       val currentThreadName = Thread.currentThread().getName
       if (sc != null) {
         logError(s"uncaught error in thread $currentThreadName, stopping SparkContext", t)
-        sc.stop()
+        sc.stopInNewThread()
       }
       if (!NonFatal(t)) {
         logError(s"throw uncaught fatal error in thread $currentThreadName", t)
```

core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala (0 additions, 13 deletions)

```diff
@@ -870,19 +870,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     verify(endpoint, never()).onDisconnected(any())
     verify(endpoint, never()).onNetworkError(any(), any())
   }
-
-  test("isInRPCThread") {
-    val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint {
-      override val rpcEnv = env
-
-      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-        case m => context.reply(rpcEnv.isInRPCThread)
-      }
-    })
-    assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true)
-    assert(env.isInRPCThread === false)
-    env.stop(rpcEndpointRef)
-  }
 }
 
 class UnserializableClass
```
