diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 5968973132942..7acb7f8825108 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -35,13 +35,22 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage

-private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
+private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf)
   extends Actor with Logging {
+  val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+
   def receive = {
     case GetMapOutputStatuses(shuffleId: Int) =>
       val hostPort = sender.path.address.hostPort
       logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
-      sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
+      val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId)
+      val serializedSize = mapOutputStatuses.size
+      if (serializedSize > maxAkkaFrameSize) {
+        throw new SparkException(
+          "spark.akka.frameSize of %d bytes exceeded! ".format(maxAkkaFrameSize) +
+          "Map output statuses were %d bytes".format(serializedSize))
+      }
+      sender ! mapOutputStatuses

     case StopMapOutputTracker =>
       logInfo("MapOutputTrackerActor stopped!")
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 5e43b5198422c..26c362d0fea7b 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -186,7 +186,7 @@ object SparkEnv extends Logging {
     }
     mapOutputTracker.trackerActor = registerOrLookup(
       "MapOutputTracker",
-      new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
+      new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))

     val shuffleFetcher = instantiateClass[ShuffleFetcher](
       "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e69f6f72d3275..07514c63377af 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.scheduler._
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}

 /**
  * Spark executor used with Mesos, YARN, and the standalone scheduler.
@@ -118,11 +118,9 @@ private[spark] class Executor(
   private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
   Thread.currentThread.setContextClassLoader(replClassLoader)

-  // Akka's message frame size. If task result is bigger than this, we use the block manager
-  // to send the result back.
-  private val akkaFrameSize = {
-    env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
-  }
+  // Akka's message frame size in bytes. If task result is bigger than this, we use the block
+  // manager to send the result back.
+  private val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)

   // Start worker thread pool
   val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
@@ -239,7 +237,7 @@ private[spark] class Executor(
         val serializedDirectResult = ser.serialize(directResult)
         logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
         val serializedResult = {
-          if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
+          if (serializedDirectResult.limit >= maxAkkaFrameSize - 1024) {
             logInfo("Storing result for " + taskId + " in local BlockManager")
             val blockId = TaskResultBlockId(taskId)
             env.blockManager.putBytes(
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index a6c9a9aaba8eb..d9fb4eb6f49b2 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -49,7 +49,7 @@ private[spark] object AkkaUtils extends Logging {

     val akkaTimeout = conf.getInt("spark.akka.timeout", 100)

-    val akkaFrameSize = conf.getInt("spark.akka.frameSize", 10)
+    val akkaFrameSize = maxFrameSizeBytes(conf)
     val akkaLogLifecycleEvents = conf.getBoolean("spark.akka.logLifecycleEvents", false)
     val lifecycleEvents = if (akkaLogLifecycleEvents) "on" else "off"
     if (!akkaLogLifecycleEvents) {
@@ -92,7 +92,7 @@ private[spark] object AkkaUtils extends Logging {
       |akka.remote.netty.tcp.port = $port
       |akka.remote.netty.tcp.tcp-nodelay = on
       |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s
-      |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB
+      |akka.remote.netty.tcp.maximum-frame-size = $akkaFrameSize b
       |akka.remote.netty.tcp.execution-pool-size = $akkaThreads
       |akka.actor.default-dispatcher.throughput = $akkaBatchSize
       |akka.log-config-on-start = $logAkkaConfig
@@ -121,4 +121,9 @@ private[spark] object AkkaUtils extends Logging {
   def lookupTimeout(conf: SparkConf): FiniteDuration = {
     Duration.create(conf.get("spark.akka.lookupTimeout", "30").toLong, "seconds")
   }
+
+  /** Returns the default max frame size for Akka messages in bytes. */
+  def maxFrameSizeBytes(conf: SparkConf): Int = {
+    conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
index cd054c1f684ab..d2e303d81c4c8 100644
--- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala
@@ -45,12 +45,12 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {

     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

     val badconf = new SparkConf
     badconf.set("spark.authenticate", "true")
     badconf.set("spark.authenticate.secret", "bad")
-    val securityManagerBad = new SecurityManager(badconf);
+    val securityManagerBad = new SecurityManager(badconf)

     assert(securityManagerBad.isAuthenticationEnabled() === true)

@@ -84,7 +84,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {

     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

     val badconf = new SparkConf
     badconf.set("spark.authenticate", "false")
@@ -136,7 +136,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {

     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

     val goodconf = new SparkConf
     goodconf.set("spark.authenticate", "true")
@@ -189,7 +189,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {

     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

     val badconf = new SparkConf
     badconf.set("spark.authenticate", "false")
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 8efa072a97911..dc70576d82419 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import scala.concurrent.Await

 import akka.actor._
+import akka.testkit.TestActorRef
 import org.scalatest.FunSuite

 import org.apache.spark.scheduler.MapStatus
@@ -51,14 +52,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   test("master start and stop") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+    tracker.trackerActor =
+      actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
     tracker.stop()
   }

   test("master register and fetch") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+    tracker.trackerActor =
+      actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -77,7 +80,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   test("master register and unregister and fetch") {
     val actorSystem = ActorSystem("test")
     val tracker = new MapOutputTrackerMaster(conf)
-    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+    tracker.trackerActor =
+      actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -100,11 +104,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     val hostname = "localhost"
     val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf,
       securityManager = new SecurityManager(conf))
-    System.setProperty("spark.driver.port", boundPort.toString)  // Will be cleared by LocalSparkContext
+
+    // Will be cleared by LocalSparkContext
+    System.setProperty("spark.driver.port", boundPort.toString)

     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-      Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")

     val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf,
       securityManager = new SecurityManager(conf))
@@ -126,7 +132,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
-           Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+      Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))

     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
     masterTracker.incrementEpoch()
@@ -136,4 +142,30 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
     // failure should be cached
     intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
   }
+
+  test("remote fetch exceeding akka frame size") {
+    val newConf = new SparkConf
+    newConf.set("spark.akka.frameSize", "1")
+    newConf.set("spark.akka.askTimeout", "1") // Fail fast
+
+    val masterTracker = new MapOutputTrackerMaster(conf)
+    val actorSystem = ActorSystem("test")
+    val actorRef = TestActorRef[MapOutputTrackerMasterActor](
+      new MapOutputTrackerMasterActor(masterTracker, newConf))(actorSystem)
+    val masterActor = actorRef.underlyingActor
+
+    // Message size should be ~123B, and no exception should be thrown
+    masterTracker.registerShuffle(10, 1)
+    masterTracker.registerMapOutput(10, 0, new MapStatus(
+      BlockManagerId("88", "mph", 1000, 0), Array.fill[Byte](10)(0)))
+    masterActor.receive(GetMapOutputStatuses(10))
+
+    // Message size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception
+    masterTracker.registerShuffle(20, 100)
+    (0 until 100).foreach { i =>
+      masterTracker.registerMapOutput(20, i, new MapStatus(
+        BlockManagerId("999", "mps", 1000, 0), Array.fill[Byte](4000000)(0)))
+    }
+    intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) }
+  }
 }
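
For reference, below is a minimal standalone sketch of the behavior this patch introduces: spark.akka.frameSize is configured in MB, AkkaUtils.maxFrameSizeBytes converts it to bytes, and the tracker actor throws a SparkException instead of letting Akka silently drop an oversized reply. The object and helper names (FrameSizeSketch, checkedReply) are illustrative only and are not part of the patch; the conversion and error message mirror the code above.

// Illustrative sketch only -- not part of the patch.
import org.apache.spark.{SparkConf, SparkException}

object FrameSizeSketch {
  // spark.akka.frameSize is set in MB; convert to bytes for comparison (default 10 MB).
  def maxFrameSizeBytes(conf: SparkConf): Int =
    conf.getInt("spark.akka.frameSize", 10) * 1024 * 1024

  // Fail loudly when a serialized reply would exceed the Akka frame limit,
  // rather than having the message dropped with no error on the sender side.
  def checkedReply(serializedSize: Int, conf: SparkConf): Unit = {
    val max = maxFrameSizeBytes(conf)
    if (serializedSize > max) {
      throw new SparkException(
        "spark.akka.frameSize of %d bytes exceeded! ".format(max) +
        "Map output statuses were %d bytes".format(serializedSize))
    }
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false).set("spark.akka.frameSize", "1") // 1 MB limit
    checkedReply(10 * 1024, conf)       // ok: well under the limit
    checkedReply(2 * 1024 * 1024, conf) // throws SparkException
  }
}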