From 38acc159e450f3fb86f2e43f040fe77a43e68f38 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 30 Apr 2015 12:53:02 -0700 Subject: [PATCH 01/30] A new RPC implementation based on the network module --- .../org/apache/spark/MapOutputTracker.scala | 2 +- .../scala/org/apache/spark/SparkEnv.scala | 19 +- .../org/apache/spark/rpc/RpcCallContext.scala | 2 +- .../rpc/RpcEndpointNotFoundException.scala | 22 + .../scala/org/apache/spark/rpc/RpcEnv.scala | 7 +- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 6 +- .../apache/spark/rpc/netty/Dispatcher.scala | 187 ++++++++ .../apache/spark/rpc/netty/IDVerifier.scala | 39 ++ .../org/apache/spark/rpc/netty/Inbox.scala | 171 +++++++ .../spark/rpc/netty/NettyRpcAddress.scala | 56 +++ .../spark/rpc/netty/NettyRpcCallContext.scala | 83 ++++ .../apache/spark/rpc/netty/NettyRpcEnv.scala | 447 ++++++++++++++++++ .../storage/BlockManagerSlaveEndpoint.scala | 2 +- .../apache/spark/MapOutputTrackerSuite.scala | 10 +- .../org/apache/spark/SSLSampleConfigs.scala | 2 + .../org/apache/spark/rpc/RpcEnvSuite.scala | 74 ++- .../apache/spark/rpc/TestRpcEndpoint.scala | 123 +++++ .../apache/spark/rpc/netty/InboxSuite.scala | 144 ++++++ .../rpc/netty/NettyRpcAddressSuite.scala | 29 ++ .../spark/rpc/netty/NettyRpcEnvSuit.scala | 38 ++ .../rpc/netty/NettyRpcHandlerSuite.scala | 67 +++ .../spark/network/client/TransportClient.java | 4 + .../spark/network/server/RpcHandler.java | 2 + .../server/TransportRequestHandler.java | 1 + .../streaming/scheduler/ReceiverTracker.scala | 2 +- 25 files changed, 1517 insertions(+), 22 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala 
create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuit.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 018422827e1c8..d96b6c5b97d4f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -44,7 +44,7 @@ private[spark] class MapOutputTrackerMasterEndpoint( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = context.sender.address.hostPort + val hostPort = context.senderAddress.hostPort logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index a185954089528..5c2f6876ecb23 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties +import akka.actor.ActorSystem import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -41,7 +42,7 @@ import 
org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator} -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} /** * :: DeveloperApi :: @@ -57,6 +58,7 @@ import org.apache.spark.util.{RpcUtils, Utils} class SparkEnv ( val executorId: String, private[spark] val rpcEnv: RpcEnv, + val actorSystem: ActorSystem, // TODO Remove actorSystem val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -74,9 +76,6 @@ class SparkEnv ( val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { - // TODO Remove actorSystem - val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem - private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -97,6 +96,9 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() + if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) { + actorSystem.shutdown() + } rpcEnv.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut @@ -245,7 +247,13 @@ object SparkEnv extends Logging { // Create the ActorSystem for Akka and get the port it binds to. 
val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) - val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + val actorSystem: ActorSystem = + if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { + rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + } else { + // Create a ActorSystem for legacy codes + AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1 + } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. if (isDriver) { @@ -397,6 +405,7 @@ object SparkEnv extends Logging { val envInstance = new SparkEnv( executorId, rpcEnv, + actorSystem, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index 3e5b64265e919..f527ec86ab7b2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -37,5 +37,5 @@ private[spark] trait RpcCallContext { /** * The sender of this message. */ - def sender: RpcEndpointRef + def senderAddress: RpcAddress } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala new file mode 100644 index 0000000000000..d177881fb3053 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc + +import org.apache.spark.SparkException + +private[rpc] class RpcEndpointNotFoundException(uri: String) + extends SparkException(s"Cannot find endpoint: $uri") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 12b6b28d4d7ec..fbb23816af163 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -34,8 +34,11 @@ private[spark] object RpcEnv { private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { // Add more RpcEnv implementations here - val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", + "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory") + // Use "netty" by default so that Jenkins can run all tests using NettyRpcEnv. + // Will change it back to "akka" before merging the new implementation. + val rpcEnvName = conf.get("spark.rpc", "netty") val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). 
newInstance().asInstanceOf[RpcEnvFactory] diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 0161962cde073..937d065ac1556 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -167,9 +167,9 @@ private[spark] class AkkaRpcEnv private[akka] ( _sender ! AkkaMessage(response, false) } - // Some RpcEndpoints need to know the sender's address - override val sender: RpcEndpointRef = - new AkkaRpcEndpointRef(defaultAddress, _sender, conf) + // Use "lazy" because most of RpcEndpoints don't need "senderAddress" + override lazy val senderAddress: RpcAddress = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address }) } else { endpoint.receive diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala new file mode 100644 index 0000000000000..7ddf2866c1de8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent.{TimeUnit, Executors, LinkedBlockingQueue, ConcurrentHashMap} + +import org.apache.spark.network.client.RpcResponseCallback + +import scala.concurrent.Promise +import scala.util.control.NonFatal + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.rpc.{RpcCallContext, RpcAddress, RpcEndpointRef, RpcEndpoint} +import org.apache.spark.util.ThreadUtils + +private class RpcEndpointPair(val endpoint: RpcEndpoint, val endpointRef: NettyRpcEndpointRef) + +private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { + + // the inboxes that are not being used + private val idleInboxes = new ConcurrentHashMap[RpcEndpoint, Inbox]() + + private val endpointToInbox = new ConcurrentHashMap[RpcEndpoint, Inbox]() + + // need a name to RpcEndpoint mapping so that we can deliver the messages + private val nameToEndpoint = new ConcurrentHashMap[String, RpcEndpointPair]() + + private val endpointToEndpointRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + + // Track the receivers whose inboxes may contain messages. 
+ private val receivers = new LinkedBlockingQueue[RpcEndpoint]() + + @volatile private var stopped = false + + def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { + val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name) + val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) + nameToEndpoint.put(name, new RpcEndpointPair(endpoint,endpointRef)) + endpointToEndpointRef.put(endpoint, endpointRef) + val inbox = new Inbox(endpointRef, endpoint) + endpointToInbox.put(endpoint, inbox) + idleInboxes.put(endpoint, inbox) + afterUpdateInbox(inbox) + endpointRef + } + + def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToEndpointRef.get(endpoint) + + def getRpcEndpointRef(name: String): RpcEndpointRef = nameToEndpoint.get(name).endpointRef + + // Should be idempotent + def unregisterRpcEndpoint(name: String): Unit = { + val endpointPair = nameToEndpoint.remove(name) + if (endpointPair != null) { + val inbox = endpointToInbox.remove(endpointPair.endpoint) + if (inbox != null) { + inbox.stop() + afterUpdateInbox(inbox) + } + endpointToEndpointRef.remove(endpointPair.endpoint) + } + } + + def stop(rpcEndpointRef: RpcEndpointRef): Unit = { + unregisterRpcEndpoint(rpcEndpointRef.name) + } + + /** + * Send a message to all registered [[RpcEndpoint]]s. 
+ * @param message + */ + def broadcastMessage(message: BroadcastMessage): Unit = { + val iter = endpointToInbox.values().iterator() + while(iter.hasNext) { + val inbox = iter.next() + postMessageToInbox(inbox, message) + } + } + + def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { + val receiver = nameToEndpoint.get(message.receiver.name) + if (receiver != null) { + val inbox = endpointToInbox.get(receiver.endpoint) + if (inbox != null) { + val rpcCallContext = + new RemoteNettyRpcCallContext( + nettyEnv, inbox.endpointRef, callback, message.senderAddress, message.needReply) + postMessageToInbox(inbox, + ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)) + } + } + } + + def postMessage(message: RequestMessage, p: Promise[Any]): Unit = { + val receiver = nameToEndpoint.get(message.receiver.name) + if (receiver != null) { + val inbox = endpointToInbox.get(receiver.endpoint) + if (inbox != null) { + val rpcCallContext = + new LocalNettyRpcCallContext( + inbox.endpointRef, message.senderAddress, message.needReply, p) + postMessageToInbox(inbox, + ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)) + } + } + } + + private def postMessageToInbox(inbox: Inbox, message: InboxMessage): Unit = { + inbox.post(message) + afterUpdateInbox(inbox) + } + + private def afterUpdateInbox(inbox: Inbox): Unit = { + // Do some work to trigger processing messages in the inbox + val endpoint = inbox.endpoint + // Replacing unsuccessfully means someone is processing it + idleInboxes.replace(endpoint, inbox, inbox) + receivers.put(endpoint) + } + + class MessageLoop extends Runnable { + override def run(): Unit = { + try { + while (!stopped) { + try { + val endpoint = receivers.take() + val inbox = idleInboxes.remove(endpoint) + if (inbox != null) { + val inboxStopped = inbox.process(Dispatcher.this) + if (!inboxStopped) { + idleInboxes.put(endpoint, inbox) + if (!inbox.isEmpty) { 
+ receivers.add(endpoint) + } + } + } else { + // other thread is processing endpoint's Inbox + } + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case ie: InterruptedException => // exit + } + } + } + + private val parallelism = Runtime.getRuntime.availableProcessors() + + private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop") + (0 until parallelism) foreach { _ => + executor.execute(new MessageLoop) + } + + def stop(): Unit = { + stopped = true + executor.shutdownNow() + } + + def awaitTermination(): Unit = { + executor.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) + } + + /** + * Return if the endpoint exists + */ + def verify(name: String): Boolean = { + nameToEndpoint.containsKey(name) + } + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala new file mode 100644 index 0000000000000..1c19484283f73 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.rpc.netty + +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, RpcEndpoint} + +/** + * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists + */ +private[netty] case class ID(name: String) + +/** + * An [[RpcEndpoint]] for remote [[RpcEnv]]s to query if a [[RpcEndpoint]] exists in this [[RpcEnv]] + */ +private[netty] class IDVerifier( + override val rpcEnv: RpcEnv, dispatcher: Dispatcher) extends RpcEndpoint { + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case ID(name) => context.reply(dispatcher.verify(name)) + } +} + +private[netty] object IDVerifier { + val NAME = "id-verifier" +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala new file mode 100644 index 0000000000000..0868c75eb0931 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent.ConcurrentLinkedQueue + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkException, Logging} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint} + +private[netty] sealed trait InboxMessage + +private[netty] case class ContentMessage( + senderAddress: RpcAddress, + content: Any, + needReply: Boolean, + context: NettyRpcCallContext) extends InboxMessage + +/** + * A message type that will be posted to all registered [[RpcEndpoint]]s + */ +private[netty] sealed trait BroadcastMessage extends InboxMessage + +private[netty] case object OnStart extends InboxMessage + +private[netty] case object OnStop extends InboxMessage + +private[netty] case class Associated(remoteAddress: RpcAddress) extends BroadcastMessage + +/** + * A broadcast message that indicates a remote address has been disconnected + */ +private[netty] case class Disassociated(remoteAddress: RpcAddress) extends BroadcastMessage + +/** + * A broadcast message that indicates a network error + */ +private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress) + extends BroadcastMessage + +/** + * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. + * @param endpointRef + * @param endpoint + */ +private[netty] class Inbox( + val endpointRef: NettyRpcEndpointRef, val endpoint: RpcEndpoint) extends Logging { + + private val messages = new ConcurrentLinkedQueue[InboxMessage]() + + // protected by "this" + private var stopped = false + + // OnStart should be the first message to process + messages.add(OnStart) + + /** + * Process stored messages. Return `true` if the `Inbox` is already stopped, and the caller will + * release all resources used by the `Inbox`. 
+ */ + def process(dispatcher: Dispatcher): Boolean = { + var exit = false + var message = messages.poll() + while (message != null) { + safelyCall(endpoint) { + message match { + case ContentMessage(_sender, content, needReply, context) => + val pf: PartialFunction[Any, Unit] = + if (needReply) { + endpoint.receiveAndReply(context) + } else { + endpoint.receive + } + try { + pf.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) + if (!needReply) { + context.finish() + } + } catch { + case NonFatal(e) => + if (needReply) { + // If the sender asks for a reply, we should send the error back to the sender + context.sendFailure(e) + } else { + context.finish() + throw e + } + } + + case OnStart => endpoint.onStart() + case OnStop => + dispatcher.unregisterRpcEndpoint(endpointRef.name) + endpoint.onStop() + assert(isEmpty, "OnStop should be the last message") + exit = true + case Associated(remoteAddress) => + endpoint.onConnected(remoteAddress) + case Disassociated(remoteAddress) => + endpoint.onDisconnected(remoteAddress) + case AssociationError(cause, remoteAddress) => + endpoint.onNetworkError(cause, remoteAddress) + } + } + message = messages.poll() + } + exit + } + + def post(message: InboxMessage): Unit = { + val dropped = + synchronized { + if (stopped) { + // We already put "OnStop" into "messages", so we should drop further messages + true + } else { + messages.add(message) + false + } + } + if (dropped) { + onDrop(message) + } + } + + def stop(): Unit = synchronized { + // The following code should be in `synchronized` so that we can make sure "OnStop" is the last + // message + if (!stopped) { + stopped = true + messages.add(OnStop) + } + } + + def isEmpty: Boolean = messages.isEmpty + + protected def onDrop(message: Any): Unit = { + logWarning(s"Drop ${message} because $endpointRef is stopped") + } + + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try { + action + } 
catch { + case NonFatal(e) => { + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + } + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala new file mode 100644 index 0000000000000..2f38b1c00f291 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc.netty + +import java.net.URI + +import org.apache.spark.SparkException +import org.apache.spark.rpc.RpcAddress + +private[netty] case class NettyRpcAddress(host: String, port: Int, name: String) { + + def toRpcAddress: RpcAddress = RpcAddress(host, port) + + override val toString = s"spark://$name@$host:$port" +} + +private[netty] object NettyRpcAddress { + + def apply(sparkUrl: String): NettyRpcAddress = { + try { + val uri = new URI(sparkUrl) + val host = uri.getHost + val port = uri.getPort + val name = uri.getUserInfo + if (uri.getScheme != "spark" || + host == null || + port < 0 || + name == null || + (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null + uri.getFragment != null || + uri.getQuery != null) { + throw new SparkException("Invalid master URL: " + sparkUrl) + } + NettyRpcAddress(host, port, name) + } catch { + case e: java.net.URISyntaxException => + throw new SparkException("Invalid master URL: " + sparkUrl, e) + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala new file mode 100644 index 0000000000000..e885c1ecc9153 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import org.apache.spark.network.client.RpcResponseCallback +import org.apache.spark.rpc.{RpcAddress, RpcCallContext} + +import scala.concurrent.{Promise, Future} + +private[netty] abstract class NettyRpcCallContext( + endpointRef: NettyRpcEndpointRef, + override val senderAddress: RpcAddress,needReply: Boolean) extends RpcCallContext{ + + protected def send(message: Any): Unit + + override def reply(response: Any): Unit = { + if (needReply) { + send(AskResponse(endpointRef, response)) + } else { + throw new IllegalStateException( + s"Cannot send $response to the sender because the sender won't handle it") + } + } + + override def sendFailure(e: Throwable): Unit = { + if (needReply) { + send(AskResponse(endpointRef, RpcFailure(e))) + } else { + throw new IllegalStateException( + "Cannot send reply to the sender because the sender won't handle it", e) + } + } + + def finish(): Unit = { + if (!needReply) { + send(SendAck(endpointRef)) + } + } +} + +/** + * If the sender and the receiver are in the same process, the reply can be sent back via `Promise`. + */ +private[netty] class LocalNettyRpcCallContext( + endpointRef: NettyRpcEndpointRef, + senderAddress: RpcAddress, + needReply: Boolean, + p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + override protected def send(message: Any): Unit = { + p.success(message) + } +} + +/** + * A [[RpcCallContext]] that will call [[RpcResponseCallback]] to send the reply back. 
+ */ +private[netty] class RemoteNettyRpcCallContext( + nettyEnv: NettyRpcEnv, + endpointRef: NettyRpcEndpointRef, + callback: RpcResponseCallback, + senderAddress: RpcAddress, + needReply: Boolean) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + + override protected def send(message: Any): Unit = { + val reply = nettyEnv.serialize(message) + callback.onSuccess(reply) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala new file mode 100644 index 0000000000000..acb59718d6973 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.spark.rpc.netty
+
+import java.io._
+import java.net.{InetSocketAddress, URI}
+import java.nio.ByteBuffer
+import java.{util => ju}
+import java.util.concurrent._
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.concurrent.{Future, Promise}
+import scala.reflect.ClassTag
+import scala.util.{Failure, Success}
+import scala.util.control.NonFatal
+
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClient}
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap}
+import org.apache.spark.network.server._
+import org.apache.spark.rpc._
+import org.apache.spark.serializer.{JavaSerializer, Serializer}
+import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+
+/**
+ * An [[RpcEnv]] built on the network module's [[TransportContext]].
+ *
+ * A message whose receiver address equals this env's own `address` is posted directly to the
+ * local [[Dispatcher]]; any other message is serialized and sent through a [[TransportClient]].
+ *
+ * @param conf Spark configuration; `spark.rpc.io.threads` is read from it
+ * @param serializer serializer used by `serialize` and `deserialize`
+ * @param host host name used to build `address`
+ * @param securityManager decides whether the SASL client/server bootstraps are installed
+ */
+private[netty] class NettyRpcEnv(
+    val conf: SparkConf, serializer: Serializer, host: String, securityManager: SecurityManager)
+  extends RpcEnv(conf) with Logging {
+
+  private val transportConf =
+    SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0))
+
+  private val dispatcher: Dispatcher = new Dispatcher(this)
+
+  private val transportContext =
+    new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this))
+
+  private val clientFactory = {
+    val bootstraps: ju.List[TransportClientBootstrap] =
+      if (securityManager.isAuthenticationEnabled()) {
+        ju.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, true))
+      } else {
+        ju.Collections.emptyList[TransportClientBootstrap]()
+      }
+    transportContext.createClientFactory(bootstraps)
+  }
+
+  /** Single-threaded scheduler that [[NettyRpcEndpointRef]] uses to schedule `ask` timeouts. */
+  val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout")
+
+  // Stays null until `start` has been called.
+  @volatile private var server: TransportServer = _
+
+  /**
+   * Creates the transport server on `port` and registers the [[IDVerifier]] endpoint.
+   */
+  def start(port: Int): Unit = {
+    val bootstraps: ju.List[TransportServerBootstrap] =
+      if (securityManager.isAuthenticationEnabled()) {
+        ju.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
+      } else {
+        ju.Collections.emptyList[TransportServerBootstrap]()
+      }
+    server = transportContext.createServer(port, bootstraps)
+    dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher))
+  }
+
+  /** The host given at construction plus the server's port; requires `start` to have run. */
+  override lazy val address: RpcAddress = {
+    require(server != null, "NettyRpcEnv has not yet started")
+    RpcAddress(host, server.getPort())
+  }
+
+  override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
+    dispatcher.registerRpcEndpoint(name, endpoint)
+  }
+
+  /**
+   * Builds a ref for `uri` and asks the [[IDVerifier]] endpoint at the same host and port whether
+   * an endpoint with that name exists. The returned future fails with
+   * [[RpcEndpointNotFoundException]] when the answer is `false`.
+   */
+  def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = {
+    val addr = NettyRpcAddress(uri)
+    val endpointRef = new NettyRpcEndpointRef(conf, addr, this)
+    val idVerifierRef =
+      new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this)
+    idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap(find =>
+      if (find) {
+        Future.successful(endpointRef)
+      } else {
+        Future.failed(new RpcEndpointNotFoundException(uri))
+      }
+    )(ThreadUtils.sameThread)
+  }
+
+  override def stop(endpointRef: RpcEndpointRef): Unit = {
+    require(endpointRef.isInstanceOf[NettyRpcEndpointRef])
+    dispatcher.stop(endpointRef)
+  }
+
+  /**
+   * Sends a one-way message. The acknowledgement (or the failure) is only logged.
+   */
+  private[netty] def send(message: RequestMessage): Unit = {
+    val remoteAddr = message.receiver.address
+    if (remoteAddr == address) {
+      // The receiver lives in this env: hand the message to the dispatcher without serializing.
+      val promise = Promise[Any]()
+      dispatcher.postMessage(message, promise)
+      promise.future.onComplete {
+        case Success(response) =>
+          val ack = response.asInstanceOf[SendAck]
+          logDebug(s"Receive ack from ${ack.sender}")
+        case Failure(e) =>
+          logError(s"Exception when sending $message", e)
+      }(ThreadUtils.sameThread)
+    } else {
+      val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
+      client.sendRpc(serialize(message), new RpcResponseCallback {
+
+        override def onFailure(e: Throwable): Unit = {
+          logError(s"Exception when sending $message", e)
+        }
+
+        override def onSuccess(response: Array[Byte]): Unit = {
+          val ack = deserialize[SendAck](response)
+          logDebug(s"Receive ack from ${ack.sender}")
+        }
+      })
+    }
+  }
+
+  /**
+   * Sends a message that expects a reply.
+   *
+   * @return a future completed with the reply carried by the [[AskResponse]], or failed with the
+   *         exception wrapped in an [[RpcFailure]] reply, with a transport failure, or with an
+   *         error raised while deserializing the response
+   */
+  private[netty] def ask(message: RequestMessage): Future[Any] = {
+    val promise = Promise[Any]()
+    val remoteAddr = message.receiver.address
+    if (remoteAddr == address) {
+      val p = Promise[Any]()
+      dispatcher.postMessage(message, p)
+      p.future.onComplete {
+        case Success(response) =>
+          completeAsk(promise, response.asInstanceOf[AskResponse])
+        case Failure(e) =>
+          failAsk(promise, e)
+      }(ThreadUtils.sameThread)
+    } else {
+      val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port)
+      client.sendRpc(serialize(message), new RpcResponseCallback {
+
+        override def onFailure(e: Throwable): Unit = {
+          failAsk(promise, e)
+        }
+
+        override def onSuccess(response: Array[Byte]): Unit = {
+          try {
+            completeAsk(promise, deserialize[AskResponse](response))
+          } catch {
+            // Without this the promise would never be completed when the bytes cannot be
+            // deserialized, and the caller would wait on a future that has no outcome.
+            case NonFatal(e) => failAsk(promise, e)
+          }
+        }
+      })
+    }
+    promise.future
+  }
+
+  /**
+   * Completes `promise` from an [[AskResponse]]: an [[RpcFailure]] reply fails it with the wrapped
+   * exception, any other reply (including `null`) completes it successfully. If the promise is
+   * already completed the response is only logged.
+   */
+  private def completeAsk(promise: Promise[Any], response: AskResponse): Unit = {
+    response.reply match {
+      case RpcFailure(e) =>
+        if (!promise.tryFailure(e)) {
+          logWarning(s"Ignore failure: ${response.reply}")
+        }
+      case reply =>
+        if (!promise.trySuccess(reply)) {
+          logWarning(s"Ignore message: $response")
+        }
+    }
+  }
+
+  /** Fails `promise` with `e`, logging `e` instead when the promise is already completed. */
+  private def failAsk(promise: Promise[Any], e: Throwable): Unit = {
+    if (!promise.tryFailure(e)) {
+      logWarning("Ignore Exception", e)
+    }
+  }
+
+  /** Serializes `content` with a fresh serializer instance and copies the bytes into an array. */
+  private[netty] def serialize(content: Any): Array[Byte] = {
+    val buffer = serializer.newInstance().serialize(content)
+    ju.Arrays.copyOfRange(
+      buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
+  }
+
+  /**
+   * Deserializes `bytes` with this env installed as `NettyRpcEnv.currentEnv` for the duration of
+   * the call; the previous value is restored afterwards, also when deserialization throws.
+   */
+  private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
+    val prevEnv = NettyRpcEnv.currentEnv
+    NettyRpcEnv.setCurrentEnv(this)
+    try {
+      serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
+    } finally {
+      NettyRpcEnv.setCurrentEnv(prevEnv)
+    }
+  }
+
+  override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = {
+    dispatcher.getRpcEndpointRef(endpoint)
+  }
+
+  // `systemName` is not part of a NettyRpcAddress and is ignored here.
+  override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String =
+    new NettyRpcAddress(address.host, address.port, endpointName).toString
+
+  override def shutdown(): Unit = {
+    cleanup()
+  }
+
+  override def awaitTermination(): Unit = {
+    dispatcher.awaitTermination()
+  }
+
+  /** Shuts down the timeout scheduler, the server, the client factory and the dispatcher. */
+  private def cleanup(): Unit = {
+    if (timeoutScheduler != null) {
+      timeoutScheduler.shutdownNow()
+    }
+    if (server != null) {
+      server.close()
+    }
+    if (clientFactory != null) {
+      clientFactory.close()
+    }
+    if (dispatcher != null) {
+      dispatcher.stop()
+    }
+  }
+}
+
+private[netty] object NettyRpcEnv extends Logging {
+
+  /**
+   * When deserializing the [[NettyRpcEndpointRef]], it needs a reference to [[NettyRpcEnv]].
+   * [[NettyRpcEnv]] will call `setCurrentEnv` before deserializing messages so that
+   * [[NettyRpcEndpointRef]] can get it via `currentEnv`.
+ */
+  private val _env = new ThreadLocal[NettyRpcEnv]
+
+  private[netty] def setCurrentEnv(env: NettyRpcEnv): Unit = {
+    _env.set(env)
+  }
+
+  private[netty] def currentEnv: NettyRpcEnv = _env.get
+}
+
+/**
+ * [[RpcEnvFactory]] that creates and starts a [[NettyRpcEnv]] using a [[JavaSerializer]].
+ */
+class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
+
+  /**
+   * Creates a [[NettyRpcEnv]] and starts it through `Utils.startServiceOnPort`, beginning at
+   * `config.port`. If starting fails with a non-fatal error the env is shut down and the error
+   * is rethrown.
+   */
+  def create(config: RpcEnvConfig): RpcEnv = {
+    val sparkConf = config.conf
+    val serializer = new JavaSerializer(sparkConf)
+    val nettyEnv = new NettyRpcEnv(sparkConf, serializer, config.host, config.securityManager)
+    val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
+      nettyEnv.start(actualPort)
+      (nettyEnv, actualPort)
+    }
+    try {
+      Utils.startServiceOnPort(config.port, startNettyRpcEnv, sparkConf, "NettyRpcEnv")._1
+    } catch {
+      case NonFatal(e) =>
+        nettyEnv.shutdown()
+        throw e
+    }
+  }
+
+  /**
+   * Instantiates `className` reflectively, preferring a constructor that takes a [[SparkConf]]
+   * and falling back to the no-arg constructor when there is none.
+   */
+  def instantiateClass[T](conf: SparkConf, className: String): T = {
+    val cls = Class.forName(className, true, Utils.getContextOrSparkClassLoader)
+    try {
+      cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+    } catch {
+      case _: NoSuchMethodException =>
+        cls.getConstructor().newInstance().asInstanceOf[T]
+    }
+  }
+}
+
+/**
+ * [[RpcEndpointRef]] for endpoints served by a [[NettyRpcEnv]].
+ *
+ * Only the [[NettyRpcAddress]] is written when the ref is serialized. On deserialization the env
+ * is taken from `NettyRpcEnv.currentEnv`.
+ */
+class NettyRpcEndpointRef(@transient conf: SparkConf)
+  extends RpcEndpointRef(conf) with Serializable with Logging {
+
+  @transient @volatile private var nettyEnv: NettyRpcEnv = _
+
+  @transient @volatile private var _address: NettyRpcAddress = _
+
+  def this(conf: SparkConf, _address: NettyRpcAddress, nettyEnv: NettyRpcEnv) {
+    this(conf)
+    this._address = _address
+    this.nettyEnv = nettyEnv
+  }
+
+  override def address: RpcAddress = _address.toRpcAddress
+
+  private def readObject(in: ObjectInputStream): Unit = {
+    in.defaultReadObject()
+    _address = in.readObject().asInstanceOf[NettyRpcAddress]
+    nettyEnv = NettyRpcEnv.currentEnv
+  }
+
+  private def writeObject(out: ObjectOutputStream): Unit = {
+    out.defaultWriteObject()
+    out.writeObject(_address)
+  }
+
+  override def name: String = _address.name
+
+  /**
+   * Sends `message` and returns a future for the reply.
+   *
+   * @param message the content to send
+   * @param timeout after this duration the returned future fails with a `TimeoutException`
+   *                unless a reply or a failure arrived first
+   * @return the reply cast to `T`
+   */
+  override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = {
+    val promise = Promise[Any]()
+    val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable {
+      override def run(): Unit = {
+        promise.tryFailure(new TimeoutException("Cannot receive any reply in " + timeout))
+      }
+    }, timeout.toNanos, TimeUnit.NANOSECONDS)
+    val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true))
+    f.onComplete { v =>
+      // A result arrived, so the pending timeout task is no longer needed.
+      timeoutCancelable.cancel(true)
+      if (!promise.tryComplete(v)) {
+        logWarning(s"Ignore message $v")
+      }
+    }(ThreadUtils.sameThread)
+    // Hand out the promise's future rather than `f`: only the promise is failed by the timeout
+    // task above, so returning `f` would make `timeout` have no effect on the caller.
+    promise.future.mapTo[T]
+  }
+
+  override def send(message: Any): Unit = {
+    require(message != null, "Message is null")
+    nettyEnv.send(RequestMessage(nettyEnv.address, this, message, false))
+  }
+
+  override def toString: String = s"NettyRpcEndpointRef(${_address})"
+
+  def toURI: URI = new URI(s"spark://${_address}")
+}
+
+/**
+ * The message that is sent from the sender to the receiver.
+ */
+private[netty] case class RequestMessage(
+    senderAddress: RpcAddress, receiver: NettyRpcEndpointRef, content: Any, needReply: Boolean)
+
+/**
+ * The base trait for all messages that are sent back from the receiver to the sender.
+ */
+private[netty] trait ResponseMessage
+
+/**
+ * The reply for `ask` from the receiver side.
+ */
+private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any)
+  extends ResponseMessage
+
+/**
+ * The acknowledgement for `send`, sent back from the receiver to the sender. It's necessary
+ * because [[TransportClient]] only cleans the resources when it receives a reply.
+ */
+private[netty] case class SendAck(sender: NettyRpcEndpointRef) extends ResponseMessage
+
+/**
+ * A response that indicates some failure happens in the receiver side.
+ */
+private[netty] case class RpcFailure(e: Throwable)
+
+/**
+ * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast
+ * network events and forward messages to [[Dispatcher]].
+ */
+class NettyRpcHandler(
+    dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging {
+
+  private type ClientAddress = RpcAddress
+  private type RemoteEnvAddress = RpcAddress
+
+  // Store all client addresses and their NettyRpcEnv addresses. Protected by "this".
+  private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]()
+  // Store the connections from other NettyRpcEnv addresses. Protected by "this".
+  private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]()
+
+  /** Host name and port of the remote end of `client`'s channel. */
+  private def clientAddressOf(client: TransportClient): ClientAddress = {
+    val socketAddress = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
+    assert(socketAddress != null)
+    RpcAddress(socketAddress.getHostName, socketAddress.getPort)
+  }
+
+  /**
+   * Deserializes the request, records which remote env the client connection belongs to, and
+   * forwards the request to the dispatcher. `Associated` is broadcast first when this is the
+   * first known connection from the sender's env.
+   */
+  override def receive(
+      client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = {
+    val requestMessage = nettyEnv.deserialize[RequestMessage](message)
+    val clientAddr = clientAddressOf(client)
+    val remoteEnvAddress = requestMessage.senderAddress
+    val associated = synchronized {
+      val seenBefore = remoteAddresses.put(clientAddr, remoteEnvAddress).isDefined
+      if (seenBefore) {
+        None
+      } else {
+        // A new client connection: bump the connection count of its env.
+        val previousCount = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
+        remoteConnectionCount.put(remoteEnvAddress, previousCount + 1)
+        // Only the very first connection of an env fires "Associated".
+        if (previousCount == 0) Some(Associated(remoteEnvAddress)) else None
+      }
+    }
+    associated.foreach(dispatcher.broadcastMessage)
+    dispatcher.postMessage(requestMessage, callback)
+  }
+
+  override def getStreamManager: StreamManager = new OneForOneStreamManager
+
+  /**
+   * Broadcasts `AssociationError` for the env that owns the failing connection, or just logs
+   * the cause when the client address is unknown.
+   */
+  override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = {
+    val clientAddr = clientAddressOf(client)
+    val knownEnv = synchronized {
+      remoteAddresses.get(clientAddr)
+    }
+    knownEnv match {
+      case Some(remoteEnvAddress) =>
+        dispatcher.broadcastMessage(AssociationError(cause, remoteEnvAddress))
+      case None =>
+        logError(cause.getMessage, cause)
+    }
+  }
+
+  /**
+   * Forgets the terminated client connection and broadcasts `Disassociated` when it was the last
+   * known connection from its env.
+   */
+  override def connectionTerminated(client: TransportClient): Unit = {
+    val clientAddr = clientAddressOf(client)
+    val disassociated = synchronized {
+      remoteAddresses.remove(clientAddr).flatMap { remoteEnvAddress =>
+        val currentCount = remoteConnectionCount.getOrElse(remoteEnvAddress, 0)
+        assert(currentCount != 0, "remoteAddresses and remoteConnectionCount are not consistent")
+        if (currentCount == 1) {
+          // That was the last client of this env: clean up and fire "Disassociated".
+          remoteConnectionCount.remove(remoteEnvAddress)
+          Some(Disassociated(remoteEnvAddress))
+        } else {
+          remoteConnectionCount.put(remoteEnvAddress, currentCount - 1)
+          None
+        }
+      }
+    }
+    disassociated.foreach(dispatcher.broadcastMessage)
+  }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index 7478ab0fc2f7a..e7999e1f75e39 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -80,7 +80,7 @@ class BlockManagerSlaveEndpoint( future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) context.reply(response) - logDebug("Sent response: " + response + " to " + context.sender) + logDebug("Sent response: " + response + " to " + context.senderAddress) } future.onFailure { case t: Throwable => logError("Error in " +
actionMessage, t) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 1fab69678d040..9c9704c5da330 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -164,10 +164,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) - val sender = mock(classOf[RpcEndpointRef]) - when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.sender).thenReturn(sender) + when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(10)) verify(rpcCallContext).reply(any()) verify(rpcCallContext, never()).sendFailure(any()) @@ -194,10 +193,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - val sender = mock(classOf[RpcEndpointRef]) - when(sender.address).thenReturn(RpcAddress("localhost", 12345)) + val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) - when(rpcCallContext.sender).thenReturn(sender) + when(rpcCallContext.senderAddress).thenReturn(senderAddress) masterEndpoint.receiveAndReply(rpcCallContext)(GetMapOutputStatuses(20)) verify(rpcCallContext, never()).reply(any()) verify(rpcCallContext).sendFailure(isA(classOf[SparkException])) diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 1a099da2c6c8e..272f64a048838 100644 --- 
a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -27,6 +27,7 @@ object SSLSampleConfigs { def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) + conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -41,6 +42,7 @@ object SSLSampleConfigs { def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) + conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) conf.set("spark.ssl.keyStorePassword", "password") diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 1f0aa759b08da..ba3c60807adc2 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.rpc +import java.io.NotSerializableException import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable @@ -99,7 +100,6 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) - val newRpcEndpointRef = rpcEndpointRef.askWithRetry[RpcEndpointRef]("Hello") val reply = newRpcEndpointRef.askWithRetry[String]("Echo") assert("Echo" === reply) @@ -511,6 +511,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(events === List( ("onConnected", remoteAddress), ("onNetworkError", remoteAddress), + ("onDisconnected", remoteAddress)) || + events === List( + ("onConnected", remoteAddress), ("onDisconnected", remoteAddress))) } } @@ -530,15 +533,82 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { "local", env.address, 
"sendWithReply-unserializable-error") try { val f = rpcEndpointRef.ask[String]("hello") - intercept[TimeoutException] { + val e = intercept[Exception] { Await.result(f, 1 seconds) } + assert(e.isInstanceOf[TimeoutException] || // For Akka + e.isInstanceOf[NotSerializableException] // For Netty + ) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() } } + test("port conflict") { + val anotherEnv = createRpcEnv(new SparkConf(), "remote", env.address.port) + assert(anotherEnv.address.port != env.address.port) + } + + test("send with ssl") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val localEnv = createRpcEnv(conf, "ssl-local", 13345) + val remoteEnv = createRpcEnv(conf, "ssl-remote", 14345) + + try { + @volatile var message: String = null + localEnv.setupEndpoint("send-ssl", new RpcEndpoint { + override val rpcEnv = localEnv + + override def receive: PartialFunction[Any, Unit] = { + case msg: String => message = msg + } + }) + val rpcEndpointRef = remoteEnv.setupEndpointRef("ssl-local", localEnv.address, "send-ssl") + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } finally { + localEnv.shutdown() + localEnv.awaitTermination() + remoteEnv.shutdown() + remoteEnv.awaitTermination() + } + } + + test("ask with ssl") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val localEnv = createRpcEnv(conf, "ssl-local", 13345) + val remoteEnv = createRpcEnv(conf, "ssl-remote", 14345) + + try { + localEnv.setupEndpoint("ask-ssl", new RpcEndpoint { + override val rpcEnv = localEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case msg: String => { + context.reply(msg) + } + } + }) + val rpcEndpointRef = remoteEnv.setupEndpointRef("ssl-local", localEnv.address, "ask-ssl") + val reply = 
rpcEndpointRef.askWithRetry[String]("hello") + assert("hello" === reply) + } finally { + localEnv.shutdown() + localEnv.awaitTermination() + remoteEnv.shutdown() + remoteEnv.awaitTermination() + } + } + } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala new file mode 100644 index 0000000000000..3268c89d8296d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.rpc
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalactic.TripleEquals
+
+/**
+ * An [[RpcEndpoint]] for tests: it records every callback it receives and offers `verify*`
+ * helpers that compare the recorded values with expected ones.
+ */
+class TestRpcEndpoint extends RpcEndpoint with TripleEquals {
+
+  override val rpcEnv: RpcEnv = null
+
+  @volatile private var receiveMessages = ArrayBuffer[Any]()
+
+  @volatile private var receiveAndReplyMessages = ArrayBuffer[Any]()
+
+  @volatile private var onConnectedMessages = ArrayBuffer[RpcAddress]()
+
+  @volatile private var onDisconnectedMessages = ArrayBuffer[RpcAddress]()
+
+  @volatile private var onNetworkErrorMessages = ArrayBuffer[(Throwable, RpcAddress)]()
+
+  @volatile private var started = false
+
+  @volatile private var stopped = false
+
+  // ---- recording callbacks ----
+
+  override def receive: PartialFunction[Any, Unit] = {
+    case message: Any => receiveMessages.append(message)
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case message: Any => receiveAndReplyMessages.append(message)
+  }
+
+  override def onConnected(remoteAddress: RpcAddress): Unit = {
+    onConnectedMessages.append(remoteAddress)
+  }
+
+  /**
+   * Records the error reported for the connection between the current node and
+   * `remoteAddress`.
+   */
+  override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+    onNetworkErrorMessages.append((cause, remoteAddress))
+  }
+
+  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
+    onDisconnectedMessages.append(remoteAddress)
+  }
+
+  override def onStart(): Unit = {
+    started = true
+  }
+
+  override def onStop(): Unit = {
+    stopped = true
+  }
+
+  // ---- inspection helpers ----
+
+  def numReceiveMessages: Int = receiveMessages.length
+
+  def verifyStarted(): Unit = {
+    assert(started, "RpcEndpoint is not started")
+  }
+
+  def verifyStopped(): Unit = {
+    assert(stopped, "RpcEndpoint is not stopped")
+  }
+
+  def verifyReceiveMessages(expected: Seq[Any]): Unit = {
+    assert(receiveMessages === expected)
+  }
+
+  def verifySingleReceiveMessage(message: Any): Unit = {
+    verifyReceiveMessages(message :: Nil)
+  }
+
+  def verifyReceiveAndReplyMessages(expected: Seq[Any]): Unit = {
+    assert(receiveAndReplyMessages === expected)
+  }
+
+  def verifySingleReceiveAndReplyMessage(message: Any): Unit = {
+    verifyReceiveAndReplyMessages(message :: Nil)
+  }
+
+  def verifySingleOnConnectedMessage(remoteAddress: RpcAddress): Unit = {
+    verifyOnConnectedMessages(remoteAddress :: Nil)
+  }
+
+  def verifyOnConnectedMessages(expected: Seq[RpcAddress]): Unit = {
+    assert(onConnectedMessages === expected)
+  }
+
+  def verifySingleOnDisconnectedMessage(remoteAddress: RpcAddress): Unit = {
+    verifyOnDisconnectedMessages(remoteAddress :: Nil)
+  }
+
+  def verifyOnDisconnectedMessages(expected: Seq[RpcAddress]): Unit = {
+    assert(onDisconnectedMessages === expected)
+  }
+
+  def verifySingleOnNetworkErrorMessage(cause: Throwable, remoteAddress: RpcAddress): Unit = {
+    verifyOnNetworkErrorMessages((cause, remoteAddress) :: Nil)
+  }
+
+  def verifyOnNetworkErrorMessages(expected: Seq[(Throwable, RpcAddress)]): Unit = {
+    assert(onNetworkErrorMessages === expected)
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala new
file mode 100644 index 0000000000000..ebee763eec56b --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */
+
+package org.apache.spark.rpc.netty
+
+import java.util.concurrent.{TimeUnit, CountDownLatch}
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.mockito.Mockito._
+import org.scalatest.FunSuite
+
+import org.apache.spark.rpc.{TestRpcEndpoint, RpcAddress}
+
+/**
+ * Tests for [[Inbox]]: messages are posted to an inbox wrapping a [[TestRpcEndpoint]], the inbox
+ * is processed against a mocked [[Dispatcher]], and the endpoint's recorded callbacks are checked.
+ */
+class InboxSuite extends FunSuite {
+
+  test("post") {
+    val endpoint = new TestRpcEndpoint
+    val endpointRef = mock(classOf[NettyRpcEndpointRef])
+    when(endpointRef.name).thenReturn("hello")
+
+    val dispatcher = mock(classOf[Dispatcher])
+
+    val inbox = new Inbox(endpointRef, endpoint)
+    val message = ContentMessage(null, "hi", false, null)
+    inbox.post(message)
+    assert(inbox.process(dispatcher) === false)
+
+    endpoint.verifySingleReceiveMessage("hi")
+
+    inbox.stop()
+    assert(inbox.process(dispatcher) === true)
+    endpoint.verifyStarted()
+    endpoint.verifyStopped()
+    verify(dispatcher).unregisterRpcEndpoint("hello")
+  }
+
+  test("post: with reply") {
+    val endpoint = new TestRpcEndpoint
+    val endpointRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+
+    val inbox = new Inbox(endpointRef, endpoint)
+    val message = ContentMessage(null, "hi", true, null)
+    inbox.post(message)
+    assert(inbox.process(dispatcher) === false)
+
+    endpoint.verifySingleReceiveAndReplyMessage("hi")
+  }
+
+  test("post: multiple threads") {
+    val endpoint = new TestRpcEndpoint
+    val endpointRef = mock(classOf[NettyRpcEndpointRef])
+    when(endpointRef.name).thenReturn("hello")
+
+    val dispatcher = mock(classOf[Dispatcher])
+
+    // The posting threads below run concurrently, so `onDrop` may be entered from several of
+    // them at once. `+=` on a @volatile var is a non-atomic read-modify-write and could lose
+    // increments, which would make the final count assertion flaky.
+    val numDroppedMessages = new AtomicInteger(0)
+    val inbox = new Inbox(endpointRef, endpoint) {
+      override def onDrop(message: Any): Unit = {
+        numDroppedMessages.incrementAndGet()
+      }
+    }
+
+    val exitLatch = new CountDownLatch(10)
+
+    for (_ <- 0 until 10) {
+      new Thread {
+        override def run(): Unit = {
+          for (_ <- 0 until 100) {
+            val message = ContentMessage(null, "hi", false, null)
+            inbox.post(message)
+          }
+          exitLatch.countDown()
+        }
+      }.start()
+    }
+    assert(inbox.process(dispatcher) === false)
+    inbox.stop()
+    assert(inbox.process(dispatcher) === true)
+
+    // Fail with a clear message if the posting threads did not finish, instead of running into
+    // a confusing count mismatch below.
+    assert(exitLatch.await(30, TimeUnit.SECONDS), "posting threads did not finish in time")
+
+    assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get)
+    endpoint.verifyStarted()
+    endpoint.verifyStopped()
+    verify(dispatcher).unregisterRpcEndpoint("hello")
+  }
+
+  test("post: Associated") {
+    val endpoint = new TestRpcEndpoint
+    val endpointRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+
+    val remoteAddress = RpcAddress("localhost", 11111)
+
+    val inbox = new Inbox(endpointRef, endpoint)
+    inbox.post(Associated(remoteAddress))
+    inbox.process(dispatcher)
+
+    endpoint.verifySingleOnConnectedMessage(remoteAddress)
+  }
+
+  test("post: Disassociated") {
+    val endpoint = new TestRpcEndpoint
+    val endpointRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+
+    val remoteAddress = RpcAddress("localhost", 11111)
+
+    val inbox = new Inbox(endpointRef, endpoint)
+    inbox.post(Disassociated(remoteAddress))
+    inbox.process(dispatcher)
+
+    endpoint.verifySingleOnDisconnectedMessage(remoteAddress)
+  }
+
+  test("post: AssociationError") {
+    val endpoint = new TestRpcEndpoint
+    val endpointRef = mock(classOf[NettyRpcEndpointRef])
+    val dispatcher = mock(classOf[Dispatcher])
+
+    val remoteAddress = RpcAddress("localhost", 11111)
+    val cause = new RuntimeException("Oops")
+
+    val inbox = new Inbox(endpointRef, endpoint)
+    inbox.post(AssociationError(cause, remoteAddress))
+    inbox.process(dispatcher)
+
+    endpoint.verifySingleOnNetworkErrorMessage(cause, remoteAddress)
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala new file mode 100644 index 0000000000000..6ffd65047eac4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import org.scalatest.FunSuite
+
+/** Checks the string form of a [[NettyRpcAddress]]. */
+class NettyRpcAddressSuite extends FunSuite {
+
+  test("toString") {
+    val rendered = NettyRpcAddress("localhost", 12345, "test").toString
+    assert(rendered === "spark://test@localhost:12345")
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuit.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuit.scala new file mode 100644 index 0000000000000..4025372492456 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuit.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import org.apache.spark.rpc._
+import org.apache.spark.{SecurityManager, SparkConf}
+
+/** Runs the shared [[RpcEnvSuite]] tests against envs created by [[NettyRpcEnvFactory]]. */
+class NettyRpcEnvSuite extends RpcEnvSuite {
+
+  override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
+    // `name` is not used here; the config is always built with the system name "test".
+    val securityManager = new SecurityManager(conf)
+    val factory = new NettyRpcEnvFactory()
+    factory.create(RpcEnvConfig(conf, "test", "localhost", port, securityManager))
+  }
+
+  test("nonexist-endpoint") {
+    val expectedUri = env.uriOf("test", env.address, "nonexist-endpoint")
+    val notFound = intercept[RpcEndpointNotFoundException] {
+      env.setupEndpointRef("test", env.address, "nonexist-endpoint")
+    }
+    // The exception message is expected to mention the URI that could not be resolved.
+    assert(notFound.getMessage.contains(expectedUri))
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala new file mode 100644 index 0000000000000..57dcff586bd82 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
+
+package org.apache.spark.rpc.netty
+
+import java.net.InetSocketAddress
+
+import io.netty.channel.Channel
+import org.apache.spark.network.client.{TransportResponseHandler, TransportClient}
+import org.mockito.Mockito._
+import org.mockito.Matchers._
+import org.scalatest.FunSuite
+
+import org.apache.spark.rpc._
+
+/**
+ * Tests which network events [[NettyRpcHandler]] broadcasts to the [[Dispatcher]], using a
+ * mocked env whose `deserialize` always yields a request sent from `localhost:12345`.
+ */
+class NettyRpcHandlerSuite extends FunSuite {
+
+  val env = mock(classOf[NettyRpcEnv])
+  when(env.deserialize(any())(any())).
+    thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false))
+
+  test("receive") {
+    val dispatcher = mock(classOf[Dispatcher])
+    val channel = mock(classOf[Channel])
+    val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
+    val handler = new NettyRpcHandler(dispatcher, env)
+
+    // Two requests that arrive from different client ports but carry the same sender env address.
+    for (clientPort <- Seq(40000, 40001)) {
+      when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", clientPort))
+      handler.receive(client, null, null)
+    }
+
+    verify(dispatcher, times(1)).broadcastMessage(Associated(RpcAddress("localhost", 12345)))
+  }
+
+  test("connectionTerminated") {
+    val dispatcher = mock(classOf[Dispatcher])
+    val channel = mock(classOf[Channel])
+    val client = new TransportClient(channel, mock(classOf[TransportResponseHandler]))
+    val handler = new NettyRpcHandler(dispatcher, env)
+
+    when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
+    handler.receive(client, null, null)
+
+    when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000))
+    handler.connectionTerminated(client)
+
+    val remoteEnvAddress = RpcAddress("localhost", 12345)
+    verify(dispatcher, times(1)).broadcastMessage(Associated(remoteEnvAddress))
+    verify(dispatcher, times(1)).broadcastMessage(Disassociated(remoteEnvAddress))
+  }
+
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java index 37f2e34ceb24d..078df6c9227d4 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -75,6 +75,10 @@ public TransportClient(Channel channel, TransportResponseHandler handler) { this.handler = Preconditions.checkNotNull(handler); } + public Channel getChannel() { + return channel; + } + public boolean isActive() { return channel.isOpen() || channel.isActive(); } diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java index 2ba92a40f8b0a..dbb7f95f55bc0 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -52,4 +52,6 @@ public abstract void receive( * No further requests will come from this client. 
*/ public void connectionTerminated(TransportClient client) { } + + public void exceptionCaught(Throwable cause, TransportClient client) { } } diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index e5159ab56d0d4..c72c43f3dea96 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -71,6 +71,7 @@ public TransportRequestHandler( @Override public void exceptionCaught(Throwable cause) { + rpcHandler.exceptionCaught(cause, reverseClient); } @Override diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index f1504b09c9873..b734a0151040c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -216,7 +216,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterReceiver(streamId, typ, host, receiverEndpoint) => - registerReceiver(streamId, typ, host, receiverEndpoint, context.sender.address) + registerReceiver(streamId, typ, host, receiverEndpoint, context.senderAddress) context.reply(true) case AddBlock(receivedBlockInfo) => context.reply(addBlock(receivedBlockInfo)) From caff122f2896199bdc9a891c18820004b52d9282 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 29 May 2015 21:44:50 +0800 Subject: [PATCH 02/30] Fix InboxSuite --- .../test/scala/org/apache/spark/rpc/netty/InboxSuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git 
a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index ebee763eec56b..26e4bc13efe3c 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc.netty import java.util.concurrent.{TimeUnit, CountDownLatch} +import java.util.concurrent.atomic.AtomicInteger import org.mockito.Mockito._ import org.scalatest.FunSuite @@ -67,10 +68,10 @@ class InboxSuite extends FunSuite { val dispatcher = mock(classOf[Dispatcher]) - @volatile var numDroppedMessages = 0 + val numDroppedMessages = new AtomicInteger(0) val inbox = new Inbox(endpointRef, endpoint) { override def onDrop(message: Any): Unit = { - numDroppedMessages += 1 + numDroppedMessages.incrementAndGet() } } @@ -93,7 +94,7 @@ class InboxSuite extends FunSuite { exitLatch.await(30, TimeUnit.SECONDS) - assert(1000 === endpoint.numReceiveMessages + numDroppedMessages) + assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get) endpoint.verifyStarted() endpoint.verifyStopped() verify(dispatcher).unregisterRpcEndpoint("hello") From 02fbca0682e3b805beff78199738ba82d83ae70f Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 30 May 2015 07:53:57 +0800 Subject: [PATCH 03/30] Fix the code style --- .../scala/org/apache/spark/rpc/netty/Dispatcher.scala | 2 +- .../org/apache/spark/rpc/netty/NettyRpcCallContext.scala | 3 ++- .../scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 2 +- .../test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala | 2 +- .../scala/org/apache/spark/rpc/netty/InboxSuite.scala | 8 ++++---- .../org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala | 6 +++--- .../org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala | 4 ++-- 7 files changed, 14 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala 
b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 7ddf2866c1de8..2b2a009e98005 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -50,7 +50,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name) val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) - nameToEndpoint.put(name, new RpcEndpointPair(endpoint,endpointRef)) + nameToEndpoint.put(name, new RpcEndpointPair(endpoint, endpointRef)) endpointToEndpointRef.put(endpoint, endpointRef) val inbox = new Inbox(endpointRef, endpoint) endpointToInbox.put(endpoint, inbox) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index e885c1ecc9153..d2802e59e9e9d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -24,7 +24,8 @@ import scala.concurrent.{Promise, Future} private[netty] abstract class NettyRpcCallContext( endpointRef: NettyRpcEndpointRef, - override val senderAddress: RpcAddress,needReply: Boolean) extends RpcCallContext{ + override val senderAddress: RpcAddress, + needReply: Boolean) extends RpcCallContext{ protected def send(message: Any): Unit diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index acb59718d6973..bceecf9b870c2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -370,7 +370,7 @@ class NettyRpcHandler( // Store all client addresses and their NettyRpcEnv 
addresses. Protected by "this". private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]() // Store the connections from other NettyRpcEnv addresses. Protected by "this". - private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]() + private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]() override def receive( client: TransportClient, message: Array[Byte], callback: RpcResponseCallback): Unit = { diff --git a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala index 3268c89d8296d..698e302534cc4 100644 --- a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala +++ b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala @@ -55,7 +55,7 @@ class TestRpcEndpoint extends RpcEndpoint with TripleEquals { * Invoked when some network error happens in the connection between the current node and * `remoteAddress`. */ - override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { onNetworkErrorMessages += cause -> remoteAddress } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 26e4bc13efe3c..205ba1262f88f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -21,11 +21,11 @@ import java.util.concurrent.{TimeUnit, CountDownLatch} import java.util.concurrent.atomic.AtomicInteger import org.mockito.Mockito._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.rpc.{TestRpcEndpoint, RpcAddress} -class InboxSuite extends FunSuite { +class InboxSuite extends SparkFunSuite { test("post") { val endpoint = new TestRpcEndpoint @@ -77,10 +77,10 @@ class InboxSuite extends FunSuite { 
val exitLatch = new CountDownLatch(10) - for(_ <- 0 until 10) { + for (_ <- 0 until 10) { new Thread { override def run(): Unit = { - for(_ <- 0 until 100) { + for (_ <- 0 until 100) { val message = ContentMessage(null, "hi", false, null) inbox.post(message) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index 6ffd65047eac4..a5d43d3704e37 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.rpc.netty -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite -class NettyRpcAddressSuite extends FunSuite { +class NettyRpcAddressSuite extends SparkFunSuite { test("toString") { - val addr = NettyRpcAddress("localhost", 12345, "test") + val addr = NettyRpcAddress("localhost", 12345, "test") assert(addr.toString === "spark://test@localhost:12345") } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 57dcff586bd82..f31f0ab83390d 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -23,11 +23,11 @@ import io.netty.channel.Channel import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} import org.mockito.Mockito._ import org.mockito.Matchers._ -import org.scalatest.FunSuite +import org.apache.spark.SparkFunSuite import org.apache.spark.rpc._ -class NettyRpcHandlerSuite extends FunSuite { +class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) when(env.deserialize(any())(any())). 
From 019334a4d15903e731aa4b4c2f3c56c8872fe246 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 1 Jun 2015 16:44:05 +0800 Subject: [PATCH 04/30] ThreadSafeInbox and ConcurrentInbox --- .../spark/deploy/worker/WorkerWatcher.scala | 11 +- .../org/apache/spark/rpc/RpcEndpoint.scala | 30 ++-- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 8 +- .../apache/spark/rpc/netty/Dispatcher.scala | 18 ++- .../org/apache/spark/rpc/netty/Inbox.scala | 139 ++++++++++++++---- .../storage/BlockManagerSlaveEndpoint.scala | 4 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 12 +- .../apache/spark/rpc/TestRpcEndpoint.scala | 2 +- .../apache/spark/rpc/netty/InboxSuite.scala | 12 +- .../spark/deploy/yarn/ApplicationMaster.scala | 5 +- 10 files changed, 170 insertions(+), 71 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 83fb991891a41..c192b86366a6b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -28,12 +28,11 @@ import org.apache.spark.rpc._ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) extends RpcEndpoint with Logging { - override def onStart() { - logInfo(s"Connecting to worker $workerUrl") - if (!isTesting) { - rpcEnv.asyncSetupEndpointRefByURI(workerUrl) - } - } + logInfo(s"Connecting to worker $workerUrl") + // workerUrl is wrong now. https://github.com/apache/spark/pull/5392 will fix it. + // if (!isTesting) { + // rpcEnv.asyncSetupEndpointRefByURI(workerUrl) + // } // Used to avoid shutting down JVM during tests // In the normal case, exitNonZero will call `System.exit(-1)` to shutdown the JVM. 
In the unit diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index d2b2baef1d8c4..94dea0d3cc8e5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -39,7 +39,21 @@ private[spark] trait RpcEnvFactory { * However, there is no guarantee that the same thread will be executing the same * [[ThreadSafeRpcEndpoint]] for different messages. */ -private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint { + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } +} /** @@ -100,20 +114,6 @@ private[spark] trait RpcEndpoint { throw cause } - /** - * Invoked before [[RpcEndpoint]] starts to handle any message. - */ - def onStart(): Unit = { - // By default, do nothing. - } - - /** - * Invoked when [[RpcEndpoint]] is stopping. - */ - def onStop(): Unit = { - // By default, do nothing. - } - /** * Invoked when `remoteAddress` is connected to the current node. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 937d065ac1556..e0eefb65ed13c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -98,7 +98,9 @@ private[spark] class AkkaRpcEnv private[akka] ( // Listen for remote client network events context.system.eventStream.subscribe(self, classOf[AssociationEvent]) safelyCall(endpoint) { - endpoint.onStart() + if (endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { + endpoint.asInstanceOf[ThreadSafeRpcEndpoint].onStart() + } } } @@ -141,7 +143,9 @@ private[spark] class AkkaRpcEnv private[akka] ( override def postStop(): Unit = { unregisterEndpoint(endpoint.self) safelyCall(endpoint) { - endpoint.onStop() + if (endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { + endpoint.asInstanceOf[ThreadSafeRpcEndpoint].onStop() + } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 2b2a009e98005..7f846796192a8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -25,7 +25,7 @@ import scala.concurrent.Promise import scala.util.control.NonFatal import org.apache.spark.{SparkException, Logging} -import org.apache.spark.rpc.{RpcCallContext, RpcAddress, RpcEndpointRef, RpcEndpoint} +import org.apache.spark.rpc._ import org.apache.spark.util.ThreadUtils private class RpcEndpointPair(val endpoint: RpcEndpoint, val endpointRef: NettyRpcEndpointRef) @@ -52,7 +52,12 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) nameToEndpoint.put(name, new RpcEndpointPair(endpoint, endpointRef)) endpointToEndpointRef.put(endpoint, endpointRef) - val inbox = new Inbox(endpointRef, 
endpoint) + val inbox = + if (endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { + new ThreadSafeInbox(endpointRef, endpoint.asInstanceOf[ThreadSafeRpcEndpoint]) + } else { + new ConcurrentInbox(endpointRef, endpoint) + } endpointToInbox.put(endpoint, inbox) idleInboxes.put(endpoint, inbox) afterUpdateInbox(inbox) @@ -139,7 +144,12 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { while (!stopped) { try { val endpoint = receivers.take() - val inbox = idleInboxes.remove(endpoint) + val inbox = + if (endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { + idleInboxes.remove(endpoint) + } else { + idleInboxes.get(endpoint) + } if (inbox != null) { val inboxStopped = inbox.process(Dispatcher.this) if (!inboxStopped) { @@ -147,6 +157,8 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { if (!inbox.isEmpty) { receivers.add(endpoint) } + } else { + idleInboxes.remove(endpoint) } } else { // other thread is processing endpoint's Inbox diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 0868c75eb0931..4e62b07cf79ee 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.util.control.NonFatal import org.apache.spark.{SparkException, Logging} -import org.apache.spark.rpc.{RpcAddress, RpcEndpoint} +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcAddress, RpcEndpoint} private[netty] sealed trait InboxMessage @@ -54,15 +54,53 @@ private[netty] case class Disassociated(remoteAddress: RpcAddress) extends Broad private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress) extends BroadcastMessage +private[netty] abstract class Inbox( + val endpointRef: NettyRpcEndpointRef, val endpoint: RpcEndpoint) extends Logging { + + protected val messages = new 
ConcurrentLinkedQueue[InboxMessage]() + + /** + * Process stored messages. Return `true` if the `Inbox` is already stopped, and the caller will + * release all resources used by the `Inbox`. + */ + def process(dispatcher: Dispatcher): Boolean + + def post(message: InboxMessage): Unit + + protected def onDrop(message: Any): Unit = { + logWarning(s"Drop ${message} because $endpointRef is stopped") + } + + def stop(): Unit + + protected def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try { + action + } catch { + case NonFatal(e) => { + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + } + } + } + } + + def isEmpty: Boolean = messages.isEmpty + +} + /** * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. * @param endpointRef * @param endpoint */ -private[netty] class Inbox( - val endpointRef: NettyRpcEndpointRef, val endpoint: RpcEndpoint) extends Logging { +private[netty] class ThreadSafeInbox( + endpointRef: NettyRpcEndpointRef, + override val endpoint: ThreadSafeRpcEndpoint) extends Inbox(endpointRef, endpoint) { - private val messages = new ConcurrentLinkedQueue[InboxMessage]() + private val _endpoint = endpoint.asInstanceOf[ThreadSafeRpcEndpoint] // protected by "this" private var stopped = false @@ -70,11 +108,7 @@ private[netty] class Inbox( // OnStart should be the first message to process messages.add(OnStart) - /** - * Process stored messages. Return `true` if the `Inbox` is already stopped, and the caller will - * release all resources used by the `Inbox`. 
- */ - def process(dispatcher: Dispatcher): Boolean = { + override def process(dispatcher: Dispatcher): Boolean = { var exit = false var message = messages.poll() while (message != null) { @@ -105,10 +139,12 @@ private[netty] class Inbox( } } - case OnStart => endpoint.onStart() + case OnStart => { + _endpoint.asInstanceOf[ThreadSafeRpcEndpoint].onStart() + } case OnStop => dispatcher.unregisterRpcEndpoint(endpointRef.name) - endpoint.onStop() + _endpoint.onStop() assert(isEmpty, "OnStop should be the last message") exit = true case Associated(remoteAddress) => @@ -124,7 +160,7 @@ private[netty] class Inbox( exit } - def post(message: InboxMessage): Unit = { + override def post(message: InboxMessage): Unit = { val dropped = synchronized { if (stopped) { @@ -140,7 +176,7 @@ private[netty] class Inbox( } } - def stop(): Unit = synchronized { + override def stop(): Unit = synchronized { // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last // message if (!stopped) { @@ -148,24 +184,75 @@ private[netty] class Inbox( messages.add(OnStop) } } +} - def isEmpty: Boolean = messages.isEmpty +/** + * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it concurrently. + * @param endpointRef + * @param endpoint + */ +private[netty] class ConcurrentInbox( + endpointRef: NettyRpcEndpointRef, + endpoint: RpcEndpoint) extends Inbox(endpointRef, endpoint) { - protected def onDrop(message: Any): Unit = { - logWarning(s"Drop ${message} because $endpointRef is stopped") - } + @volatile private var stopped = false - private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { - try { - action - } catch { - case NonFatal(e) => { - try { - endpoint.onError(e) - } catch { - case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + /** + * Process stored messages. Return `true` if the `Inbox` is already stopped, and the caller will + * release all resources used by the `Inbox`. 
+ */ + override def process(dispatcher: Dispatcher): Boolean = { + var message = messages.poll() + while (!stopped && message != null) { + safelyCall(endpoint) { + message match { + case ContentMessage(_sender, content, needReply, context) => + val pf: PartialFunction[Any, Unit] = + if (needReply) { + endpoint.receiveAndReply(context) + } else { + endpoint.receive + } + try { + pf.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) + if (!needReply) { + context.finish() + } + } catch { + case NonFatal(e) => + if (needReply) { + // If the sender asks a reply, we should send the error back to the sender + context.sendFailure(e) + } else { + context.finish() + throw e + } + } + case Associated(remoteAddress) => + endpoint.onConnected(remoteAddress) + case Disassociated(remoteAddress) => + endpoint.onDisconnected(remoteAddress) + case AssociationError(cause, remoteAddress) => + endpoint.onNetworkError(cause, remoteAddress) } } + message = messages.poll() + } + stopped + } + + override def post(message: InboxMessage): Unit = { + if (stopped) { + onDrop() + } else { + messages.add(message) } } + + override def stop(): Unit = { + stopped = true + } + } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index e7999e1f75e39..e749631bf6f19 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.util.ThreadUtils import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} import 
org.apache.spark.storage.BlockManagerMessages._ @@ -33,7 +33,7 @@ class BlockManagerSlaveEndpoint( override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends RpcEndpoint with Logging { + extends ThreadSafeRpcEndpoint with Logging { private val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool") diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index ba3c60807adc2..204a8ad3fd268 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -175,7 +175,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val stopLatch = new CountDownLatch(1) val calledMethods = mutable.ArrayBuffer[String]() - val endpoint = new RpcEndpoint { + val endpoint = new ThreadSafeRpcEndpoint { override val rpcEnv = env override def onStart(): Unit = { @@ -199,7 +199,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("onError: error in onStart") { @volatile var e: Throwable = null - env.setupEndpoint("onError-onStart", new RpcEndpoint { + env.setupEndpoint("onError-onStart", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def onStart(): Unit = { @@ -222,7 +222,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("onError: error in onStop") { @volatile var e: Throwable = null - val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint { + val endpointRef = env.setupEndpoint("onError-onStop", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def receive: PartialFunction[Any, Unit] = { @@ -269,7 +269,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("self: call in onStart") { @volatile var callSelfSuccessfully = false - env.setupEndpoint("self-onStart", new RpcEndpoint { + 
env.setupEndpoint("self-onStart", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def onStart(): Unit = { @@ -313,7 +313,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("self: call in onStop") { @volatile var selfOption: Option[RpcEndpointRef] = null - val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { + val endpointRef = env.setupEndpoint("self-onStop", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def receive: PartialFunction[Any, Unit] = { @@ -369,7 +369,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("stop(RpcEndpointRef) reentrant") { @volatile var onStopCount = 0 - val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint { + val endpointRef = env.setupEndpoint("stop-reentrant", new ThreadSafeRpcEndpoint { override val rpcEnv = env override def receive: PartialFunction[Any, Unit] = { diff --git a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala index 698e302534cc4..5e8da3e205ab0 100644 --- a/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala +++ b/core/src/test/scala/org/apache/spark/rpc/TestRpcEndpoint.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.scalactic.TripleEquals -class TestRpcEndpoint extends RpcEndpoint with TripleEquals { +class TestRpcEndpoint extends ThreadSafeRpcEndpoint with TripleEquals { override val rpcEnv: RpcEnv = null diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 205ba1262f88f..b8f82acdaa6b4 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -34,7 +34,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) - val inbox = new Inbox(endpointRef, endpoint) + 
val inbox = new ThreadSafeInbox(endpointRef, endpoint) val message = ContentMessage(null, "hi", false, null) inbox.post(message) assert(inbox.process(dispatcher) === false) @@ -53,7 +53,7 @@ class InboxSuite extends SparkFunSuite { val endpointRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new ThreadSafeInbox(endpointRef, endpoint) val message = ContentMessage(null, "hi", true, null) inbox.post(message) assert(inbox.process(dispatcher) === false) @@ -69,7 +69,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) val numDroppedMessages = new AtomicInteger(0) - val inbox = new Inbox(endpointRef, endpoint) { + val inbox = new ThreadSafeInbox(endpointRef, endpoint) { override def onDrop(message: Any): Unit = { numDroppedMessages.incrementAndGet() } @@ -107,7 +107,7 @@ class InboxSuite extends SparkFunSuite { val remoteAddress = RpcAddress("localhost", 11111) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new ThreadSafeInbox(endpointRef, endpoint) inbox.post(Associated(remoteAddress)) inbox.process(dispatcher) @@ -121,7 +121,7 @@ class InboxSuite extends SparkFunSuite { val remoteAddress = RpcAddress("localhost", 11111) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new ThreadSafeInbox(endpointRef, endpoint) inbox.post(Disassociated(remoteAddress)) inbox.process(dispatcher) @@ -136,7 +136,7 @@ class InboxSuite extends SparkFunSuite { val remoteAddress = RpcAddress("localhost", 11111) val cause = new RuntimeException("Oops") - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new ThreadSafeInbox(endpointRef, endpoint) inbox.post(AssociationError(cause, remoteAddress)) inbox.process(dispatcher) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 002d7b6eaf498..802c729dd7c09 100644 --- 
a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -535,10 +535,7 @@ private[spark] class ApplicationMaster( override val rpcEnv: RpcEnv, driver: RpcEndpointRef, isClusterMode: Boolean) extends RpcEndpoint with Logging { - override def onStart(): Unit = { - driver.send(RegisterClusterManager(self)) - - } + driver.send(RegisterClusterManager(self)) override def receive: PartialFunction[Any, Unit] = { case x: AddWebUIFilter => From c8acb8ad15c8e860cb467ac2493b3a14a1c49a1d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 24 Jun 2015 09:49:11 +0800 Subject: [PATCH 05/30] Make new classes private --- .../main/scala/org/apache/spark/rpc/netty/Dispatcher.scala | 2 +- .../main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 7f846796192a8..bff0dbee81b93 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -138,7 +138,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { receivers.put(endpoint) } - class MessageLoop extends Runnable { + private[netty] class MessageLoop extends Runnable { override def run(): Unit = { try { while (!stopped) { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index bceecf9b870c2..e1636186c5597 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -243,7 +243,7 @@ private[netty] object NettyRpcEnv extends Logging { private[netty] def currentEnv: NettyRpcEnv = _env.get } -class NettyRpcEnvFactory extends RpcEnvFactory with 
Logging { +private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { def create(config: RpcEnvConfig): RpcEnv = { val sparkConf = config.conf @@ -273,7 +273,7 @@ class NettyRpcEnvFactory extends RpcEnvFactory with Logging { } } -class NettyRpcEndpointRef(@transient conf: SparkConf) +private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) extends RpcEndpointRef(conf) with Serializable with Logging { @transient @volatile private var nettyEnv: NettyRpcEnv = _ @@ -361,7 +361,7 @@ private[netty] case class RpcFailure(e: Throwable) * Maintain the mapping relations between client addresses and [[RpcEnv]] addresses, broadcast * network events and forward messages to [[Dispatcher]]. */ -class NettyRpcHandler( +private[netty] class NettyRpcHandler( dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { private type ClientAddress = RpcAddress From a4fd9d2718479528a01bd6720a46f2bddcf5e01b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 24 Jul 2015 01:13:34 +0800 Subject: [PATCH 06/30] Fix WorkerWatcherSuite --- .../org/apache/spark/deploy/worker/WorkerWatcher.scala | 5 ++--- .../org/apache/spark/deploy/worker/WorkerWatcherSuite.scala | 6 ++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 778c11acf6c96..ef42f32a13b31 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -24,7 +24,8 @@ import org.apache.spark.rpc._ * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. 
*/ -private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) +private[spark] class WorkerWatcher( + override val rpcEnv: RpcEnv, workerUrl: String, isTesting: Boolean = false) extends RpcEndpoint with Logging { logInfo(s"Connecting to worker $workerUrl") @@ -38,8 +39,6 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin // true rather than calling `System.exit`. The user can check `isShutDown` to know if // `exitNonZero` is called. private[deploy] var isShutDown = false - private[deploy] def setTesting(testing: Boolean) = isTesting = testing - private var isTesting = false // Lets us filter events only from the worker's actor system private val expectedAddress = RpcAddress.fromURIString(workerUrl) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index cd24d79423316..dbde5ab317e2f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -26,8 +26,7 @@ class WorkerWatcherSuite extends SparkFunSuite { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") - val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) - workerWatcher.setTesting(testing = true) + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) assert(workerWatcher.isShutDown) @@ -40,8 +39,7 @@ class WorkerWatcherSuite extends SparkFunSuite { val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val otherAddress = "akka://test@4.3.2.1:1234/user/OtherActor" val otherAkkaAddress = RpcAddress("4.3.2.1", 1234) 
- val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) - workerWatcher.setTesting(testing = true) + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(otherAkkaAddress) assert(!workerWatcher.isShutDown) From 0c0862148efba41ad8c2f429493ec94dd1eab4f2 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 24 Jul 2015 22:31:37 +0800 Subject: [PATCH 07/30] Check if a name is used; Fix the code style --- .../apache/spark/rpc/netty/Dispatcher.scala | 6 ++-- .../apache/spark/rpc/netty/IDVerifier.scala | 2 +- .../org/apache/spark/rpc/netty/Inbox.scala | 6 ++-- .../spark/rpc/netty/NettyRpcCallContext.scala | 4 +-- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 32 +++++++++---------- .../apache/spark/rpc/netty/InboxSuite.scala | 4 +-- .../spark/rpc/netty/NettyRpcEnvSuit.scala | 2 +- .../rpc/netty/NettyRpcHandlerSuite.scala | 2 +- 8 files changed, 29 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index bff0dbee81b93..b7dad3e543db5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -50,7 +50,9 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name) val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) - nameToEndpoint.put(name, new RpcEndpointPair(endpoint, endpointRef)) + if (nameToEndpoint.putIfAbsent(name, new RpcEndpointPair(endpoint, endpointRef)) != null) { + throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") + } endpointToEndpointRef.put(endpoint, endpointRef) val inbox = if 
(endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { @@ -91,7 +93,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { */ def broadcastMessage(message: BroadcastMessage): Unit = { val iter = endpointToInbox.values().iterator() - while(iter.hasNext) { + while (iter.hasNext) { val inbox = iter.next() postMessageToInbox(inbox, message) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala index 1c19484283f73..6061c9b8de944 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/IDVerifier.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.rpc.netty -import org.apache.spark.rpc.{RpcCallContext, RpcEnv, RpcEndpoint} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv} /** * A message used to ask the remote [[IDVerifier]] if an [[RpcEndpoint]] exists diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 4e62b07cf79ee..60791669ac8d7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -21,8 +21,8 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.util.control.NonFatal -import org.apache.spark.{SparkException, Logging} -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcAddress, RpcEndpoint} +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, ThreadSafeRpcEndpoint} private[netty] sealed trait InboxMessage @@ -81,7 +81,7 @@ private[netty] abstract class Inbox( try { endpoint.onError(e) } catch { - case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + case NonFatal(e) => logWarning(s"Ignore error", e) } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala 
b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index d2802e59e9e9d..9bbd3a00ea2e4 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -17,11 +17,11 @@ package org.apache.spark.rpc.netty +import scala.concurrent.Promise + import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc.{RpcAddress, RpcCallContext} -import scala.concurrent.{Promise, Future} - private[netty] abstract class NettyRpcCallContext( endpointRef: NettyRpcEndpointRef, override val senderAddress: RpcAddress, diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index e7b0b44de2622..6733b344a8f75 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -19,7 +19,7 @@ package org.apache.spark.rpc.netty import java.io._ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer -import java.{util => ju} +import java.util.{Arrays, Collections, List => JList} import java.util.concurrent._ import scala.collection.mutable @@ -28,6 +28,7 @@ import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.SparkTransportConf @@ -36,7 +37,6 @@ import org.apache.spark.network.server._ import org.apache.spark.rpc._ import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.util.{ThreadUtils, Utils} -import org.apache.spark.{Logging, SecurityManager, SparkConf} private[netty] class NettyRpcEnv( val conf: SparkConf, 
serializer: Serializer, host: String, securityManager: SecurityManager) @@ -51,11 +51,11 @@ private[netty] class NettyRpcEnv( new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) private val clientFactory = { - val bootstraps: ju.List[TransportClientBootstrap] = + val bootstraps: JList[TransportClientBootstrap] = if (securityManager.isAuthenticationEnabled()) { - ju.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, true)) + Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, true)) } else { - ju.Collections.emptyList[TransportClientBootstrap]() + Collections.emptyList[TransportClientBootstrap]() } transportContext.createClientFactory(bootstraps) } @@ -65,11 +65,11 @@ private[netty] class NettyRpcEnv( @volatile private var server: TransportServer = _ def start(port: Int): Unit = { - val bootstraps: ju.List[TransportServerBootstrap] = + val bootstraps: JList[TransportServerBootstrap] = if (securityManager.isAuthenticationEnabled()) { - ju.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) + Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) } else { - ju.Collections.emptyList[TransportServerBootstrap]() + Collections.emptyList[TransportServerBootstrap]() } server = transportContext.createServer(port, bootstraps) dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) @@ -142,11 +142,10 @@ private[netty] class NettyRpcEnv( val reply = response.asInstanceOf[AskResponse] if (reply.reply != null && reply.reply.isInstanceOf[RpcFailure]) { if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure + ${reply.reply}") + logWarning(s"Ignore failure: ${reply.reply}") } - } - else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message + ${reply}") + } else if (!promise.trySuccess(reply.reply)) { + logWarning(s"Ignore message: ${reply}") } case Failure(e) => if (!promise.tryFailure(e)) { @@ 
-167,11 +166,10 @@ private[netty] class NettyRpcEnv( val reply = deserialize[AskResponse](response) if (reply.reply != null && reply.reply.isInstanceOf[RpcFailure]) { if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure + ${reply.reply}") + logWarning(s"Ignore failure: ${reply.reply}") } - } - else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message + ${reply}") + } else if (!promise.trySuccess(reply.reply)) { + logWarning(s"Ignore message: ${reply}") } } }) @@ -181,7 +179,7 @@ private[netty] class NettyRpcEnv( private[netty] def serialize(content: Any): Array[Byte] = { val buffer = serializer.newInstance().serialize(content) - ju.Arrays.copyOfRange( + Arrays.copyOfRange( buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index b8f82acdaa6b4..178a2f7ccebb2 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{TimeUnit, CountDownLatch} +import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicInteger import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite -import org.apache.spark.rpc.{TestRpcEndpoint, RpcAddress} +import org.apache.spark.rpc.{RpcAddress, TestRpcEndpoint} class InboxSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuit.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuit.scala index 4025372492456..76050106f756c 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuit.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuit.scala @@ -17,8 +17,8 @@ package org.apache.spark.rpc.netty 
-import org.apache.spark.rpc._ import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.rpc._ class NettyRpcEnvSuite extends RpcEnvSuite { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index 364d12a2a4dda..06ca035d199e8 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.rpc.netty import java.net.InetSocketAddress import io.netty.channel.Channel -import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} import org.mockito.Mockito._ import org.mockito.Matchers._ import org.apache.spark.SparkFunSuite +import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { From c20d205ae400ebf0bece7062dd680a5aa2076717 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 24 Jul 2015 22:34:49 +0800 Subject: [PATCH 08/30] Add spark.rpc.netty.dispatcher.parallelism to control the thread number --- .../src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index b7dad3e543db5..063c763a2a263 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -175,7 +175,8 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { } } - private val parallelism = Runtime.getRuntime.availableProcessors() + private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism", + Runtime.getRuntime.availableProcessors()) private val executor = 
ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop") (0 until parallelism) foreach { _ => From ea84f6c5649793910167d2ad69af4e43a54f86d7 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 24 Jul 2015 22:55:11 +0800 Subject: [PATCH 09/30] Use Scala collections and JavaConversions; Use securityManager.isSaslEncryptionEnabled() --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 6733b344a8f75..666e58f278101 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -19,9 +19,10 @@ package org.apache.spark.rpc.netty import java.io._ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer -import java.util.{Arrays, Collections, List => JList} +import java.util.Arrays import java.util.concurrent._ +import scala.collection.JavaConversions._ import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag @@ -30,7 +31,7 @@ import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext -import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClient} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ @@ -51,11 +52,12 @@ private[netty] class NettyRpcEnv( new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) private val clientFactory = { - val bootstraps: JList[TransportClientBootstrap] = + val bootstraps = if (securityManager.isAuthenticationEnabled()) { 
- Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, true)) + Seq(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) } else { - Collections.emptyList[TransportClientBootstrap]() + Seq.empty } transportContext.createClientFactory(bootstraps) } @@ -65,11 +67,11 @@ private[netty] class NettyRpcEnv( @volatile private var server: TransportServer = _ def start(port: Int): Unit = { - val bootstraps: JList[TransportServerBootstrap] = + val bootstraps = if (securityManager.isAuthenticationEnabled()) { - Arrays.asList(new SaslServerBootstrap(transportConf, securityManager)) + Seq(new SaslServerBootstrap(transportConf, securityManager)) } else { - Collections.emptyList[TransportServerBootstrap]() + Seq.empty } server = transportContext.createServer(port, bootstraps) dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) From 0de34c8af28cc9d5825d02481dce1c3455e5f2f9 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 24 Jul 2015 23:01:13 +0800 Subject: [PATCH 10/30] Fix a wrong file name --- .../rpc/netty/{NettyRpcEnvSuit.scala => NettyRpcEnvSuite.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename core/src/test/scala/org/apache/spark/rpc/netty/{NettyRpcEnvSuit.scala => NettyRpcEnvSuite.scala} (100%) diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuit.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala similarity index 100% rename from core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuit.scala rename to core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala From ac69cb594608ee030b68be3e3671d2d67df80bd4 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 24 Jul 2015 23:22:28 +0800 Subject: [PATCH 11/30] Fix the compiler warning --- core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 60791669ac8d7..8c14ba58eba86 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -236,6 +236,10 @@ private[netty] class ConcurrentInbox( endpoint.onDisconnected(remoteAddress) case AssociationError(cause, remoteAddress) => endpoint.onNetworkError(cause, remoteAddress) + case OnStart => + throw new IllegalStateException("Non-thread-safe RpcEndpoint doesn't support OnStart") + case OnStop => + throw new IllegalStateException("Non-thread-safe RpcEndpoint doesn't support OnStop") } } message = messages.poll() From 8343274e081d1ec422dfe41f3d57c2aa4eb87373 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 9 Sep 2015 23:28:51 +0800 Subject: [PATCH 12/30] A single Inbox implementation --- .../apache/spark/deploy/worker/Worker.scala | 2 +- .../org/apache/spark/rpc/RpcEndpoint.scala | 56 ++-- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 8 +- .../apache/spark/rpc/netty/Dispatcher.scala | 45 +-- .../org/apache/spark/rpc/netty/Inbox.scala | 273 ++++++++---------- .../spark/rpc/netty/NettyRpcCallContext.scala | 7 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 120 +++++--- .../org/apache/spark/rpc/RpcEnvSuite.scala | 12 +- .../apache/spark/rpc/netty/InboxSuite.scala | 14 +- 9 files changed, 259 insertions(+), 278 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 770927c80f7a4..93a1b3f310422 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -329,7 +329,7 @@ private[deploy] class Worker( registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( new Runnable { override def run(): Unit = Utils.tryLogNonFatalError { - 
self.send(ReregisterWithMaster) + Option(self).foreach(_.send(ReregisterWithMaster)) } }, INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index fb752753f4a55..b9af77ecc26e2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -28,34 +28,6 @@ private[spark] trait RpcEnvFactory { def create(config: RpcEnvConfig): RpcEnv } -/** - * A trait that requires RpcEnv thread-safely sending messages to it. - * - * Thread-safety means processing of one message happens before processing of the next message by - * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a - * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the - * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. - * - * However, there is no guarantee that the same thread will be executing the same - * [[ThreadSafeRpcEndpoint]] for different messages. - */ -private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint { - /** - * Invoked before [[RpcEndpoint]] starts to handle any message. - */ - def onStart(): Unit = { - // By default, do nothing. - } - - /** - * Invoked when [[RpcEndpoint]] is stopping. - */ - def onStop(): Unit = { - // By default, do nothing. - } -} - - /** * An end point for the RPC that defines what functions to trigger given a message. * @@ -136,6 +108,20 @@ private[spark] trait RpcEndpoint { // By default, do nothing. } + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + /** * A convenient method to stop [[RpcEndpoint]]. 
*/ @@ -146,3 +132,17 @@ private[spark] trait RpcEndpoint { } } } + +/** + * A trait that requires RpcEnv thread-safely sending messages to it. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a + * [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the + * [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same + * [[ThreadSafeRpcEndpoint]] for different messages. + */ +private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint { +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index ae4bfe9e9dc96..89854a6fcbbce 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -97,9 +97,7 @@ private[spark] class AkkaRpcEnv private[akka] ( // Listen for remote client network events context.system.eventStream.subscribe(self, classOf[AssociationEvent]) safelyCall(endpoint) { - if (endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { - endpoint.asInstanceOf[ThreadSafeRpcEndpoint].onStart() - } + endpoint.onStart() } } @@ -142,9 +140,7 @@ private[spark] class AkkaRpcEnv private[akka] ( override def postStop(): Unit = { unregisterEndpoint(endpoint.self) safelyCall(endpoint) { - if (endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { - endpoint.asInstanceOf[ThreadSafeRpcEndpoint].onStop() - } + endpoint.onStop() } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 063c763a2a263..b1a99629459c7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala 
@@ -17,7 +17,7 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{TimeUnit, Executors, LinkedBlockingQueue, ConcurrentHashMap} +import java.util.concurrent.{TimeUnit, LinkedBlockingQueue, ConcurrentHashMap} import org.apache.spark.network.client.RpcResponseCallback @@ -32,9 +32,6 @@ private class RpcEndpointPair(val endpoint: RpcEndpoint, val endpointRef: NettyR private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { - // the inboxes that are not being used - private val idleInboxes = new ConcurrentHashMap[RpcEndpoint, Inbox]() - private val endpointToInbox = new ConcurrentHashMap[RpcEndpoint, Inbox]() // need a name to RpcEndpoint mapping so that we can delivery the messages @@ -54,14 +51,8 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") } endpointToEndpointRef.put(endpoint, endpointRef) - val inbox = - if (endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { - new ThreadSafeInbox(endpointRef, endpoint.asInstanceOf[ThreadSafeRpcEndpoint]) - } else { - new ConcurrentInbox(endpointRef, endpoint) - } + val inbox = new Inbox(endpointRef, endpoint) endpointToInbox.put(endpoint, inbox) - idleInboxes.put(endpoint, inbox) afterUpdateInbox(inbox) endpointRef } @@ -74,7 +65,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { def unregisterRpcEndpoint(name: String): Unit = { val endpointPair = nameToEndpoint.remove(name) if (endpointPair != null) { - val inbox = endpointToInbox.remove(endpointPair.endpoint) + val inbox = endpointToInbox.get(endpointPair.endpoint) if (inbox != null) { inbox.stop() afterUpdateInbox(inbox) @@ -109,8 +100,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { nettyEnv, inbox.endpointRef, callback, message.senderAddress, message.needReply) postMessageToInbox(inbox, ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)) + 
return } } + callback.onFailure( + new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } def postMessage(message: RequestMessage, p: Promise[Any]): Unit = { @@ -123,8 +117,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { inbox.endpointRef, message.senderAddress, message.needReply, p) postMessageToInbox(inbox, ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)) + return } } + p.tryFailure( + new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } private def postMessageToInbox(inbox: Inbox, message: InboxMessage): Unit = { @@ -134,10 +131,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private def afterUpdateInbox(inbox: Inbox): Unit = { // Do some work to trigger processing messages in the inbox - val endpoint = inbox.endpoint - // Replacing unsuccessfully means someone is processing it - idleInboxes.replace(endpoint, inbox, inbox) - receivers.put(endpoint) + receivers.put(inbox.endpoint) } private[netty] class MessageLoop extends Runnable { @@ -146,24 +140,14 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { while (!stopped) { try { val endpoint = receivers.take() - val inbox = - if (endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { - idleInboxes.remove(endpoint) - } else { - idleInboxes.get(endpoint) - } + val inbox = endpointToInbox.get(endpoint) if (inbox != null) { val inboxStopped = inbox.process(Dispatcher.this) - if (!inboxStopped) { - idleInboxes.put(endpoint, inbox) - if (!inbox.isEmpty) { - receivers.add(endpoint) - } - } else { - idleInboxes.remove(endpoint) + if (inboxStopped) { + endpointToInbox.remove(endpoint) } } else { - // other thread is processing endpoint's Inbox + // The endpoint has been stopped } } catch { case NonFatal(e) => logError(e.getMessage, e) @@ -179,6 +163,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { 
Runtime.getRuntime.availableProcessors()) private val executor = ThreadUtils.newDaemonFixedThreadPool(parallelism, "dispatcher-event-loop") + (0 until parallelism) foreach { _ => executor.execute(new MessageLoop) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 8c14ba58eba86..56efcefe15d1b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -17,7 +17,8 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.ConcurrentLinkedQueue +import java.util.LinkedList +import javax.annotation.concurrent.GuardedBy import scala.util.control.NonFatal @@ -54,113 +55,128 @@ private[netty] case class Disassociated(remoteAddress: RpcAddress) extends Broad private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress) extends BroadcastMessage -private[netty] abstract class Inbox( - val endpointRef: NettyRpcEndpointRef, val endpoint: RpcEndpoint) extends Logging { - - protected val messages = new ConcurrentLinkedQueue[InboxMessage]() - - /** - * Process stored messages. Return `true` if the `Inbox` is already stopped, and the caller will - * release all resources used by the `Inbox`. - */ - def process(dispatcher: Dispatcher): Boolean - - def post(message: InboxMessage): Unit - - protected def onDrop(message: Any): Unit = { - logWarning(s"Drop ${message} because $endpointRef is stopped") - } - - def stop(): Unit - - protected def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { - try { - action - } catch { - case NonFatal(e) => { - try { - endpoint.onError(e) - } catch { - case NonFatal(e) => logWarning(s"Ignore error", e) - } - } - } - } - - def isEmpty: Boolean = messages.isEmpty - -} - /** * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. 
* @param endpointRef * @param endpoint */ -private[netty] class ThreadSafeInbox( - endpointRef: NettyRpcEndpointRef, - override val endpoint: ThreadSafeRpcEndpoint) extends Inbox(endpointRef, endpoint) { +private[netty] class Inbox( + val endpointRef: NettyRpcEndpointRef, + val endpoint: RpcEndpoint) extends Logging { + + private val supportConcurrent = !endpoint.isInstanceOf[ThreadSafeRpcEndpoint] - private val _endpoint = endpoint.asInstanceOf[ThreadSafeRpcEndpoint] + @GuardedBy("this") + protected val messages = new LinkedList[InboxMessage]() - // protected by "this" + @GuardedBy("this") private var stopped = false + @GuardedBy("this") + private var enableConcurrent = false + + @GuardedBy("this") + private var workerCount = 0 + // OnStart should be the first message to process - messages.add(OnStart) + synchronized { + messages.add(OnStart) + } - override def process(dispatcher: Dispatcher): Boolean = { - var exit = false - var message = messages.poll() - while (message != null) { - safelyCall(endpoint) { - message match { - case ContentMessage(_sender, content, needReply, context) => - val pf: PartialFunction[Any, Unit] = - if (needReply) { - endpoint.receiveAndReply(context) - } else { - endpoint.receive - } - try { - pf.applyOrElse[Any, Unit](content, { msg => - throw new SparkException(s"Unmatched message $message from ${_sender}") - }) - if (!needReply) { - context.finish() - } - } catch { - case NonFatal(e) => + def process(dispatcher: Dispatcher): Boolean = { + var message: InboxMessage = null + synchronized { + if (!enableConcurrent && workerCount != 0) { + return false + } + message = messages.poll() + if (message != null) { + workerCount += 1 + } else { + return false + } + } + var skipFinally = false + try { + while (true) { + safelyCall(endpoint) { + message match { + case ContentMessage(_sender, content, needReply, context) => + val pf: PartialFunction[Any, Unit] = if (needReply) { - // If the sender asks a reply, we should send the error back to 
the sender - context.sendFailure(e) + endpoint.receiveAndReply(context) } else { + endpoint.receive + } + try { + pf.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) + if (!needReply) { context.finish() - throw e } + } catch { + case NonFatal(e) => + if (needReply) { + // If the sender asks a reply, we should send the error back to the sender + context.sendFailure(e) + } else { + context.finish() + throw e + } + } + + case OnStart => { + endpoint.onStart() + if (supportConcurrent) { + synchronized { + enableConcurrent = true + } + } } + case OnStop => + dispatcher.unregisterRpcEndpoint(endpointRef.name) + endpoint.onStop() + assert(isEmpty, "OnStop should be the last message") + return true + case Associated(remoteAddress) => + endpoint.onConnected(remoteAddress) + case Disassociated(remoteAddress) => + endpoint.onDisconnected(remoteAddress) + case AssociationError(cause, remoteAddress) => + endpoint.onNetworkError(cause, remoteAddress) + } + } - case OnStart => { - _endpoint.asInstanceOf[ThreadSafeRpcEndpoint].onStart() + synchronized { + // "enableConcurrent" will be set to false after `onStop` is called, so we should check it + // every time. 
+ if (!enableConcurrent && workerCount != 1) { + // If we are not the only one worker, exit + skipFinally = true + workerCount -= 1 + return false + } + message = messages.poll() + if (message == null) { + skipFinally = true + workerCount -= 1 + return false } - case OnStop => - dispatcher.unregisterRpcEndpoint(endpointRef.name) - _endpoint.onStop() - assert(isEmpty, "OnStop should be the last message") - exit = true - case Associated(remoteAddress) => - endpoint.onConnected(remoteAddress) - case Disassociated(remoteAddress) => - endpoint.onDisconnected(remoteAddress) - case AssociationError(cause, remoteAddress) => - endpoint.onNetworkError(cause, remoteAddress) } } - message = messages.poll() + return false + } finally { + if (!skipFinally) { + // Reset `workerCount` if some exception is thrown. + synchronized { + workerCount -= 1 + } + } } - exit } - override def post(message: InboxMessage): Unit = { + def post(message: InboxMessage): Unit = { val dropped = synchronized { if (stopped) { @@ -176,87 +192,38 @@ private[netty] class ThreadSafeInbox( } } - override def stop(): Unit = synchronized { + def stop(): Unit = synchronized { // The following codes should be in `synchronized` so that we can make sure "OnStop" is the last // message if (!stopped) { + // We should disable concurrent here. Then when RpcEndpoint.onStop is called, it's the only + // thread that is processing messages. So `RpcEndpoint.onStop` can release its resources + // safely. + enableConcurrent = false stopped = true messages.add(OnStop) + // Note: The concurrent events in messages will be processed one by one. } } -} -/** - * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it concurrently. 
- * @param endpointRef - * @param endpoint - */ -private[netty] class ConcurrentInbox( - endpointRef: NettyRpcEndpointRef, - endpoint: RpcEndpoint) extends Inbox(endpointRef, endpoint) { + protected def onDrop(message: Any): Unit = { + logWarning(s"Drop ${message} because $endpointRef is stopped") + } - @volatile private var stopped = false + def isEmpty: Boolean = messages.isEmpty - /** - * Process stored messages. Return `true` if the `Inbox` is already stopped, and the caller will - * release all resources used by the `Inbox`. - */ - override def process(dispatcher: Dispatcher): Boolean = { - var message = messages.poll() - while (!stopped && message != null) { - safelyCall(endpoint) { - message match { - case ContentMessage(_sender, content, needReply, context) => - val pf: PartialFunction[Any, Unit] = - if (needReply) { - endpoint.receiveAndReply(context) - } else { - endpoint.receive - } - try { - pf.applyOrElse[Any, Unit](content, { msg => - throw new SparkException(s"Unmatched message $message from ${_sender}") - }) - if (!needReply) { - context.finish() - } - } catch { - case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - context.sendFailure(e) - } else { - context.finish() - throw e - } - } - case Associated(remoteAddress) => - endpoint.onConnected(remoteAddress) - case Disassociated(remoteAddress) => - endpoint.onDisconnected(remoteAddress) - case AssociationError(cause, remoteAddress) => - endpoint.onNetworkError(cause, remoteAddress) - case OnStart => - throw new IllegalStateException("Non-thread-safe RpcEndpoint doesn't support OnStart") - case OnStop => - throw new IllegalStateException("Non-thread-safe RpcEndpoint doesn't support OnStop") + protected def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try { + action + } catch { + case NonFatal(e) => { + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logWarning(s"Ignore error", e) } } - message = 
messages.poll() - } - stopped - } - - override def post(message: InboxMessage): Unit = { - if (stopped) { - onDrop() - } else { - messages.add(message) } } - override def stop(): Unit = { - stopped = true - } - } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index 9bbd3a00ea2e4..e9ff66ae7f367 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -19,13 +19,14 @@ package org.apache.spark.rpc.netty import scala.concurrent.Promise +import org.apache.spark.Logging import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc.{RpcAddress, RpcCallContext} private[netty] abstract class NettyRpcCallContext( endpointRef: NettyRpcEndpointRef, override val senderAddress: RpcAddress, - needReply: Boolean) extends RpcCallContext{ + needReply: Boolean) extends RpcCallContext with Logging { protected def send(message: Any): Unit @@ -42,8 +43,9 @@ private[netty] abstract class NettyRpcCallContext( if (needReply) { send(AskResponse(endpointRef, RpcFailure(e))) } else { + logError(e.getMessage, e) throw new IllegalStateException( - "Cannot send reply to the sender because the sender won't handle it", e) + "Cannot send reply to the sender because the sender won't handle it") } } @@ -62,6 +64,7 @@ private[netty] class LocalNettyRpcCallContext( senderAddress: RpcAddress, needReply: Boolean, p: Promise[Any]) extends NettyRpcCallContext(endpointRef, senderAddress, needReply) { + override protected def send(message: Any): Unit = { p.success(message) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 666e58f278101..6874c377e372a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ 
b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import java.util.Arrays import java.util.concurrent._ -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag @@ -31,7 +31,7 @@ import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext -import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.client._ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ @@ -52,28 +52,32 @@ private[netty] class NettyRpcEnv( new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this)) private val clientFactory = { - val bootstraps = + val bootstraps: Seq[TransportClientBootstrap] = if (securityManager.isAuthenticationEnabled()) { Seq(new SaslClientBootstrap(transportConf, "", securityManager, securityManager.isSaslEncryptionEnabled())) } else { Seq.empty } - transportContext.createClientFactory(bootstraps) + transportContext.createClientFactory(bootstraps.asJava) } val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") + private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( + "netty-rpc-connection", + conf.getInt("spark.rpc.connect.threads", 256)) + @volatile private var server: TransportServer = _ def start(port: Int): Unit = { - val bootstraps = + val bootstraps: Seq[TransportServerBootstrap] = if (securityManager.isAuthenticationEnabled()) { Seq(new SaslServerBootstrap(transportConf, securityManager)) } else { Seq.empty } - server = transportContext.createServer(port, bootstraps) + server = transportContext.createServer(port, bootstraps.asJava) 
dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) } @@ -118,18 +122,30 @@ private[netty] class NettyRpcEnv( logError(s"Exception when sending $message", e) }(ThreadUtils.sameThread) } else { - val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port) - client.sendRpc(serialize(message), new RpcResponseCallback { - - override def onFailure(e: Throwable): Unit = { - logError(s"Exception when sending $message", e) - } - - override def onSuccess(response: Array[Byte]): Unit = { - val ack = deserialize[SendAck](response) - logDebug(s"Receive ack from ${ack.sender}") + try { + // `createClient` will block if it cannot find a known connection, so we should run it in + // clientConnectionExecutor + clientConnectionExecutor.execute(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port) + client.sendRpc(serialize(message), new RpcResponseCallback { + + override def onFailure(e: Throwable): Unit = { + logError(s"Exception when sending $message", e) + } + + override def onSuccess(response: Array[Byte]): Unit = { + val ack = deserialize[SendAck](response) + logDebug(s"Receive ack from ${ack.sender}") + } + }) + } + }) + } catch { + case e: RejectedExecutionException => { + // `send` after shutting clientConnectionExecutor down, ignore it } - }) + } } } @@ -219,6 +235,9 @@ private[netty] class NettyRpcEnv( if (dispatcher != null) { dispatcher.stop() } + if (clientConnectionExecutor != null) { + clientConnectionExecutor.shutdownNow() + } } override def deserialize[T](deserializationAction: () => T): T = { @@ -395,43 +414,54 @@ private[netty] class NettyRpcHandler( override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] - assert(addr != null) - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = - 
synchronized { - remoteAddresses.get(clientAddr).map(AssociationError(cause, _)) + if (addr != null) { + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val broadcastMessage = + synchronized { + remoteAddresses.get(clientAddr).map(AssociationError(cause, _)) + } + if (broadcastMessage.isEmpty) { + logError(cause.getMessage, cause) + } else { + dispatcher.broadcastMessage(broadcastMessage.get) } - if (broadcastMessage.isEmpty) { - logError(cause.getMessage, cause) } else { - dispatcher.broadcastMessage(broadcastMessage.get) + // If the channel is closed before connecting, its remoteAddress will be null. + // See java.net.Socket.getRemoteSocketAddress + // Because we cannot get a RpcAddress, just log it + logError("Exception before connecting to the client", cause) } } override def connectionTerminated(client: TransportClient): Unit = { val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] - assert(addr != null) - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - val broadcastMessage = - synchronized { - // If the last connection to a remote RpcEnv is terminated, we should broadcast - // "Disassociated" - remoteAddresses.get(clientAddr).flatMap { remoteEnvAddress => - remoteAddresses -= clientAddr - val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) - assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent") - if (count - 1 == 0) { - // We lost all clients, so clean up and fire "Disassociated" - remoteConnectionCount.remove(remoteEnvAddress) - Some(Disassociated(remoteEnvAddress)) - } else { - // Decrease the connection number of remoteEnvAddress - remoteConnectionCount.put(remoteEnvAddress, count - 1) - None + if (addr != null) { + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + val broadcastMessage = + synchronized { + // If the last connection to a remote RpcEnv is terminated, we should broadcast + // "Disassociated" + remoteAddresses.get(clientAddr).flatMap { 
remoteEnvAddress => + remoteAddresses -= clientAddr + val count = remoteConnectionCount.getOrElse(remoteEnvAddress, 0) + assert(count != 0, "remoteAddresses and remoteConnectionCount are not consistent") + if (count - 1 == 0) { + // We lost all clients, so clean up and fire "Disassociated" + remoteConnectionCount.remove(remoteEnvAddress) + Some(Disassociated(remoteEnvAddress)) + } else { + // Decrease the connection number of remoteEnvAddress + remoteConnectionCount.put(remoteEnvAddress, count - 1) + None + } } } - } - broadcastMessage.foreach(dispatcher.broadcastMessage) + broadcastMessage.foreach(dispatcher.broadcastMessage) + } else { + // If the channel is closed before connecting, its remoteAddress will be null. In this case, + // we can ignore it since we don't fire "Associated". + // See java.net.Socket.getRemoteSocketAddress + } } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index fe8df0dfa7435..4b298af8d5d30 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -180,7 +180,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val stopLatch = new CountDownLatch(1) val calledMethods = mutable.ArrayBuffer[String]() - val endpoint = new ThreadSafeRpcEndpoint { + val endpoint = new RpcEndpoint { override val rpcEnv = env override def onStart(): Unit = { @@ -204,7 +204,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("onError: error in onStart") { @volatile var e: Throwable = null - env.setupEndpoint("onError-onStart", new ThreadSafeRpcEndpoint { + env.setupEndpoint("onError-onStart", new RpcEndpoint { override val rpcEnv = env override def onStart(): Unit = { @@ -227,7 +227,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("onError: error in onStop") { @volatile var e: Throwable = null - val 
endpointRef = env.setupEndpoint("onError-onStop", new ThreadSafeRpcEndpoint { + val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint { override val rpcEnv = env override def receive: PartialFunction[Any, Unit] = { @@ -274,7 +274,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("self: call in onStart") { @volatile var callSelfSuccessfully = false - env.setupEndpoint("self-onStart", new ThreadSafeRpcEndpoint { + env.setupEndpoint("self-onStart", new RpcEndpoint { override val rpcEnv = env override def onStart(): Unit = { @@ -318,7 +318,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("self: call in onStop") { @volatile var selfOption: Option[RpcEndpointRef] = null - val endpointRef = env.setupEndpoint("self-onStop", new ThreadSafeRpcEndpoint { + val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { override val rpcEnv = env override def receive: PartialFunction[Any, Unit] = { @@ -374,7 +374,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { test("stop(RpcEndpointRef) reentrant") { @volatile var onStopCount = 0 - val endpointRef = env.setupEndpoint("stop-reentrant", new ThreadSafeRpcEndpoint { + val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint { override val rpcEnv = env override def receive: PartialFunction[Any, Unit] = { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 178a2f7ccebb2..0d545d520becd 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite -import org.apache.spark.rpc.{RpcAddress, TestRpcEndpoint} +import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcAddress, TestRpcEndpoint} class 
InboxSuite extends SparkFunSuite { @@ -34,7 +34,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) - val inbox = new ThreadSafeInbox(endpointRef, endpoint) + val inbox = new Inbox(endpointRef, endpoint) val message = ContentMessage(null, "hi", false, null) inbox.post(message) assert(inbox.process(dispatcher) === false) @@ -53,7 +53,7 @@ class InboxSuite extends SparkFunSuite { val endpointRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) - val inbox = new ThreadSafeInbox(endpointRef, endpoint) + val inbox = new Inbox(endpointRef, endpoint) val message = ContentMessage(null, "hi", true, null) inbox.post(message) assert(inbox.process(dispatcher) === false) @@ -69,7 +69,7 @@ class InboxSuite extends SparkFunSuite { val dispatcher = mock(classOf[Dispatcher]) val numDroppedMessages = new AtomicInteger(0) - val inbox = new ThreadSafeInbox(endpointRef, endpoint) { + val inbox = new Inbox(endpointRef, endpoint) { override def onDrop(message: Any): Unit = { numDroppedMessages.incrementAndGet() } @@ -107,7 +107,7 @@ class InboxSuite extends SparkFunSuite { val remoteAddress = RpcAddress("localhost", 11111) - val inbox = new ThreadSafeInbox(endpointRef, endpoint) + val inbox = new Inbox(endpointRef, endpoint) inbox.post(Associated(remoteAddress)) inbox.process(dispatcher) @@ -121,7 +121,7 @@ class InboxSuite extends SparkFunSuite { val remoteAddress = RpcAddress("localhost", 11111) - val inbox = new ThreadSafeInbox(endpointRef, endpoint) + val inbox = new Inbox(endpointRef, endpoint) inbox.post(Disassociated(remoteAddress)) inbox.process(dispatcher) @@ -136,7 +136,7 @@ class InboxSuite extends SparkFunSuite { val remoteAddress = RpcAddress("localhost", 11111) val cause = new RuntimeException("Oops") - val inbox = new ThreadSafeInbox(endpointRef, endpoint) + val inbox = new Inbox(endpointRef, endpoint) inbox.post(AssociationError(cause, remoteAddress)) inbox.process(dispatcher) From 
9ae0873ba9a3436a5ee323ab5806fe1288d799fc Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 15 Sep 2015 17:55:12 +0800 Subject: [PATCH 13/30] make sure calling RpcEndpoint.onStop when stopping RpcEnv to release all resources --- .../apache/spark/rpc/netty/Dispatcher.scala | 82 ++++++++++++++----- .../org/apache/spark/rpc/netty/Inbox.scala | 1 - .../apache/spark/rpc/netty/NettyRpcEnv.scala | 7 ++ .../org/apache/spark/util/ThreadUtils.scala | 6 +- .../apache/spark/rpc/netty/InboxSuite.scala | 2 - 5 files changed, 72 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index b1a99629459c7..fff164d8430bf 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -17,14 +17,15 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{TimeUnit, LinkedBlockingQueue, ConcurrentHashMap} - -import org.apache.spark.network.client.RpcResponseCallback +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.concurrent.Promise import scala.util.control.NonFatal import org.apache.spark.{SparkException, Logging} +import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ import org.apache.spark.util.ThreadUtils @@ -42,18 +43,24 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { // Track the receivers whose inboxes may contain messages. 
private val receivers = new LinkedBlockingQueue[RpcEndpoint]() - @volatile private var stopped = false + @GuardedBy("this") + private var stopped = false def registerRpcEndpoint(name: String, endpoint: RpcEndpoint): NettyRpcEndpointRef = { val addr = new NettyRpcAddress(nettyEnv.address.host, nettyEnv.address.port, name) val endpointRef = new NettyRpcEndpointRef(nettyEnv.conf, addr, nettyEnv) - if (nameToEndpoint.putIfAbsent(name, new RpcEndpointPair(endpoint, endpointRef)) != null) { - throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") + synchronized { + if (stopped) { + throw new IllegalStateException("RpcEnv has been stopped") + } + if (nameToEndpoint.putIfAbsent(name, new RpcEndpointPair(endpoint, endpointRef)) != null) { + throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") + } + endpointToEndpointRef.put(endpoint, endpointRef) + val inbox = new Inbox(endpointRef, endpoint) + endpointToInbox.put(endpoint, inbox) + receivers.put(inbox.endpoint) } - endpointToEndpointRef.put(endpoint, endpointRef) - val inbox = new Inbox(endpointRef, endpoint) - endpointToInbox.put(endpoint, inbox) - afterUpdateInbox(inbox) endpointRef } @@ -62,20 +69,26 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { def getRpcEndpointRef(name: String): RpcEndpointRef = nameToEndpoint.get(name).endpointRef // Should be idempotent - def unregisterRpcEndpoint(name: String): Unit = { + private def unregisterRpcEndpoint(name: String): Unit = { val endpointPair = nameToEndpoint.remove(name) if (endpointPair != null) { val inbox = endpointToInbox.get(endpointPair.endpoint) if (inbox != null) { inbox.stop() - afterUpdateInbox(inbox) + receivers.put(inbox.endpoint) } endpointToEndpointRef.remove(endpointPair.endpoint) } } def stop(rpcEndpointRef: RpcEndpointRef): Unit = { - unregisterRpcEndpoint(rpcEndpointRef.name) + synchronized { + if (stopped) { + // This endpoint will be stopped by Distpatcher.stop() 
method. + return + } + unregisterRpcEndpoint(rpcEndpointRef.name) + } } /** @@ -124,22 +137,26 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } - private def postMessageToInbox(inbox: Inbox, message: InboxMessage): Unit = { + private def postMessageToInbox(inbox: Inbox, message: InboxMessage): Unit = synchronized { + if (stopped) { + logWarning(s"Drop ${message} because RpcEnv has been stopped") + return + } inbox.post(message) - afterUpdateInbox(inbox) - } - - private def afterUpdateInbox(inbox: Inbox): Unit = { - // Do some work to trigger processing messages in the inbox receivers.put(inbox.endpoint) } private[netty] class MessageLoop extends Runnable { override def run(): Unit = { try { - while (!stopped) { + while (true) { try { val endpoint = receivers.take() + if (endpoint == DummyEndpoint) { + // Put DummyEndpoint back so that other MessageLoops can see it. + receivers.put(DummyEndpoint) + return + } val inbox = endpointToInbox.get(endpoint) if (inbox != null) { val inboxStopped = inbox.process(Dispatcher.this) @@ -169,8 +186,22 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { } def stop(): Unit = { - stopped = true - executor.shutdownNow() + synchronized { + if (stopped) { + return + } + stopped = true + } + // When we reach here, other threads won't update `nameToEndpoint`. So we can guarantee all + // registered endpoints will be stopped correctly. + for (name <- nameToEndpoint.keySet().asScala) { + unregisterRpcEndpoint(name) + } + // When we reach here, the new items put into receivers will always be DummyEndpoint, others + // will be rejected. So that we can make sure we will process all messages that have already in + // the Inboxes. 
+ receivers.put(DummyEndpoint) + executor.shutdown() } def awaitTermination(): Unit = { @@ -185,3 +216,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { } } + +/** + * A dummy endpoint that indicates MessageLoop should exit its loop. + */ +private[netty] object DummyEndpoint extends RpcEndpoint { + override val rpcEnv: RpcEnv = null +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 56efcefe15d1b..975dfcb4be391 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -135,7 +135,6 @@ private[netty] class Inbox( } } case OnStop => - dispatcher.unregisterRpcEndpoint(endpointRef.name) endpoint.onStop() assert(isEmpty, "OnStop should be the last message") return true diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 6874c377e372a..01aaf23d8c6b4 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -336,6 +336,13 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) override def toString: String = s"NettyRpcEndpointRef(${_address})" def toURI: URI = new URI(s"spark://${_address}") + + final override def equals(that: Any): Boolean = that match { + case other: NettyRpcEndpointRef => _address == other._address + case _ => false + } + + final override def hashCode(): Int = if (_address == null) 0 else _address.hashCode() } /** diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index ca5624a3d8b3d..93ee32dc1f6ca 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -84,6 +84,10 @@ 
private[spark] object ThreadUtils { */ def newDaemonSingleThreadScheduledExecutor(threadName: String): ScheduledExecutorService = { val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build() - Executors.newSingleThreadScheduledExecutor(threadFactory) + val executor = new ScheduledThreadPoolExecutor(1, threadFactory) + // By default, a cancelled task is not automatically removed from the work queue until its delay + // elapses. We have to enable it manually. + executor.setRemoveOnCancelPolicy(true) + executor } } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 0d545d520becd..f68ac8659f018 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -45,7 +45,6 @@ class InboxSuite extends SparkFunSuite { assert(inbox.process(dispatcher) === true) endpoint.verifyStarted() endpoint.verifyStopped() - verify(dispatcher).unregisterRpcEndpoint("hello") } test("post: with reply") { @@ -97,7 +96,6 @@ class InboxSuite extends SparkFunSuite { assert(1000 === endpoint.numReceiveMessages + numDroppedMessages.get) endpoint.verifyStarted() endpoint.verifyStopped() - verify(dispatcher).unregisterRpcEndpoint("hello") } test("post: Associated") { From e8ecab8c20e496b961b4ce51dac1e33d840dc2d4 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 16 Sep 2015 08:59:53 +0800 Subject: [PATCH 14/30] Fix a race condition --- .../src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala | 3 ++- core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index fff164d8430bf..203c37510788d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ 
b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -66,6 +66,8 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToEndpointRef.get(endpoint) + def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointToEndpointRef.remove(endpoint) + def getRpcEndpointRef(name: String): RpcEndpointRef = nameToEndpoint.get(name).endpointRef // Should be idempotent @@ -77,7 +79,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { inbox.stop() receivers.put(inbox.endpoint) } - endpointToEndpointRef.remove(endpointPair.endpoint) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 975dfcb4be391..dfe6c24332135 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -135,6 +135,7 @@ private[netty] class Inbox( } } case OnStop => + dispatcher.removeRpcEndpointRef(endpoint) endpoint.onStop() assert(isEmpty, "OnStop should be the last message") return true From 8589699405976afbc0b9cf8712af6e1ec718a00e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 17 Sep 2015 22:03:34 +0800 Subject: [PATCH 15/30] DummyEndpoint -> PoisonEndpoint --- .../apache/spark/rpc/netty/Dispatcher.scala | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 203c37510788d..7e3e4cc9b3d0a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -153,9 +153,9 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { while (true) { try { val endpoint = receivers.take() - if (endpoint == DummyEndpoint) { - // Put DummyEndpoint back 
so that other MessageLoops can see it. - receivers.put(DummyEndpoint) + if (endpoint == PoisonEndpoint) { + // Put PoisonEndpoint back so that other MessageLoops can see it. + receivers.put(PoisonEndpoint) return } val inbox = endpointToInbox.get(endpoint) @@ -198,10 +198,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { for (name <- nameToEndpoint.keySet().asScala) { unregisterRpcEndpoint(name) } - // When we reach here, the new items put into receivers will always be DummyEndpoint, others + // When we reach here, the new items put into receivers will always be `PoisonEndpoint`, others // will be rejected. So that we can make sure we will process all messages that have already in // the Inboxes. - receivers.put(DummyEndpoint) + receivers.put(PoisonEndpoint) executor.shutdown() } @@ -216,11 +216,10 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { nameToEndpoint.containsKey(name) } -} - -/** - * A dummy endpoint that indicates MessageLoop should exit its loop. - */ -private[netty] object DummyEndpoint extends RpcEndpoint { - override val rpcEnv: RpcEnv = null + /** + * A poison endpoint that indicates MessageLoop should exit its loop. 
+ */ + private object PoisonEndpoint extends RpcEndpoint { + override val rpcEnv: RpcEnv = null + } } From 89c92c9c84506b877c6dee61a1bf6ef532d52314 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 17 Sep 2015 22:07:05 +0800 Subject: [PATCH 16/30] Move MessageLoop after all fields and methods --- .../apache/spark/rpc/netty/Dispatcher.scala | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 7e3e4cc9b3d0a..13021389deed0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -147,36 +147,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { receivers.put(inbox.endpoint) } - private[netty] class MessageLoop extends Runnable { - override def run(): Unit = { - try { - while (true) { - try { - val endpoint = receivers.take() - if (endpoint == PoisonEndpoint) { - // Put PoisonEndpoint back so that other MessageLoops can see it. 
- receivers.put(PoisonEndpoint) - return - } - val inbox = endpointToInbox.get(endpoint) - if (inbox != null) { - val inboxStopped = inbox.process(Dispatcher.this) - if (inboxStopped) { - endpointToInbox.remove(endpoint) - } - } else { - // The endpoint has been stopped - } - } catch { - case NonFatal(e) => logError(e.getMessage, e) - } - } - } catch { - case ie: InterruptedException => // exit - } - } - } - private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism", Runtime.getRuntime.availableProcessors()) @@ -216,6 +186,36 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { nameToEndpoint.containsKey(name) } + private class MessageLoop extends Runnable { + override def run(): Unit = { + try { + while (true) { + try { + val endpoint = receivers.take() + if (endpoint == PoisonEndpoint) { + // Put PoisonEndpoint back so that other MessageLoops can see it. + receivers.put(PoisonEndpoint) + return + } + val inbox = endpointToInbox.get(endpoint) + if (inbox != null) { + val inboxStopped = inbox.process(Dispatcher.this) + if (inboxStopped) { + endpointToInbox.remove(endpoint) + } + } else { + // The endpoint has been stopped + } + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case ie: InterruptedException => // exit + } + } + } + /** * A poison endpoint that indicates MessageLoop should exit its loop. 
*/ From 0e08d8b7b38e048549a1c5a5b3b5671afb76dabc Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 17 Sep 2015 22:19:36 +0800 Subject: [PATCH 17/30] Add comments and code style fix --- .../scala/org/apache/spark/rpc/netty/Inbox.scala | 14 ++++++++++---- .../apache/spark/rpc/netty/NettyRpcAddress.scala | 2 +- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index dfe6c24332135..00a78eca072a0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -42,10 +42,13 @@ private[netty] case object OnStart extends InboxMessage private[netty] case object OnStop extends InboxMessage +/** + * A broadcast message that indicates connecting to a remote node. + */ private[netty] case class Associated(remoteAddress: RpcAddress) extends BroadcastMessage /** - * A broadcast message that indicates + * A broadcast message that indicates a remote connection is lost. */ private[netty] case class Disassociated(remoteAddress: RpcAddress) extends BroadcastMessage @@ -64,8 +67,6 @@ private[netty] class Inbox( val endpointRef: NettyRpcEndpointRef, val endpoint: RpcEndpoint) extends Logging { - private val supportConcurrent = !endpoint.isInstanceOf[ThreadSafeRpcEndpoint] - @GuardedBy("this") protected val messages = new LinkedList[InboxMessage]() @@ -83,6 +84,10 @@ private[netty] class Inbox( messages.add(OnStart) } + /** + * Process stored messages. Return `true` if the `Inbox` is already stopped, and the caller will + * release all resources used by the `Inbox`. 
+ */ def process(dispatcher: Dispatcher): Boolean = { var message: InboxMessage = null synchronized { @@ -128,12 +133,13 @@ private[netty] class Inbox( case OnStart => { endpoint.onStart() - if (supportConcurrent) { + if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { synchronized { enableConcurrent = true } } } + case OnStop => dispatcher.removeRpcEndpointRef(endpoint) endpoint.onStop() diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala index 2f38b1c00f291..3142cc4fb58e2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala @@ -49,7 +49,7 @@ private[netty] object NettyRpcAddress { NettyRpcAddress(host, port, name) } catch { case e: java.net.URISyntaxException => - throw new SparkException("Invalid master URL: " + sparkUrl, e) + throw new SparkException("Invalid Spark URL: " + sparkUrl, e) } } From e130a4c65e81652337f4c36719d8dc5323d14f44 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 17 Sep 2015 22:25:38 +0800 Subject: [PATCH 18/30] Remove skipFinally --- .../org/apache/spark/rpc/netty/Inbox.scala | 116 ++++++++---------- 1 file changed, 53 insertions(+), 63 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 00a78eca072a0..a27f1def94597 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -101,85 +101,75 @@ private[netty] class Inbox( return false } } - var skipFinally = false - try { - while (true) { - safelyCall(endpoint) { - message match { - case ContentMessage(_sender, content, needReply, context) => - val pf: PartialFunction[Any, Unit] = + while (true) { + safelyCall(endpoint) { + message match { + case ContentMessage(_sender, content, needReply, context) => + val pf: 
PartialFunction[Any, Unit] = + if (needReply) { + endpoint.receiveAndReply(context) + } else { + endpoint.receive + } + try { + pf.applyOrElse[Any, Unit](content, { msg => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) + if (!needReply) { + context.finish() + } + } catch { + case NonFatal(e) => if (needReply) { - endpoint.receiveAndReply(context) + // If the sender asks a reply, we should send the error back to the sender + context.sendFailure(e) } else { - endpoint.receive - } - try { - pf.applyOrElse[Any, Unit](content, { msg => - throw new SparkException(s"Unmatched message $message from ${_sender}") - }) - if (!needReply) { context.finish() + throw e } - } catch { - case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - context.sendFailure(e) - } else { - context.finish() - throw e - } - } + } - case OnStart => { - endpoint.onStart() - if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { - synchronized { - enableConcurrent = true - } + case OnStart => { + endpoint.onStart() + if (!endpoint.isInstanceOf[ThreadSafeRpcEndpoint]) { + synchronized { + enableConcurrent = true } } - - case OnStop => - dispatcher.removeRpcEndpointRef(endpoint) - endpoint.onStop() - assert(isEmpty, "OnStop should be the last message") - return true - case Associated(remoteAddress) => - endpoint.onConnected(remoteAddress) - case Disassociated(remoteAddress) => - endpoint.onDisconnected(remoteAddress) - case AssociationError(cause, remoteAddress) => - endpoint.onNetworkError(cause, remoteAddress) } - } - synchronized { - // "enableConcurrent" will be set to false after `onStop` is called, so we should check it - // every time. 
- if (!enableConcurrent && workerCount != 1) { - // If we are not the only one worker, exit - skipFinally = true - workerCount -= 1 - return false - } - message = messages.poll() - if (message == null) { - skipFinally = true + case OnStop => + dispatcher.removeRpcEndpointRef(endpoint) + endpoint.onStop() + assert(isEmpty, "OnStop should be the last message") workerCount -= 1 - return false - } + return true + case Associated(remoteAddress) => + endpoint.onConnected(remoteAddress) + case Disassociated(remoteAddress) => + endpoint.onDisconnected(remoteAddress) + case AssociationError(cause, remoteAddress) => + endpoint.onNetworkError(cause, remoteAddress) } } - return false - } finally { - if (!skipFinally) { - // Reset `workerCount` if some exception is thrown. - synchronized { + + synchronized { + // "enableConcurrent" will be set to false after `onStop` is called, so we should check it + // every time. + if (!enableConcurrent && workerCount != 1) { + // If we are not the only one worker, exit + workerCount -= 1 + return false + } + message = messages.poll() + if (message == null) { workerCount -= 1 + return false } } } + // We won't reach here. Just make the compiler happy. 
+ return false } def post(message: InboxMessage): Unit = { From 792491ecaf873250fed73d56eda7501a35427523 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 17 Sep 2015 22:38:42 +0800 Subject: [PATCH 19/30] Fix an exception message --- .../main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala index 3142cc4fb58e2..1876b25592086 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcAddress.scala @@ -44,7 +44,7 @@ private[netty] object NettyRpcAddress { (uri.getPath != null && !uri.getPath.isEmpty) || // uri.getPath returns "" instead of null uri.getFragment != null || uri.getQuery != null) { - throw new SparkException("Invalid master URL: " + sparkUrl) + throw new SparkException("Invalid Spark URL: " + sparkUrl) } NettyRpcAddress(host, port, name) } catch { From ce21e1ae657283ac442dcccac48ea995fa517f22 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 17 Sep 2015 22:59:59 +0800 Subject: [PATCH 20/30] Add getInbox and update postMessageToInbox to make sure we can send the exception back if Dispatcher is stopped --- .../apache/spark/rpc/netty/Dispatcher.scala | 88 ++++++++++++------- 1 file changed, 56 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 13021389deed0..f54ac288c7128 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -100,51 +100,75 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { val iter = endpointToInbox.values().iterator() while (iter.hasNext) { val inbox = iter.next() - postMessageToInbox(inbox, message) + 
postMessageToInbox(inbox, message, () => { + logWarning(s"Drop ${message} because RpcEnv has been stopped") + }) } } def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { - val receiver = nameToEndpoint.get(message.receiver.name) + def onDispatcherStop(): Unit = { + callback.onFailure( + new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) + } + + val inbox = getInbox(message.receiver.name) + if (inbox != null) { + val rpcCallContext = + new RemoteNettyRpcCallContext( + nettyEnv, inbox.endpointRef, callback, message.senderAddress, message.needReply) + postMessageToInbox( + inbox, + ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext), + onDispatcherStop) + } else { + onDispatcherStop() + } + } + + private def getInbox(endpointName: String): Inbox = { + val receiver = nameToEndpoint.get(endpointName) if (receiver != null) { - val inbox = endpointToInbox.get(receiver.endpoint) - if (inbox != null) { - val rpcCallContext = - new RemoteNettyRpcCallContext( - nettyEnv, inbox.endpointRef, callback, message.senderAddress, message.needReply) - postMessageToInbox(inbox, - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)) - return - } + endpointToInbox.get(receiver.endpoint) + } else { + null } - callback.onFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } def postMessage(message: RequestMessage, p: Promise[Any]): Unit = { - val receiver = nameToEndpoint.get(message.receiver.name) - if (receiver != null) { - val inbox = endpointToInbox.get(receiver.endpoint) - if (inbox != null) { - val rpcCallContext = - new LocalNettyRpcCallContext( - inbox.endpointRef, message.senderAddress, message.needReply, p) - postMessageToInbox(inbox, - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext)) - return - } + def onDispatcherStop(): Unit = { + p.tryFailure( + new 
SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) + } + + val inbox = getInbox(message.receiver.name) + if (inbox != null) { + val rpcCallContext = + new LocalNettyRpcCallContext( + inbox.endpointRef, message.senderAddress, message.needReply, p) + postMessageToInbox( + inbox, + ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext), + onDispatcherStop) + } else { + onDispatcherStop() } - p.tryFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } - private def postMessageToInbox(inbox: Inbox, message: InboxMessage): Unit = synchronized { - if (stopped) { - logWarning(s"Drop ${message} because RpcEnv has been stopped") - return + private def postMessageToInbox(inbox: Inbox, message: InboxMessage, onStop: () => Unit): Unit = { + var shouldCallOnStop = false + synchronized { + if (stopped) { + shouldCallOnStop = true + } else { + inbox.post(message) + receivers.put(inbox.endpoint) + } + } + if (shouldCallOnStop) { + // We don't need to call `onStop` in the `synchronized` block + onStop() } - inbox.post(message) - receivers.put(inbox.endpoint) } private val parallelism = nettyEnv.conf.getInt("spark.rpc.netty.dispatcher.parallelism", From c904a9d77727e13a58724bbcd8341a4606868242 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 17 Sep 2015 23:05:16 +0800 Subject: [PATCH 21/30] Remove BroadcastMessage --- .../scala/org/apache/spark/rpc/netty/Dispatcher.scala | 4 ++-- .../main/scala/org/apache/spark/rpc/netty/Inbox.scala | 11 +++-------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index f54ac288c7128..d7ffe6f9b7767 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -85,7 +85,7 @@ private[netty] class 
Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { def stop(rpcEndpointRef: RpcEndpointRef): Unit = { synchronized { if (stopped) { - // This endpoint will be stopped by Distpatcher.stop() method. + // This endpoint will be stopped by Dispatcher.stop() method. return } unregisterRpcEndpoint(rpcEndpointRef.name) @@ -96,7 +96,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { * Send a message to all registered [[RpcEndpoint]]s. * @param message */ - def broadcastMessage(message: BroadcastMessage): Unit = { + def broadcastMessage(message: InboxMessage): Unit = { val iter = endpointToInbox.values().iterator() while (iter.hasNext) { val inbox = iter.next() diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index a27f1def94597..32c0ddf561d7b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -33,11 +33,6 @@ private[netty] case class ContentMessage( needReply: Boolean, context: NettyRpcCallContext) extends InboxMessage -/** - * A message type that will be posted to all registered [[RpcEndpoint]] - */ -private[netty] sealed trait BroadcastMessage extends InboxMessage - private[netty] case object OnStart extends InboxMessage private[netty] case object OnStop extends InboxMessage @@ -45,18 +40,18 @@ private[netty] case object OnStop extends InboxMessage /** * A broadcast message that indicates connecting to a remote node. */ -private[netty] case class Associated(remoteAddress: RpcAddress) extends BroadcastMessage +private[netty] case class Associated(remoteAddress: RpcAddress) extends InboxMessage /** * A broadcast message that indicates a remote connection is lost. 
*/ -private[netty] case class Disassociated(remoteAddress: RpcAddress) extends BroadcastMessage +private[netty] case class Disassociated(remoteAddress: RpcAddress) extends InboxMessage /** * A broadcast message that indicates a network error */ private[netty] case class AssociationError(cause: Throwable, remoteAddress: RpcAddress) - extends BroadcastMessage + extends InboxMessage /** * A inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. From 81173de3056b6f958b033c9cc83525d1051f04a5 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 19 Sep 2015 22:26:50 +0800 Subject: [PATCH 22/30] Add missing synchronized and other minor fixes --- .../main/scala/org/apache/spark/rpc/netty/Inbox.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 32c0ddf561d7b..e65a7a3c78025 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -137,7 +137,7 @@ private[netty] class Inbox( dispatcher.removeRpcEndpointRef(endpoint) endpoint.onStop() assert(isEmpty, "OnStop should be the last message") - workerCount -= 1 + synchronized { workerCount -= 1 } return true case Associated(remoteAddress) => endpoint.onConnected(remoteAddress) @@ -164,7 +164,7 @@ private[netty] class Inbox( } } // We won't reach here. Just make the compiler happy. - return false + throw new IllegalStateException("ShouldNotReachHere") } def post(message: InboxMessage): Unit = { @@ -197,13 +197,14 @@ private[netty] class Inbox( } } + // Visible for testing. 
protected def onDrop(message: Any): Unit = { logWarning(s"Drop ${message} because $endpointRef is stopped") } - def isEmpty: Boolean = messages.isEmpty + def isEmpty: Boolean = synchronized { messages.isEmpty } - protected def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { try { action } catch { From d540a80e2598d425353772d936383644b18b81a0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 19 Sep 2015 23:42:02 +0800 Subject: [PATCH 23/30] Reorganize dispatcher to be easier to read. Use less data structures to keep data; "endpointRefs" is not strictly necessary, but was left as a performance optimization, since RpcEndpoint.self calls getRpcEndpointRef indirectly, and this extra map makes that call O(1) instead of O(n). Stole @vanzin's idea from https://github.com/vanzin/spark/commit/e6674673518ef95cc20c5f0d0f113ce6711b7917 --- .../apache/spark/rpc/netty/Dispatcher.scala | 159 +++++++----------- .../org/apache/spark/rpc/netty/Inbox.scala | 18 +- .../apache/spark/rpc/netty/InboxSuite.scala | 15 +- 3 files changed, 83 insertions(+), 109 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index d7ffe6f9b7767..b3362cb53b2b8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -33,15 +33,18 @@ private class RpcEndpointPair(val endpoint: RpcEndpoint, val endpointRef: NettyR private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { - private val endpointToInbox = new ConcurrentHashMap[RpcEndpoint, Inbox]() - - // need a name to RpcEndpoint mapping so that we can delivery the messages - private val nameToEndpoint = new ConcurrentHashMap[String, RpcEndpointPair]() + private class EndpointData( + val name: String, + val endpoint: RpcEndpoint, + val ref: NettyRpcEndpointRef) { + 
val inbox = new Inbox(ref, endpoint) + } - private val endpointToEndpointRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + private val endpoints = new ConcurrentHashMap[String, EndpointData]() + private val endpointRefs = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() // Track the receivers whose inboxes may contain messages. - private val receivers = new LinkedBlockingQueue[RpcEndpoint]() + private val receivers = new LinkedBlockingQueue[EndpointData]() @GuardedBy("this") private var stopped = false @@ -53,33 +56,30 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { if (stopped) { throw new IllegalStateException("RpcEnv has been stopped") } - if (nameToEndpoint.putIfAbsent(name, new RpcEndpointPair(endpoint, endpointRef)) != null) { + if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) { throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") } - endpointToEndpointRef.put(endpoint, endpointRef) - val inbox = new Inbox(endpointRef, endpoint) - endpointToInbox.put(endpoint, inbox) - receivers.put(inbox.endpoint) + val data = endpoints.get(name) + endpointRefs.put(data.endpoint, data.ref) + receivers.put(data) } endpointRef } - def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToEndpointRef.get(endpoint) - - def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointToEndpointRef.remove(endpoint) + def getRpcEndpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointRefs.get(endpoint) - def getRpcEndpointRef(name: String): RpcEndpointRef = nameToEndpoint.get(name).endpointRef + def removeRpcEndpointRef(endpoint: RpcEndpoint): Unit = endpointRefs.remove(endpoint) // Should be idempotent private def unregisterRpcEndpoint(name: String): Unit = { - val endpointPair = nameToEndpoint.remove(name) - if (endpointPair != null) { - val inbox = endpointToInbox.get(endpointPair.endpoint) - if (inbox != null) { - inbox.stop() - 
receivers.put(inbox.endpoint) - } + val data = endpoints.remove(name) + if (data != null) { + data.inbox.stop() + receivers.put(data) } + // Don't clean `endpointRefs` here because it's possible that some messages are being processed + // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via + // `removeRpcEndpointRef`. } def stop(rpcEndpointRef: RpcEndpointRef): Unit = { @@ -97,77 +97,63 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { * @param message */ def broadcastMessage(message: InboxMessage): Unit = { - val iter = endpointToInbox.values().iterator() + val iter = endpoints.keySet().iterator() while (iter.hasNext) { - val inbox = iter.next() - postMessageToInbox(inbox, message, () => { - logWarning(s"Drop ${message} because RpcEnv has been stopped") - }) + val name = iter.next + postMessageToInbox(name, (_) => message, + () => { logWarning(s"Drop ${message} because ${name} has been stopped") }) } } def postMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { - def onDispatcherStop(): Unit = { - callback.onFailure( - new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) - } - - val inbox = getInbox(message.receiver.name) - if (inbox != null) { + def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { val rpcCallContext = new RemoteNettyRpcCallContext( - nettyEnv, inbox.endpointRef, callback, message.senderAddress, message.needReply) - postMessageToInbox( - inbox, - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext), - onDispatcherStop) - } else { - onDispatcherStop() + nettyEnv, sender, callback, message.senderAddress, message.needReply) + ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) } - } - private def getInbox(endpointName: String): Inbox = { - val receiver = nameToEndpoint.get(endpointName) - if (receiver != null) { - 
endpointToInbox.get(receiver.endpoint) - } else { - null + def onEndpointStopped(): Unit = { + callback.onFailure( + new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } + + postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped) } def postMessage(message: RequestMessage, p: Promise[Any]): Unit = { - def onDispatcherStop(): Unit = { + def createMessage(sender: NettyRpcEndpointRef): InboxMessage = { + val rpcCallContext = + new LocalNettyRpcCallContext(sender, message.senderAddress, message.needReply, p) + ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext) + } + + def onEndpointStopped(): Unit = { p.tryFailure( new SparkException(s"Could not find ${message.receiver.name} or it has been stopped")) } - val inbox = getInbox(message.receiver.name) - if (inbox != null) { - val rpcCallContext = - new LocalNettyRpcCallContext( - inbox.endpointRef, message.senderAddress, message.needReply, p) - postMessageToInbox( - inbox, - ContentMessage(message.senderAddress, message.content, message.needReply, rpcCallContext), - onDispatcherStop) - } else { - onDispatcherStop() - } + postMessageToInbox(message.receiver.name, createMessage, onEndpointStopped) } - private def postMessageToInbox(inbox: Inbox, message: InboxMessage, onStop: () => Unit): Unit = { - var shouldCallOnStop = false - synchronized { - if (stopped) { - shouldCallOnStop = true - } else { - inbox.post(message) - receivers.put(inbox.endpoint) + private def postMessageToInbox( + endpointName: String, + createMessageFn: NettyRpcEndpointRef => InboxMessage, + onStopped: () => Unit): Unit = { + val shouldCallOnStop = + synchronized { + val data = endpoints.get(endpointName) + if (stopped || data == null) { + true + } else { + data.inbox.post(createMessageFn(data.ref)) + receivers.put(data) + false + } } - } if (shouldCallOnStop) { // We don't need to call `onStop` in the `synchronized` block - onStop() + onStopped() } } @@ 
-187,14 +173,9 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { } stopped = true } - // When we reach here, other threads won't update `nameToEndpoint`. So we can guarantee all - // registered endpoints will be stopped correctly. - for (name <- nameToEndpoint.keySet().asScala) { - unregisterRpcEndpoint(name) - } - // When we reach here, the new items put into receivers will always be `PoisonEndpoint`, others - // will be rejected. So that we can make sure we will process all messages that have already in - // the Inboxes. + // Stop all endpoints. This will queue all endpoints for processing by the message loops. + endpoints.keySet().asScala.foreach(unregisterRpcEndpoint) + // Enqueue a message that tells the message loops to stop. receivers.put(PoisonEndpoint) executor.shutdown() } @@ -207,7 +188,7 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { * Return if the endpoint exists */ def verify(name: String): Boolean = { - nameToEndpoint.containsKey(name) + endpoints.containsKey(name) } private class MessageLoop extends Runnable { @@ -215,21 +196,13 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { try { while (true) { try { - val endpoint = receivers.take() - if (endpoint == PoisonEndpoint) { + val data = receivers.take() + if (data == PoisonEndpoint) { // Put PoisonEndpoint back so that other MessageLoops can see it. receivers.put(PoisonEndpoint) return } - val inbox = endpointToInbox.get(endpoint) - if (inbox != null) { - val inboxStopped = inbox.process(Dispatcher.this) - if (inboxStopped) { - endpointToInbox.remove(endpoint) - } - } else { - // The endpoint has been stopped - } + data.inbox.process(Dispatcher.this) } catch { case NonFatal(e) => logError(e.getMessage, e) } @@ -243,7 +216,5 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { /** * A poison endpoint that indicates MessageLoop should exit its loop. 
*/ - private object PoisonEndpoint extends RpcEndpoint { - override val rpcEnv: RpcEnv = null - } + private val PoisonEndpoint = new EndpointData(null, null, null) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index e65a7a3c78025..4a30fa1a73300 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -80,20 +80,19 @@ private[netty] class Inbox( } /** - * Process stored messages. Return `true` if the `Inbox` is already stopped, and the caller will - * release all resources used by the `Inbox`. + * Process stored messages. */ - def process(dispatcher: Dispatcher): Boolean = { + def process(dispatcher: Dispatcher): Unit = { var message: InboxMessage = null synchronized { if (!enableConcurrent && workerCount != 0) { - return false + return } message = messages.poll() if (message != null) { workerCount += 1 } else { - return false + return } } while (true) { @@ -134,11 +133,12 @@ private[netty] class Inbox( } case OnStop => + assert(workerCount == 1) dispatcher.removeRpcEndpointRef(endpoint) endpoint.onStop() assert(isEmpty, "OnStop should be the last message") synchronized { workerCount -= 1 } - return true + return case Associated(remoteAddress) => endpoint.onConnected(remoteAddress) case Disassociated(remoteAddress) => @@ -154,17 +154,15 @@ private[netty] class Inbox( if (!enableConcurrent && workerCount != 1) { // If we are not the only one worker, exit workerCount -= 1 - return false + return } message = messages.poll() if (message == null) { workerCount -= 1 - return false + return } } } - // We won't reach here. Just make the compiler happy. 
- throw new IllegalStateException("ShouldNotReachHere") } def post(message: InboxMessage): Unit = { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index f68ac8659f018..a6f9267a4953c 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -37,12 +37,14 @@ class InboxSuite extends SparkFunSuite { val inbox = new Inbox(endpointRef, endpoint) val message = ContentMessage(null, "hi", false, null) inbox.post(message) - assert(inbox.process(dispatcher) === false) + inbox.process(dispatcher) + assert(inbox.isEmpty) endpoint.verifySingleReceiveMessage("hi") inbox.stop() - assert(inbox.process(dispatcher) === true) + inbox.process(dispatcher) + assert(inbox.isEmpty) endpoint.verifyStarted() endpoint.verifyStopped() } @@ -55,7 +57,8 @@ class InboxSuite extends SparkFunSuite { val inbox = new Inbox(endpointRef, endpoint) val message = ContentMessage(null, "hi", true, null) inbox.post(message) - assert(inbox.process(dispatcher) === false) + inbox.process(dispatcher) + assert(inbox.isEmpty) endpoint.verifySingleReceiveAndReplyMessage("hi") } @@ -87,9 +90,11 @@ class InboxSuite extends SparkFunSuite { } }.start() } - assert(inbox.process(dispatcher) === false) + inbox.process(dispatcher) + assert(inbox.isEmpty) inbox.stop() - assert(inbox.process(dispatcher) === true) + inbox.process(dispatcher) + assert(inbox.isEmpty) exitLatch.await(30, TimeUnit.SECONDS) From af6df386877fcef22256890c5e2a1c63bc19246a Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 20 Sep 2015 12:06:51 +0800 Subject: [PATCH 24/30] Fix the style to address vanzin's comments --- .../apache/spark/rpc/netty/Dispatcher.scala | 2 -- .../org/apache/spark/rpc/netty/Inbox.scala | 7 ++-- .../spark/rpc/netty/NettyRpcCallContext.scala | 2 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 33 +++++++++++-------- 
.../org/apache/spark/rpc/RpcEnvSuite.scala | 4 +-- .../spark/rpc/netty/NettyRpcEnvSuite.scala | 2 +- 6 files changed, 28 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index b3362cb53b2b8..d71e6f01dbb29 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -29,8 +29,6 @@ import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ import org.apache.spark.util.ThreadUtils -private class RpcEndpointPair(val endpoint: RpcEndpoint, val endpointRef: NettyRpcEndpointRef) - private[netty] class Dispatcher(nettyEnv: NettyRpcEnv) extends Logging { private class EndpointData( diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 4a30fa1a73300..f185f0949bf20 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -133,16 +133,17 @@ private[netty] class Inbox( } case OnStop => - assert(workerCount == 1) + assert(synchronized { workerCount } == 1) dispatcher.removeRpcEndpointRef(endpoint) endpoint.onStop() assert(isEmpty, "OnStop should be the last message") - synchronized { workerCount -= 1 } - return + case Associated(remoteAddress) => endpoint.onConnected(remoteAddress) + case Disassociated(remoteAddress) => endpoint.onDisconnected(remoteAddress) + case AssociationError(cause, remoteAddress) => endpoint.onNetworkError(cause, remoteAddress) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index e9ff66ae7f367..75dcc02a0c5a9 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ 
b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -51,7 +51,7 @@ private[netty] abstract class NettyRpcCallContext( def finish(): Unit = { if (!needReply) { - send(SendAck(endpointRef)) + send(Ack(endpointRef)) } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 01aaf23d8c6b4..3d3827900ea17 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -21,6 +21,7 @@ import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer import java.util.Arrays import java.util.concurrent._ +import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable @@ -57,7 +58,7 @@ private[netty] class NettyRpcEnv( Seq(new SaslClientBootstrap(transportConf, "", securityManager, securityManager.isSaslEncryptionEnabled())) } else { - Seq.empty + Nil } transportContext.createClientFactory(bootstraps.asJava) } @@ -75,7 +76,7 @@ private[netty] class NettyRpcEnv( if (securityManager.isAuthenticationEnabled()) { Seq(new SaslServerBootstrap(transportConf, securityManager)) } else { - Seq.empty + Nil } server = transportContext.createServer(port, bootstraps.asJava) dispatcher.registerRpcEndpoint(IDVerifier.NAME, new IDVerifier(this, dispatcher)) @@ -95,13 +96,13 @@ private[netty] class NettyRpcEnv( val endpointRef = new NettyRpcEndpointRef(conf, addr, this) val idVerifierRef = new NettyRpcEndpointRef(conf, NettyRpcAddress(addr.host, addr.port, IDVerifier.NAME), this) - idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap(find => + idVerifierRef.ask[Boolean](ID(endpointRef.name)).flatMap { find => if (find) { Future.successful(endpointRef) } else { Future.failed(new RpcEndpointNotFoundException(uri)) } - )(ThreadUtils.sameThread) + }(ThreadUtils.sameThread) } override def stop(endpointRef: RpcEndpointRef): Unit = 
{ @@ -116,7 +117,7 @@ private[netty] class NettyRpcEnv( dispatcher.postMessage(message, promise) promise.future.onComplete { case Success(response) => - val ack = response.asInstanceOf[SendAck] + val ack = response.asInstanceOf[Ack] logDebug(s"Receive ack from ${ack.sender}") case Failure(e) => logError(s"Exception when sending $message", e) @@ -135,7 +136,7 @@ private[netty] class NettyRpcEnv( } override def onSuccess(response: Array[Byte]): Unit = { - val ack = deserialize[SendAck](response) + val ack = deserialize[Ack](response) logDebug(s"Receive ack from ${ack.sender}") } }) @@ -144,6 +145,7 @@ private[netty] class NettyRpcEnv( } catch { case e: RejectedExecutionException => { // `send` after shutting clientConnectionExecutor down, ignore it + logWarning(s"Cannot send ${message} because RpcEnv is stopped") } } } @@ -158,7 +160,7 @@ private[netty] class NettyRpcEnv( p.future.onComplete { case Success(response) => val reply = response.asInstanceOf[AskResponse] - if (reply.reply != null && reply.reply.isInstanceOf[RpcFailure]) { + if (reply.reply.isInstanceOf[RpcFailure]) { if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { logWarning(s"Ignore failure: ${reply.reply}") } @@ -310,7 +312,6 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) override def name: String = _address.name - override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { val promise = Promise[Any]() val timeoutCancelable = nettyEnv.timeoutScheduler.schedule(new Runnable { @@ -319,12 +320,12 @@ private[netty] class NettyRpcEndpointRef(@transient conf: SparkConf) } }, timeout.duration.toNanos, TimeUnit.NANOSECONDS) val f = nettyEnv.ask(RequestMessage(nettyEnv.address, this, message, true)) - f.onComplete(v => { + f.onComplete { v => timeoutCancelable.cancel(true) if (!promise.tryComplete(v)) { logWarning(s"Ignore message $v") } - })(ThreadUtils.sameThread) + }(ThreadUtils.sameThread) 
promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) } @@ -366,7 +367,7 @@ private[netty] case class AskResponse(sender: NettyRpcEndpointRef, reply: Any) * A message to send back to the receiver side. It's necessary because [[TransportClient]] only * clean the resources when it receives a reply. */ -private[netty] case class SendAck(sender: NettyRpcEndpointRef) extends ResponseMessage +private[netty] case class Ack(sender: NettyRpcEndpointRef) extends ResponseMessage /** * A response that indicates some failure happens in the receiver side. @@ -383,9 +384,15 @@ private[netty] class NettyRpcHandler( private type ClientAddress = RpcAddress private type RemoteEnvAddress = RpcAddress - // Store all client addresses and their NettyRpcEnv addresses. Protected by "this". + // Store all client addresses and their NettyRpcEnv addresses. + @GuardedBy("this") private val remoteAddresses = new mutable.HashMap[ClientAddress, RemoteEnvAddress]() - // Store the connections from other NettyRpcEnv addresses. Protected by "this". + + // Store the connections from other NettyRpcEnv addresses. We need to keep track of the connection + // count because `TransportClientFactory.createClient` will create multiple connections + // (at most `spark.shuffle.io.numConnectionsPerPeer` connections) and randomly select a connection + // to send the message. See `TransportClientFactory.createClient` for more details. 
+ @GuardedBy("this") private val remoteConnectionCount = new mutable.HashMap[RemoteEnvAddress, Int]() override def receive( diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 4b298af8d5d30..ce8afb7546d88 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -555,7 +555,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(anotherEnv.address.port != env.address.port) } - test("send with ssl") { + test("send with authentication") { val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") @@ -585,7 +585,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } - test("ask with ssl") { + test("ask with authentication") { val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index 76050106f756c..be19668e17c04 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -27,7 +27,7 @@ class NettyRpcEnvSuite extends RpcEnvSuite { new NettyRpcEnvFactory().create(config) } - test("nonexist-endpoint") { + test("non-existent endpoint") { val uri = env.uriOf("test", env.address, "nonexist-endpoint") val e = intercept[RpcEndpointNotFoundException] { env.setupEndpointRef("test", env.address, "nonexist-endpoint") From 3262dd1b047db28d0db0bf98bdc1fb599913f435 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sun, 20 Sep 2015 12:19:50 +0800 Subject: [PATCH 25/30] Reuse JavaSerializerInstance --- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 
deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 3d3827900ea17..9b0135838b78d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -37,13 +37,17 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.rpc._ -import org.apache.spark.serializer.{JavaSerializer, Serializer} +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{ThreadUtils, Utils} private[netty] class NettyRpcEnv( - val conf: SparkConf, serializer: Serializer, host: String, securityManager: SecurityManager) + val conf: SparkConf, serializer: JavaSerializer, host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { + // Use JavaSerializerInstance in multiple threads is safe. 
However, if we plan to support + // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance + private val javaSerializerInstance = serializer.newInstance() + private val transportConf = SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0)) @@ -198,14 +202,14 @@ private[netty] class NettyRpcEnv( } private[netty] def serialize(content: Any): Array[Byte] = { - val buffer = serializer.newInstance().serialize(content) + val buffer = javaSerializerInstance.serialize(content) Arrays.copyOfRange( buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit) } private[netty] def deserialize[T: ClassTag](bytes: Array[Byte]): T = { deserialize { () => - serializer.newInstance().deserialize[T](ByteBuffer.wrap(bytes)) + javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes)) } } From 90770681778d961874a7c6e6ac7fa2a98e223bb0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 22 Sep 2015 09:23:17 +0800 Subject: [PATCH 26/30] Use JavaSerializerInstance in the constructor instead of JavaSerializer --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 9b0135838b78d..04ed3040cbde0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -37,16 +37,14 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ import org.apache.spark.rpc._ -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, JavaSerializerInstance} import org.apache.spark.util.{ThreadUtils, Utils} private[netty] class NettyRpcEnv( - val conf: SparkConf,
serializer: JavaSerializer, host: String, securityManager: SecurityManager) - extends RpcEnv(conf) with Logging { - - // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support - // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance - private val javaSerializerInstance = serializer.newInstance() + val conf: SparkConf, + javaSerializerInstance: JavaSerializerInstance, + host: String, + securityManager: SecurityManager) extends RpcEnv(conf) with Logging { private val transportConf = SparkTransportConf.fromSparkConf(conf, conf.getInt("spark.rpc.io.threads", 0)) @@ -272,8 +270,12 @@ private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { def create(config: RpcEnvConfig): RpcEnv = { val sparkConf = config.conf - val serializer = new JavaSerializer(sparkConf) - val nettyEnv = new NettyRpcEnv(sparkConf, serializer, config.host, config.securityManager) + // Use JavaSerializerInstance in multiple threads is safe. However, if we plan to support + // KryoSerializer in future, we have to use ThreadLocal to store SerializerInstance + val javaSerializerInstance = + new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance] + val nettyEnv = + new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager) val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort => nettyEnv.start(actualPort) (nettyEnv, actualPort) From 610d1550c66ec15827f04e202367f58fc8f79e9d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 22 Sep 2015 09:52:45 +0800 Subject: [PATCH 27/30] Use clientConnectionExecutor to run the blocking 'createClient' for 'ask' --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 49 +++++++++++++------ 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 04ed3040cbde0..8efd67579942a 100644 --- 
a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -67,6 +67,9 @@ private[netty] class NettyRpcEnv( val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") + // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool + // to implement non-blocking send/ask. + // TODO: a non-blocking TransportClientFactory.createClient in future private val clientConnectionExecutor = ThreadUtils.newDaemonCachedThreadPool( "netty-rpc-connection", conf.getInt("spark.rpc.connect.threads", 256)) @@ -175,26 +178,40 @@ private[netty] class NettyRpcEnv( } }(ThreadUtils.sameThread) } else { - val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port) - client.sendRpc(serialize(message), new RpcResponseCallback { + try { + // `createClient` will block if it cannot find a known connection, so we should run it in + // clientConnectionExecutor + clientConnectionExecutor.execute(new Runnable { + override def run(): Unit = { + val client = clientFactory.createClient(remoteAddr.host, remoteAddr.port) + client.sendRpc(serialize(message), new RpcResponseCallback { - override def onFailure(e: Throwable): Unit = { - if (!promise.tryFailure(e)) { - logWarning("Ignore Exception", e) - } - } + override def onFailure(e: Throwable): Unit = { + if (!promise.tryFailure(e)) { + logWarning("Ignore Exception", e) + } + } - override def onSuccess(response: Array[Byte]): Unit = { - val reply = deserialize[AskResponse](response) - if (reply.reply != null && reply.reply.isInstanceOf[RpcFailure]) { - if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { - logWarning(s"Ignore failure: ${reply.reply}") - } - } else if (!promise.trySuccess(reply.reply)) { - logWarning(s"Ignore message: ${reply}") + override def onSuccess(response: Array[Byte]): Unit = { + val reply = deserialize[AskResponse](response) + if (reply.reply != 
null && reply.reply.isInstanceOf[RpcFailure]) { + if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { + logWarning(s"Ignore failure: ${reply.reply}") + } + } else if (!promise.trySuccess(reply.reply)) { + logWarning(s"Ignore message: ${reply}") + } + } + }) + } + }) + } catch { + case e: RejectedExecutionException => { + if (!promise.tryFailure(e)) { + logWarning(s"Ignore failure", e) } } - }) + } } promise.future } From b998ec19587fd1786da578ff1f837f3db990a1c4 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 22 Sep 2015 09:56:14 +0800 Subject: [PATCH 28/30] ssl -> authentication --- .../org/apache/spark/rpc/RpcEnvSuite.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index ce8afb7546d88..e836946a59431 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -560,19 +560,20 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "ssl-local", 13345) - val remoteEnv = createRpcEnv(conf, "ssl-remote", 14345) + val localEnv = createRpcEnv(conf, "authentication-local", 13345) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345) try { @volatile var message: String = null - localEnv.setupEndpoint("send-ssl", new RpcEndpoint { + localEnv.setupEndpoint("send-authentication", new RpcEndpoint { override val rpcEnv = localEnv override def receive: PartialFunction[Any, Unit] = { case msg: String => message = msg } }) - val rpcEndpointRef = remoteEnv.setupEndpointRef("ssl-local", localEnv.address, "send-ssl") + val rpcEndpointRef = + remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "send-authentication") rpcEndpointRef.send("hello") 
eventually(timeout(5 seconds), interval(10 millis)) { assert("hello" === message) @@ -590,11 +591,11 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val localEnv = createRpcEnv(conf, "ssl-local", 13345) - val remoteEnv = createRpcEnv(conf, "ssl-remote", 14345) + val localEnv = createRpcEnv(conf, "authentication-local", 13345) + val remoteEnv = createRpcEnv(conf, "authentication-remote", 14345) try { - localEnv.setupEndpoint("ask-ssl", new RpcEndpoint { + localEnv.setupEndpoint("ask-authentication", new RpcEndpoint { override val rpcEnv = localEnv override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { @@ -603,7 +604,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } }) - val rpcEndpointRef = remoteEnv.setupEndpointRef("ssl-local", localEnv.address, "ask-ssl") + val rpcEndpointRef = + remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "ask-authentication") val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } finally { From 1c1ec99f6211a51da7f54e4c6912acf2c4b6c544 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 22 Sep 2015 09:59:51 +0800 Subject: [PATCH 29/30] Add more comments to RpcEndpoint.onStop and fix onDrop --- core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala | 3 ++- core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala | 4 ++-- .../test/scala/org/apache/spark/rpc/netty/InboxSuite.scala | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index b9af77ecc26e2..f1ddc6d2cd438 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -116,7 +116,8 @@ private[spark] trait RpcEndpoint { } /** - * Invoked when 
[[RpcEndpoint]] is stopping. + * Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot + * use it to send or ask messages. */ def onStop(): Unit = { // By default, do nothing. diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index f185f0949bf20..4803548365aba 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -178,7 +178,7 @@ private[netty] class Inbox( } } if (dropped) { - onDrop() + onDrop(message) } } @@ -197,7 +197,7 @@ private[netty] class Inbox( } // Visible for testing. - protected def onDrop(message: Any): Unit = { + protected def onDrop(message: InboxMessage): Unit = { logWarning(s"Drop ${message} because $endpointRef is stopped") } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index a6f9267a4953c..ff83ab9b32cb9 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -72,7 +72,7 @@ class InboxSuite extends SparkFunSuite { val numDroppedMessages = new AtomicInteger(0) val inbox = new Inbox(endpointRef, endpoint) { - override def onDrop(message: Any): Unit = { + override def onDrop(message: InboxMessage): Unit = { numDroppedMessages.incrementAndGet() } } From 90de09594195b9b23a4994b062eaaef10a6a61e5 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 23 Sep 2015 07:55:48 +0800 Subject: [PATCH 30/30] Remove null check --- .../src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 8efd67579942a..5522b40782d9e 100644 --- 
a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -194,7 +194,7 @@ private[netty] class NettyRpcEnv( override def onSuccess(response: Array[Byte]): Unit = { val reply = deserialize[AskResponse](response) - if (reply.reply != null && reply.reply.isInstanceOf[RpcFailure]) { + if (reply.reply.isInstanceOf[RpcFailure]) { if (!promise.tryFailure(reply.reply.asInstanceOf[RpcFailure].e)) { logWarning(s"Ignore failure: ${reply.reply}") }