
Commit ba80eaf

[SPARK-18280][CORE] Fix potential deadlock in StandaloneSchedulerBackend.dead
## What changes were proposed in this pull request?

`StandaloneSchedulerBackend.dead` is called in an RPC thread, so it must not call `SparkContext.stop` on the same thread: `SparkContext.stop` blocks until all RPC threads exit, so calling it from inside an RPC thread deadlocks. This PR adds a thread-local flag that is set inside RPC threads; `SparkContext.stop` uses it to decide whether to launch a new thread to stop the SparkContext.

## How was this patch tested?

Jenkins

Author: Shixiong Zhu <[email protected]>

Closes #15775 from zsxwing/SPARK-18280.
1 parent 21bbf94 commit ba80eaf
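Reduced to a self-contained sketch, the pattern is a thread-local flag that RPC worker threads set on entry and that `stop` checks before doing any blocking work. Everything below (`ShutdownDemo`, `doStop`, the thread names) is illustrative; the real names appear in the diffs that follow:

```scala
import scala.util.DynamicVariable

object ShutdownDemo {
  // Per-thread flag; each RPC message loop sets it to true on entry.
  val rpcThreadFlag = new DynamicVariable[Boolean](false)

  def stop(): Unit = {
    if (rpcThreadFlag.value) {
      // stop() blocks until all RPC threads exit, so calling it from an RPC
      // thread would deadlock; hand the shutdown to a fresh daemon thread.
      new Thread("stop-demo") {
        setDaemon(true)
        override def run(): Unit = doStop()
      }.start()
    } else {
      doStop()
    }
  }

  // Stand-in for the real shutdown logic.
  private def doStop(): Unit =
    println(s"stopping from thread ${Thread.currentThread().getName}")

  def main(args: Array[String]): Unit = {
    val rpcThread = new Thread("rpc-loop") {
      override def run(): Unit = {
        rpcThreadFlag.value = true // what the dispatcher's message loop does
        stop()                     // delegates to the "stop-demo" thread
      }
    }
    rpcThread.start()
    rpcThread.join()
    Thread.sleep(100) // let the daemon thread print before the JVM exits
  }
}
```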


5 files changed (+41 -2 lines)


core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 20 additions & 2 deletions
```diff
@@ -1757,8 +1757,26 @@ class SparkContext(config: SparkConf) extends Logging {
    */
   def listJars(): Seq[String] = addedJars.keySet.toSeq
 
-  // Shut down the SparkContext.
-  def stop() {
+  /**
+   * Shut down the SparkContext.
+   */
+  def stop(): Unit = {
+    if (env.rpcEnv.isInRPCThread) {
+      // `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread.
+      // We should launch a new thread to call `stop` to avoid dead-lock.
+      new Thread("stop-spark-context") {
+        setDaemon(true)
+
+        override def run(): Unit = {
+          _stop()
+        }
+      }.start()
+    } else {
+      _stop()
+    }
+  }
+
+  private def _stop() {
     if (LiveListenerBus.withinListenerThread.value) {
       throw new SparkException(
         s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}")
```

core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala

Lines changed: 4 additions & 0 deletions
```diff
@@ -147,6 +147,10 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
    */
   def openChannel(uri: String): ReadableByteChannel
 
+  /**
+   * Return if the current thread is a RPC thread.
+   */
+  def isInRPCThread: Boolean
 }
 
 /**
```

core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -201,6 +201,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
   /** Message loop used for dispatching messages. */
   private class MessageLoop extends Runnable {
     override def run(): Unit = {
+      NettyRpcEnv.rpcThreadFlag.value = true
       try {
         while (true) {
           try {
```

core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala

Lines changed: 3 additions & 0 deletions
```diff
@@ -408,10 +408,13 @@ private[netty] class NettyRpcEnv(
 
   }
 
+  override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value
 }
 
 private[netty] object NettyRpcEnv extends Logging {
 
+  private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false)
+
   /**
    * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
    * Use `currentEnv` to wrap the deserialization codes. E.g.,
```
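`DynamicVariable` here is `scala.util.DynamicVariable`: each thread reads the value it (or the thread that spawned it) last set, falling back to the `false` default. A minimal standalone sketch of that behavior:

```scala
import scala.util.DynamicVariable

object FlagDemo {
  val flag = new DynamicVariable[Boolean](false)

  def main(args: Array[String]): Unit = {
    val worker = new Thread("worker") {
      override def run(): Unit = {
        flag.value = true // visible to this thread and threads it spawns later
        println(s"worker sees ${flag.value}") // true
      }
    }
    worker.start()
    worker.join()
    println(s"main sees ${flag.value}") // false: main never set the flag
  }
}
```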

core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala

Lines changed: 13 additions & 0 deletions
```diff
@@ -870,6 +870,19 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
     verify(endpoint, never()).onDisconnected(any())
     verify(endpoint, never()).onNetworkError(any(), any())
   }
+
+  test("isInRPCThread") {
+    val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint {
+      override val rpcEnv = env
+
+      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+        case m => context.reply(rpcEnv.isInRPCThread)
+      }
+    })
+    assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true)
+    assert(env.isInRPCThread === false)
+    env.stop(rpcEndpointRef)
+  }
 }
 
 class UnserializableClass
```
