core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -45,7 +45,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = context.sender.address.hostPort
val hostPort = context.senderAddress.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
val serializedSize = mapOutputStatuses.size
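This hunk is the calling side of the RpcCallContext change further down in this diff: an endpoint now sees only the sender's RpcAddress, not a full RpcEndpointRef. A minimal sketch of an endpoint written against the new accessor (EchoEndpoint is a hypothetical name; it assumes the usual context.reply(...) for sending the response):

// assumes: import org.apache.spark.rpc._ and org.apache.spark.Logging
private[spark] class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint with Logging {
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case msg: String =>
      // only the caller's address is available; there is no sender ref to message directly
      logInfo(s"Echo request from ${context.senderAddress.hostPort}")
      context.reply(msg)
  }
}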
20 changes: 15 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -20,11 +20,10 @@ package org.apache.spark
import java.io.File
import java.net.Socket

import akka.actor.ActorSystem

import scala.collection.mutable
import scala.util.Properties

import akka.actor.ActorSystem
import com.google.common.collect.MapMaker

import org.apache.spark.annotation.DeveloperApi
@@ -41,7 +40,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator}
import org.apache.spark.util.{RpcUtils, Utils}
import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}

/**
* :: DeveloperApi ::
@@ -57,6 +56,7 @@ import org.apache.spark.util.{RpcUtils, Utils}
class SparkEnv (
val executorId: String,
private[spark] val rpcEnv: RpcEnv,
_actorSystem: ActorSystem, // TODO Remove actorSystem
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@@ -76,7 +76,7 @@ class SparkEnv (

// TODO Remove actorSystem
@deprecated("Actor system is no longer supported as of 1.4.0", "1.4.0")
val actorSystem: ActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
val actorSystem: ActorSystem = _actorSystem

private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
@@ -100,6 +100,9 @@ class SparkEnv (
blockManager.master.stop()
metricsSystem.stop()
outputCommitCoordinator.stop()
if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) {
actorSystem.shutdown()
}
rpcEnv.shutdown()

// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
@@ -249,7 +252,13 @@ object SparkEnv extends Logging {
// Create the ActorSystem for Akka and get the port it binds to.
val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager)
val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
val actorSystem: ActorSystem =
if (rpcEnv.isInstanceOf[AkkaRpcEnv]) {
rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem
} else {
// Create an ActorSystem for legacy code
AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)._1
}

// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
if (isDriver) {
@@ -395,6 +404,7 @@ object SparkEnv extends Logging {
val envInstance = new SparkEnv(
executorId,
rpcEnv,
actorSystem,
serializer,
closureSerializer,
cacheManager,
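Net effect of the SparkEnv changes: the ActorSystem is no longer assumed to live inside the RpcEnv. If the configured RpcEnv is Akka-based, its ActorSystem is reused; otherwise a standalone system is created solely for code that has not yet been migrated off Akka, and stop() must shut that standalone system down itself (in the Akka case rpcEnv.shutdown() already terminates it). A hedged sketch of the create-or-reuse decision in isolation (the helper name is hypothetical; AkkaUtils.createActorSystem returns the system together with its bound port):

// assumes: akka.actor.ActorSystem, org.apache.spark.rpc.akka.AkkaRpcEnv, org.apache.spark.util.AkkaUtils
private def actorSystemFor(
    rpcEnv: RpcEnv,
    name: String,
    hostname: String,
    port: Int,
    conf: SparkConf,
    securityManager: SecurityManager): ActorSystem = rpcEnv match {
  case akkaEnv: AkkaRpcEnv =>
    akkaEnv.actorSystem // reuse the system owned by the Akka-based RpcEnv
  case _ =>
    // standalone system for legacy Akka code; SparkEnv.stop() has to shut it down explicitly
    AkkaUtils.createActorSystem(name, hostname, port, conf, securityManager)._1
}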
core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -329,7 +329,7 @@ private[deploy] class Worker(
registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate(
new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
self.send(ReregisterWithMaster)
Option(self).foreach(_.send(ReregisterWithMaster))
}
},
INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS,
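The Option wrapper is not cosmetic: the updated onStop() scaladoc later in this diff notes that self is null once an endpoint is stopping, and the registration retry timer can fire during Worker shutdown. The same guard, sketched as a small helper for any endpoint whose scheduled tasks may outlive it (the helper name and message are placeholders):

// inside an RpcEndpoint subclass
private def sendToSelfIfRunning(msg: Any): Unit = {
  Option(self).foreach(_.send(msg)) // no-op once the endpoint has stopped and self is null
}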
core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala
@@ -24,14 +24,13 @@ import org.apache.spark.rpc._
* Actor which connects to a worker process and terminates the JVM if the connection is severed.
* Provides fate sharing between a worker and its associated child processes.
*/
private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String)
private[spark] class WorkerWatcher(
override val rpcEnv: RpcEnv, workerUrl: String, isTesting: Boolean = false)
extends RpcEndpoint with Logging {

override def onStart() {
logInfo(s"Connecting to worker $workerUrl")
if (!isTesting) {
rpcEnv.asyncSetupEndpointRefByURI(workerUrl)
}
logInfo(s"Connecting to worker $workerUrl")
if (!isTesting) {
rpcEnv.asyncSetupEndpointRefByURI(workerUrl)
}

// Used to avoid shutting down JVM during tests
@@ -40,8 +39,6 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String)
// true rather than calling `System.exit`. The user can check `isShutDown` to know if
// `exitNonZero` is called.
private[deploy] var isShutDown = false
private[deploy] def setTesting(testing: Boolean) = isTesting = testing
private var isTesting = false

// Lets filter events only from the worker's rpc system
private val expectedAddress = RpcAddress.fromURIString(workerUrl)
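Turning isTesting into a constructor parameter (instead of a var flipped later through setTesting) means the flag already has its final value when the watcher starts connecting, so tests can no longer race the connection attempt. A test would presumably construct it like this (the worker URL is a placeholder):

val watcher = new WorkerWatcher(rpcEnv, "spark://Worker@localhost:12345", isTesting = true)
// before this change: new WorkerWatcher(rpcEnv, url) followed by setTesting(true), which could come too late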
core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
@@ -37,5 +37,5 @@ private[spark] trait RpcCallContext {
/**
* The sender of this message.
*/
def sender: RpcEndpointRef
def senderAddress: RpcAddress
}
51 changes: 26 additions & 25 deletions core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -28,20 +28,6 @@ private[spark] trait RpcEnvFactory {
def create(config: RpcEnvConfig): RpcEnv
}

/**
* A trait that requires RpcEnv thread-safely sending messages to it.
*
* Thread-safety means processing of one message happens before processing of the next message by
* the same [[ThreadSafeRpcEndpoint]]. In the other words, changes to internal fields of a
* [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
* [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
*
* However, there is no guarantee that the same thread will be executing the same
* [[ThreadSafeRpcEndpoint]] for different messages.
*/
private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint


/**
* An end point for the RPC that defines what functions to trigger given a message.
*
@@ -101,38 +87,39 @@ private[spark] trait RpcEndpoint {
}

/**
* Invoked before [[RpcEndpoint]] starts to handle any message.
* Invoked when `remoteAddress` is connected to the current node.
*/
def onStart(): Unit = {
def onConnected(remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}

/**
* Invoked when [[RpcEndpoint]] is stopping.
* Invoked when `remoteAddress` is lost.
*/
def onStop(): Unit = {
def onDisconnected(remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}

/**
* Invoked when `remoteAddress` is connected to the current node.
* Invoked when some network error happens in the connection between the current node and
* `remoteAddress`.
*/
def onConnected(remoteAddress: RpcAddress): Unit = {
def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
// By default, do nothing.
}

/**
* Invoked when `remoteAddress` is lost.
* Invoked before [[RpcEndpoint]] starts to handle any message.
*/
def onDisconnected(remoteAddress: RpcAddress): Unit = {
def onStart(): Unit = {
// By default, do nothing.
}

/**
* Invoked when some network error happens in the connection between the current node and
* `remoteAddress`.
* Invoked when [[RpcEndpoint]] is stopping. `self` will be `null` in this method and you cannot
* use it to send or ask messages.
*/
def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
def onStop(): Unit = {
// By default, do nothing.
}

@@ -146,3 +133,17 @@ private[spark] trait RpcEndpoint {
}
}
}

/**
* A trait that requires RpcEnv thread-safely sending messages to it.
*
* Thread-safety means processing of one message happens before processing of the next message by
* the same [[ThreadSafeRpcEndpoint]]. In other words, changes to internal fields of a
* [[ThreadSafeRpcEndpoint]] are visible when processing the next message, and fields in the
* [[ThreadSafeRpcEndpoint]] need not be volatile or equivalent.
*
* However, there is no guarantee that the same thread will be executing the same
* [[ThreadSafeRpcEndpoint]] for different messages.
*/
private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint {
}
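To make the thread-safety contract above concrete, here is a hedged sketch of a ThreadSafeRpcEndpoint (ConnectionCounter is a hypothetical name) that keeps plain mutable state across the lifecycle callbacks reordered in this file; no @volatile is needed because the RpcEnv delivers messages to it one at a time with happens-before between them:

// assumes: import org.apache.spark.rpc._ and org.apache.spark.Logging
private[spark] class ConnectionCounter(override val rpcEnv: RpcEnv)
  extends ThreadSafeRpcEndpoint with Logging {

  private var active = 0 // a plain field is safe under the ThreadSafeRpcEndpoint contract

  override def onConnected(remoteAddress: RpcAddress): Unit = { active += 1 }

  override def onDisconnected(remoteAddress: RpcAddress): Unit = { active -= 1 }

  override def receive: PartialFunction[Any, Unit] = {
    case "count" => logInfo(s"$active connections from remote rpc systems")
  }
}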
core/src/main/scala/org/apache/spark/rpc/RpcEndpointNotFoundException.scala
@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.rpc

import org.apache.spark.SparkException

private[rpc] class RpcEndpointNotFoundException(uri: String)
extends SparkException(s"Cannot find endpoint: $uri")
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -36,8 +36,9 @@ private[spark] object RpcEnv {

private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = {
// Add more RpcEnv implementations here
val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory")
val rpcEnvName = conf.get("spark.rpc", "akka")
val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory",
"netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory")
val rpcEnvName = conf.get("spark.rpc", "netty")
val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory]
}
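With this hunk the default transport becomes netty. To stay on the Akka implementation during migration, or to plug in a custom factory, spark.rpc can presumably be set explicitly; the getOrElse above lets a fully-qualified factory class name pass through unchanged:

val conf = new SparkConf()
// short name, resolved through the rpcEnvNames map:
conf.set("spark.rpc", "akka")
// alternatively, a fully-qualified factory class name is used as-is:
// conf.set("spark.rpc", "org.apache.spark.rpc.netty.NettyRpcEnvFactory")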
core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -166,9 +166,9 @@ private[spark] class AkkaRpcEnv private[akka] (
_sender ! AkkaMessage(response, false)
}

// Some RpcEndpoints need to know the sender's address
override val sender: RpcEndpointRef =
new AkkaRpcEndpointRef(defaultAddress, _sender, conf)
// Use "lazy" because most RpcEndpoints don't need "senderAddress"
override lazy val senderAddress: RpcAddress =
new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address
})
} else {
endpoint.receive
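Making senderAddress a lazy val is a small allocation win: the AkkaRpcEndpointRef wrapper is only built for the minority of endpoints that actually read the address. A tiny illustration of the lazy-val semantics relied on here (Context and buildAddress are hypothetical):

class Context(buildAddress: () => RpcAddress) {
  // the right-hand side runs at most once, on first access; callers that never
  // touch senderAddress never pay for constructing it
  lazy val senderAddress: RpcAddress = buildAddress()
}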