13 changes: 11 additions & 2 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -35,13 +35,22 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage

private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
extends Actor with Logging {
val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)

def receive = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = sender.path.address.hostPort
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
val serializedSize = mapOutputStatuses.size
if (serializedSize > maxAkkaFrameSize) {
throw new SparkException(
"spark.akka.frameSize of %d bytes exceeded! ".format(maxAkkaFrameSize) +
"Map output statuses were %d bytes".format(serializedSize))
}
sender ! mapOutputStatuses

case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
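For readers skimming the diff, a minimal standalone sketch of the guard this hunk adds (stand-in types, not the real Spark classes): without the check, Akka silently drops any reply larger than the configured frame size and the requesting side just blocks until its ask timeout; with it, the driver fails fast with a descriptive exception.

// Sketch only: SketchSparkException and the byte-array argument are stand-ins
// for SparkException and the serialized map output statuses.
object FrameSizeGuardSketch {
  class SketchSparkException(msg: String) extends Exception(msg)

  // spark.akka.frameSize is configured in MiB; Akka enforces the limit in bytes.
  def maxFrameSizeBytes(frameSizeMiB: Int): Int = frameSizeMiB * 1024 * 1024

  // Throw instead of letting an oversized reply vanish inside Akka.
  def checkedReply(serialized: Array[Byte], frameSizeMiB: Int): Array[Byte] = {
    val maxBytes = maxFrameSizeBytes(frameSizeMiB)
    if (serialized.length > maxBytes) {
      throw new SketchSparkException(
        "spark.akka.frameSize of %d bytes exceeded! ".format(maxBytes) +
          "Map output statuses were %d bytes".format(serialized.length))
    }
    serialized
  }
}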
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -186,7 +186,7 @@ object SparkEnv extends Logging {
}
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))

val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
12 changes: 5 additions & 7 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.Utils
import org.apache.spark.util.{AkkaUtils, Utils}

/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
@@ -118,11 +118,9 @@ private[spark] class Executor(
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)

// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = {
env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
}
// Akka's message frame size in bytes. If task result is bigger than this, we use the block
// manager to send the result back.
private val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)

// Start worker thread pool
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
@@ -239,7 +237,7 @@ private[spark] class Executor(
val serializedDirectResult = ser.serialize(directResult)
logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
val serializedResult = {
if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
if (serializedDirectResult.limit >= maxAkkaFrameSize - 1024) {
logInfo("Storing result for " + taskId + " in local BlockManager")
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
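This hunk only renames akkaFrameSize and derives it from the shared helper; the surrounding logic is unchanged. As a hedged sketch of that logic with hypothetical names: the executor keeps roughly 1 KiB of headroom under the frame size for message overhead, and any serialized task result at or above that threshold goes to the BlockManager, with only a small reference sent back over Akka.

// Sketch only; the ResultPath values are illustrative, not Spark types.
object ResultPathSketch {
  sealed trait ResultPath
  case object DirectResult extends ResultPath   // small enough for an Akka reply
  case object IndirectResult extends ResultPath // stored in the BlockManager

  // Mirrors the `serializedDirectResult.limit >= maxAkkaFrameSize - 1024` check above.
  def choose(serializedResultBytes: Int, maxAkkaFrameSize: Int): ResultPath =
    if (serializedResultBytes >= maxAkkaFrameSize - 1024) IndirectResult
    else DirectResult
}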
9 changes: 7 additions & 2 deletions core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -49,7 +49,7 @@ private[spark] object AkkaUtils extends Logging {

val akkaTimeout = conf.getInt("spark.akka.timeout", 100)

val akkaFrameSize = conf.getInt("spark.akka.frameSize", 10)
val akkaFrameSize = maxFrameSizeBytes(conf)
val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false)
val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off"
if (!akkaLogLifecycleEvents) {
@@ -92,7 +92,7 @@ private[spark] object AkkaUtils extends Logging {
|akka.remote.netty.tcp.port = $port
|akka.remote.netty.tcp.tcp-nodelay = on
|akka.remote.netty.tcp.connection-timeout = $akkaTimeout s
|akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB
|akka.remote.netty.tcp.maximum-frame-size = $akkaFrameSize b
|akka.remote.netty.tcp.execution-pool-size = $akkaThreads
|akka.actor.default-dispatcher.throughput = $akkaBatchSize
|akka.log-config-on-start = $logAkkaConfig
@@ -121,4 +121,9 @@
def lookupTimeout(conf: SparkConf): FiniteDuration = {
Duration.create(conf.get("spark.akka.lookupTimeout", "30").toLong, "seconds")
}

/** Returns the default max frame size for Akka messages in bytes. */
def maxFrameSizeBytes(conf: SparkConf): Int = {
conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024
}
}
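The new helper centralizes the MiB-to-bytes conversion, so the Akka config line above now interpolates a byte count with a "b" unit instead of a MiB suffix. A small sketch of the arithmetic, assuming the default value of 10 and a stand-in for SparkConf:

// SimpleConf is a stand-in for SparkConf, used only to make the sketch runnable.
object FrameSizeConfigSketch {
  final case class SimpleConf(settings: Map[String, String] = Map.empty) {
    def getInt(key: String, default: Int): Int =
      settings.get(key).map(_.toInt).getOrElse(default)
  }

  def maxFrameSizeBytes(conf: SimpleConf): Int =
    conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024

  def main(args: Array[String]): Unit = {
    val bytes = maxFrameSizeBytes(SimpleConf())   // default 10 MiB -> 10485760 bytes
    println(s"akka.remote.netty.tcp.maximum-frame-size = $bytes b")
  }
}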
10 changes: 5 additions & 5 deletions core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
@@ -45,12 +45,12 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {

val masterTracker = new MapOutputTrackerMaster(conf)
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

val badconf = new SparkConf
badconf.set("spark.authenticate", "true")
badconf.set("spark.authenticate.secret", "bad")
val securityManagerBad = new SecurityManager(badconf);
val securityManagerBad = new SecurityManager(badconf)

assert(securityManagerBad.isAuthenticationEnabled() === true)

@@ -84,7 +84,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {

val masterTracker = new MapOutputTrackerMaster(conf)
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

val badconf = new SparkConf
badconf.set("spark.authenticate", "false")
@@ -136,7 +136,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {

val masterTracker = new MapOutputTrackerMaster(conf)
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

val goodconf = new SparkConf
goodconf.set("spark.authenticate", "true")
@@ -189,7 +189,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {

val masterTracker = new MapOutputTrackerMaster(conf)
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

val badconf = new SparkConf
badconf.set("spark.authenticate", "false")
44 changes: 38 additions & 6 deletions core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import scala.concurrent.Await

import akka.actor._
import akka.testkit.TestActorRef
import org.scalatest.FunSuite

import org.apache.spark.scheduler.MapStatus
@@ -51,14 +51,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master start and stop") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.trackerActor =
actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.stop()
}

test("master register and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.trackerActor =
actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -77,7 +80,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster(conf)
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.trackerActor =
actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -100,11 +104,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
val hostname = "localhost"
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf,
securityManager = new SecurityManager(conf))
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext

// Will be cleared by LocalSparkContext
System.setProperty("spark.driver.port", boundPort.toString)

val masterTracker = new MapOutputTrackerMaster(conf)
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
securityManager = new SecurityManager(conf))
@@ -126,7 +132,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
masterTracker.incrementEpoch()
slaveTracker.updateEpoch(masterTracker.getEpoch)
assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))

masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
masterTracker.incrementEpoch()
@@ -136,4 +142,30 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
// failure should be cached
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
}

test("remote fetch exceeding akka frame size") {
val newConf = new SparkConf
newConf.set("spark.akka.frameSize", "1")
newConf.set("spark.akka.askTimeout", "1") // Fail fast
Review comment (Contributor): is this conf needed anymore? is it relevant with the local actor ref?

Reply (Contributor, Author): Yeah, we need it to lower the frame size so we can actually do this within a "short" amount of time.

val masterTracker = new MapOutputTrackerMaster(conf)
val actorSystem = ActorSystem("test")
val actorRef = TestActorRef[MapOutputTrackerMasterActor](
new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
val masterActor = actorRef.underlyingActor

// Frame size should be ~123B, and no exception should be thrown
masterTracker.registerShuffle(10, 1)
masterTracker.registerMapOutput(10, 0, new MapStatus(
BlockManagerId("88", "mph", 1000, 0), Array.fill[Byte](10)(0)))
masterActor.receive(GetMapOutputStatuses(10))

// Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception
masterTracker.registerShuffle(20, 100)
(0 until 100).foreach { i =>
masterTracker.registerMapOutput(20, i, new MapStatus(
BlockManagerId("999", "mps", 1000, 0), Array.fill[Byte](4000000)(0)))
}
intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) }
}
}
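A note on the testkit pattern used in the new test, with a minimal sketch under stated assumptions (the actor and message types below are toys, not the Spark ones): TestActorRef constructs the actor synchronously and exposes it via underlyingActor, so calling receive directly runs on the test thread and any exception propagates straight to the caller, which is what lets intercept[SparkException] assert the oversized-frame case without asynchronous plumbing.

import akka.actor.{Actor, ActorSystem}
import akka.testkit.TestActorRef

// Toy actor that rejects payloads above a limit, mirroring the shape of the frame-size test.
object TestActorRefSketch {
  case class Ping(payloadBytes: Int)

  class LimitActor(maxBytes: Int) extends Actor {
    def receive = {
      case Ping(p) if p > maxBytes => throw new IllegalArgumentException(s"payload of $p bytes too large")
      case Ping(_)                 => // accepted; nothing to reply in this sketch
    }
  }

  def main(args: Array[String]): Unit = {
    implicit val system: ActorSystem = ActorSystem("sketch")
    val ref = TestActorRef[LimitActor](new LimitActor(maxBytes = 10))
    ref.underlyingActor.receive(Ping(5))          // fine, runs synchronously on this thread
    try ref.underlyingActor.receive(Ping(100))    // throws, and the exception reaches the caller
    catch { case e: IllegalArgumentException => println("caught: " + e.getMessage) }
    system.shutdown()                             // Akka 2.2/2.3-era shutdown call, matching this PR's vintage
  }
}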