61 changes: 29 additions & 32 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -21,13 +21,11 @@ import java.io._
 import java.util.concurrent.ConcurrentHashMap
 import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
-import scala.collection.mutable.{HashSet, HashMap, Map}
-import scala.concurrent.Await
+import scala.collection.mutable.{HashSet, Map}
 import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
 
-import akka.actor._
-import akka.pattern.ask
-
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.BlockManagerId
@@ -38,34 +36,35 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
 
-/** Actor class for MapOutputTrackerMaster */
-private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
-  extends Actor with ActorLogReceive with Logging {
+/** RpcEndpoint class for MapOutputTrackerMaster */
+private[spark] class MapOutputTrackerMasterEndpoint(
+    override val rpcEnv: RpcEnv, tracker: MapOutputTrackerMaster, conf: SparkConf)
+  extends RpcEndpoint with Logging {
   val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
 
-  override def receiveWithLogging: PartialFunction[Any, Unit] = {
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case GetMapOutputStatuses(shuffleId: Int) =>
-      val hostPort = sender.path.address.hostPort
+      val hostPort = context.sender.address.hostPort
       logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
       val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
       val serializedSize = mapOutputStatuses.size
       if (serializedSize > maxAkkaFrameSize) {
         val msg = s"Map output statuses were $serializedSize bytes which " +
           s"exceeds spark.akka.frameSize ($maxAkkaFrameSize bytes)."
 
-        /* For SPARK-1244 we'll opt for just logging an error and then throwing an exception.
-         * Note that on exception the actor will just restart. A bigger refactoring (SPARK-1239)
-         * will ultimately remove this entire code path. */
+        /* For SPARK-1244 we'll opt for just logging an error and then sending it to the sender.
+         * A bigger refactoring (SPARK-1239) will ultimately remove this entire code path. */
         val exception = new SparkException(msg)
         logError(msg, exception)
-        throw exception
+        context.sendFailure(exception)
[Inline review thread on context.sendFailure(exception):]

Contributor: Does Akka always serialize exceptions? There are exceptions that cannot be serialized, and we should be careful there. Can you add a unit test calling sendFailure with an unserializable exception to make sure things still work?

Contributor: This also raises another question: what are the semantics in the case of exceptions? Are they ignored, logged and ignored, or propagated back to the caller?

Member (author): > There are exceptions that cannot be serialized, and we should be careful there.

Because Throwable implements java.io.Serializable, I assume all exceptions can be serialized. Do you mean that some exception may throw an exception in writeObject(java.io.ObjectOutputStream out) to prevent it from being serialized?

Contributor: In an ideal world all exceptions would be serializable, but I have seen an exception in the past that was not serializable (because it contained a field that was not serializable).

Member (author): Such an exception will be swallowed by Akka by default. I added ErrorMonitor to handle errors caught by Akka. Although they cannot be serialized, at least we can log them.

Contributor: One more thing - can you update the RpcEnv doc to specify what happens in the case of uncaught exceptions?

Member (author): Done.
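Aside on the serializability question above: Throwable does implement java.io.Serializable, but serialization can still fail when an exception carries a non-serializable field, which is exactly the case the reviewers are worried about. A self-contained sketch in plain Scala (hypothetical names, not Spark's actual test code); the diff resumes below:

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical exception type: Exception itself is Serializable, but the
// java.lang.Object field is not, so Java serialization fails on it.
class UnserializableException extends Exception("boom") {
  private val unserializableField = new Object()
}

object SerializationCheck extends App {
  val out = new ObjectOutputStream(new ByteArrayOutputStream())
  try {
    out.writeObject(new UnserializableException)
    println("serialized fine")
  } catch {
    case e: NotSerializableException =>
      // This is the case sendFailure has to survive: the reply cannot be
      // serialized, so the error must at least be logged somewhere.
      println("caught: " + e)
  }
}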

+      } else {
+        context.reply(mapOutputStatuses)
       }
-      sender ! mapOutputStatuses
 
     case StopMapOutputTracker =>
-      logInfo("MapOutputTrackerActor stopped!")
-      sender ! true
-      context.stop(self)
+      logInfo("MapOutputTrackerMasterEndpoint stopped!")
+      context.reply(true)
+      stop()
   }
 }

@@ -75,12 +74,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
  * (driver and executor) use different HashMap to store its metadata.
  */
 private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
-  private val timeout = AkkaUtils.askTimeout(conf)
-  private val retryAttempts = AkkaUtils.numRetries(conf)
-  private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
 
-  /** Set to the MapOutputTrackerActor living on the driver. */
-  var trackerActor: ActorRef = _
+  /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */
+  var trackerEndpoint: RpcEndpointRef = _
 
   /**
    * This HashMap has different behavior for the driver and the executors.
@@ -105,22 +101,22 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
   private val fetching = new HashSet[Int]
 
   /**
-   * Send a message to the trackerActor and get its result within a default timeout, or
+   * Send a message to the trackerEndpoint and get its result within a default timeout, or
    * throw a SparkException if this fails.
    */
-  protected def askTracker(message: Any): Any = {
+  protected def askTracker[T: ClassTag](message: Any): T = {
     try {
-      AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout)
+      trackerEndpoint.askWithReply[T](message)
     } catch {
       case e: Exception =>
         logError("Error communicating with MapOutputTracker", e)
         throw new SparkException("Error communicating with MapOutputTracker", e)
     }
   }
 
-  /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */
+  /** Send a one-way message to the trackerEndpoint, to which we expect it to reply with true. */
   protected def sendTracker(message: Any) {
-    val response = askTracker(message)
+    val response = askTracker[Boolean](message)
     if (response != true) {
       throw new SparkException(
         "Error reply received from MapOutputTracker. Expecting true, got " + response.toString)
@@ -157,11 +153,10 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging

     if (fetchedStatuses == null) {
       // We won the race to fetch the output locs; do so
-      logInfo("Doing the fetch; tracker actor = " + trackerActor)
+      logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
       // This try-finally prevents hangs due to timeouts:
       try {
-        val fetchedBytes =
-          askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]]
+        val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
         fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
         logInfo("Got the output locations")
         mapStatuses.put(shuffleId, fetchedStatuses)
@@ -328,7 +323,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   override def stop() {
     sendTracker(StopMapOutputTracker)
     mapStatuses.clear()
-    trackerActor = null
+    trackerEndpoint = null
     metadataCleaner.cancel()
     cachedSerializedStatuses.clear()
   }
@@ -350,6 +345,8 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr

 private[spark] object MapOutputTracker extends Logging {
 
+  val ENDPOINT_NAME = "MapOutputTracker"
+
   // Serialize an array of map output locations into an efficient byte format so that we can send
   // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
   // generally be pretty compressible because many map outputs will be on the same hostname.
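To illustrate the compressibility rationale in the comment above: a self-contained GZIP + Java-serialization round trip in the same spirit (the real serializeMapStatuses operates on Array[MapStatus] and adds its own framing; everything here is illustrative only):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object GzipRoundTrip {
  def serialize(obj: AnyRef): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val objOut = new ObjectOutputStream(new GZIPOutputStream(bytes))
    objOut.writeObject(obj)
    objOut.close() // flushes the GZIP trailer
    bytes.toByteArray
  }

  def deserialize[T](bytes: Array[Byte]): T = {
    val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
    val obj = objIn.readObject().asInstanceOf[T]
    objIn.close()
    obj
  }

  def main(args: Array[String]): Unit = {
    // Repetitive hostnames compress well, which is the rationale in the comment above.
    val statuses = Array.fill(1000)("host-1.example.com:7337")
    val compressed = serialize(statuses)
    println(s"compressed ${statuses.length} entries into ${compressed.length} bytes")
    assert(deserialize[Array[String]](compressed).sameElements(statuses))
  }
}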
18 changes: 4 additions & 14 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -24,7 +24,6 @@ import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.util.Properties
 
-import akka.actor._
 import com.google.common.collect.MapMaker
 
 import org.apache.spark.annotation.DeveloperApi
@@ -41,7 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
 import org.apache.spark.storage._
-import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils}
+import org.apache.spark.util.{RpcUtils, Utils}
 
 /**
  * :: DeveloperApi ::
@@ -286,15 +285,6 @@ object SparkEnv extends Logging {
     val closureSerializer = instantiateClassFromConf[Serializer](
       "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")
 
-    def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
-      if (isDriver) {
-        logInfo("Registering " + name)
-        actorSystem.actorOf(Props(newActor), name = name)
-      } else {
-        AkkaUtils.makeDriverRef(name, conf, actorSystem)
-      }
-    }
-
     def registerOrLookupEndpoint(
         name: String, endpointCreator: => RpcEndpoint):
       RpcEndpointRef = {
@@ -314,9 +304,9 @@ object SparkEnv extends Logging {

     // Have to assign trackerActor after initialization as MapOutputTrackerActor
     // requires the MapOutputTracker itself
-    mapOutputTracker.trackerActor = registerOrLookup(
-      "MapOutputTracker",
-      new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
+    mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(
+        rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
 
     // Let the user specify short names for shuffle managers
     val shortShuffleMgrNames = Map(
Expand Down
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -30,7 +30,9 @@ import org.apache.spark.util.{AkkaUtils, Utils}
 /**
  * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to
  * receives messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote
- * nodes, and deliver them to corresponding [[RpcEndpoint]]s.
+ * nodes, and deliver them to corresponding [[RpcEndpoint]]s. If an [[RpcEndpoint]] throws an
+ * uncaught exception, [[RpcEnv]] will use [[RpcCallContext.sendFailure]] to send it back to the
+ * sender, or log it if there is no sender or it is a `NotSerializableException`.
  *
  * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri.
  */
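To make the documented failure path concrete: for the caller, a failure sent back via RpcCallContext.sendFailure surfaces as a failed ask future. A minimal plain-Akka sketch of that behavior, assuming Akka 2.3.x (Spark's AkkaRpcEnv wraps failures in its own internal message type rather than using Status.Failure directly; all names here are hypothetical):

import scala.concurrent.Await
import scala.concurrent.duration._
import akka.actor.{Actor, ActorSystem, Props, Status}
import akka.pattern.ask
import akka.util.Timeout

// Hypothetical endpoint that always fails, standing in for the frame-size
// error path in MapOutputTrackerMasterEndpoint above.
class FailingEndpoint extends Actor {
  def receive: Actor.Receive = {
    case _ => sender ! Status.Failure(new IllegalStateException("reply too large"))
  }
}

object SendFailureSketch extends App {
  implicit val timeout: Timeout = Timeout(5.seconds)
  val system = ActorSystem("sketch")
  val ref = system.actorOf(Props[FailingEndpoint], "failing")
  try {
    Await.result(ref ? "get", 5.seconds)
  } catch {
    // The remote exception surfaces on the calling side as the future's failure.
    case e: IllegalStateException => println("ask failed with: " + e.getMessage)
  }
  system.shutdown()
}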
19 changes: 17 additions & 2 deletions core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala
@@ -17,16 +17,16 @@

 package org.apache.spark.rpc.akka
 
-import java.net.URI
 import java.util.concurrent.ConcurrentHashMap
 
-import scala.concurrent.{Await, Future}
+import scala.concurrent.Future
 import scala.concurrent.duration._
 import scala.language.postfixOps
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address}
+import akka.event.Logging.Error
 import akka.pattern.{ask => akkaAsk}
 import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent}
 import org.apache.spark.{SparkException, Logging, SparkConf}
@@ -242,10 +242,25 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
   def create(config: RpcEnvConfig): RpcEnv = {
     val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
       config.name, config.host, config.port, config.conf, config.securityManager)
+    actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor")
     new AkkaRpcEnv(actorSystem, config.conf, boundPort)
   }
 }
 
+/**
+ * Monitor errors reported by Akka and log them.
+ */
+private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging {
+
+  override def preStart(): Unit = {
+    context.system.eventStream.subscribe(self, classOf[Error])
+  }
+
+  override def receiveWithLogging: Actor.Receive = {
+    case Error(cause: Throwable, _, _, message: String) => logError(message, cause)
+  }
+}

 private[akka] class AkkaRpcEndpointRef(
     @transient defaultAddress: RpcAddress,
     @transient _actorRef: => ActorRef,