From 9ef279ccfc6b8fedf4e3eb24b9cd7b5bc4ce7424 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 28 Aug 2014 14:20:12 -0700 Subject: [PATCH 01/10] Initial refactoring to move ConnectionManager to use the BlockTransferService. --- .../scala/org/apache/spark/SparkEnv.scala | 14 +- ...eiverTest.scala => BlockDataManager.scala} | 29 +- .../spark/network/BlockFetchingListener.scala | 37 ++ .../spark/network/BlockTransferService.scala | 81 +++++ .../spark/network/ConnectionManagerTest.scala | 103 ------ .../apache/spark/network/ManagedBuffer.scala | 95 +++++ .../org/apache/spark/network/SenderTest.scala | 76 ---- .../cm}/BlockMessage.scala | 7 +- .../cm}/BlockMessageArray.scala | 8 +- .../network/{ => cm}/BufferMessage.scala | 6 +- .../network/cm/CMBlockTransferService.scala | 190 ++++++++++ .../spark/network/{ => cm}/Connection.scala | 6 +- .../spark/network/{ => cm}/ConnectionId.scala | 2 +- .../network/{ => cm}/ConnectionManager.scala | 24 +- .../{ => cm}/ConnectionManagerId.scala | 2 +- .../spark/network/{ => cm}/Message.scala | 2 +- .../spark/network/{ => cm}/MessageChunk.scala | 4 +- .../network/{ => cm}/MessageChunkHeader.scala | 5 +- .../network/{ => cm}/SecurityMessage.scala | 8 +- .../spark/serializer/KryoSerializer.scala | 2 +- .../hash/BlockStoreShuffleFetcher.scala | 14 +- .../shuffle/hash/HashShuffleReader.scala | 4 +- .../spark/storage/BlockFetcherIterator.scala | 328 ------------------ .../apache/spark/storage/BlockManager.scala | 116 ++----- .../apache/spark/storage/BlockManagerId.scala | 21 +- .../spark/storage/BlockManagerWorker.scala | 147 -------- .../storage/ShuffleBlockFetcherIterator.scala | 266 ++++++++++++++ .../apache/spark/storage/ThreadingTest.scala | 2 +- .../org/apache/spark/util/JsonProtocol.scala | 6 +- .../org/apache/spark/DistributedSuite.scala | 4 +- .../{ => cm}/ConnectionManagerSuite.scala | 10 +- .../storage/BlockFetcherIteratorSuite.scala | 9 +- .../spark/storage/BlockManagerSuite.scala | 15 +- 33 files changed, 784 
insertions(+), 859 deletions(-) rename core/src/main/scala/org/apache/spark/network/{ReceiverTest.scala => BlockDataManager.scala} (56%) create mode 100644 core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala create mode 100644 core/src/main/scala/org/apache/spark/network/BlockTransferService.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala create mode 100644 core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/SenderTest.scala rename core/src/main/scala/org/apache/spark/{storage => network/cm}/BlockMessage.scala (97%) rename core/src/main/scala/org/apache/spark/{storage => network/cm}/BlockMessageArray.scala (98%) rename core/src/main/scala/org/apache/spark/network/{ => cm}/BufferMessage.scala (99%) create mode 100644 core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala rename core/src/main/scala/org/apache/spark/network/{ => cm}/Connection.scala (99%) rename core/src/main/scala/org/apache/spark/network/{ => cm}/ConnectionId.scala (97%) rename core/src/main/scala/org/apache/spark/network/{ => cm}/ConnectionManager.scala (99%) rename core/src/main/scala/org/apache/spark/network/{ => cm}/ConnectionManagerId.scala (97%) rename core/src/main/scala/org/apache/spark/network/{ => cm}/Message.scala (98%) rename core/src/main/scala/org/apache/spark/network/{ => cm}/MessageChunk.scala (96%) rename core/src/main/scala/org/apache/spark/network/{ => cm}/MessageChunkHeader.scala (96%) rename core/src/main/scala/org/apache/spark/network/{ => cm}/SecurityMessage.scala (97%) delete mode 100644 core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala delete mode 100644 core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala create mode 100644 core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala rename core/src/test/scala/org/apache/spark/network/{ => 
cm}/ConnectionManagerSuite.scala (97%) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 72716567ca99..294a58fafc36 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -31,7 +31,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.network.ConnectionManager +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.cm.CMBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} @@ -59,8 +60,8 @@ class SparkEnv ( val mapOutputTracker: MapOutputTracker, val shuffleManager: ShuffleManager, val broadcastManager: BroadcastManager, + val blockTransferService: BlockTransferService, val blockManager: BlockManager, - val connectionManager: ConnectionManager, val securityManager: SecurityManager, val httpFileServer: HttpFileServer, val sparkFilesDir: String, @@ -79,6 +80,7 @@ class SparkEnv ( Option(httpFileServer).foreach(_.stop()) mapOutputTracker.stop() shuffleManager.stop() + blockTransferService.stop() broadcastManager.stop() blockManager.stop() blockManager.master.stop() @@ -223,14 +225,14 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) + val blockTransferService = new CMBlockTransferService(conf, securityManager) + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager, mapOutputTracker, shuffleManager) - - val connectionManager = 
blockManager.connectionManager + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) @@ -278,8 +280,8 @@ object SparkEnv extends Logging { mapOutputTracker, shuffleManager, broadcastManager, + blockTransferService, blockManager, - connectionManager, securityManager, httpFileServer, sparkFilesDir, diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala similarity index 56% rename from core/src/main/scala/org/apache/spark/network/ReceiverTest.scala rename to core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 53a6038a9b59..e0e91724271c 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -17,21 +17,20 @@ package org.apache.spark.network -import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.storage.StorageLevel -private[spark] object ReceiverTest { - def main(args: Array[String]) { - val conf = new SparkConf - val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) - println("Started connection manager with id = " + manager.id) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - /* println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis) */ - val buffer = ByteBuffer.wrap("response".getBytes("utf-8")) - Some(Message.createBufferMessage(buffer, msg.id)) - }) - Thread.currentThread.join() - } -} +trait BlockDataManager { + + /** + * Interface to get local block data. + * + * @return Some(buffer) if the block exists locally, and None if it doesn't. + */ + def getBlockData(blockId: String): Option[ManagedBuffer] + /** + * Put the block locally, using the given storage level. 
+ */ + def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit +} diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala new file mode 100644 index 000000000000..c1dfcf1c12d3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.util.EventListener + + +/** + * Listener callback interface for [[BlockTransferService.fetchBlocks]]. + */ +trait BlockFetchingListener extends EventListener { + + /** + * Called once per successfully fetched block. + */ + def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit + + /** + * Called upon failures. 
+ */ + def onBlockFetchFailure(exception: Exception): Unit +} diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala new file mode 100644 index 000000000000..0aa4a85531fa --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import org.apache.spark.storage.StorageLevel + + +abstract class BlockTransferService { + + /** + * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch + * local blocks or put local blocks. + */ + def init(blockDataManager: BlockDataManager) + + /** + * Tear down the transfer service. + */ + def stop(): Unit + + /** + * Port number the service is listening on, available only after [[init]] is invoked. + */ + def port: Int + + /** + * Host name the service is listening on, available only after [[init]] is invoked. + */ + def hostName: String + + /** + * Fetch a sequence of blocks from a remote node, available only after [[init]] is invoked. + * + * This takes a sequence so the implementation can batch requests. 
+ */ + def fetchBlocks( + hostName: String, + port: Int, + blockIds: Seq[String], + listener: BlockFetchingListener): Unit + + /** + * Fetch a single block from a remote node, available only after [[init]] is invoked. + * + * This is functionally equivalent to + * {{{ + * fetchBlocks(hostName, port, Seq(blockId)).iterator().next()._2 + * }}} + */ + def fetchBlock(hostName: String, port: Int, blockId: String): ManagedBuffer = { + //fetchBlocks(hostName, port, Seq(blockId)).iterator().next()._2 + null + } + + /** + * Upload a single block to a remote node, available only after [[init]] is invoked. + * + * This call blocks until the upload completes, or throws an exception upon failures. + */ + def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, + level: StorageLevel): Unit +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala deleted file mode 100644 index 4894ecd41f6e..000000000000 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network - -import java.nio.ByteBuffer - -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.io.Source - -import org.apache.spark._ - -private[spark] object ConnectionManagerTest extends Logging{ - def main(args: Array[String]) { - // - the master URL - a list slaves to run connectionTest on - // [num of tasks] - the number of parallel tasks to be initiated default is number of slave - // hosts [size of msg in MB (integer)] - the size of messages to be sent in each task, - // default is 10 [count] - how many times to run, default is 3 [await time in seconds] : - // await time (in seconds), default is 600 - if (args.length < 2) { - println("Usage: ConnectionManagerTest [num of tasks] " + - "[size of msg in MB (integer)] [count] [await time in seconds)] ") - System.exit(1) - } - - if (args(0).startsWith("local")) { - println("This runs only on a mesos cluster") - } - - val sc = new SparkContext(args(0), "ConnectionManagerTest") - val slavesFile = Source.fromFile(args(1)) - val slaves = slavesFile.mkString.split("\n") - slavesFile.close() - - /* println("Slaves") */ - /* slaves.foreach(println) */ - val tasknum = if (args.length > 2) args(2).toInt else slaves.length - val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 - val count = if (args.length > 4) args(4).toInt else 3 - val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second - println("Running " + count + " rounds of test: " + "parallel tasks = " + tasknum + ", " + - "msg size = " + size/1024/1024 + " MB, awaitTime = " + awaitTime) - val slaveConnManagerIds = sc.parallelize(0 until tasknum, tasknum).map( - i => SparkEnv.get.connectionManager.id).collect() - println("\nSlave ConnectionManagerIds") - slaveConnManagerIds.foreach(println) - println - - (0 until count).foreach(i => { - val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { - val connManager = SparkEnv.get.connectionManager - val 
thisConnManagerId = connManager.id - connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - logInfo("Received [" + msg + "] from [" + id + "]") - None - }) - - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val startTime = System.currentTimeMillis - val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map{ slaveConnManagerId => - { - val bufferMessage = Message.createBufferMessage(buffer.duplicate) - logInfo("Sending [" + bufferMessage + "] to [" + slaveConnManagerId + "]") - connManager.sendMessageReliably(slaveConnManagerId, bufferMessage) - } - } - val results = futures.map(f => Await.result(f, awaitTime)) - val finishTime = System.currentTimeMillis - Thread.sleep(5000) - - val mb = size * results.size / 1024.0 / 1024.0 - val ms = finishTime - startTime - val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * - 1000.0) + " MB/s" - logInfo(resultStr) - resultStr - }).collect() - - println("---------------------") - println("Run " + i) - resultStrs.foreach(println) - println("---------------------") - }) - } -} - diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala new file mode 100644 index 000000000000..f51724593a9b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +import java.io.{File, FileInputStream, InputStream} +import java.nio.ByteBuffer + +import io.netty.buffer.{ByteBufInputStream, ByteBuf, Unpooled} +import io.netty.channel.DefaultFileRegion + +import org.apache.spark.storage.FileSegment +import org.apache.spark.util.ByteBufferInputStream + + +/** + * Provides a buffer abstraction that allows pooling and reuse. + */ +abstract class ManagedBuffer { + // Note that all the methods are defined with parenthesis because their implementations can + // have side effects (io operations). + + def byteBuffer(): ByteBuffer = throw new UnsupportedOperationException + + def fileSegment(): Option[FileSegment] = None + + def inputStream(): InputStream = throw new UnsupportedOperationException + + def release(): Unit = throw new UnsupportedOperationException + + def size: Long + + private[network] def toNetty(): AnyRef +} + + +/** + * A ManagedBuffer backed by a segment in a file. + */ +final class FileSegmentManagedBuffer(file: File, offset: Long, length: Long) + extends ManagedBuffer { + + override def size: Long = length + + override private[network] def toNetty(): AnyRef = { + val fileChannel = new FileInputStream(file).getChannel + new DefaultFileRegion(fileChannel, offset, length) + } +} + + +/** + * A ManagedBuffer backed by [[java.nio.ByteBuffer]]. 
+ */ +final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { + + override def byteBuffer() = buf + + override def inputStream() = new ByteBufferInputStream(buf) + + override def size: Long = buf.remaining() + + override private[network] def toNetty(): AnyRef = Unpooled.wrappedBuffer(buf) +} + + +/** + * A ManagedBuffer backed by a Netty [[ByteBuf]]. + */ +final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { + + override def byteBuffer() = buf.nioBuffer() + + override def inputStream() = new ByteBufInputStream(buf) + + override def release(): Unit = buf.release() + + override def size: Long = buf.readableBytes() + + override private[network] def toNetty(): AnyRef = buf +} diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala deleted file mode 100644 index ea2ad104ecae..000000000000 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network - -import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf} - -import scala.concurrent.Await -import scala.concurrent.duration.Duration -import scala.util.Try - -private[spark] object SenderTest { - def main(args: Array[String]) { - - if (args.length < 2) { - println("Usage: SenderTest ") - System.exit(1) - } - - val targetHost = args(0) - val targetPort = args(1).toInt - val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - val conf = new SparkConf - val manager = new ConnectionManager(0, conf, new SecurityManager(conf)) - println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { - println("Received [" + msg + "] from [" + id + "]") - None - }) - - val size = 100 * 1024 * 1024 - val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) - buffer.flip - - val targetServer = args(0) - - val count = 100 - (0 until count).foreach(i => { - val dataMessage = Message.createBufferMessage(buffer.duplicate) - val startTime = System.currentTimeMillis - /* println("Started timer at " + startTime) */ - val promise = manager.sendMessageReliably(targetConnectionManagerId, dataMessage) - val responseStr: String = Try(Await.result(promise, Duration.Inf)) - .map { response => - val buffer = response.asInstanceOf[BufferMessage].buffers(0) - new String(buffer.array, "utf-8") - }.getOrElse("none") - - val finishTime = System.currentTimeMillis - val mb = size / 1024.0 / 1024.0 - val ms = finishTime - startTime - // val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms at " + (mb / ms - // * 1000.0) + " MB/s" - val resultStr = "Sent " + mb + " MB " + targetServer + " in " + ms + " ms (" + - (mb / ms * 1000.0).toInt + "MB/s) | Response = " + responseStr - println(resultStr) - }) - } -} - diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala 
b/core/src/main/scala/org/apache/spark/network/cm/BlockMessage.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/storage/BlockMessage.scala rename to core/src/main/scala/org/apache/spark/network/cm/BlockMessage.scala index a2bfce7b4a0f..107d0131efd7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/BlockMessage.scala @@ -15,14 +15,13 @@ * limitations under the License. */ -package org.apache.spark.storage +package org.apache.spark.network.cm import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.StringBuilder +import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} -import org.apache.spark.network._ +import scala.collection.mutable.{ArrayBuffer, StringBuilder} private[spark] case class GetBlock(id: BlockId) private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/cm/BlockMessageArray.scala similarity index 98% rename from core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala rename to core/src/main/scala/org/apache/spark/network/cm/BlockMessageArray.scala index 973d85c0a9b3..b0f770261c19 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/BlockMessageArray.scala @@ -15,14 +15,14 @@ * limitations under the License. 
*/ -package org.apache.spark.storage +package org.apache.spark.network.cm import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer - import org.apache.spark._ -import org.apache.spark.network._ +import org.apache.spark.storage.{StorageLevel, TestBlockId} + +import scala.collection.mutable.ArrayBuffer private[spark] class BlockMessageArray(var blockMessages: Seq[BlockMessage]) diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/cm/BufferMessage.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/network/BufferMessage.scala rename to core/src/main/scala/org/apache/spark/network/cm/BufferMessage.scala index af35f1fc3e45..5f7761838ab3 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/BufferMessage.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.storage.BlockManager +import scala.collection.mutable.ArrayBuffer + private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) extends Message(Message.BUFFER_MESSAGE, id_) { diff --git a/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala new file mode 100644 index 000000000000..3b61c0ee852c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.cm + +import java.nio.ByteBuffer + +import scala.concurrent.Await +import scala.concurrent.duration.Duration + +import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} +import org.apache.spark.network._ +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.util.Utils + + +/** + * A [[BlockTransferService]] implementation based on our [[ConnectionManager]]. + */ +final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityManager) + extends BlockTransferService with Logging { + + private var cm: ConnectionManager = _ + + private var blockDataManager: BlockDataManager = _ + + /** + * Port number the service is listening on, available only after [[init]] is invoked. + */ + override def port: Int = cm.id.port + + /** + * Host name the service is listening on, available only after [[init]] is invoked. + */ + override def hostName: String = cm.id.host + + /** + * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch + * local blocks or put local blocks. 
+ */ + override def init(blockDataManager: BlockDataManager): Unit = { + this.blockDataManager = blockDataManager + cm = new ConnectionManager( + conf.getInt("spark.blockManager.port", 0), + conf, + securityManager, + "Connection manager for block manager") + cm.onReceiveMessage(onBlockMessageReceive) + } + + /** + * Tear down the transfer service. + */ + override def stop(): Unit = { + if (cm != null) { + cm.stop() + } + } + + override def fetchBlocks( + hostName: String, + port: Int, + blockIds: Seq[String], + listener: BlockFetchingListener): Unit = { + + val cmId = new ConnectionManagerId(hostName, port) + val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => + BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) + }) + + val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) + + // If succeeds in getting blocks from a remote connection manager, put the block in results. + future.onSuccess { case message => + val bufferMessage = message.asInstanceOf[BufferMessage] + val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + + for (blockMessage <- blockMessageArray) { + if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { + listener.onBlockFetchFailure( + new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + } else { + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + listener.onBlockFetchSuccess( + blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData)) + } + } + }(cm.futureExecContext) + } + + /** + * Upload a single block to a remote node, available only after [[init]] is invoked. + * + * This call blocks until the upload completes, or throws an exception upon failures. 
+ */ + override def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, + level: StorageLevel): Unit = { + val msg = PutBlock(BlockId(blockId), blockData.byteBuffer(), level) + val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) + val remoteCmId = new ConnectionManagerId(hostname, port) // remote `hostname` param, not our own `hostName` + + // TODO: Not wait infinitely. + Await.result(cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage), + Duration.Inf) + } + + private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { + logDebug("Handling message " + msg) + msg match { + case bufferMessage: BufferMessage => + try { + logDebug("Handling as a buffer message " + bufferMessage) + val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) + logDebug("Parsed as a block message array") + val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) + Some(new BlockMessageArray(responseMessages).toBufferMessage) + } catch { + case e: Exception => { + logError("Exception handling buffer message", e) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } + } + + case otherMessage: Any => + logError("Unknown type message received: " + otherMessage) + val errorMessage = Message.createBufferMessage(msg.id) + errorMessage.hasError = true + Some(errorMessage) + } + } + + private def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { + blockMessage.getType match { + case BlockMessage.TYPE_PUT_BLOCK => + val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) + logDebug("Received [" + msg + "]") + putBlock(msg.id.toString, msg.data, msg.level) + None + + case BlockMessage.TYPE_GET_BLOCK => + val msg = new GetBlock(blockMessage.getId) + logDebug("Received [" + msg + "]") + val buffer = getBlock(msg.id.toString) + if (buffer == null) { + return None + } +
Some(BlockMessage.fromGotBlock(GotBlock(msg.id, buffer))) + + case _ => None + } + } + + private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + val startTimeMs = System.currentTimeMillis() + logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) + blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level) + logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + + " with data size: " + bytes.limit) + } + + private def getBlock(blockId: String): ByteBuffer = { + val startTimeMs = System.currentTimeMillis() + logDebug("GetBlock " + blockId + " started from " + startTimeMs) + val buffer = blockDataManager.getBlockData(blockId).orNull + logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + + " and got buffer " + buffer) + if (buffer == null) null else buffer.byteBuffer() // absent block: null, which the caller maps to None + } +} diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/cm/Connection.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/network/Connection.scala rename to core/src/main/scala/org/apache/spark/network/cm/Connection.scala index 5285ec82c1b6..080c3e7dd42a 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/Connection.scala @@ -15,16 +15,16 @@ * limitations under the License.
*/ -package org.apache.spark.network +package org.apache.spark.network.cm import java.net._ import java.nio._ import java.nio.channels._ -import scala.collection.mutable.{ArrayBuffer, HashMap, Queue} - import org.apache.spark._ +import scala.collection.mutable.{ArrayBuffer, HashMap, Queue} + private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/cm/ConnectionId.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/network/ConnectionId.scala rename to core/src/main/scala/org/apache/spark/network/cm/ConnectionId.scala index d579c165a191..7b358a4d2598 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/ConnectionId.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/cm/ConnectionManager.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/network/ConnectionManager.scala rename to core/src/main/scala/org/apache/spark/network/cm/ConnectionManager.scala index 578d80626300..f9e35fb793fa 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/ConnectionManager.scala @@ -15,31 +15,25 @@ * limitations under the License. 
*/ -package org.apache.spark.network +package org.apache.spark.network.cm import java.io.IOException +import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ -import java.net._ -import java.util.{Timer, TimerTask} import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} +import java.util.{Timer, TimerTask} -import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.collection.mutable.HashSet -import scala.collection.mutable.SynchronizedMap -import scala.collection.mutable.SynchronizedQueue +import org.apache.spark._ +import org.apache.spark.util.{SystemClock, Utils} -import scala.concurrent.{Await, ExecutionContext, Future, Promise} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} import scala.concurrent.duration._ +import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps -import org.apache.spark._ -import org.apache.spark.util.{SystemClock, Utils} - private[spark] class ConnectionManager( port: Int, conf: SparkConf, @@ -904,7 +898,7 @@ private[spark] class ConnectionManager( private[spark] object ConnectionManager { - import ExecutionContext.Implicits.global + import scala.concurrent.ExecutionContext.Implicits.global def main(args: Array[String]) { val conf = new SparkConf diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/cm/ConnectionManagerId.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala rename to core/src/main/scala/org/apache/spark/network/cm/ConnectionManagerId.scala index 57f7586883af..b6b2cb0db429 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerId.scala +++ 
b/core/src/main/scala/org/apache/spark/network/cm/ConnectionManagerId.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.net.InetSocketAddress diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/cm/Message.scala similarity index 98% rename from core/src/main/scala/org/apache/spark/network/Message.scala rename to core/src/main/scala/org/apache/spark/network/cm/Message.scala index 04ea50f62918..5b5bcc2d966e 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/Message.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.net.InetSocketAddress import java.nio.ByteBuffer diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/cm/MessageChunk.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/network/MessageChunk.scala rename to core/src/main/scala/org/apache/spark/network/cm/MessageChunk.scala index d0f986a12bfe..95b46cd11f6b 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunk.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/MessageChunk.scala @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.network +package org.apache.spark.network.cm import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -private[network] +private[cm] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { val size = if (buffer == null) 0 else buffer.remaining diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/cm/MessageChunkHeader.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala rename to core/src/main/scala/org/apache/spark/network/cm/MessageChunkHeader.scala index f3ecca5f992e..7087c7ad6c50 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/MessageChunkHeader.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm -import java.net.InetAddress -import java.net.InetSocketAddress +import java.net.{InetAddress, InetSocketAddress} import java.nio.ByteBuffer private[spark] class MessageChunkHeader( diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/cm/SecurityMessage.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/network/SecurityMessage.scala rename to core/src/main/scala/org/apache/spark/network/cm/SecurityMessage.scala index 9af9e2e8e9e5..f59df06fb3d9 100644 --- a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/SecurityMessage.scala @@ -15,15 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.network +package org.apache.spark.network.cm import java.nio.ByteBuffer -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.StringBuilder - import org.apache.spark._ -import org.apache.spark.network._ + +import scala.collection.mutable.{ArrayBuffer, StringBuilder} /** * SecurityMessage is class that contains the connectionId and sasl token diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 87ef9bb0b43c..dd0421a5c15a 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,9 +27,9 @@ import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.spark._ import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.network.cm.{PutBlock, GotBlock, GetBlock} import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ -import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.collection.CompactBuffer diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 12b475658e29..6cf9305977a3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -21,10 +21,9 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.executor.ShuffleReadMetrics import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} +import 
org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} import org.apache.spark.util.CompletionIterator private[hash] object BlockStoreShuffleFetcher extends Logging { @@ -32,8 +31,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer, - shuffleMetrics: ShuffleReadMetrics) + serializer: Serializer) : Iterator[T] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) @@ -74,7 +72,13 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } - val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics) + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + SparkEnv.get.blockTransferService, + blockManager, + blocksByAddress, + serializer, + SparkEnv.get.conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024) val itr = blockFetcherItr.flatMap(unpackBlock) val completionIter = CompletionIterator[T, Iterator[T]](itr, { diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 7bed97a63f0f..88a5f1e5ddf5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -36,10 +36,8 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser, - readMetrics) + val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) 
{ if (dep.mapSideCombine) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala deleted file mode 100644 index ca60ec78b62e..000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.concurrent.LinkedBlockingQueue -import org.apache.spark.network.netty.client.{BlockClientListener, LazyInitIterator, ReferenceCountedBuffer} - -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue -import scala.util.{Failure, Success} - -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.executor.ShuffleReadMetrics -import org.apache.spark.network.BufferMessage -import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils - -/** - * A block fetcher iterator interface. There are two implementations: - * - * BasicBlockFetcherIterator: uses a custom-built NIO communication layer. 
- * NettyBlockFetcherIterator: uses Netty (OIO) as the communication layer. - * - * Eventually we would like the two to converge and use a single NIO-based communication layer, - * but extensive tests show that under some circumstances (e.g. large shuffles with lots of cores), - * NIO would perform poorly and thus the need for the Netty OIO one. - */ - -private[storage] -trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { - def initialize() -} - - -private[storage] -object BlockFetcherIterator { - - /** - * A request to fetch blocks from a remote BlockManager. - * @param address remote BlockManager to fetch from. - * @param blocks Sequence of tuple, where the first element is the block id, - * and the second element is the estimated size, used to calculate bytesInFlight. - */ - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { - val size = blocks.map(_._2).sum - } - - /** - * Result of a fetch from a remote block. A failure is represented as size == -1. - * @param blockId block id - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. - */ - class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 - } - - // TODO: Refactor this whole thing to make code more reusable. - class BasicBlockFetcherIterator( - private val blockManager: BlockManager, - val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, - readMetrics: ShuffleReadMetrics) - extends BlockFetcherIterator { - - import blockManager._ - - if (blocksByAddress == null) { - throw new IllegalArgumentException("BlocksByAddress is null") - } - - // Total number blocks fetched (local + remote). 
Also number of FetchResults expected - protected var _numBlocksToFetch = 0 - - protected var startTime = System.currentTimeMillis - - // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks - protected val localBlocksToFetch = new ArrayBuffer[BlockId]() - - // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks - protected val remoteBlocksToFetch = new HashSet[BlockId]() - - // A queue to hold our results. - protected val results = new LinkedBlockingQueue[FetchResult] - - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight - protected val fetchRequests = new Queue[FetchRequest] - - // Current bytes in flight from our requests - protected var bytesInFlight = 0L - - protected def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) - val blockMessageArray = new BlockMessageArray(req.blocks.map { - case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId)) - }) - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - future.onComplete { - case Success(message) => { - val bufferMessage = message.asInstanceOf[BufferMessage] - val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - throw new SparkException( - "Unexpected message " + blockMessage.getType + " received from " + cmId) - } - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - results.put(new FetchResult(blockId, sizeMap(blockId), - () => 
dataDeserialize(blockId, blockMessage.getData, serializer))) - // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can - // be incrementing bytes read at the same time (SPARK-2625). - readMetrics.remoteBytesRead += networkSize - readMetrics.remoteBlocksFetched += 1 - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - } - case Failure(exception) => { - logError("Could not get block(s) from " + cmId, exception) - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } - } - } - - protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { - // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. - val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) - - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. 
- val remoteRequests = new ArrayBuffer[FetchRequest] - var totalBlocks = 0 - for ((address, blockInfos) <- blocksByAddress) { - totalBlocks += blockInfos.size - if (address == blockManagerId) { - // Filter out zero-sized blocks - localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1) - _numBlocksToFetch += localBlocksToFetch.size - } else { - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(BlockId, Long)] - while (iterator.hasNext) { - val (blockId, size) = iterator.next() - // Skip empty blocks - if (size > 0) { - curBlocks += ((blockId, size)) - remoteBlocksToFetch += blockId - _numBlocksToFetch += 1 - curRequestSize += size - } else if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) - } - if (curRequestSize >= targetRequestSize) { - // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) - curBlocks = new ArrayBuffer[(BlockId, Long)] - logDebug(s"Creating fetch request of $curRequestSize at $address") - curRequestSize = 0 - } - } - // Add in the final request - if (!curBlocks.isEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) - } - } - } - logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " + - totalBlocks + " blocks") - remoteRequests - } - - protected def getLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. 
Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocksToFetch) { - try { - // getLocalFromDisk never return None but throws BlockException - val iter = getLocalFromDisk(id, serializer).get - // Pass 0 as size since it's not in flight - readMetrics.localBlocksFetched += 1 - results.put(new FetchResult(id, 0, () => iter)) - logDebug("Got local block " + id) - } catch { - case e: Exception => { - logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(id, -1, null)) - return - } - } - } - } - - override def initialize() { - // Split local and remote blocks. - val remoteRequests = splitLocalRemoteBlocks() - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(remoteRequests) - - // Send out initial requests for blocks, up to our maxBytesInFlight - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - - val numFetches = remoteRequests.size - fetchRequests.size - logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) - - // Get Local Blocks - startTime = System.currentTimeMillis - getLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") - } - - // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue - // as they arrive. - @volatile protected var resultsGotten = 0 - - override def hasNext: Boolean = resultsGotten < _numBlocksToFetch - - override def next(): (BlockId, Option[Iterator[Any]]) = { - resultsGotten += 1 - val startFetchWait = System.currentTimeMillis() - val result = results.take() - val stopFetchWait = System.currentTimeMillis() - readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) - if (! 
result.failed) bytesInFlight -= result.size - while (!fetchRequests.isEmpty && - (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { - sendRequest(fetchRequests.dequeue()) - } - (result.blockId, if (result.failed) None else Some(result.deserialize())) - } - } - // End of BasicBlockFetcherIterator - - class NettyBlockFetcherIterator( - blockManager: BlockManager, - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, - readMetrics: ShuffleReadMetrics) - extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) { - - override protected def sendRequest(req: FetchRequest) { - logDebug("Sending request for %d blocks (%s) from %s".format( - req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - val cmId = new ConnectionManagerId(req.address.host, req.address.port) - - bytesInFlight += req.size - val sizeMap = req.blocks.toMap // so we can look up the size of each blockID - - // This could throw a TimeoutException. In that case we will just retry the task. - val client = blockManager.nettyBlockClientFactory.createClient( - cmId.host, req.address.nettyPort) - val blocks = req.blocks.map(_._1.toString) - - client.fetchBlocks( - blocks, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = { - logError(s"Could not get block(s) from $cmId with error: $errorMsg") - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } - } - - override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { - // Increment the reference count so the buffer won't be recycled. - // TODO: This could result in memory leaks when the task is stopped due to exception - // before the iterator is exhausted. 
- data.retain() - val buf = data.byteBuffer() - val blockSize = buf.remaining() - val bid = BlockId(blockId) - - // TODO: remove code duplication between here and BlockManager.dataDeserialization. - results.put(new FetchResult(bid, sizeMap(bid), () => { - def createIterator: Iterator[Any] = { - val stream = blockManager.wrapForCompression(bid, data.inputStream()) - serializer.newInstance().deserializeStream(stream).asIterator - } - new LazyInitIterator(createIterator) { - // Release the buffer when we are done traversing it. - override def close(): Unit = data.release() - } - })) - - readMetrics.synchronized { - readMetrics.remoteBytesRead += blockSize - readMetrics.remoteBlocksFetched += 1 - } - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) - } - } - ) - } - } - // End of NettyBlockFetcherIterator -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 12a92d44f4c3..cc5d505303fc 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -20,6 +20,8 @@ package org.apache.spark.storage import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} +import scala.concurrent.ExecutionContext.Implicits.global + import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, Future} import scala.concurrent.duration._ @@ -32,8 +34,6 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ -import org.apache.spark.network.netty.client.BlockFetchingClientFactory -import org.apache.spark.network.netty.server.BlockServer import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.util._ @@ -60,18 +60,16 @@ 
private[spark] class BlockManager( defaultSerializer: Serializer, maxMemory: Long, val conf: SparkConf, - securityManager: SecurityManager, mapOutputTracker: MapOutputTracker, - shuffleManager: ShuffleManager) + shuffleManager: ShuffleManager, + blockTransferService: BlockTransferService) extends BlockDataProvider with Logging { + //blockTransferService.init(this) + private val port = conf.getInt("spark.blockManager.port", 0) val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, conf) - val connectionManager = - new ConnectionManager(port, conf, securityManager, "Connection manager for block manager") - - implicit val futureExecContext = connectionManager.futureExecContext private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -90,31 +88,8 @@ private[spark] class BlockManager( new TachyonStore(this, tachyonBlockManager) } - private val useNetty = conf.getBoolean("spark.shuffle.use.netty", false) - - // If we use Netty for shuffle, start a new Netty-based shuffle sender service. 
- private[storage] val nettyBlockClientFactory: BlockFetchingClientFactory = { - if (useNetty) new BlockFetchingClientFactory(conf) else null - } - - private val nettyBlockServer: BlockServer = { - if (useNetty) { - val server = new BlockServer(conf, this) - logInfo(s"Created NettyBlockServer binding to port: ${server.port}") - server - } else { - null - } - } - - private val nettyPort: Int = if (useNetty) nettyBlockServer.port else 0 - val blockManagerId = BlockManagerId( - executorId, connectionManager.id.host, connectionManager.id.port, nettyPort) - - // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory - // for receiving shuffle outputs) - val maxBytesInFlight = conf.getLong("spark.reducer.maxMbInFlight", 48) * 1024 * 1024 + executorId, blockTransferService.hostName, blockTransferService.port) // Whether to compress broadcast variables that are stored private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) @@ -157,11 +132,11 @@ private[spark] class BlockManager( master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, - securityManager: SecurityManager, mapOutputTracker: MapOutputTracker, - shuffleManager: ShuffleManager) = { + shuffleManager: ShuffleManager, + blockTransferService: BlockTransferService) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, securityManager, mapOutputTracker, shuffleManager) + conf, mapOutputTracker, shuffleManager, blockTransferService) } /** @@ -170,7 +145,6 @@ private[spark] class BlockManager( */ private def initialize(): Unit = { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) - BlockManagerWorker.startBlockManagerWorker(this) } /** @@ -527,8 +501,8 @@ private[spark] class BlockManager( val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") - val data = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), 
ConnectionManagerId(loc.host, loc.port)) + val data = blockTransferService.fetchBlock(loc.host, loc.port, blockId.toString).byteBuffer() + if (data != null) { if (asBlockResult) { return Some(new BlockResult( @@ -562,28 +536,6 @@ private[spark] class BlockManager( None } - /** - * Get multiple blocks from local and remote block manager using their BlockManagerIds. Returns - * an Iterator of (block ID, value) pairs so that clients may handle blocks in a pipelined - * fashion as they're received. Expects a size in bytes to be provided for each block fetched, - * so that we can control the maxMegabytesInFlight for the fetch. - */ - def getMultiple( - blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, - readMetrics: ShuffleReadMetrics): BlockFetcherIterator = { - val iter = - if (conf.getBoolean("spark.shuffle.use.netty", false)) { - new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer, - readMetrics) - } else { - new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer, - readMetrics) - } - iter.initialize() - iter - } - def putIterator( blockId: BlockId, values: Iterator[Any], @@ -836,12 +788,15 @@ private[spark] class BlockManager( data.rewind() logDebug(s"Try to replicate $blockId once; The size of the data is ${data.limit()} Bytes. " + s"To node: $peer") - val putBlock = PutBlock(blockId, data, tLevel) - val cmId = new ConnectionManagerId(peer.host, peer.port) - val syncPutBlockSuccess = BlockManagerWorker.syncPutBlock(putBlock, cmId) - if (!syncPutBlockSuccess) { - logError(s"Failed to call syncPutBlock to $peer") + + try { + blockTransferService.uploadBlock( + peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) + } catch { + case e: Exception => + logError(s"Failed to replicate block to $peer", e) } + logDebug("Replicating BlockId %s once used %fs; The size of the data is %d bytes." 
.format(blockId, (System.nanoTime - start) / 1e6, data.limit())) } @@ -1066,40 +1021,13 @@ private[spark] class BlockManager( bytes: ByteBuffer, serializer: Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() - - def getIterator: Iterator[Any] = { - val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) - serializer.newInstance().deserializeStream(stream).asIterator - } - - if (blockId.isShuffle) { - /* Reducer may need to read many local shuffle blocks and will wrap them into Iterators - * at the beginning. The wrapping will cost some memory (compression instance - * initialization, etc.). Reducer reads shuffle blocks one by one so we could do the - * wrapping lazily to save memory. */ - class LazyProxyIterator(f: => Iterator[Any]) extends Iterator[Any] { - lazy val proxy = f - override def hasNext: Boolean = proxy.hasNext - override def next(): Any = proxy.next() - } - new LazyProxyIterator(getIterator) - } else { - getIterator - } + val stream = wrapForCompression(blockId, new ByteBufferInputStream(bytes, true)) + serializer.newInstance().deserializeStream(stream).asIterator } def stop(): Unit = { - connectionManager.stop() shuffleBlockManager.stop() diskBlockManager.stop() - - if (nettyBlockClientFactory != null) { - nettyBlockClientFactory.stop() - } - if (nettyBlockServer != null) { - nettyBlockServer.stop() - } - actorSystem.stop(slaveActor) blockInfo.clear() memoryStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index b1585bd8199d..f39510160e63 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -36,11 +36,10 @@ import org.apache.spark.util.Utils class BlockManagerId private ( private var executorId_ : String, private var host_ : String, - private var port_ : Int, - private var nettyPort_ : Int - ) extends 
Externalizable { + private var port_ : Int) + extends Externalizable { - private def this() = this(null, null, 0, 0) // For deserialization only + private def this() = this(null, null, 0) // For deserialization only def executorId: String = executorId_ @@ -60,32 +59,29 @@ class BlockManagerId private ( def port: Int = port_ - def nettyPort: Int = nettyPort_ override def writeExternal(out: ObjectOutput) { out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) - out.writeInt(nettyPort_) } override def readExternal(in: ObjectInput) { executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() - nettyPort_ = in.readInt() } @throws(classOf[IOException]) private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this) - override def toString = "BlockManagerId(%s, %s, %d, %d)".format(executorId, host, port, nettyPort) + override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, host, port) - override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port + nettyPort + override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port override def equals(that: Any) = that match { case id: BlockManagerId => - executorId == id.executorId && port == id.port && host == id.host && nettyPort == id.nettyPort + executorId == id.executorId && port == id.port && host == id.host case _ => false } @@ -100,11 +96,10 @@ private[spark] object BlockManagerId { * @param execId ID of the executor. * @param host Host name of the block manager. * @param port Port of the block manager. - * @param nettyPort Optional port for the Netty-based shuffle sender. * @return A new [[org.apache.spark.storage.BlockManagerId]]. 
*/ - def apply(execId: String, host: String, port: Int, nettyPort: Int) = - getCachedBlockManagerId(new BlockManagerId(execId, host, port, nettyPort)) + def apply(execId: String, host: String, port: Int) = + getCachedBlockManagerId(new BlockManagerId(execId, host, port)) def apply(in: ObjectInput) = { val obj = new BlockManagerId() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala deleted file mode 100644 index bf002a42d5dc..000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala +++ /dev/null @@ -1,147 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - -import org.apache.spark.Logging -import org.apache.spark.network._ -import org.apache.spark.util.Utils - -import scala.concurrent.Await -import scala.concurrent.duration.Duration -import scala.util.{Try, Failure, Success} - -/** - * A network interface for BlockManager. Each slave should have one - * BlockManagerWorker. - * - * TODO: Use event model. 
- */ -private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends Logging { - - blockManager.connectionManager.onReceiveMessage(onBlockMessageReceive) - - def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { - logDebug("Handling message " + msg) - msg match { - case bufferMessage: BufferMessage => { - try { - logDebug("Handling as a buffer message " + bufferMessage) - val blockMessages = BlockMessageArray.fromBufferMessage(bufferMessage) - logDebug("Parsed as a block message array") - val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) - Some(new BlockMessageArray(responseMessages).toBufferMessage) - } catch { - case e: Exception => { - logError("Exception handling buffer message", e) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) - } - } - } - case otherMessage: Any => { - logError("Unknown type message received: " + otherMessage) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) - } - } - } - - def processBlockMessage(blockMessage: BlockMessage): Option[BlockMessage] = { - blockMessage.getType match { - case BlockMessage.TYPE_PUT_BLOCK => { - val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) - logDebug("Received [" + pB + "]") - putBlock(pB.id, pB.data, pB.level) - None - } - case BlockMessage.TYPE_GET_BLOCK => { - val gB = new GetBlock(blockMessage.getId) - logDebug("Received [" + gB + "]") - val buffer = getBlock(gB.id) - if (buffer == null) { - return None - } - Some(BlockMessage.fromGotBlock(GotBlock(gB.id, buffer))) - } - case _ => None - } - } - - private def putBlock(id: BlockId, bytes: ByteBuffer, level: StorageLevel) { - val startTimeMs = System.currentTimeMillis() - logDebug("PutBlock " + id + " started from " + startTimeMs + " with data: " + bytes) - blockManager.putBytes(id, bytes, level) - logDebug("PutBlock 
" + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " with data size: " + bytes.limit) - } - - private def getBlock(id: BlockId): ByteBuffer = { - val startTimeMs = System.currentTimeMillis() - logDebug("GetBlock " + id + " started from " + startTimeMs) - val buffer = blockManager.getLocalBytes(id) match { - case Some(bytes) => bytes - case None => null - } - logDebug("GetBlock " + id + " used " + Utils.getUsedTimeMs(startTimeMs) - + " and got buffer " + buffer) - buffer - } -} - -private[spark] object BlockManagerWorker extends Logging { - private var blockManagerWorker: BlockManagerWorker = null - - def startBlockManagerWorker(manager: BlockManager) { - blockManagerWorker = new BlockManagerWorker(manager) - } - - def syncPutBlock(msg: PutBlock, toConnManagerId: ConnectionManagerId): Boolean = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val blockMessage = BlockMessage.fromPutBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val resultMessage = Try(Await.result(connectionManager.sendMessageReliably( - toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) - resultMessage.isSuccess - } - - def syncGetBlock(msg: GetBlock, toConnManagerId: ConnectionManagerId): ByteBuffer = { - val blockManager = blockManagerWorker.blockManager - val connectionManager = blockManager.connectionManager - val blockMessage = BlockMessage.fromGetBlock(msg) - val blockMessageArray = new BlockMessageArray(blockMessage) - val responseMessage = Try(Await.result(connectionManager.sendMessageReliably( - toConnManagerId, blockMessageArray.toBufferMessage), Duration.Inf)) - responseMessage match { - case Success(message) => { - val bufferMessage = message.asInstanceOf[BufferMessage] - logDebug("Response message received " + bufferMessage) - BlockMessageArray.fromBufferMessage(bufferMessage).foreach(blockMessage => { - logDebug("Found " + blockMessage) - return blockMessage.getData - }) 
- } - case Failure(exception) => logDebug("No response message received") - } - null - } -} diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala new file mode 100644 index 000000000000..d4ed33eb506c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.util.concurrent.LinkedBlockingQueue + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashSet +import scala.collection.mutable.Queue + +import org.apache.spark.{TaskContext, Logging, SparkException} +import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} +import org.apache.spark.serializer.Serializer +import org.apache.spark.util.Utils + + +/** + * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block + * manager. For remote blocks, it fetches them using the provided BlockTransferService. 
+ * + * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a + * pipelined fashion as they are received. + * + * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid + * using too much memory. + * + * @param context + * @param blockManager + * @param blocksByAddress + * @param serializer + * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. + */ +private[spark] +final class ShuffleBlockFetcherIterator( + context: TaskContext, + blockTransferService: BlockTransferService, + blockManager: BlockManager, + blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], + serializer: Serializer, + maxBytesInFlight: Long) + extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { + + import ShuffleBlockFetcherIterator._ + + /** + * Total number of blocks to fetch. This can be smaller than the total number of blocks + * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. + * + * This should equal localBlocks.size + remoteBlocks.size. + */ + private[this] var numBlocksToFetch = 0 + + /** + * The number of blocks proccessed by the caller. The iterator is exhausted when + * [[numBlocksProcessed]] == [[numBlocksToFetch]]. + */ + private[this] var numBlocksProcessed = 0 + + private[this] val startTime = System.currentTimeMillis + + /** Local blocks to fetch, excluding zero-sized blocks. */ + private[this] val localBlocks = new ArrayBuffer[BlockId]() + + /** Remote blocks to fetch, excluding zero-sized blocks. */ + private[this] val remoteBlocks = new HashSet[BlockId]() + + /** + * A queue to hold our results. This turns the asynchronous model provided by + * [[BlockTransferService]] into a synchronous model (iterator). 
+ */ + private[this] val results = new LinkedBlockingQueue[FetchResult] + + // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + // the number of bytes in flight is limited to maxBytesInFlight + private[this] val fetchRequests = new Queue[FetchRequest] + + // Current bytes in flight from our requests + private[this] var bytesInFlight = 0L + + private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + + initialize() + + private[this] def sendRequest(req: FetchRequest) { + logDebug("Sending request for %d blocks (%s) from %s".format( + req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) + bytesInFlight += req.size + + // so we can look up the size of each blockID + val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap + val blockIds = req.blocks.map(_._1.toString) + + blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, + new BlockFetchingListener { + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), + () => blockManager.dataDeserialize(BlockId(blockId), data.byteBuffer(), serializer) + )) + shuffleMetrics.remoteBytesRead += data.size + shuffleMetrics.remoteBlocksFetched += 1 + logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + } + + override def onBlockFetchFailure(exception: Exception): Unit = { + + } + } + ) + // case Failure(exception) => { + // logError("Could not get block(s) from " + cmId, exception) + // for ((blockId, size) <- req.blocks) { + // results.put(new FetchResult(blockId, -1, null)) + // } + // } + // } + } + + private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + 
// nodes, rather than blocking on reading output from one node. + val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) + logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + + // Split local and remote blocks. Remote blocks are further split into FetchRequests of size + // at most maxBytesInFlight in order to limit the amount of data in flight. + val remoteRequests = new ArrayBuffer[FetchRequest] + + // Tracks total number of blocks (including zero sized blocks) + var totalBlocks = 0 + for ((address, blockInfos) <- blocksByAddress) { + totalBlocks += blockInfos.size + if (address == blockManager.blockManagerId) { + // Filter out zero-sized blocks + localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) + numBlocksToFetch += localBlocks.size + } else { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[(BlockId, Long)] + while (iterator.hasNext) { + val (blockId, size) = iterator.next() + // Skip empty blocks + if (size > 0) { + curBlocks += ((blockId, size)) + remoteBlocks += blockId + numBlocksToFetch += 1 + curRequestSize += size + } else if (size < 0) { + throw new BlockException(blockId, "Negative block size " + size) + } + if (curRequestSize >= targetRequestSize) { + // Add this FetchRequest + remoteRequests += new FetchRequest(address, curBlocks) + curBlocks = new ArrayBuffer[(BlockId, Long)] + logDebug(s"Creating fetch request of $curRequestSize at $address") + curRequestSize = 0 + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + remoteRequests += new FetchRequest(address, curBlocks) + } + } + } + logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks") + remoteRequests + } + + private[this] def fetchLocalBlocks() { + // Get the local blocks while remote blocks are being fetched. 
Note that it's okay to do + // these all at once because they will just memory-map some files, so they won't consume + // any memory that might exceed our maxBytesInFlight + for (id <- localBlocks) { + try { + shuffleMetrics.localBlocksFetched += 1 + results.put(new FetchResult(id, 0, () => blockManager.getLocalFromDisk(id, serializer).get)) + logDebug("Got local block " + id) + } catch { + case e: Exception => + logError(s"Error occurred while fetching local blocks", e) + results.put(new FetchResult(id, -1, null)) + return + } + } + } + + private[this] def initialize(): Unit = { + // Split local and remote blocks. + val remoteRequests = splitLocalRemoteBlocks() + // Add the remote requests into our queue in a random order + fetchRequests ++= Utils.randomize(remoteRequests) + + // Send out initial requests for blocks, up to our maxBytesInFlight + while (fetchRequests.nonEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + + val numFetches = remoteRequests.size - fetchRequests.size + logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime)) + + // Get Local Blocks + fetchLocalBlocks() + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + } + + override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch + + override def next(): (BlockId, Option[Iterator[Any]]) = { + numBlocksProcessed += 1 + val startFetchWait = System.currentTimeMillis() + val result = results.take() + val stopFetchWait = System.currentTimeMillis() + shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) + if (!result.failed) { + bytesInFlight -= result.size + } + while (!fetchRequests.isEmpty && + (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { + sendRequest(fetchRequests.dequeue()) + } + (result.blockId, if (result.failed) None else Some(result.deserialize())) + } +} + + +private[storage] 
+object ShuffleBlockFetcherIterator { + + /** + * A request to fetch blocks from a remote BlockManager. + * @param address remote BlockManager to fetch from. + * @param blocks Sequence of tuple, where the first element is the block id, + * and the second element is the estimated size, used to calculate bytesInFlight. + */ + class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { + val size = blocks.map(_._2).sum + } + + /** + * Result of a fetch from a remote block. A failure is represented as size == -1. + * @param blockId block id + * @param size estimated size of the block, used to calculate bytesInFlight. + * Note that this is NOT the exact bytes. + * @param deserialize closure to return the result in the form of an Iterator. + */ + class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { + def failed: Boolean = size == -1 + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index aa83ea90ee9e..8a836bbba274 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -102,7 +102,7 @@ private[spark] object ThreadingTest { conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf), new MapOutputTrackerMaster(conf), new HashShuffleManager(conf)) + new MapOutputTrackerMaster(conf), new HashShuffleManager(conf), null) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index db7384705fc1..a7543454eca1 100644 --- 
a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -295,8 +295,7 @@ private[spark] object JsonProtocol { def blockManagerIdToJson(blockManagerId: BlockManagerId): JValue = { ("Executor ID" -> blockManagerId.executorId) ~ ("Host" -> blockManagerId.host) ~ - ("Port" -> blockManagerId.port) ~ - ("Netty Port" -> blockManagerId.nettyPort) + ("Port" -> blockManagerId.port) } def jobResultToJson(jobResult: JobResult): JValue = { @@ -644,8 +643,7 @@ private[spark] object JsonProtocol { val executorId = (json \ "Executor ID").extract[String] val host = (json \ "Host").extract[String] val port = (json \ "Port").extract[Int] - val nettyPort = (json \ "Netty Port").extract[Int] - BlockManagerId(executorId, host, port, nettyPort) + BlockManagerId(executorId, host, port) } def jobResultFromJson(json: JValue): JobResult = { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 41c294f727b3..5406fcc2ac83 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark +import org.apache.spark.network.cm.{GetBlock, BlockManagerWorker, ConnectionManagerId} import org.scalatest.BeforeAndAfter import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ @@ -24,8 +25,7 @@ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} import org.apache.spark.SparkContext._ -import org.apache.spark.network.ConnectionManagerId -import org.apache.spark.storage.{BlockManagerWorker, GetBlock, RDDBlockId, StorageLevel} +import org.apache.spark.storage.{RDDBlockId, StorageLevel} class NotSerializableClass class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {} diff --git a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/network/cm/ConnectionManagerSuite.scala similarity index 97% rename from core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala rename to core/src/test/scala/org/apache/spark/network/cm/ConnectionManagerSuite.scala index e2f4d4c57cdb..258492051173 100644 --- a/core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/cm/ConnectionManagerSuite.scala @@ -15,23 +15,17 @@ * limitations under the License. */ -package org.apache.spark.network +package org.apache.spark.network.cm import java.io.IOException import java.nio._ -import java.util.concurrent.TimeoutException import org.apache.spark.{SecurityManager, SparkConf} import org.scalatest.FunSuite -import org.mockito.Mockito._ -import org.mockito.Matchers._ - -import scala.concurrent.TimeoutException -import scala.concurrent.{Await, TimeoutException} import scala.concurrent.duration._ +import scala.concurrent.{Await, TimeoutException} import scala.language.postfixOps -import scala.util.{Failure, Success, Try} /** * Test the ConnectionManager with various security settings. 
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala index bcbfe8baf36a..56d5907d4f2c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala @@ -24,16 +24,17 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global -import org.scalatest.{FunSuite, Matchers} - import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} import org.mockito.stubbing.Answer import org.mockito.invocation.InvocationOnMock -import org.apache.spark.storage.BlockFetcherIterator._ -import org.apache.spark.network.{ConnectionManager, Message} +import org.scalatest.{FunSuite, Matchers} + +import org.apache.spark.network.cm._ import org.apache.spark.executor.ShuffleReadMetrics +import org.apache.spark.storage.BlockFetcherIterator._ + class BlockFetcherIteratorSuite extends FunSuite with Matchers { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f32ce6f9fcc7..8c458b99e6bc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -21,10 +21,15 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays import java.util.concurrent.TimeUnit +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.implicitConversions +import scala.language.postfixOps + import akka.actor._ import akka.pattern.ask import akka.util.Timeout -import org.apache.spark.shuffle.hash.HashShuffleManager import org.mockito.invocation.InvocationOnMock import org.mockito.Matchers.any @@ -38,17 +43,13 @@ import 
org.scalatest.Matchers import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod -import org.apache.spark.network.{Message, ConnectionManagerId} +import org.apache.spark.network.cm._ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.language.implicitConversions -import scala.language.postfixOps class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter with PrivateMethodTester { From ae05fcd47b52f0da3db669d74888f3cc0780f33b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 2 Sep 2014 01:06:26 -0700 Subject: [PATCH 02/10] Updated tests, although DistributedSuite is hanging. 
--- .../spark/network/BlockTransferService.scala | 35 ++- .../apache/spark/network/ManagedBuffer.scala | 10 +- .../apache/spark/storage/BlockManager.scala | 26 +- .../storage/ShuffleBlockFetcherIterator.scala | 3 +- .../org/apache/spark/DistributedSuite.scala | 16 +- .../storage/BlockFetcherIteratorSuite.scala | 232 ------------------ .../spark/storage/BlockManagerSuite.scala | 119 +-------- .../spark/storage/DiskBlockManagerSuite.scala | 7 +- 8 files changed, 85 insertions(+), 363 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 0aa4a85531fa..645adbce2311 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -63,8 +63,39 @@ abstract class BlockTransferService { * }}} */ def fetchBlock(hostName: String, port: Int, blockId: String): ManagedBuffer = { - //fetchBlocks(hostName, port, Seq(blockId)).iterator().next()._2 - null + // TODO(rxin): Add timeout? 
+ val lock = new Object + @volatile var result: Either[ManagedBuffer, Exception] = null + fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener { + override def onBlockFetchFailure(exception: Exception): Unit = { + lock.synchronized { + result = Right(exception) + lock.notify() + } + } + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + lock.synchronized { + result = Left(data) + lock.notify() + } + } + }) + + // Sleep until result is no longer null + lock.synchronized { + while (result == null) { + try { + lock.wait() + } catch { + case e: InterruptedException => + } + } + } + + result match { + case Left(data: ManagedBuffer) => data + case Right(e: Exception) => throw e + } } /** diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index f51724593a9b..8f48152b3cd2 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -17,8 +17,9 @@ package org.apache.spark.network -import java.io.{File, FileInputStream, InputStream} +import java.io.{RandomAccessFile, File, FileInputStream, InputStream} import java.nio.ByteBuffer +import java.nio.channels.FileChannel.MapMode import io.netty.buffer.{ByteBufInputStream, ByteBuf, Unpooled} import io.netty.channel.DefaultFileRegion @@ -34,7 +35,7 @@ abstract class ManagedBuffer { // Note that all the methods are defined with parenthesis because their implementations can // have side effects (io operations). 
- def byteBuffer(): ByteBuffer = throw new UnsupportedOperationException + def byteBuffer(): ByteBuffer def fileSegment(): Option[FileSegment] = None @@ -56,6 +57,11 @@ final class FileSegmentManagedBuffer(file: File, offset: Long, length: Long) override def size: Long = length + override def byteBuffer(): ByteBuffer = { + val channel = new RandomAccessFile(file, "r").getChannel + channel.map(MapMode.READ_ONLY, offset, length) + } + override private[network] def toNetty(): AnyRef = { val fileChannel = new FileInputStream(file).getChannel new DefaultFileRegion(fileChannel, offset, length) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index cc5d505303fc..cfb18dbe4e29 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -63,9 +63,9 @@ private[spark] class BlockManager( mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, blockTransferService: BlockTransferService) - extends BlockDataProvider with Logging { + extends BlockDataManager with Logging { - //blockTransferService.init(this) + blockTransferService.init(this) private val port = conf.getInt("spark.blockManager.port", 0) val shuffleBlockManager = new ShuffleBlockManager(this, shuffleManager) @@ -207,20 +207,34 @@ private[spark] class BlockManager( } } - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + /** + * Interface to get local block data. + * + * @return Some(buffer) if the block exists locally, and None if it doesn't. 
+ */ + override def getBlockData(blockId: String): Option[ManagedBuffer] = { val bid = BlockId(blockId) if (bid.isShuffle) { - Left(diskBlockManager.getBlockLocation(bid)) + val fileSegment = diskBlockManager.getBlockLocation(bid) + Some(new FileSegmentManagedBuffer(fileSegment.file, fileSegment.offset, fileSegment.length)) } else { val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { - Right(blockBytesOpt.get) + val buffer = blockBytesOpt.get + Some(new NioByteBufferManagedBuffer(buffer)) } else { - throw new BlockNotFoundException(blockId) + None } } } + /** + * Put the block locally, using the given storage level. + */ + override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = { + putBytes(BlockId(blockId), data.byteBuffer(), level) + } + /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon. 
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d4ed33eb506c..5c647f9754f4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -231,7 +231,8 @@ final class ShuffleBlockFetcherIterator( if (!result.failed) { bytesInFlight -= result.size } - while (!fetchRequests.isEmpty && + // Send fetch requests up to maxBytesInFlight + while (fetchRequests.nonEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 5406fcc2ac83..b54163a3bcfa 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark -import org.apache.spark.network.cm.{GetBlock, BlockManagerWorker, ConnectionManagerId} import org.scalatest.BeforeAndAfter import org.scalatest.FunSuite import org.scalatest.concurrent.Timeouts._ @@ -202,12 +201,13 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter val blockIds = data.partitions.indices.map(index => RDDBlockId(data.id, index)).toArray val blockId = blockIds(0) val blockManager = SparkEnv.get.blockManager - blockManager.master.getLocations(blockId).foreach(id => { - val bytes = BlockManagerWorker.syncGetBlock( - GetBlock(blockId), ConnectionManagerId(id.host, id.port)) - val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList + val blockTransfer = SparkEnv.get.blockTransferService + blockManager.master.getLocations(blockId).foreach { cmId => + val bytes = blockTransfer.fetchBlock(cmId.host, cmId.port, 
blockId.toString) + val deserialized = blockManager.dataDeserialize(blockId, bytes.byteBuffer()) + .asInstanceOf[Iterator[Int]].toList assert(deserialized === (1 to 100).toList) - }) + } } test("compute without caching when no partitions fit in memory") { @@ -268,6 +268,8 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter DistributedSuite.amMaster = true sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { + println("i = " + i) + Console.out.flush() val data = sc.parallelize(Seq(true, true), 2) assert(data.count === 2) assert(data.map(markNodeIfIdentity).collect.size === 2) @@ -339,6 +341,7 @@ object DistributedSuite { // Act like an identity function, but if the argument is true, set mark to true. def markNodeIfIdentity(item: Boolean): Boolean = { if (item) { + println("marking node!!!!!!!!!!!!!!!") assert(!amMaster) mark = true } @@ -349,6 +352,7 @@ object DistributedSuite { // crashing the entire JVM. def failOnMarkedIdentity(item: Boolean): Boolean = { if (mark) { + println("failing node !!!!!!!!!!!!!!!") System.exit(42) } item diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala deleted file mode 100644 index d689d6c15d5a..000000000000 --- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.io.IOException -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.future -import scala.concurrent.ExecutionContext.Implicits.global - -import org.mockito.Mockito._ -import org.mockito.Matchers.{any, eq => meq} -import org.mockito.stubbing.Answer -import org.mockito.invocation.InvocationOnMock - -import org.scalatest.{FunSuite, Matchers} - -import org.apache.spark.network.cm._ -import org.apache.spark.executor.ShuffleReadMetrics -import org.apache.spark.storage.BlockFetcherIterator._ - - -class BlockFetcherIteratorSuite extends FunSuite with Matchers { - - test("block fetch from local fails using BasicBlockFetcherIterator") { - val blockManager = mock(classOf[BlockManager]) - val connManager = mock(classOf[ConnectionManager]) - doReturn(connManager).when(blockManager).connectionManager - doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId - - doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight - - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) - - val optItr = mock(classOf[Option[Iterator[Any]]]) - val answer = new Answer[Option[Iterator[Any]]] { - override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { - throw new Exception - } - } - - // 3rd block is going to fail - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any()) - 
doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any()) - doAnswer(answer).when(blockManager).getLocalFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any()) - - val bmId = BlockManagerId("test-client", "test-client", 1) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) - ) - - val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, - new ShuffleReadMetrics()) - - iterator.initialize() - - // 3rd getLocalFromDisk invocation should be failed - verify(blockManager, times(3)).getLocalFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - // the 2nd element of the tuple returned by iterator.next should be defined when fetching successfully - assert(iterator.next._2.isDefined, "1st element should be defined but is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next._2.isDefined, "2nd element should be defined but is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - // 3rd fetch should be failed - assert(!iterator.next._2.isDefined, "3rd element should not be defined but is actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") - // Don't call next() after fetching non-defined element even if thare are rest of elements in the iterator. - // Otherwise, BasicBlockFetcherIterator hangs up. 
- } - - - test("block fetch from local succeed using BasicBlockFetcherIterator") { - val blockManager = mock(classOf[BlockManager]) - val connManager = mock(classOf[ConnectionManager]) - doReturn(connManager).when(blockManager).connectionManager - doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId - - doReturn((48 * 1024 * 1024).asInstanceOf[Long]).when(blockManager).maxBytesInFlight - - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) - - val optItr = mock(classOf[Option[Iterator[Any]]]) - - // All blocks should be fetched successfully - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(0)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(1)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalFromDisk(meq(blIds(4)), any()) - - val bmId = BlockManagerId("test-client", "test-client", 1) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) - ) - - val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null, - new ShuffleReadMetrics()) - - iterator.initialize() - - // getLocalFromDis should be invoked for all of 5 blocks - verify(blockManager, times(5)).getLocalFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - assert(iterator.next._2.isDefined, "All elements should be defined but 1st element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next._2.isDefined, "All elements should be defined but 2nd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 
elements but actually has 2 elements") - assert(iterator.next._2.isDefined, "All elements should be defined but 3rd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") - assert(iterator.next._2.isDefined, "All elements should be defined but 4th element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") - assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined") - } - - test("block fetch from remote fails using BasicBlockFetcherIterator") { - val blockManager = mock(classOf[BlockManager]) - val connManager = mock(classOf[ConnectionManager]) - when(blockManager.connectionManager).thenReturn(connManager) - - val f = future { - throw new IOException("Send failed or we received an error ACK") - } - when(connManager.sendMessageReliably(any(), - any())).thenReturn(f) - when(blockManager.futureExecContext).thenReturn(global) - - when(blockManager.blockManagerId).thenReturn( - BlockManagerId("test-client", "test-client", 1)) - when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) - - val blId1 = ShuffleBlockId(0,0,0) - val blId2 = ShuffleBlockId(0,1,0) - val bmId = BlockManagerId("test-server", "test-server", 1) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, Seq((blId1, 1L), (blId2, 1L))) - ) - - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null, new ShuffleReadMetrics()) - - iterator.initialize() - iterator.foreach{ - case (_, r) => { - (!r.isDefined) should be(true) - } - } - } - - test("block fetch from remote succeed using BasicBlockFetcherIterator") { - val blockManager = mock(classOf[BlockManager]) - val connManager = mock(classOf[ConnectionManager]) - when(blockManager.connectionManager).thenReturn(connManager) - - val blId1 = ShuffleBlockId(0,0,0) - val blId2 = ShuffleBlockId(0,1,0) - val buf1 = 
ByteBuffer.allocate(4) - val buf2 = ByteBuffer.allocate(4) - buf1.putInt(1) - buf1.flip() - buf2.putInt(1) - buf2.flip() - val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1)) - val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2)) - val blockMessageArray = new BlockMessageArray( - Seq(blockMessage1, blockMessage2)) - - val bufferMessage = blockMessageArray.toBufferMessage - val buffer = ByteBuffer.allocate(bufferMessage.size) - val arrayBuffer = new ArrayBuffer[ByteBuffer] - bufferMessage.buffers.foreach{ b => - buffer.put(b) - } - buffer.flip() - arrayBuffer += buffer - - val f = future { - Message.createBufferMessage(arrayBuffer) - } - when(connManager.sendMessageReliably(any(), - any())).thenReturn(f) - when(blockManager.futureExecContext).thenReturn(global) - - when(blockManager.blockManagerId).thenReturn( - BlockManagerId("test-client", "test-client", 1)) - when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024) - - val bmId = BlockManagerId("test-server", "test-server", 1) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, Seq((blId1, 1L), (blId2, 1L))) - ) - - val iterator = new BasicBlockFetcherIterator(blockManager, - blocksByAddress, null, new ShuffleReadMetrics()) - iterator.initialize() - iterator.foreach{ - case (_, r) => { - (r.isDefined) should be(true) - } - } - } -} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b20ed4cbc56c..bd980763c780 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -21,6 +21,8 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays import java.util.concurrent.TimeUnit +import org.apache.spark.network.cm.CMBlockTransferService + import scala.collection.mutable.ArrayBuffer import scala.concurrent.Await import 
scala.concurrent.duration._ @@ -31,10 +33,7 @@ import akka.actor._ import akka.pattern.ask import akka.util.Timeout -import org.mockito.invocation.InvocationOnMock -import org.mockito.Matchers.any -import org.mockito.Mockito.{doAnswer, mock, spy, when} -import org.mockito.stubbing.Answer +import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ @@ -43,7 +42,6 @@ import org.scalatest.Matchers import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.executor.DataReadMethod -import org.apache.spark.network.cm._ import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager @@ -74,8 +72,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { - new BlockManager(name, actorSystem, master, serializer, maxMem, conf, securityMgr, - mapOutputTracker, shuffleManager) + val transfer = new CMBlockTransferService(conf, securityMgr) + new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + mapOutputTracker, shuffleManager, transfer) } before { @@ -793,8 +792,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("block store put failure") { // Use Java serializer so we can create an unserializable error. + val transfer = new CMBlockTransferService(conf, securityMgr) store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - securityMgr, mapOutputTracker, shuffleManager) + mapOutputTracker, shuffleManager, transfer) // The put should fail since a1 is not serializable. 
class UnserializableClass @@ -1007,109 +1007,6 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter assert(!store.memoryStore.contains(rdd(1, 0)), "rdd_1_0 was in store") } - test("return error message when error occurred in BlockManagerWorker#onBlockMessageReceive") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker, shuffleManager) - - val worker = spy(new BlockManagerWorker(store)) - val connManagerId = mock(classOf[ConnectionManagerId]) - - // setup request block messages - val reqBlId1 = ShuffleBlockId(0,0,0) - val reqBlId2 = ShuffleBlockId(0,1,0) - val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) - val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) - val reqBlockMessages = new BlockMessageArray( - Seq(reqBlockMessage1, reqBlockMessage2)) - val reqBufferMessage = reqBlockMessages.toBufferMessage - - val answer = new Answer[Option[BlockMessage]] { - override def answer(invocation: InvocationOnMock) - :Option[BlockMessage]= { - throw new Exception - } - } - - doAnswer(answer).when(worker).processBlockMessage(any()) - - // Test when exception was thrown during processing block messages - var ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) - - assert(ackMessage.isDefined, "When Exception was thrown in " + - "BlockManagerWorker#processBlockMessage, " + - "ackMessage should be defined") - assert(ackMessage.get.hasError, "When Exception was thown in " + - "BlockManagerWorker#processBlockMessage, " + - "ackMessage should have error") - - val notBufferMessage = mock(classOf[Message]) - - // Test when not BufferMessage was received - ackMessage = worker.onBlockMessageReceive(notBufferMessage, connManagerId) - assert(ackMessage.isDefined, "When not BufferMessage was passed to " + - "BlockManagerWorker#onBlockMessageReceive, " + - "ackMessage should be defined") - assert(ackMessage.get.hasError, "When not BufferMessage was 
passed to " + - "BlockManagerWorker#onBlockMessageReceive, " + - "ackMessage should have error") - } - - test("return ack message when no error occurred in BlocManagerWorker#onBlockMessageReceive") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, - securityMgr, mapOutputTracker, shuffleManager) - - val worker = spy(new BlockManagerWorker(store)) - val connManagerId = mock(classOf[ConnectionManagerId]) - - // setup request block messages - val reqBlId1 = ShuffleBlockId(0,0,0) - val reqBlId2 = ShuffleBlockId(0,1,0) - val reqBlockMessage1 = BlockMessage.fromGetBlock(GetBlock(reqBlId1)) - val reqBlockMessage2 = BlockMessage.fromGetBlock(GetBlock(reqBlId2)) - val reqBlockMessages = new BlockMessageArray( - Seq(reqBlockMessage1, reqBlockMessage2)) - - val tmpBufferMessage = reqBlockMessages.toBufferMessage - val buffer = ByteBuffer.allocate(tmpBufferMessage.size) - val arrayBuffer = new ArrayBuffer[ByteBuffer] - tmpBufferMessage.buffers.foreach{ b => - buffer.put(b) - } - buffer.flip() - arrayBuffer += buffer - val reqBufferMessage = Message.createBufferMessage(arrayBuffer) - - // setup ack block messages - val buf1 = ByteBuffer.allocate(4) - val buf2 = ByteBuffer.allocate(4) - buf1.putInt(1) - buf1.flip() - buf2.putInt(1) - buf2.flip() - val ackBlockMessage1 = BlockMessage.fromGotBlock(GotBlock(reqBlId1, buf1)) - val ackBlockMessage2 = BlockMessage.fromGotBlock(GotBlock(reqBlId2, buf2)) - - val answer = new Answer[Option[BlockMessage]] { - override def answer(invocation: InvocationOnMock) - :Option[BlockMessage]= { - if (invocation.getArguments()(0).asInstanceOf[BlockMessage].eq( - reqBlockMessage1)) { - return Some(ackBlockMessage1) - } else { - return Some(ackBlockMessage2) - } - } - } - - doAnswer(answer).when(worker).processBlockMessage(any()) - - val ackMessage = worker.onBlockMessageReceive(reqBufferMessage, connManagerId) - assert(ackMessage.isDefined, "When BlockManagerWorker#onBlockMessageReceive " + - "was executed successfully, 
ackMessage should be defined") - assert(!ackMessage.get.hasError, "When BlockManagerWorker#onBlockMessageReceive " + - "was executed successfully, ackMessage should not have error") - } - test("reserve/release unroll memory") { store = makeBlockManager(12000) val memoryStore = store.memoryStore diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index aabaeadd7a07..98fa4544385c 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import org.apache.spark.network.cm.CMBlockTransferService import org.apache.spark.shuffle.hash.HashShuffleManager import scala.collection.mutable @@ -61,7 +62,6 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before rootDir1 = Files.createTempDir() rootDir1.deleteOnExit() rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath - println("Created root dirs: " + rootDirs) } override def afterAll() { @@ -153,8 +153,9 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before val master = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true, confCopy, new LiveListenerBus))), confCopy) - val store = new BlockManager("", actorSystem, master , serializer, confCopy, - securityManager, null, shuffleManager) + val transfer = new CMBlockTransferService(confCopy, securityManager) + val store = new BlockManager("", actorSystem, master, serializer, confCopy, + mapOutputTracker = null, shuffleManager, transfer) try { From 98c668ae98e6e7d3d22504b3607527ad162356fc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 2 Sep 2014 15:03:58 -0700 Subject: [PATCH 03/10] Added failure handling and fixed unit tests. 
--- .../spark/network/BlockFetchingListener.scala | 2 +- .../spark/network/BlockTransferService.scala | 24 ++++++++++--------- .../network/cm/CMBlockTransferService.scala | 6 ++++- .../storage/ShuffleBlockFetcherIterator.scala | 17 +++++++------ .../org/apache/spark/DistributedSuite.scala | 2 -- 5 files changed, 27 insertions(+), 24 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index c1dfcf1c12d3..6bc123701bdd 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -33,5 +33,5 @@ trait BlockFetchingListener extends EventListener { /** * Called upon failures. */ - def onBlockFetchFailure(exception: Exception): Unit + def onBlockFetchFailure(exception: Throwable): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 645adbce2311..3000fc74fdb0 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -44,7 +44,11 @@ abstract class BlockTransferService { def hostName: String /** - * Fetch a sequence of blocks from a remote node, available only after [[init]] is invoked. + * Fetch a sequence of blocks from a remote node asynchronously, + * available only after [[init]] is invoked. + * + * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block, + * while [[BlockFetchingListener.onBlockFetchSuccess]] is called once per failure. * * This takes a sequence so the implementation can batch requests. */ @@ -55,19 +59,17 @@ abstract class BlockTransferService { listener: BlockFetchingListener): Unit /** - * Fetch a single block from a remote node, available only after [[init]] is invoked. 
- * - * This is functionally equivalent to - * {{{ - * fetchBlocks(hostName, port, Seq(blockId)).iterator().next()._2 - * }}} + * Fetch a single block from a remote node, synchronously, + * available only after [[init]] is invoked. */ def fetchBlock(hostName: String, port: Int, blockId: String): ManagedBuffer = { // TODO(rxin): Add timeout? + + // A monitor for the thread to wait on. val lock = new Object - @volatile var result: Either[ManagedBuffer, Exception] = null + @volatile var result: Either[ManagedBuffer, Throwable] = null fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener { - override def onBlockFetchFailure(exception: Exception): Unit = { + override def onBlockFetchFailure(exception: Throwable): Unit = { lock.synchronized { result = Right(exception) lock.notify() @@ -93,8 +95,8 @@ abstract class BlockTransferService { } result match { - case Left(data: ManagedBuffer) => data - case Right(e: Exception) => throw e + case Left(data) => data + case Right(e) => throw e } } diff --git a/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala index 3b61c0ee852c..86d6396dfdc2 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala @@ -84,7 +84,7 @@ final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityMan val future = cm.sendMessageReliably(cmId, blockMessageArray.toBufferMessage) - // If succeeds in getting blocks from a remote connection manager, put the block in results. + // Register the listener on success/failure future callback. 
future.onSuccess { case message => val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) @@ -101,6 +101,10 @@ final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityMan } } }(cm.futureExecContext) + + future.onFailure { case exception => + listener.onBlockFetchFailure(exception) + }(cm.futureExecContext) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 5c647f9754f4..bdba7b133f92 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -116,18 +116,17 @@ final class ShuffleBlockFetcherIterator( logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } - override def onBlockFetchFailure(exception: Exception): Unit = { - + override def onBlockFetchFailure(e: Throwable): Unit = { + logError("Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + // Note that there is a chance that some blocks have been fetched successfully, but we + // still add them to the failed queue. This is fine because when the caller sees a + // FetchFailedException, it is going to fail the entire task anyway.
+ for ((blockId, size) <- req.blocks) { + results.put(new FetchResult(blockId, -1, null)) + } } } ) - // case Failure(exception) => { - // logError("Could not get block(s) from " + cmId, exception) - // for ((blockId, size) <- req.blocks) { - // results.put(new FetchResult(blockId, -1, null)) - // } - // } - // } } private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index b54163a3bcfa..2cd1fc187028 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -341,7 +341,6 @@ object DistributedSuite { // Act like an identity function, but if the argument is true, set mark to true. def markNodeIfIdentity(item: Boolean): Boolean = { if (item) { - println("marking node!!!!!!!!!!!!!!!") assert(!amMaster) mark = true } @@ -352,7 +351,6 @@ object DistributedSuite { // crashing the entire JVM. def failOnMarkedIdentity(item: Boolean): Boolean = { if (mark) { - println("failing node !!!!!!!!!!!!!!!") System.exit(42) } item From 07ccf0db4c250c0160e42cf6efb030e7502d30e7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 2 Sep 2014 15:12:54 -0700 Subject: [PATCH 04/10] Added init check to CMBlockTransferService. 
--- .../network/cm/CMBlockTransferService.scala | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala index 86d6396dfdc2..227d73c77075 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala @@ -41,12 +41,18 @@ final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityMan /** * Port number the service is listening on, available only after [[init]] is invoked. */ - override def port: Int = cm.id.port + override def port: Int = { + checkInit() + cm.id.port + } /** * Host name the service is listening on, available only after [[init]] is invoked. */ - override def hostName: String = cm.id.host + override def hostName: String = { + checkInit() + cm.id.host + } /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -76,6 +82,7 @@ final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityMan port: Int, blockIds: Seq[String], listener: BlockFetchingListener): Unit = { + checkInit() val cmId = new ConnectionManagerId(hostName, port) val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => @@ -118,6 +125,7 @@ final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityMan blockId: String, blockData: ManagedBuffer, level: StorageLevel) { + checkInit() val msg = PutBlock(BlockId(blockId), blockData.byteBuffer(), level) val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) val remoteCmId = new ConnectionManagerId(hostName, port) @@ -127,6 +135,10 @@ final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityMan Duration.Inf) } + private def checkInit(): Unit = if (cm == null) { + throw new 
IllegalStateException(getClass.getName + " has not been initialized") + } + private def onBlockMessageReceive(msg: Message, id: ConnectionManagerId): Option[Message] = { logDebug("Handling message " + msg) msg match { From 2c6b1e1b0cf77e926b0f140ee3360cdf3c12fc4e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 2 Sep 2014 17:37:05 -0700 Subject: [PATCH 05/10] Removed println in test cases. --- core/src/test/scala/org/apache/spark/DistributedSuite.scala | 3 --- 1 file changed, 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 2cd1fc187028..03c2de99cd97 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -135,7 +135,6 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter sc.parallelize(1 to 10, 2).foreach { x => if (x == 1) System.exit(42) } } assert(thrown.getClass === classOf[SparkException]) - System.out.println(thrown.getMessage) assert(thrown.getMessage.contains("failed 4 times")) } } @@ -268,8 +267,6 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter DistributedSuite.amMaster = true sc = new SparkContext(clusterUrl, "test") for (i <- 1 to 3) { - println("i = " + i) - Console.out.flush() val data = sc.parallelize(Seq(true, true), 2) assert(data.count === 2) assert(data.map(markNodeIfIdentity).collect.size === 2) From 8a1046ee0b52fb0b695cc92e62c402311ea4f49f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 2 Sep 2014 22:56:25 -0700 Subject: [PATCH 06/10] Code review feedback: 1. Rename package name from cm to nio. 2. Refined BlockTransferService and ManagedBuffer interfaces. 
--- .../scala/org/apache/spark/SparkEnv.scala | 4 +- .../spark/network/BlockFetchingListener.scala | 2 +- .../spark/network/BlockTransferService.scala | 35 ++++++--- .../apache/spark/network/ManagedBuffer.scala | 73 ++++++++++--------- .../network/{cm => nio}/BlockMessage.scala | 19 +---- .../{cm => nio}/BlockMessageArray.scala | 6 +- .../network/{cm => nio}/BufferMessage.scala | 7 +- .../network/{cm => nio}/Connection.scala | 6 +- .../network/{cm => nio}/ConnectionId.scala | 6 +- .../{cm => nio}/ConnectionManager.scala | 11 +-- .../{cm => nio}/ConnectionManagerId.scala | 6 +- .../spark/network/{cm => nio}/Message.scala | 7 +- .../network/{cm => nio}/MessageChunk.scala | 4 +- .../{cm => nio}/MessageChunkHeader.scala | 6 +- .../NioBlockTransferService.scala} | 23 +++--- .../network/{cm => nio}/SecurityMessage.scala | 10 +-- .../spark/serializer/KryoSerializer.scala | 2 +- .../shuffle/FileShuffleBlockManager.scala | 2 +- .../shuffle/IndexShuffleBlockManager.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 9 +-- .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../org/apache/spark/DistributedSuite.scala | 4 +- .../{cm => nio}/ConnectionManagerSuite.scala | 9 ++- .../hash/HashShuffleManagerSuite.scala | 5 +- .../spark/storage/BlockManagerSuite.scala | 6 +- .../spark/storage/DiskBlockManagerSuite.scala | 2 +- 26 files changed, 141 insertions(+), 127 deletions(-) rename core/src/main/scala/org/apache/spark/network/{cm => nio}/BlockMessage.scala (91%) rename core/src/main/scala/org/apache/spark/network/{cm => nio}/BlockMessageArray.scala (98%) rename core/src/main/scala/org/apache/spark/network/{cm => nio}/BufferMessage.scala (98%) rename core/src/main/scala/org/apache/spark/network/{cm => nio}/Connection.scala (99%) rename core/src/main/scala/org/apache/spark/network/{cm => nio}/ConnectionId.scala (88%) rename core/src/main/scala/org/apache/spark/network/{cm => nio}/ConnectionManager.scala (99%) rename core/src/main/scala/org/apache/spark/network/{cm => 
nio}/ConnectionManagerId.scala (88%) rename core/src/main/scala/org/apache/spark/network/{cm => nio}/Message.scala (95%) rename core/src/main/scala/org/apache/spark/network/{cm => nio}/MessageChunk.scala (96%) rename core/src/main/scala/org/apache/spark/network/{cm => nio}/MessageChunkHeader.scala (95%) rename core/src/main/scala/org/apache/spark/network/{cm/CMBlockTransferService.scala => nio/NioBlockTransferService.scala} (92%) rename core/src/main/scala/org/apache/spark/network/{cm => nio}/SecurityMessage.scala (97%) rename core/src/test/scala/org/apache/spark/network/{cm => nio}/ConnectionManagerSuite.scala (99%) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index ad461ae10662..1642a2f8140c 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService -import org.apache.spark.network.cm.CMBlockTransferService +import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} @@ -226,7 +226,7 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - val blockTransferService = new CMBlockTransferService(conf, securityManager) + val blockTransferService = new NioBlockTransferService(conf, securityManager) val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index 6bc123701bdd..34acaa563ca5 100644 --- 
a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -31,7 +31,7 @@ trait BlockFetchingListener extends EventListener { def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit /** - * Called upon failures. + * Called upon failures. For each failure, this is called only once (i.e. not once per block). */ def onBlockFetchFailure(exception: Throwable): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 3000fc74fdb0..39e0a60398e8 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -17,6 +17,9 @@ package org.apache.spark.network +import scala.concurrent.{Await, Future} +import scala.concurrent.duration.Duration + import org.apache.spark.storage.StorageLevel @@ -48,9 +51,11 @@ abstract class BlockTransferService { * available only after [[init]] is invoked. * * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block, - * while [[BlockFetchingListener.onBlockFetchSuccess]] is called once per failure. + * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block). * - * This takes a sequence so the implementation can batch requests. + * Note that this API takes a sequence so the implementation can batch requests, and does not + * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as + * the data of a block is fetched, rather than waiting for all blocks to be fetched. */ def fetchBlocks( hostName: String, @@ -59,12 +64,21 @@ abstract class BlockTransferService { listener: BlockFetchingListener): Unit /** - * Fetch a single block from a remote node, synchronously, - * available only after [[init]] is invoked. 
+ Upload a single block to a remote node, available only after [[init]] is invoked. */ - def fetchBlock(hostName: String, port: Int, blockId: String): ManagedBuffer = { - // TODO(rxin): Add timeout? + def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, + level: StorageLevel): Future[Unit] + /** + * A special case of [[fetchBlocks]], since it only fetches one block and is blocking. + * + * It is also only available after [[init]] is invoked. + */ + def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = { // A monitor for the thread to wait on. val lock = new Object @volatile var result: Either[ManagedBuffer, Throwable] = null @@ -103,12 +117,15 @@ abstract class BlockTransferService { /** * Upload a single block to a remote node, available only after [[init]] is invoked. * - * This call blocks until the upload completes, or throws an exception upon failures. + * This method is similar to [[uploadBlock]], except this one blocks the thread + * until the upload finishes. 
*/ - def uploadBlock( + def uploadBlockSync( hostname: String, port: Int, blockId: String, blockData: ManagedBuffer, - level: StorageLevel): Unit + level: StorageLevel): Unit = { + Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf) + } } diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 8f48152b3cd2..5d30e9e7183f 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -17,85 +17,90 @@ package org.apache.spark.network -import java.io.{RandomAccessFile, File, FileInputStream, InputStream} +import java.io.{FileInputStream, RandomAccessFile, File, InputStream} import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode -import io.netty.buffer.{ByteBufInputStream, ByteBuf, Unpooled} -import io.netty.channel.DefaultFileRegion +import io.netty.buffer.{ByteBufInputStream, ByteBuf} -import org.apache.spark.storage.FileSegment import org.apache.spark.util.ByteBufferInputStream /** - * Provides a buffer abstraction that allows pooling and reuse. + * This interface provides an immutable view for data in the form of bytes. The implementation + * should specify how the data is provided: + * + * - FileSegmentManagedBuffer: data backed by part of a file + * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer + * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf */ -abstract class ManagedBuffer { +sealed abstract class ManagedBuffer { // Note that all the methods are defined with parenthesis because their implementations can // have side effects (io operations). - def byteBuffer(): ByteBuffer - - def fileSegment(): Option[FileSegment] = None - - def inputStream(): InputStream = throw new UnsupportedOperationException - - def release(): Unit = throw new UnsupportedOperationException - + /** Number of bytes of the data. 
*/ def size: Long - private[network] def toNetty(): AnyRef + /** + * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the + * returned ByteBuffer should not affect the content of this buffer. + */ + def nioByteBuffer(): ByteBuffer + + /** + * Exposes this buffer's data as an InputStream. The underlying implementation does not + * necessarily check for the length of bytes read, so the caller is responsible for making sure + * it does not go over the limit. + */ + def inputStream(): InputStream } /** - * A ManagedBuffer backed by a segment in a file. + * A [[ManagedBuffer]] backed by a segment in a file */ -final class FileSegmentManagedBuffer(file: File, offset: Long, length: Long) +final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) extends ManagedBuffer { override def size: Long = length - override def byteBuffer(): ByteBuffer = { + override def nioByteBuffer(): ByteBuffer = { val channel = new RandomAccessFile(file, "r").getChannel channel.map(MapMode.READ_ONLY, offset, length) } - override private[network] def toNetty(): AnyRef = { - val fileChannel = new FileInputStream(file).getChannel - new DefaultFileRegion(fileChannel, offset, length) + override def inputStream(): InputStream = { + val is = new FileInputStream(file) + is.skip(offset) + is } } /** - * A ManagedBuffer backed by [[java.nio.ByteBuffer]]. + * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. */ final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { - override def byteBuffer() = buf - - override def inputStream() = new ByteBufferInputStream(buf) - override def size: Long = buf.remaining() - override private[network] def toNetty(): AnyRef = Unpooled.wrappedBuffer(buf) + override def nioByteBuffer() = buf + + override def inputStream() = new ByteBufferInputStream(buf) } /** - * A ManagedBuffer backed by a Netty [[ByteBuf]]. + * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. 
*/ final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { - override def byteBuffer() = buf.nioBuffer() - - override def inputStream() = new ByteBufInputStream(buf) + override def size: Long = buf.readableBytes() - override def release(): Unit = buf.release() + override def nioByteBuffer() = buf.nioBuffer() - override def size: Long = buf.readableBytes() + override def inputStream() = new ByteBufInputStream(buf) - override private[network] def toNetty(): AnyRef = buf + // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it. + def release(): Unit = buf.release() } diff --git a/core/src/main/scala/org/apache/spark/network/cm/BlockMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala similarity index 91% rename from core/src/main/scala/org/apache/spark/network/cm/BlockMessage.scala rename to core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala index 107d0131efd7..b573f1a8a5fc 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessage.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.nio.ByteBuffer @@ -23,11 +23,12 @@ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import scala.collection.mutable.{ArrayBuffer, StringBuilder} +// private[spark] because we need to register them in Kryo private[spark] case class GetBlock(id: BlockId) private[spark] case class GotBlock(id: BlockId, data: ByteBuffer) private[spark] case class PutBlock(id: BlockId, data: ByteBuffer, level: StorageLevel) -private[spark] class BlockMessage() { +private[nio] class BlockMessage() { // Un-initialized: typ = 0 // GetBlock: typ = 1 // GotBlock: typ = 2 @@ -158,7 +159,7 @@ private[spark] class BlockMessage() { } } -private[spark] object BlockMessage { +private[nio] object BlockMessage { val TYPE_NON_INITIALIZED: Int = 0 val TYPE_GET_BLOCK: Int = 1 val TYPE_GOT_BLOCK: Int = 2 @@ -193,16 +194,4 @@ private[spark] object BlockMessage { newBlockMessage.set(putBlock) newBlockMessage } - - def main(args: Array[String]) { - val B = new BlockMessage() - val blockId = TestBlockId("ABC") - B.set(new PutBlock(blockId, ByteBuffer.allocate(10), StorageLevel.MEMORY_AND_DISK_SER_2)) - val bMsg = B.toBufferMessage - val C = new BlockMessage() - C.set(bMsg) - - println(B.getId + " " + B.getLevel) - println(C.getId + " " + C.getLevel) - } } diff --git a/core/src/main/scala/org/apache/spark/network/cm/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala similarity index 98% rename from core/src/main/scala/org/apache/spark/network/cm/BlockMessageArray.scala rename to core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala index b0f770261c19..a1a2c00ed154 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BlockMessageArray.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.nio.ByteBuffer @@ -24,7 +24,7 @@ import org.apache.spark.storage.{StorageLevel, TestBlockId} import scala.collection.mutable.ArrayBuffer -private[spark] +private[nio] class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging { @@ -102,7 +102,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) } } -private[spark] object BlockMessageArray { +private[nio] object BlockMessageArray { def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() diff --git a/core/src/main/scala/org/apache/spark/network/cm/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala similarity index 98% rename from core/src/main/scala/org/apache/spark/network/cm/BufferMessage.scala rename to core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala index 5f7761838ab3..3b245c5c7a4f 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/BufferMessage.scala @@ -15,15 +15,16 @@ * limitations under the License. 
*/ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.nio.ByteBuffer +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.storage.BlockManager -import scala.collection.mutable.ArrayBuffer -private[spark] +private[nio] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int) extends Message(Message.BUFFER_MESSAGE, id_) { diff --git a/core/src/main/scala/org/apache/spark/network/cm/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/network/cm/Connection.scala rename to core/src/main/scala/org/apache/spark/network/nio/Connection.scala index 080c3e7dd42a..74074a8dcbff 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.net._ import java.nio._ @@ -25,7 +25,7 @@ import org.apache.spark._ import scala.collection.mutable.{ArrayBuffer, HashMap, Queue} -private[spark] +private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) extends Logging { @@ -190,7 +190,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, } -private[spark] +private[nio] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, remoteId_ : ConnectionManagerId, id_ : ConnectionId) extends Connection(SocketChannel.open, selector_, remoteId_, id_) { diff --git a/core/src/main/scala/org/apache/spark/network/cm/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala similarity index 88% rename from core/src/main/scala/org/apache/spark/network/cm/ConnectionId.scala rename to 
core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala index 7b358a4d2598..764dc5e5503e 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionId.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.network.cm +package org.apache.spark.network.nio -private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { +private[nio] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId } -private[spark] object ConnectionId { +private[nio] object ConnectionId { def createConnectionIdFromString(connectionIdString: String): ConnectionId = { val res = connectionIdString.split("_").map(_.trim()) diff --git a/core/src/main/scala/org/apache/spark/network/cm/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala similarity index 99% rename from core/src/main/scala/org/apache/spark/network/cm/ConnectionManager.scala rename to core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index f9e35fb793fa..09d3ea306515 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.io.IOException import java.net._ @@ -26,15 +26,16 @@ import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} import java.util.{Timer, TimerTask} -import org.apache.spark._ -import org.apache.spark.util.{SystemClock, Utils} - import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps -private[spark] class ConnectionManager( +import org.apache.spark._ +import org.apache.spark.util.{SystemClock, Utils} + + +private[nio] class ConnectionManager( port: Int, conf: SparkConf, securityManager: SecurityManager, diff --git a/core/src/main/scala/org/apache/spark/network/cm/ConnectionManagerId.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala similarity index 88% rename from core/src/main/scala/org/apache/spark/network/cm/ConnectionManagerId.scala rename to core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala index b6b2cb0db429..cbb37ec5ced1 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/ConnectionManagerId.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManagerId.scala @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.net.InetSocketAddress import org.apache.spark.util.Utils -private[spark] case class ConnectionManagerId(host: String, port: Int) { +private[nio] case class ConnectionManagerId(host: String, port: Int) { // DEBUG code Utils.checkHost(host) assert (port > 0) @@ -30,7 +30,7 @@ private[spark] case class ConnectionManagerId(host: String, port: Int) { } -private[spark] object ConnectionManagerId { +private[nio] object ConnectionManagerId { def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = { new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort) } diff --git a/core/src/main/scala/org/apache/spark/network/cm/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala similarity index 95% rename from core/src/main/scala/org/apache/spark/network/cm/Message.scala rename to core/src/main/scala/org/apache/spark/network/nio/Message.scala index 5b5bcc2d966e..0b874c289125 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -15,14 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.net.InetSocketAddress import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -private[spark] abstract class Message(val typ: Long, val id: Int) { + +private[nio] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null var started = false var startTime = -1L @@ -42,7 +43,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { } -private[spark] object Message { +private[nio] object Message { val BUFFER_MESSAGE = 1111111111L var lastId = 1 diff --git a/core/src/main/scala/org/apache/spark/network/cm/MessageChunk.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/network/cm/MessageChunk.scala rename to core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala index 95b46cd11f6b..278c5ac356ef 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/MessageChunk.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunk.scala @@ -15,13 +15,13 @@ * limitations under the License. 
*/ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer -private[cm] +private[nio] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) { val size = if (buffer == null) 0 else buffer.remaining diff --git a/core/src/main/scala/org/apache/spark/network/cm/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala similarity index 95% rename from core/src/main/scala/org/apache/spark/network/cm/MessageChunkHeader.scala rename to core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala index 7087c7ad6c50..6e20f291c5ce 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/MessageChunkHeader.scala @@ -15,12 +15,12 @@ * limitations under the License. */ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.net.{InetAddress, InetSocketAddress} import java.nio.ByteBuffer -private[spark] class MessageChunkHeader( +private[nio] class MessageChunkHeader( val typ: Long, val id: Int, val totalSize: Int, @@ -56,7 +56,7 @@ private[spark] class MessageChunkHeader( } -private[spark] object MessageChunkHeader { +private[nio] object MessageChunkHeader { val HEADER_SIZE = 45 def create(buffer: ByteBuffer): MessageChunkHeader = { diff --git a/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala similarity index 92% rename from core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala rename to core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 227d73c77075..59958ee89423 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/CMBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ 
-15,12 +15,11 @@ * limitations under the License. */ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.nio.ByteBuffer -import scala.concurrent.Await -import scala.concurrent.duration.Duration +import scala.concurrent.Future import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} import org.apache.spark.network._ @@ -29,9 +28,10 @@ import org.apache.spark.util.Utils /** - * A [[BlockTransferService]] implementation based on our [[ConnectionManager]]. + * A [[BlockTransferService]] implementation based on [[ConnectionManager]], a custom + * implementation using Java NIO. */ -final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityManager) +final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityManager) extends BlockTransferService with Logging { private var cm: ConnectionManager = _ @@ -124,15 +124,14 @@ final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityMan port: Int, blockId: String, blockData: ManagedBuffer, - level: StorageLevel) { + level: StorageLevel) + : Future[Unit] = { checkInit() - val msg = PutBlock(BlockId(blockId), blockData.byteBuffer(), level) + val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level) val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) val remoteCmId = new ConnectionManagerId(hostName, port) - - // TODO: Not wait infinitely. 
- Await.result(cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage), - Duration.Inf) + val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) + reply.map(x => ())(cm.futureExecContext) } private def checkInit(): Unit = if (cm == null) { @@ -201,6 +200,6 @@ final class CMBlockTransferService(conf: SparkConf, securityManager: SecurityMan val buffer = blockDataManager.getBlockData(blockId).orNull logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) - buffer.byteBuffer() + buffer.nioByteBuffer() } } diff --git a/core/src/main/scala/org/apache/spark/network/cm/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/network/cm/SecurityMessage.scala rename to core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala index f59df06fb3d9..747a2088a725 100644 --- a/core/src/main/scala/org/apache/spark/network/cm/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/SecurityMessage.scala @@ -15,14 +15,14 @@ * limitations under the License. */ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.nio.ByteBuffer -import org.apache.spark._ - import scala.collection.mutable.{ArrayBuffer, StringBuilder} +import org.apache.spark._ + /** * SecurityMessage is class that contains the connectionId and sasl token * used in SASL negotiation. 
SecurityMessage has routines for converting @@ -52,7 +52,7 @@ import scala.collection.mutable.{ArrayBuffer, StringBuilder} * - Length of the token * - Token */ -private[spark] class SecurityMessage() extends Logging { +private[nio] class SecurityMessage extends Logging { private var connectionId: String = null private var token: Array[Byte] = null @@ -132,7 +132,7 @@ private[spark] class SecurityMessage() extends Logging { } } -private[spark] object SecurityMessage { +private[nio] object SecurityMessage { /** * Convert the given BufferMessage to a SecurityMessage by parsing the contents diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index dd0421a5c15a..d6386f8c06ff 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -27,7 +27,7 @@ import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.spark._ import org.apache.spark.broadcast.HttpBroadcast -import org.apache.spark.network.cm.{PutBlock, GotBlock, GetBlock} +import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock} import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ import org.apache.spark.util.BoundedPriorityQueue diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 6c63147e50d6..292ac0d66366 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -169,7 +169,7 @@ class FileShuffleBlockManager(conf: SparkConf) override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { val segment = getBlockData(blockId) - Some(segment.byteBuffer()) + Some(segment.nioByteBuffer()) } override def 
getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 9558b2c45fd8..4ab34336d3f0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -91,7 +91,7 @@ class IndexShuffleBlockManager extends ShuffleBlockManager { } override def getBytes(blockId: ShuffleBlockId): Option[ByteBuffer] = { - Some(getBlockData(blockId).byteBuffer()) + Some(getBlockData(blockId).nioByteBuffer()) } override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 97900c0e2d39..8a84ba443b57 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -35,7 +35,7 @@ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.{ShuffleBlockManager, ShuffleManager} +import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.util._ @@ -67,7 +67,6 @@ private[spark] class BlockManager( blockTransferService.init(this) - private val port = conf.getInt("spark.blockManager.port", 0) val diskBlockManager = new DiskBlockManager(this, conf) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -230,7 +229,7 @@ private[spark] class BlockManager( * Put the block locally, using the given storage level. 
*/ override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = { - putBytes(BlockId(blockId), data.byteBuffer(), level) + putBytes(BlockId(blockId), data.nioByteBuffer(), level) } /** @@ -520,7 +519,7 @@ private[spark] class BlockManager( val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") - val data = blockTransferService.fetchBlock(loc.host, loc.port, blockId.toString).byteBuffer() + val data = blockTransferService.fetchBlockSync(loc.host, loc.port, blockId.toString).nioByteBuffer() if (data != null) { if (asBlockResult) { @@ -809,7 +808,7 @@ private[spark] class BlockManager( s"To node: $peer") try { - blockTransferService.uploadBlock( + blockTransferService.uploadBlockSync( peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 2548303cb62b..f0abcfdd6090 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -109,7 +109,7 @@ final class ShuffleBlockFetcherIterator( new BlockFetchingListener { override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), - () => blockManager.dataDeserialize(BlockId(blockId), data.byteBuffer(), serializer) + () => blockManager.dataDeserialize(BlockId(blockId), data.nioByteBuffer(), serializer) )) shuffleMetrics.remoteBytesRead += data.size shuffleMetrics.remoteBlocksFetched += 1 diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 03c2de99cd97..81b64c36ddca 100644 --- 
a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -202,8 +202,8 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter val blockManager = SparkEnv.get.blockManager val blockTransfer = SparkEnv.get.blockTransferService blockManager.master.getLocations(blockId).foreach { cmId => - val bytes = blockTransfer.fetchBlock(cmId.host, cmId.port, blockId.toString) - val deserialized = blockManager.dataDeserialize(blockId, bytes.byteBuffer()) + val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, blockId.toString) + val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer()) .asInstanceOf[Iterator[Int]].toList assert(deserialized === (1 to 100).toList) } diff --git a/core/src/test/scala/org/apache/spark/network/cm/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala similarity index 99% rename from core/src/test/scala/org/apache/spark/network/cm/ConnectionManagerSuite.scala rename to core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala index 258492051173..9f49587cdc67 100644 --- a/core/src/test/scala/org/apache/spark/network/cm/ConnectionManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/nio/ConnectionManagerSuite.scala @@ -15,18 +15,19 @@ * limitations under the License. */ -package org.apache.spark.network.cm +package org.apache.spark.network.nio import java.io.IOException import java.nio._ -import org.apache.spark.{SecurityManager, SparkConf} -import org.scalatest.FunSuite - import scala.concurrent.duration._ import scala.concurrent.{Await, TimeoutException} import scala.language.postfixOps +import org.scalatest.FunSuite + +import org.apache.spark.{SecurityManager, SparkConf} + /** * Test the ConnectionManager with various security settings. 
*/ diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala index 152cf0f26a9d..ba47fe5e25b9 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.FunSuite import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.ManagedBuffer +import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FileShuffleBlockManager import org.apache.spark.storage.{ShuffleBlockId, FileSegment} @@ -34,7 +34,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext { private val testConf = new SparkConf(false) private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) { - val segment = buffer.fileSegment().get + assert(buffer.isInstanceOf[FileSegmentManagedBuffer]) + val segment = buffer.asInstanceOf[FileSegmentManagedBuffer] assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath) assert(expected.offset === segment.offset) assert(expected.length === segment.length) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index b2fca9c4edd9..5a015e252191 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -21,7 +21,7 @@ import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays import java.util.concurrent.TimeUnit -import org.apache.spark.network.cm.CMBlockTransferService +import org.apache.spark.network.nio.NioBlockTransferService import 
scala.collection.mutable.ArrayBuffer import scala.concurrent.Await @@ -72,7 +72,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = { - val transfer = new CMBlockTransferService(conf, securityMgr) + val transfer = new NioBlockTransferService(conf, securityMgr) new BlockManager(name, actorSystem, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer) } @@ -792,7 +792,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter test("block store put failure") { // Use Java serializer so we can create an unserializable error. - val transfer = new CMBlockTransferService(conf, securityMgr) + val transfer = new NioBlockTransferService(conf, securityMgr) store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index d2997a7cc010..e4522e00a622 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import java.io.{File, FileWriter} -import org.apache.spark.network.cm.CMBlockTransferService +import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.shuffle.hash.HashShuffleManager import scala.collection.mutable From e29c721132fcee79d65d4b6e30dd4ee46a814ef7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 2 Sep 2014 22:59:46 -0700 Subject: [PATCH 07/10] Updated comment for ShuffleBlockFetcherIterator. 
--- .../spark/storage/ShuffleBlockFetcherIterator.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f0abcfdd6090..439bbe06f131 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -39,10 +39,13 @@ import org.apache.spark.util.Utils * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid * using too much memory. * - * @param context - * @param blockManager - * @param blocksByAddress - * @param serializer + * @param context [[TaskContext]], used for metrics update + * @param blockTransferService [[BlockTransferService]] for fetching remote blocks + * @param blockManager [[BlockManager]] for reading local blocks + * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. + * For each block we also require the size (in bytes as a long field) in + * order to throttle the memory usage. + * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. */ private[spark] From 2960c93a85db42bb6a018e28c44cde2aed73f3d6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 2 Sep 2014 23:55:56 -0700 Subject: [PATCH 08/10] Added ShuffleBlockFetcherIteratorSuite. 
--- .../spark/network/BlockTransferService.scala | 2 +- .../apache/spark/storage/ThreadingTest.scala | 120 ------------ .../ShuffleBlockFetcherIteratorSuite.scala | 183 ++++++++++++++++++ 3 files changed, 184 insertions(+), 121 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala create mode 100644 core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 39e0a60398e8..84d991fa6808 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -74,7 +74,7 @@ abstract class BlockTransferService { level: StorageLevel): Future[Unit] /** - * A special case of [[fetchBlocks]], since it only fetches on block and is blocking. + * A special case of [[fetchBlocks]], as it fetches only one block and is blocking. * * It is also only available after [[init]] is invoked. */ diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala deleted file mode 100644 index 8a836bbba274..000000000000 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.util.concurrent.ArrayBlockingQueue - -import akka.actor._ -import org.apache.spark.shuffle.hash.HashShuffleManager -import util.Random - -import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} -import org.apache.spark.scheduler.LiveListenerBus -import org.apache.spark.serializer.KryoSerializer - -/** - * This class tests the BlockManager and MemoryStore for thread safety and - * deadlocks. It spawns a number of producer and consumer threads. Producer - * threads continuously pushes blocks into the BlockManager and consumer - * threads continuously retrieves the blocks form the BlockManager and tests - * whether the block is correct or not. 
- */ -private[spark] object ThreadingTest { - - val numProducers = 5 - val numBlocksPerProducer = 20000 - - private[spark] class ProducerThread(manager: BlockManager, id: Int) extends Thread { - val queue = new ArrayBlockingQueue[(BlockId, Seq[Int])](100) - - override def run() { - for (i <- 1 to numBlocksPerProducer) { - val blockId = TestBlockId("b-" + id + "-" + i) - val blockSize = Random.nextInt(1000) - val block = (1 to blockSize).map(_ => Random.nextInt()) - val level = randomLevel() - val startTime = System.currentTimeMillis() - manager.putIterator(blockId, block.iterator, level, tellMaster = true) - println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") - queue.add((blockId, block)) - } - println("Producer thread " + id + " terminated") - } - - def randomLevel(): StorageLevel = { - math.abs(Random.nextInt()) % 4 match { - case 0 => StorageLevel.MEMORY_ONLY - case 1 => StorageLevel.MEMORY_ONLY_SER - case 2 => StorageLevel.MEMORY_AND_DISK - case 3 => StorageLevel.MEMORY_AND_DISK_SER - } - } - } - - private[spark] class ConsumerThread( - manager: BlockManager, - queue: ArrayBlockingQueue[(BlockId, Seq[Int])] - ) extends Thread { - var numBlockConsumed = 0 - - override def run() { - println("Consumer thread started") - while(numBlockConsumed < numBlocksPerProducer) { - val (blockId, block) = queue.take() - val startTime = System.currentTimeMillis() - manager.get(blockId) match { - case Some(retrievedBlock) => - assert(retrievedBlock.data.toList.asInstanceOf[List[Int]] == block.toList, - "Block " + blockId + " did not match") - println("Got block " + blockId + " in " + - (System.currentTimeMillis - startTime) + " ms") - case None => - assert(false, "Block " + blockId + " could not be retrieved") - } - numBlockConsumed += 1 - } - println("Consumer thread terminated") - } - } - - def main(args: Array[String]) { - System.setProperty("spark.kryoserializer.buffer.mb", "1") - val actorSystem = ActorSystem("test") - val conf = new 
SparkConf() - val serializer = new KryoSerializer(conf) - val blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf) - val blockManager = new BlockManager( - "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new MapOutputTrackerMaster(conf), new HashShuffleManager(conf), null) - val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) - val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) - producers.foreach(_.start) - consumers.foreach(_.start) - producers.foreach(_.join) - consumers.foreach(_.join) - blockManager.stop() - blockManagerMaster.stop() - actorSystem.shutdown() - actorSystem.awaitTermination() - println("Everything stopped.") - println( - "It will take sometime for the JVM to clean all temporary files and shutdown. Sit tight.") - } -} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala new file mode 100644 index 000000000000..809bd7092965 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import org.apache.spark.TaskContext +import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} + +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.scalatest.FunSuite + + +class ShuffleBlockFetcherIteratorSuite extends FunSuite { + + test("handle local read failures in BlockManager") { + val transfer = mock(classOf[BlockTransferService]) + val blockManager = mock(classOf[BlockManager]) + doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId + + val blIds = Array[BlockId]( + ShuffleBlockId(0,0,0), + ShuffleBlockId(0,1,0), + ShuffleBlockId(0,2,0), + ShuffleBlockId(0,3,0), + ShuffleBlockId(0,4,0)) + + val optItr = mock(classOf[Option[Iterator[Any]]]) + val answer = new Answer[Option[Iterator[Any]]] { + override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { + throw new Exception + } + } + + // 3rd block is going to fail + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) + doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) + + val bmId = BlockManagerId("test-client", "test-client", 1) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new ShuffleBlockFetcherIterator( + new TaskContext(0, 0, 0), + transfer, + blockManager, + blocksByAddress, + null, + 48 * 1024 * 1024) + + // Without exhausting the iterator, the 
iterator should be lazy and not call + // getLocalShuffleFromDisk. + verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") + // the 2nd element of the tuple returned by iterator.next should be defined when + // fetching successfully + assert(iterator.next()._2.isDefined, + "1st element should be defined but is not actually defined") + verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") + assert(iterator.next()._2.isDefined, + "2nd element should be defined but is not actually defined") + verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") + // 3rd fetch should fail + intercept[Exception] { + iterator.next() + } + verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any()) + } + + test("handle local read successes") { + val transfer = mock(classOf[BlockTransferService]) + val blockManager = mock(classOf[BlockManager]) + doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId + + val blIds = Array[BlockId]( + ShuffleBlockId(0,0,0), + ShuffleBlockId(0,1,0), + ShuffleBlockId(0,2,0), + ShuffleBlockId(0,3,0), + ShuffleBlockId(0,4,0)) + + val optItr = mock(classOf[Option[Iterator[Any]]]) + + // All blocks should be fetched successfully + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) + doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) + + val bmId = BlockManagerId("test-client",
"test-client", 1) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new ShuffleBlockFetcherIterator( + new TaskContext(0, 0, 0), + transfer, + blockManager, + blocksByAddress, + null, + 48 * 1024 * 1024) + + // The iterator should be lazy: before it is consumed, getLocalShuffleFromDisk is not called. + verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) + + assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") + assert(iterator.next()._2.isDefined, + "All elements should be defined but 1st element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") + assert(iterator.next()._2.isDefined, + "All elements should be defined but 2nd element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") + assert(iterator.next()._2.isDefined, + "All elements should be defined but 3rd element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") + assert(iterator.next()._2.isDefined, + "All elements should be defined but 4th element is not actually defined") + assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") + assert(iterator.next()._2.isDefined, + "All elements should be defined but 5th element is not actually defined") + + verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any()) + } + + test("handle remote fetch failures in BlockTransferService") { + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] + listener.onBlockFetchFailure(new Exception("blah")) + }
+ }) + + val blockManager = mock(classOf[BlockManager]) + + when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1)) + + val blId1 = ShuffleBlockId(0, 0, 0) + val blId2 = ShuffleBlockId(0, 1, 0) + val bmId = BlockManagerId("test-server", "test-server", 1) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (bmId, Seq((blId1, 1L), (blId2, 1L)))) + + val iterator = new ShuffleBlockFetcherIterator( + new TaskContext(0, 0, 0), + transfer, + blockManager, + blocksByAddress, + null, + 48 * 1024 * 1024) + + iterator.foreach { case (_, iterOption) => + assert(!iterOption.isDefined) + } + } +} From 13321569b169b173444e5c8a4aab2975ebc3244d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 3 Sep 2014 00:05:24 -0700 Subject: [PATCH 09/10] Fixed style violation from refactoring. --- .../src/main/scala/org/apache/spark/storage/BlockManager.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8a84ba443b57..05261e0e55af 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -519,7 +519,8 @@ private[spark] class BlockManager( val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") - val data = blockTransferService.fetchBlockSync(loc.host, loc.port, blockId.toString).nioByteBuffer() + val data = blockTransferService.fetchBlockSync( + loc.host, loc.port, blockId.toString).nioByteBuffer() if (data != null) { if (asBlockResult) { From 1dfd3d7b1da8af79f3100b5661e9d457bab4f06f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 8 Sep 2014 12:51:01 -0700 Subject: [PATCH 10/10] Limit the length of the FileInputStream. 
--- .../org/apache/spark/network/ManagedBuffer.scala | 5 +++-- .../org/apache/spark/storage/BlockManager.scala | 14 ++++---------- .../storage/ShuffleBlockFetcherIterator.scala | 3 ++- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 5d30e9e7183f..dcecb6beeea9 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -21,6 +21,7 @@ import java.io.{FileInputStream, RandomAccessFile, File, InputStream} import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode +import com.google.common.io.ByteStreams import io.netty.buffer.{ByteBufInputStream, ByteBuf} import org.apache.spark.util.ByteBufferInputStream @@ -72,7 +73,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt override def inputStream(): InputStream = { val is = new FileInputStream(file) is.skip(offset) - is + ByteStreams.limit(is, length) } } @@ -84,7 +85,7 @@ final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def size: Long = buf.remaining() - override def nioByteBuffer() = buf + override def nioByteBuffer() = buf.duplicate() override def inputStream() = new ByteBufferInputStream(buf) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 05261e0e55af..d1bee3d2c033 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -339,16 +339,10 @@ private[spark] class BlockManager( * shuffle blocks. It is safe to do so without a lock on block info since disk store * never deletes (recent) items. 
*/ - def getLocalShuffleFromDisk( - blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - - val shuffleBlockManager = shuffleManager.shuffleBlockManager - val values = shuffleBlockManager.getBytes(blockId.asInstanceOf[ShuffleBlockId]).map( - bytes => this.dataDeserialize(blockId, bytes, serializer)) - - values.orElse { - throw new BlockException(blockId, s"Block $blockId not found on disk, though it should be") - } + def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { + val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + val is = wrapForCompression(blockId, buf.inputStream()) + Some(serializer.newInstance().deserializeStream(is).asIterator) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 439bbe06f131..c8e708aa6b1b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -112,7 +112,8 @@ final class ShuffleBlockFetcherIterator( new BlockFetchingListener { override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), - () => blockManager.dataDeserialize(BlockId(blockId), data.nioByteBuffer(), serializer) + () => serializer.newInstance().deserializeStream( + blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator )) shuffleMetrics.remoteBytesRead += data.size shuffleMetrics.remoteBlocksFetched += 1