Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1757,8 +1757,26 @@ class SparkContext(config: SparkConf) extends Logging {
*/
def listJars(): Seq[String] = addedJars.keySet.toSeq

// Shut down the SparkContext.
def stop() {
/**
* Shut down the SparkContext.
*/
def stop(): Unit = {
if (env.rpcEnv.isInRPCThread) {
// `stop` will block until all RPC threads exit, so we cannot call stop inside a RPC thread.
// We should launch a new thread to call `stop` to avoid dead-lock.
new Thread("stop-spark-context") {
setDaemon(true)

override def run(): Unit = {
_stop()
}
}.start()
} else {
_stop()
}
}

private def _stop() {
if (LiveListenerBus.withinListenerThread.value) {
throw new SparkException(
s"Cannot stop SparkContext within listener thread of ${LiveListenerBus.name}")
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ private[spark] abstract class RpcEnv(conf: SparkConf) {
*/
def openChannel(uri: String): ReadableByteChannel

/**
* Return if the current thread is a RPC thread.
*/
def isInRPCThread: Boolean
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging {
/** Message loop used for dispatching messages. */
private class MessageLoop extends Runnable {
override def run(): Unit = {
NettyRpcEnv.rpcThreadFlag.value = true
try {
while (true) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,13 @@ private[netty] class NettyRpcEnv(

}

override def isInRPCThread: Boolean = NettyRpcEnv.rpcThreadFlag.value
}

private[netty] object NettyRpcEnv extends Logging {

private[netty] val rpcThreadFlag = new DynamicVariable[Boolean](false)

/**
* When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
* Use `currentEnv` to wrap the deserialization codes. E.g.,
Expand Down
13 changes: 13 additions & 0 deletions core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,19 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll {
verify(endpoint, never()).onDisconnected(any())
verify(endpoint, never()).onNetworkError(any(), any())
}

test("isInRPCThread") {
val rpcEndpointRef = env.setupEndpoint("isInRPCThread", new RpcEndpoint {
override val rpcEnv = env

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case m => context.reply(rpcEnv.isInRPCThread)
}
})
assert(rpcEndpointRef.askWithRetry[Boolean]("hello") === true)
assert(env.isInRPCThread === false)
env.stop(rpcEndpointRef)
}
}

class UnserializableClass
Expand Down