diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 009ed6477584..6123b182c1a5 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -39,6 +40,7 @@ import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} + /** * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), @@ -234,7 +236,12 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - val blockTransferService = new NioBlockTransferService(conf, securityManager) + // TODO(rxin): Config option based on class name, similar to shuffle mgr and compression codec. + val blockTransferService = if (conf.getBoolean("spark.shuffle.use.netty", false)) { + new NettyBlockTransferService(conf) + } else { + new NioBlockTransferService(conf, securityManager) + } val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index e0e91724271c..0eeffe0e7c5e 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -20,14 +20,14 @@ package org.apache.spark.network import org.apache.spark.storage.StorageLevel +private[spark] trait BlockDataManager { /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. */ - def getBlockData(blockId: String): Option[ManagedBuffer] + def getBlockData(blockId: String): ManagedBuffer /** * Put the block locally, using the given storage level. diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index 34acaa563ca5..dd70e2664793 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -23,6 +23,7 @@ import java.util.EventListener /** * Listener callback interface for [[BlockTransferService.fetchBlocks]]. */ +private[spark] trait BlockFetchingListener extends EventListener { /** @@ -31,7 +32,7 @@ trait BlockFetchingListener extends EventListener { def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit /** - * Called upon failures. For each failure, this is called only once (i.e. not once per block). + * Called at least once per block upon failures. */ - def onBlockFetchFailure(exception: Throwable): Unit + def onBlockFetchFailure(blockId: String, exception: Throwable): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 84d991fa6808..d3ed683c7e88 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -17,13 +17,17 @@ package org.apache.spark.network +import java.io.Closeable +import java.nio.ByteBuffer + import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration import org.apache.spark.storage.StorageLevel -abstract class BlockTransferService { +private[spark] +abstract class BlockTransferService extends Closeable { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -34,7 +38,7 @@ abstract class BlockTransferService { /** * Tear down the transfer service. */ - def stop(): Unit + def close(): Unit /** * Port number the service is listening on, available only after [[init]] is invoked. @@ -50,9 +54,6 @@ abstract class BlockTransferService { * Fetch a sequence of blocks from a remote node asynchronously, * available only after [[init]] is invoked. * - * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block, - * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block). - * * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. @@ -83,7 +84,7 @@ abstract class BlockTransferService { val lock = new Object @volatile var result: Either[ManagedBuffer, Throwable] = null fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { lock.synchronized { result = Right(exception) lock.notify() @@ -91,7 +92,10 @@ abstract class BlockTransferService { } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { lock.synchronized { - result = Left(data) + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + ret.flip() + result = Left(new NioManagedBuffer(ret)) lock.notify() } } diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index a4409181ec90..8f8467b046b5 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -25,7 +25,8 @@ import java.nio.channels.FileChannel.MapMode import scala.util.Try import com.google.common.io.ByteStreams -import io.netty.buffer.{ByteBufInputStream, ByteBuf} +import io.netty.buffer.{Unpooled, ByteBufInputStream, ByteBuf} +import io.netty.channel.DefaultFileRegion import org.apache.spark.util.{ByteBufferInputStream, Utils} @@ -34,11 +35,17 @@ import org.apache.spark.util.{ByteBufferInputStream, Utils} * This interface provides an immutable view for data in the form of bytes. The implementation * should specify how the data is provided: * - * - FileSegmentManagedBuffer: data backed by part of a file - * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer - * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf + * - [[FileSegmentManagedBuffer]]: data backed by part of a file + * - [[NioManagedBuffer]]: data backed by a NIO ByteBuffer + * - [[NettyManagedBuffer]]: data backed by a Netty ByteBuf + * + * The concrete buffer implementation might be managed outside the JVM garbage collector. + * For example, in the case of [[NettyManagedBuffer]], the buffers are reference counted. + * In that case, if the buffer is going to be passed around to a different thread, retain/release + * should be called. */ -sealed abstract class ManagedBuffer { +private[spark] +abstract class ManagedBuffer { // Note that all the methods are defined with parenthesis because their implementations can // have side effects (io operations). @@ -57,12 +64,29 @@ sealed abstract class ManagedBuffer { * it does not go over the limit. */ def inputStream(): InputStream + + /** + * Increment the reference count by one if applicable. + */ + def retain(): this.type + + /** + * If applicable, decrement the reference count by one and deallocates the buffer if the + * reference count reaches zero. + */ + def release(): this.type + + /** + * Convert the buffer into an Netty object, used to write the data out. + */ + private[network] def convertToNetty(): AnyRef } /** * A [[ManagedBuffer]] backed by a segment in a file */ +private[spark] final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) extends ManagedBuffer { @@ -113,6 +137,15 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt } } + private[network] override def convertToNetty(): AnyRef = { + val fileChannel = new FileInputStream(file).getChannel + new DefaultFileRegion(fileChannel, offset, length) + } + + // Content of file segments are not in-memory, so no need to reference count. + override def retain(): this.type = this + override def release(): this.type = this + override def toString: String = s"${getClass.getName}($file, $offset, $length)" } @@ -120,20 +153,30 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt /** * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. */ -final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { +private[spark] +final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def size: Long = buf.remaining() override def nioByteBuffer() = buf.duplicate() override def inputStream() = new ByteBufferInputStream(buf) + + private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf) + + // [[ByteBuffer]] is managed by the JVM garbage collector itself. + override def retain(): this.type = this + override def release(): this.type = this + + override def toString: String = s"${getClass.getName}($buf)" } /** * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. */ -final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { +private[spark] +final class NettyManagedBuffer(buf: ByteBuf) extends ManagedBuffer { override def size: Long = buf.readableBytes() @@ -141,6 +184,17 @@ final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { override def inputStream() = new ByteBufInputStream(buf) - // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it. - def release(): Unit = buf.release() + private[network] override def convertToNetty(): AnyRef = buf + + override def retain(): this.type = { + buf.retain() + this + } + + override def release(): this.type = { + buf.release() + this + } + + override def toString: String = s"${getClass.getName}($buf)" } diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/core/src/main/scala/org/apache/spark/network/exceptions.scala similarity index 65% rename from core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala rename to core/src/main/scala/org/apache/spark/network/exceptions.scala index e28219dd7745..d918d358c4ad 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala +++ b/core/src/main/scala/org/apache/spark/network/exceptions.scala @@ -15,15 +15,17 @@ * limitations under the License. */ -package org.apache.spark.network.netty.client +package org.apache.spark.network -import java.util.EventListener +class BlockFetchFailureException(blockId: String, errorMsg: String, cause: Throwable) + extends Exception(errorMsg, cause) { + def this(blockId: String, errorMsg: String) = this(blockId, errorMsg, null) +} -trait BlockClientListener extends EventListener { - - def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit - def onFetchFailure(blockId: String, errorMsg: String): Unit +class BlockUploadFailureException(blockId: String, cause: Throwable) + extends Exception(s"Failed to fetch block $blockId", cause) { + def this(blockId: String) = this(blockId, null) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala new file mode 100644 index 000000000000..6bdbf88d337c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.Closeable +import java.util.concurrent.TimeoutException + +import scala.concurrent.{Future, promise} + +import io.netty.channel.{ChannelFuture, ChannelFutureListener} + +import org.apache.spark.Logging +import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener} +import org.apache.spark.storage.StorageLevel + + +/** + * Client for [[NettyBlockTransferService]]. The connection to server must have been established + * using [[BlockClientFactory]] before instantiating this. + * + * This class is used to make requests to the server , while [[BlockClientHandler]] is responsible + * for handling responses from the server. + * + * Concurrency: thread safe and can be called from multiple threads. + * + * @param cf the ChannelFuture for the connection. + * @param handler [[BlockClientHandler]] for handling outstanding requests. + */ +@throws[TimeoutException] +private[netty] +class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closeable with Logging { + + private[this] val serverAddr = cf.channel().remoteAddress().toString + + def isActive: Boolean = cf.channel().isActive + + /** + * Ask the remote server for a sequence of blocks, and execute the callback. + * + * Note that this is asynchronous and returns immediately. Upstream caller should throttle the + * rate of fetching; otherwise we could run out of memory due to large outstanding fetches. + * + * @param blockIds sequence of block ids to fetch. + * @param listener callback to fire on fetch success / failure. + */ + def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = { + var startTime: Long = 0 + logTrace { + startTime = System.currentTimeMillis() + s"Sending request $blockIds to $serverAddr" + } + + blockIds.foreach { blockId => + handler.addFetchRequest(blockId, listener) + } + + cf.channel().writeAndFlush(BlockFetchRequest(blockIds)).addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace { + val timeTaken = System.currentTimeMillis() - startTime + s"Sending request $blockIds to $serverAddr took $timeTaken ms" + } + } else { + // Fail all blocks. + val errorMsg = + s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}" + logError(errorMsg, future.cause) + blockIds.foreach { blockId => + handler.removeFetchRequest(blockId) + listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) + } + } + } + }) + } + + def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] = + { + var startTime: Long = 0 + logTrace { + startTime = System.currentTimeMillis() + s"Uploading block ($blockId) to $serverAddr" + } + val f = cf.channel().writeAndFlush(new BlockUploadRequest(blockId, data, storageLevel)) + + val p = promise[Unit]() + handler.addUploadRequest(blockId, p) + f.addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace { + val timeTaken = System.currentTimeMillis() - startTime + s"Uploading block ($blockId) to $serverAddr took $timeTaken ms" + } + } else { + // Fail all blocks. + val errorMsg = + s"Failed to upload block $blockId to $serverAddr: ${future.cause.getMessage}" + logError(errorMsg, future.cause) + } + } + }) + + p.future + } + + /** Close the connection. This does NOT block till the connection is closed. */ + def close(): Unit = cf.channel().close() +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala new file mode 100644 index 000000000000..8021cfdf42d1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.Closeable +import java.util.concurrent.{ConcurrentHashMap, TimeoutException} + +import io.netty.bootstrap.Bootstrap +import io.netty.buffer.PooledByteBufAllocator +import io.netty.channel._ +import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} +import io.netty.channel.nio.NioEventLoopGroup +import io.netty.channel.socket.SocketChannel +import io.netty.channel.socket.nio.NioSocketChannel +import io.netty.util.internal.PlatformDependent + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils + + +/** + * Factory for creating [[BlockClient]] by using createClient. + * + * The factory maintains a connection pool to other hosts and should return the same [[BlockClient]] + * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s. + */ +private[netty] +class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable { + + def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) + + /** A thread factory so the threads are named (for debugging). */ + private[this] val threadFactory = Utils.namedThreadFactory("spark-netty-client") + + /** Socket channel type, initialized by [[init]] depending ioMode. */ + private[this] var socketChannelClass: Class[_ <: Channel] = _ + + /** Thread pool shared by all clients. */ + private[this] var workerGroup: EventLoopGroup = _ + + private[this] val connectionPool = new ConcurrentHashMap[(String, Int), BlockClient] + + // The encoders are stateless and can be shared among multiple clients. + private[this] val encoder = new ClientRequestEncoder + private[this] val decoder = new ServerResponseDecoder + + init() + + /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ + private def init(): Unit = { + def initNio(): Unit = { + socketChannelClass = classOf[NioSocketChannel] + workerGroup = new NioEventLoopGroup(conf.clientThreads, threadFactory) + } + def initEpoll(): Unit = { + socketChannelClass = classOf[EpollSocketChannel] + workerGroup = new EpollEventLoopGroup(conf.clientThreads, threadFactory) + } + + // For auto mode, first try epoll (only available on Linux), then nio. + conf.ioMode match { + case "nio" => initNio() + case "epoll" => initEpoll() + case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() + } + } + + /** + * Create a new BlockFetchingClient connecting to the given remote host / port. + * + * This blocks until a connection is successfully established. + * + * Concurrency: This method is safe to call from multiple threads. + */ + def createClient(remoteHost: String, remotePort: Int): BlockClient = { + // Get connection from the connection pool first. + // If it is not found or not active, create a new one. + val cachedClient = connectionPool.get((remoteHost, remotePort)) + if (cachedClient != null && cachedClient.isActive) { + return cachedClient + } + + logDebug(s"Creating new connection to $remoteHost:$remotePort") + + // There is a chance two threads are creating two different clients connecting to the same host. + // But that's probably ok ... + + val handler = new BlockClientHandler + + val bootstrap = new Bootstrap + bootstrap.group(workerGroup) + .channel(socketChannelClass) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs) + + // Use pooled buffers to reduce temporary buffer allocation + bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()) + + bootstrap.handler(new ChannelInitializer[SocketChannel] { + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("clientRequestEncoder", encoder) + .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) + .addLast("serverResponseDecoder", decoder) + .addLast("handler", handler) + } + }) + + // Connect to the remote server + val cf: ChannelFuture = bootstrap.connect(remoteHost, remotePort) + if (!cf.awaitUninterruptibly(conf.connectTimeoutMs)) { + throw new TimeoutException( + s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)") + } + + val client = new BlockClient(cf, handler) + connectionPool.put((remoteHost, remotePort), client) + client + } + + /** Close all connections in the connection pool, and shutdown the worker thread pool. */ + override def close(): Unit = { + val iter = connectionPool.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + entry.getValue.close() + connectionPool.remove(entry.getKey) + } + + if (workerGroup != null) { + workerGroup.shutdownGracefully() + } + } + + /** + * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches + * are disabled because the ByteBufs are allocated by the event loop thread, but released by the + * executor thread rather than the event loop thread. Those thread-local caches actually delay + * the recycling of buffers, leading to larger memory usage. + */ + private def createPooledByteBufAllocator(): PooledByteBufAllocator = { + def getPrivateStaticField(name: String): Int = { + val f = PooledByteBufAllocator.DEFAULT.getClass.getDeclaredField(name) + f.setAccessible(true) + f.getInt(null) + } + new PooledByteBufAllocator( + PlatformDependent.directBufferPreferred(), + getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), + getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), + getPrivateStaticField("DEFAULT_PAGE_SIZE"), + getPrivateStaticField("DEFAULT_MAX_ORDER"), + 0, // tinyCacheSize + 0, // smallCacheSize + 0 // normalCacheSize + ) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala new file mode 100644 index 000000000000..5e28a07a461f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.util.concurrent.ConcurrentHashMap + +import scala.concurrent.Promise + +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.network.{BlockFetchFailureException, BlockUploadFailureException, BlockFetchingListener} + + +/** + * Handler that processes server responses, in response to requests issued from [[BlockClient]]. + * It works by tracking the list of outstanding requests (and their callbacks). + * + * Concurrency: thread safe and can be called from multiple threads. + */ +private[netty] +class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { + + /** Tracks the list of outstanding requests and their listeners on success/failure. */ + private[this] val outstandingFetches: java.util.Map[String, BlockFetchingListener] = + new ConcurrentHashMap[String, BlockFetchingListener] + + private[this] val outstandingUploads: java.util.Map[String, Promise[Unit]] = + new ConcurrentHashMap[String, Promise[Unit]] + + def addFetchRequest(blockId: String, listener: BlockFetchingListener): Unit = { + outstandingFetches.put(blockId, listener) + } + + def removeFetchRequest(blockId: String): Unit = { + outstandingFetches.remove(blockId) + } + + def addUploadRequest(blockId: String, promise: Promise[Unit]): Unit = { + outstandingUploads.put(blockId, promise) + } + + /** + * Fire the failure callback for all outstanding requests. This is called when we have an + * uncaught exception or pre-mature connection termination. + */ + private def failOutstandingRequests(cause: Throwable): Unit = { + val iter1 = outstandingFetches.entrySet().iterator() + while (iter1.hasNext) { + val entry = iter1.next() + entry.getValue.onBlockFetchFailure(entry.getKey, cause) + } + // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests + // as well. But I guess that is ok given the caller will fail as soon as any requests fail. + outstandingFetches.clear() + + val iter2 = outstandingUploads.entrySet().iterator() + while (iter2.hasNext) { + val entry = iter2.next() + entry.getValue.failure(new RuntimeException(s"Failed to upload block ${entry.getKey}")) + } + outstandingUploads.clear() + } + + override def channelUnregistered(ctx: ChannelHandlerContext): Unit = { + if (outstandingFetches.size() > 0) { + logError("Still have " + outstandingFetches.size() + " requests outstanding " + + s"when connection from ${ctx.channel.remoteAddress} is closed") + failOutstandingRequests(new RuntimeException( + s"Connection from ${ctx.channel.remoteAddress} closed")) + } + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + if (outstandingFetches.size() > 0) { + logError( + s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}", cause) + failOutstandingRequests(cause) + } + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, response: ServerResponse) { + val server = ctx.channel.remoteAddress.toString + response match { + case BlockFetchSuccess(blockId, buf) => + val listener = outstandingFetches.get(blockId) + if (listener == null) { + logWarning(s"Got a response for block $blockId from $server but it is not outstanding") + buf.release() + } else { + outstandingFetches.remove(blockId) + listener.onBlockFetchSuccess(blockId, buf) + buf.release() + } + case BlockFetchFailure(blockId, errorMsg) => + val listener = outstandingFetches.get(blockId) + if (listener == null) { + logWarning( + s"Got a response for block $blockId from $server ($errorMsg) but it is not outstanding") + } else { + outstandingFetches.remove(blockId) + listener.onBlockFetchFailure(blockId, new BlockFetchFailureException(blockId, errorMsg)) + } + case BlockUploadSuccess(blockId) => + val p = outstandingUploads.get(blockId) + if (p == null) { + logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") + } else { + outstandingUploads.remove(blockId) + p.success(Unit) + } + case BlockUploadFailure(blockId, error) => + val p = outstandingUploads.get(blockId) + if (p == null) { + logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") + } else { + outstandingUploads.remove(blockId) + p.failure(new BlockUploadFailureException(blockId)) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala similarity index 50% rename from core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala rename to core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 7b2f9a8d4dfd..e2eb7c379f14 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -15,50 +15,30 @@ * limitations under the License. */ -package org.apache.spark.network.netty.server +package org.apache.spark.network.netty +import java.io.Closeable import java.net.InetSocketAddress import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption} -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} +import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.socket.oio.OioServerSocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.network.netty.NettyConfig -import org.apache.spark.storage.BlockDataProvider +import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} + +import org.apache.spark.Logging +import org.apache.spark.network.BlockDataManager import org.apache.spark.util.Utils /** - * Server for serving Spark data blocks. - * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]]. - * - * Protocol for requesting blocks (client to server): - * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n" - * - * Protocol for sending blocks (server to client): - * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data. - * - * frame-length should not include the length of itself. - * If block-id-length is negative, then this is an error message rather than block-data. The real - * length is the absolute value of the frame-length. - * + * Server for the [[NettyBlockTransferService]]. */ -private[spark] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging { - - def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = { - this(new NettyConfig(sparkConf), dataProvider) - } +private[netty] +class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) + extends Closeable with Logging { def port: Int = _port @@ -74,42 +54,24 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Lo /** Initialize the server. */ private def init(): Unit = { bootstrap = new ServerBootstrap - val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss") - val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker") + val threadFactory = Utils.namedThreadFactory("spark-netty-server") // Use only one thread to accept connections, and 2 * num_cores for worker. def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new NioEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) + val bossGroup = new NioEventLoopGroup(conf.serverThreads, threadFactory) + val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) } - def initOio(): Unit = { - val bossGroup = new OioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new OioEventLoopGroup(0, workerThreadFactory) - bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) - } def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory) - val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) + val bossGroup = new EpollEventLoopGroup(conf.serverThreads, threadFactory) + val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) } conf.ioMode match { case "nio" => initNio() - case "oio" => initOio() case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. - try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } + case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() } // Use pooled buffers to reduce temporary buffer allocation @@ -121,18 +83,18 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Lo bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) } conf.receiveBuf.foreach { receiveBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) + bootstrap.childOption[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) } conf.sendBuf.foreach { sendBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) + bootstrap.childOption[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) } bootstrap.childHandler(new ChannelInitializer[SocketChannel] { override def initChannel(ch: SocketChannel): Unit = { ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) + .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) + .addLast("clientRequestDecoder", new ClientRequestDecoder) + .addLast("serverResponseEncoder", new ServerResponseEncoder) .addLast("handler", new BlockServerHandler(dataProvider)) } }) @@ -142,11 +104,14 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Lo val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] _port = addr.getPort - _hostName = addr.getHostName + // _hostName = addr.getHostName + _hostName = Utils.localHostName() + + logInfo(s"Server started ${_hostName}:${_port}") } /** Shutdown the server. */ - def stop(): Unit = { + def close(): Unit = { if (channelFuture != null) { channelFuture.channel().close().awaitUninterruptibly() channelFuture = null diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala new file mode 100644 index 000000000000..44687f0b770e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.channel._ + +import org.apache.spark.Logging +import org.apache.spark.network.{ManagedBuffer, BlockDataManager} +import org.apache.spark.storage.StorageLevel + + +/** + * A handler that processes requests from clients and writes block data back. + * + * The messages should have been processed by the pipeline setup by BlockServerChannelInitializer. + */ +private[netty] class BlockServerHandler(dataProvider: BlockDataManager) + extends SimpleChannelInboundHandler[ClientRequest] with Logging { + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, request: ClientRequest): Unit = { + request match { + case BlockFetchRequest(blockIds) => + blockIds.foreach(processFetchRequest(ctx, _)) + case BlockUploadRequest(blockId, data, level) => + processUploadRequest(ctx, blockId, data, level) + } + } // end of channelRead0 + + private def processFetchRequest(ctx: ChannelHandlerContext, blockId: String): Unit = { + // A helper function to send error message back to the client. + def client = ctx.channel.remoteAddress.toString + + def respondWithError(error: String): Unit = { + ctx.writeAndFlush(new BlockFetchFailure(blockId, error)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (!future.isSuccess) { + // TODO: Maybe log the success case as well. + logError(s"Error sending error back to $client", future.cause) + ctx.close() + } + } + } + ) + } + + logTrace(s"Received request from $client to fetch block $blockId") + + // First make sure we can find the block. If not, send error back to the user. + var buf: ManagedBuffer = null + try { + buf = dataProvider.getBlockData(blockId) + } catch { + case e: Exception => + logError(s"Error opening block $blockId for request from $client", e) + respondWithError(e.getMessage) + return + } + + ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${buf.size} B) back to $client") + } else { + logError( + s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() + } + } + } + ) + } // end of processBlockRequest + + private def processUploadRequest( + ctx: ChannelHandlerContext, + blockId: String, + data: ManagedBuffer, + level: StorageLevel): Unit = { + // A helper function to send error message back to the client. + def client = ctx.channel.remoteAddress.toString + + try { + dataProvider.putBlockData(blockId, data, level) + ctx.writeAndFlush(BlockUploadSuccess(blockId)).addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (!future.isSuccess) { + logError(s"Error sending an ACK back to client $client") + } + } + }) + } catch { + case e: Throwable => + logError(s"Error processing uploaded block $blockId", e) + ctx.writeAndFlush(BlockUploadFailure(blockId, e.getMessage)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (!future.isSuccess) { + logError(s"Error sending an ACK back to client $client") + } + } + }) + } + } // end of processUploadRequest +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala new file mode 100644 index 000000000000..b7f979dccd0f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import scala.concurrent.Future + +import org.apache.spark.SparkConf +import org.apache.spark.network._ +import org.apache.spark.storage.StorageLevel + + +/** + * A [[BlockTransferService]] implementation based on Netty. + * + * See protocol.scala for the communication protocol between server and client + */ +private[spark] +final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { + + private[this] val nettyConf: NettyConfig = new NettyConfig(conf) + + private[this] var server: BlockServer = _ + private[this] var clientFactory: BlockClientFactory = _ + + override def init(blockDataManager: BlockDataManager): Unit = { + server = new BlockServer(nettyConf, blockDataManager) + clientFactory = new BlockClientFactory(nettyConf) + } + + override def close(): Unit = { + if (server != null) { + server.close() + } + if (clientFactory != null) { + clientFactory.close() + } + } + + override def fetchBlocks( + hostName: String, + port: Int, + blockIds: Seq[String], + listener: BlockFetchingListener): Unit = { + clientFactory.createClient(hostName, port).fetchBlocks(blockIds, listener) + } + + override def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, + level: StorageLevel): Future[Unit] = { + clientFactory.createClient(hostName, port).uploadBlock(blockId, blockData, level) + } + + override def hostName: String = { + if (server == null) { + throw new IllegalStateException("Server has not been started") + } + server.hostName + } + + override def port: Int = { + if (server == null) { + throw new IllegalStateException("Server has not been started") + } + server.port + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala index b5870152c5a6..7c3074e93979 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala @@ -31,18 +31,20 @@ class NettyConfig(conf: SparkConf) { /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase - /** Connect timeout in secs. Default 60 secs. */ - private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000 - - /** - * Percentage of the desired amount of time spent for I/O in the child event loops. - * Only applicable in nio and epoll. - */ - private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80) + /** Connect timeout in secs. Default 120 secs. */ + private[netty] val connectTimeoutMs = { + conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000 + } /** Requested maximum length of the queue of incoming connections. */ private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) + /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ + private[netty] val serverThreads: Int = conf.getInt("spark.shuffle.io.serverThreads", 0) + + /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ + private[netty] val clientThreads: Int = conf.getInt("spark.shuffle.io.clientThreads", 0) + /** * Receive buffer size (SO_RCVBUF). * Note: the optimal size for receive buffer and send buffer should be @@ -51,7 +53,7 @@ class NettyConfig(conf: SparkConf) { * buffer size should be ~ 1.25MB */ private[netty] val receiveBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) + conf.getOption("spark.shuffle.io.receiveBuffer").map(_.toInt) /** Send buffer size (SO_SNDBUF). */ private[netty] val sendBuf: Option[Int] = diff --git a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala deleted file mode 100644 index 0d7695072a7b..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import org.apache.spark.storage.{BlockId, FileSegment} - -trait PathResolver { - /** Get the file segment in which the given block resides. */ - def getBlockLocation(blockId: BlockId): FileSegment -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala deleted file mode 100644 index 5aea7ba2f367..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.util.concurrent.TimeoutException - -import io.netty.bootstrap.Bootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.socket.SocketChannel -import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption} -import io.netty.handler.codec.LengthFieldBasedFrameDecoder -import io.netty.handler.codec.string.StringEncoder -import io.netty.util.CharsetUtil - -import org.apache.spark.Logging - -/** - * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]]. - * Use [[BlockFetchingClientFactory]] to instantiate this client. - * - * The constructor blocks until a connection is successfully established. - * - * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol. - * - * Concurrency: thread safe and can be called from multiple threads. - */ -@throws[TimeoutException] -private[spark] -class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int) - extends Logging { - - private val handler = new BlockFetchingClientHandler - - /** Netty Bootstrap for creating the TCP connection. */ - private val bootstrap: Bootstrap = { - val b = new Bootstrap - b.group(factory.workerGroup) - .channel(factory.socketChannelClass) - // Use pooled buffers to reduce temporary buffer allocation - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - // Disable Nagle's Algorithm since we don't want packets to wait - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs) - - b.handler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)) - // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4 - .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4)) - .addLast("handler", handler) - } - }) - b - } - - /** Netty ChannelFuture for the connection. */ - private val cf: ChannelFuture = bootstrap.connect(hostname, port) - if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) { - throw new TimeoutException( - s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)") - } - - /** - * Ask the remote server for a sequence of blocks, and execute the callback. - * - * Note that this is asynchronous and returns immediately. Upstream caller should throttle the - * rate of fetching; otherwise we could run out of memory. - * - * @param blockIds sequence of block ids to fetch. - * @param listener callback to fire on fetch success / failure. - */ - def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = { - // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline. - // It's also best to limit the number of "flush" calls since it requires system calls. - // Let's concatenate the string and then call writeAndFlush once. - // This is also why this implementation might be more efficient than multiple, separate - // fetch block calls. - var startTime: Long = 0 - logTrace { - startTime = System.nanoTime - s"Sending request $blockIds to $hostname:$port" - } - - blockIds.foreach { blockId => - handler.addRequest(blockId, listener) - } - - val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n") - writeFuture.addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace { - val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 - s"Sending request $blockIds to $hostname:$port took $timeTaken ms" - } - } else { - // Fail all blocks. - val errorMsg = - s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}" - logError(errorMsg, future.cause) - blockIds.foreach { blockId => - listener.onFetchFailure(blockId, errorMsg) - handler.removeRequest(blockId) - } - } - } - }) - } - - def waitForClose(): Unit = { - cf.channel().closeFuture().sync() - } - - def close(): Unit = cf.channel().close() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala deleted file mode 100644 index 2b28402c52b4..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.nio.NioSocketChannel -import io.netty.channel.socket.oio.OioSocketChannel -import io.netty.channel.{EventLoopGroup, Channel} - -import org.apache.spark.SparkConf -import org.apache.spark.network.netty.NettyConfig -import org.apache.spark.util.Utils - -/** - * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses - * the worker thread pool for Netty. - * - * Concurrency: createClient is safe to be called from multiple threads concurrently. - */ -private[spark] -class BlockFetchingClientFactory(val conf: NettyConfig) { - - def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) - - /** A thread factory so the threads are named (for debugging). */ - val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") - - /** The following two are instantiated by the [[init]] method, depending ioMode. */ - var socketChannelClass: Class[_ <: Channel] = _ - var workerGroup: EventLoopGroup = _ - - init() - - /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ - private def init(): Unit = { - def initOio(): Unit = { - socketChannelClass = classOf[OioSocketChannel] - workerGroup = new OioEventLoopGroup(0, threadFactory) - } - def initNio(): Unit = { - socketChannelClass = classOf[NioSocketChannel] - workerGroup = new NioEventLoopGroup(0, threadFactory) - } - def initEpoll(): Unit = { - socketChannelClass = classOf[EpollSocketChannel] - workerGroup = new EpollEventLoopGroup(0, threadFactory) - } - - conf.ioMode match { - case "nio" => initNio() - case "oio" => initOio() - case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. - try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } - } - } - - /** - * Create a new BlockFetchingClient connecting to the given remote host / port. - * - * This blocks until a connection is successfully established. - * - * Concurrency: This method is safe to call from multiple threads. - */ - def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = { - new BlockFetchingClient(this, remoteHost, remotePort) - } - - def stop(): Unit = { - if (workerGroup != null) { - workerGroup.shutdownGracefully() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala deleted file mode 100644 index 83265b164299..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import io.netty.buffer.ByteBuf -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging - - -/** - * Handler that processes server responses. It uses the protocol documented in - * [[org.apache.spark.network.netty.server.BlockServer]]. - * - * Concurrency: thread safe and can be called from multiple threads. - */ -private[client] -class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging { - - /** Tracks the list of outstanding requests and their listeners on success/failure. */ - private val outstandingRequests = java.util.Collections.synchronizedMap { - new java.util.HashMap[String, BlockClientListener] - } - - def addRequest(blockId: String, listener: BlockClientListener): Unit = { - outstandingRequests.put(blockId, listener) - } - - def removeRequest(blockId: String): Unit = { - outstandingRequests.remove(blockId) - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" - logError(errorMsg, cause) - - // Fire the failure callback for all outstanding blocks - outstandingRequests.synchronized { - val iter = outstandingRequests.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - entry.getValue.onFetchFailure(entry.getKey, errorMsg) - } - outstandingRequests.clear() - } - - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { - val totalLen = in.readInt() - val blockIdLen = in.readInt() - val blockIdBytes = new Array[Byte](math.abs(blockIdLen)) - in.readBytes(blockIdBytes) - val blockId = new String(blockIdBytes) - val blockSize = totalLen - math.abs(blockIdLen) - 4 - - def server = ctx.channel.remoteAddress.toString - - // blockIdLen is negative when it is an error message. - if (blockIdLen < 0) { - val errorMessageBytes = new Array[Byte](blockSize) - in.readBytes(errorMessageBytes) - val errorMsg = new String(errorMessageBytes) - logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server") - - val listener = outstandingRequests.get(blockId) - if (listener == null) { - // Ignore callback - logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") - } else { - outstandingRequests.remove(blockId) - listener.onFetchFailure(blockId, errorMsg) - } - } else { - logTrace(s"Received block $blockId ($blockSize B) from $server") - - val listener = outstandingRequests.get(blockId) - if (listener == null) { - // Ignore callback - logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") - } else { - outstandingRequests.remove(blockId) - listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in)) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala deleted file mode 100644 index 9740ee64d1f2..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -/** - * A simple iterator that lazily initializes the underlying iterator. - * - * The use case is that sometimes we might have many iterators open at the same time, and each of - * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer). - * This could lead to too many buffers open. If this iterator is used, we lazily initialize those - * buffers. - */ -private[spark] -class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] { - - lazy val proxy = createIterator - - override def hasNext: Boolean = { - val gotNext = proxy.hasNext - if (!gotNext) { - close() - } - gotNext - } - - override def next(): Any = proxy.next() - - def close(): Unit = Unit -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala deleted file mode 100644 index ea1abf5eccc2..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.io.InputStream -import java.nio.ByteBuffer - -import io.netty.buffer.{ByteBuf, ByteBufInputStream} - - -/** - * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty. - * This is a Scala value class. - * - * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of - * reference by the retain method and release method. - */ -private[spark] -class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal { - - /** Return the nio ByteBuffer view of the underlying buffer. */ - def byteBuffer(): ByteBuffer = underlying.nioBuffer - - /** Creates a new input stream that starts from the current position of the buffer. */ - def inputStream(): InputStream = new ByteBufInputStream(underlying) - - /** Increment the reference counter by one. */ - def retain(): Unit = underlying.retain() - - /** Decrement the reference counter by one and release the buffer if the ref count is 0. */ - def release(): Unit = underlying.release() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala new file mode 100644 index 000000000000..13942f3d0adc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer +import java.util.{List => JList} + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelHandler.Sharable +import io.netty.handler.codec._ + +import org.apache.spark.Logging +import org.apache.spark.network.{NioManagedBuffer, NettyManagedBuffer, ManagedBuffer} +import org.apache.spark.storage.StorageLevel + + +/** Messages from the client to the server. */ +private[netty] +sealed trait ClientRequest { + def id: Byte +} + +/** + * Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can + * correspond to multiple [[ServerResponse]]s. + */ +private[netty] +final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { + override def id = 0 +} + +/** + * Request to upload a block to the server. Currently the server does not ack the upload request. + */ +private[netty] +final case class BlockUploadRequest( + blockId: String, + data: ManagedBuffer, + level: StorageLevel) + extends ClientRequest { + require(blockId.length <= Byte.MaxValue) + override def id = 1 +} + + +/** Messages from server to client (usually in response to some [[ClientRequest]]. */ +private[netty] +sealed trait ServerResponse { + def id: Byte +} + +/** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */ +private[netty] +final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 0 +} + +/** Response to [[BlockFetchRequest]] when there is an error fetching the block. */ +private[netty] +final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 1 +} + +/** Response to [[BlockUploadRequest]] when a block is successfully uploaded. */ +private[netty] +final case class BlockUploadSuccess(blockId: String) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 2 +} + +/** Response to [[BlockUploadRequest]] when there is an error uploading the block. */ +private[netty] +final case class BlockUploadFailure(blockId: String, error: String) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 3 +} + + +/** + * Encoder for [[ClientRequest]] used in client side. + * + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@Sharable +private[netty] +final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] { + override def encode(ctx: ChannelHandlerContext, in: ClientRequest, out: JList[Object]): Unit = { + in match { + case BlockFetchRequest(blocks) => + // 8 bytes: frame size + // 1 byte: BlockFetchRequest vs BlockUploadRequest + // 4 byte: num blocks + // then for each block id write 1 byte for blockId.length and then blockId itself + val frameLength = 8 + 1 + 4 + blocks.size + blocks.map(_.size).fold(0)(_ + _) + val buf = ctx.alloc().buffer(frameLength) + + buf.writeLong(frameLength) + buf.writeByte(in.id) + buf.writeInt(blocks.size) + blocks.foreach { blockId => + ProtocolUtils.writeBlockId(buf, blockId) + } + + assert(buf.writableBytes() == 0) + out.add(buf) + + case BlockUploadRequest(blockId, data, level) => + // 8 bytes: frame size + // 1 byte: msg id (BlockFetchRequest vs BlockUploadRequest) + // 1 byte: blockId.length + // data itself (length can be derived from: frame size - 1 - blockId.length) + val headerLength = 8 + 1 + 1 + blockId.length + 5 + val frameLength = headerLength + data.size + val header = ctx.alloc().buffer(headerLength) + + // Call this before we add header to out so in case of exceptions + // we don't send anything at all. + val body = data.convertToNetty() + + header.writeLong(frameLength) + header.writeByte(in.id) + ProtocolUtils.writeBlockId(header, blockId) + header.writeInt(level.toInt) + header.writeByte(level.replication) + + assert(header.writableBytes() == 0) + out.add(header) + out.add(body) + } + } +} + + +/** + * Decoder in the server side to decode client requests. + * This decoder is stateless so it is safe to be shared by multiple threads. + * + * This assumes the inbound messages have been processed by a frame decoder created by + * [[ProtocolUtils.createFrameDecoder()]]. + */ +@Sharable +private[netty] +final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { + override protected def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = + { + val msgTypeId = in.readByte() + val decoded = msgTypeId match { + case 0 => // BlockFetchRequest + val numBlocks = in.readInt() + val blockIds = Seq.fill(numBlocks) { ProtocolUtils.readBlockId(in) } + BlockFetchRequest(blockIds) + + case 1 => // BlockUploadRequest + val blockId = ProtocolUtils.readBlockId(in) + val level = new StorageLevel(in.readInt(), in.readByte()) + + val ret = ByteBuffer.allocate(in.readableBytes()) + ret.put(in.nioBuffer()) + ret.flip() + BlockUploadRequest(blockId, new NioManagedBuffer(ret), level) + } + + assert(decoded.id == msgTypeId) + out.add(decoded) + } +} + + +/** + * Encoder used by the server side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@Sharable +private[netty] +final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging { + override def encode(ctx: ChannelHandlerContext, in: ServerResponse, out: JList[Object]): Unit = { + in match { + case BlockFetchSuccess(blockId, data) => + // Handle the body first so if we encounter an error getting the body, we can respond + // with an error instead. + var body: AnyRef = null + try { + body = data.convertToNetty() + } catch { + case e: Exception => + // Re-encode this message as BlockFetchFailure. + logError(s"Error opening block $blockId for client ${ctx.channel.remoteAddress}", e) + encode(ctx, new BlockFetchFailure(blockId, e.getMessage), out) + return + } + + // If we got here, body cannot be null + // 8 bytes = long for frame length + // 1 byte = message id (type) + // 1 byte = block id length + // followed by block id itself + val headerLength = 8 + 1 + 1 + blockId.length + val frameLength = headerLength + data.size + val header = ctx.alloc().buffer(headerLength) + header.writeLong(frameLength) + header.writeByte(in.id) + ProtocolUtils.writeBlockId(header, blockId) + + assert(header.writableBytes() == 0) + out.add(header) + out.add(body) + + case BlockFetchFailure(blockId, error) => + val frameLength = 8 + 1 + 1 + blockId.length + error.length + val buf = ctx.alloc().buffer(frameLength) + buf.writeLong(frameLength) + buf.writeByte(in.id) + ProtocolUtils.writeBlockId(buf, blockId) + buf.writeBytes(error.getBytes) + + assert(buf.writableBytes() == 0) + out.add(buf) + + case BlockUploadSuccess(blockId) => + val frameLength = 8 + 1 + 1 + blockId.length + val buf = ctx.alloc().buffer(frameLength) + buf.writeLong(frameLength) + buf.writeByte(in.id) + ProtocolUtils.writeBlockId(buf, blockId) + + assert(buf.writableBytes() == 0) + out.add(buf) + + case BlockUploadFailure(blockId, error) => + val frameLength = 8 + 1 + 1 + blockId.length + + error.length + val buf = ctx.alloc().buffer(frameLength) + buf.writeLong(frameLength) + buf.writeByte(in.id) + ProtocolUtils.writeBlockId(buf, blockId) + buf.writeBytes(error.getBytes) + + assert(buf.writableBytes() == 0) + out.add(buf) + } + } +} + + +/** + * Decoder in the client side to decode server responses. + * This decoder is stateless so it is safe to be shared by multiple threads. + * + * This assumes the inbound messages have been processed by a frame decoder created by + * [[ProtocolUtils.createFrameDecoder()]]. + */ +@Sharable +private[netty] +final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { + override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = { + val msgId = in.readByte() + val decoded = msgId match { + case 0 => // BlockFetchSuccess + val blockId = ProtocolUtils.readBlockId(in) + in.retain() + BlockFetchSuccess(blockId, new NettyManagedBuffer(in)) + + case 1 => // BlockFetchFailure + val blockId = ProtocolUtils.readBlockId(in) + val errorBytes = new Array[Byte](in.readableBytes()) + in.readBytes(errorBytes) + BlockFetchFailure(blockId, new String(errorBytes)) + + case 2 => // BlockUploadSuccess + BlockUploadSuccess(ProtocolUtils.readBlockId(in)) + + case 3 => // BlockUploadFailure + val blockId = ProtocolUtils.readBlockId(in) + val errorBytes = new Array[Byte](in.readableBytes()) + in.readBytes(errorBytes) + BlockUploadFailure(blockId, new String(errorBytes)) + } + + assert(decoded.id == msgId) + out.add(decoded) + } +} + + +private[netty] object ProtocolUtils { + + /** LengthFieldBasedFrameDecoder used before all decoders. */ + def createFrameDecoder(): ByteToMessageDecoder = { + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 8 + // lengthAdjustment = -8, i.e. exclude the 8 byte length itself + // initialBytesToStrip = 8, i.e. strip out the length field itself + new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8) + } + + // TODO(rxin): Make sure these work for all charsets. + def readBlockId(in: ByteBuf): String = { + val numBytesToRead = in.readByte().toInt + val bytes = new Array[Byte](numBytesToRead) + in.readBytes(bytes) + new String(bytes) + } + + def writeBlockId(out: ByteBuf, blockId: String): Unit = { + out.writeByte(blockId.length) + out.writeBytes(blockId.getBytes) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala deleted file mode 100644 index 162e9cc6828d..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -/** - * Header describing a block. This is used only in the server pipeline. - * - * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it. - * - * @param blockSize length of the block content, excluding the length itself. - * If positive, this is the header for a block (not part of the header). - * If negative, this is the header and content for an error message. - * @param blockId block id - * @param error some error message from reading the block - */ -private[server] -class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala deleted file mode 100644 index 8e4dda4ef859..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.handler.codec.MessageToByteEncoder - -/** - * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol. - */ -private[server] -class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] { - override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = { - // message = message length (4 bytes) + block id length (4 bytes) + block id + block data - // message length = block id length (4 bytes) + size of block id + size of block data - val blockIdBytes = msg.blockId.getBytes - msg.error match { - case Some(errorMsg) => - val errorBytes = errorMsg.getBytes - out.writeInt(4 + blockIdBytes.length + errorBytes.size) - out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors - out.writeBytes(blockIdBytes) // next is blockId itself - out.writeBytes(errorBytes) // error message - case None => - out.writeInt(4 + blockIdBytes.length + msg.blockSize) - out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length - out.writeBytes(blockIdBytes) // next is blockId itself - // msg of size blockSize will be written by ServerHandler - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala deleted file mode 100644 index cc70bd0c5c47..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.channel.ChannelInitializer -import io.netty.channel.socket.SocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil -import org.apache.spark.storage.BlockDataProvider - - -/** Channel initializer that sets up the pipeline for the BlockServer. */ -private[netty] -class BlockServerChannelInitializer(dataProvider: BlockDataProvider) - extends ChannelInitializer[SocketChannel] { - - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala deleted file mode 100644 index 40dd5e5d1a2a..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.io.FileInputStream -import java.nio.ByteBuffer -import java.nio.channels.FileChannel - -import io.netty.buffer.Unpooled -import io.netty.channel._ - -import org.apache.spark.Logging -import org.apache.spark.storage.{FileSegment, BlockDataProvider} - - -/** - * A handler that processes requests from clients and writes block data back. - * - * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first - * so channelRead0 is called once per line (i.e. per block id). - */ -private[server] -class BlockServerHandler(dataProvider: BlockDataProvider) - extends SimpleChannelInboundHandler[String] with Logging { - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = { - def client = ctx.channel.remoteAddress.toString - - // A helper function to send error message back to the client. - def respondWithError(error: String): Unit = { - ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (!future.isSuccess) { - // TODO: Maybe log the success case as well. - logError(s"Error sending error back to $client", future.cause) - ctx.close() - } - } - } - ) - } - - def writeFileSegment(segment: FileSegment): Unit = { - // Send error message back if the block is too large. Even though we are capable of sending - // large (2G+) blocks, the receiving end cannot handle it so let's fail fast. - // Once we fixed the receiving end to be able to process large blocks, this should be removed. - // Also make sure we update BlockHeaderEncoder to support length > 2G. - - // See [[BlockHeaderEncoder]] for the way length is encoded. - if (segment.length + blockId.length + 4 > Int.MaxValue) { - respondWithError(s"Block $blockId size ($segment.length) greater than 2G") - return - } - - var fileChannel: FileChannel = null - try { - fileChannel = new FileInputStream(segment.file).getChannel - } catch { - case e: Exception => - logError( - s"Error opening channel for $blockId in ${segment.file} for request from $client", e) - respondWithError(e.getMessage) - } - - // Found the block. Send it back. - if (fileChannel != null) { - // Write the header and block data. In the case of failures, the listener on the block data - // write should close the connection. - ctx.write(new BlockHeader(segment.length.toInt, blockId)) - - val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length) - ctx.writeAndFlush(region).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${segment.length} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - } - - def writeByteBuffer(buf: ByteBuffer): Unit = { - ctx.write(new BlockHeader(buf.remaining, blockId)) - ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - - logTrace(s"Received request from $client to fetch block $blockId") - - var blockData: Either[FileSegment, ByteBuffer] = null - - // First make sure we can find the block. If not, send error back to the user. - try { - blockData = dataProvider.getBlockData(blockId) - } catch { - case e: Exception => - logError(s"Error opening block $blockId for request from $client", e) - respondWithError(e.getMessage) - return - } - - blockData match { - case Left(segment) => writeFileSegment(segment) - case Right(buf) => writeByteBuffer(buf) - } - - } // end of channelRead0 -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index b389b9a2022c..e942b43d9cc4 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -71,7 +71,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa /** * Tear down the transfer service. */ - override def stop(): Unit = { + override def close(): Unit = { if (cm != null) { cm.stop() } @@ -96,21 +96,25 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { + for (blockMessage: BlockMessage <- blockMessageArray) { if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - listener.onBlockFetchFailure( - new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + if (blockMessage.getId != null) { + listener.onBlockFetchFailure(blockMessage.getId.toString, + new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + } } else { val blockId = blockMessage.getId val networkSize = blockMessage.getData.limit() listener.onBlockFetchSuccess( - blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData)) + blockId.toString, new NioManagedBuffer(blockMessage.getData)) } } }(cm.futureExecContext) future.onFailure { case exception => - listener.onBlockFetchFailure(exception) + blockIds.foreach { blockId => + listener.onBlockFetchFailure(blockId, exception) + } }(cm.futureExecContext) } @@ -189,7 +193,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) - blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level) + blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " with data size: " + bytes.limit) } @@ -197,9 +201,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa private def getBlock(blockId: String): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId).orNull + val buffer = blockDataManager.getBlockData(blockId) logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) - if (buffer == null) null else buffer.nioByteBuffer() + buffer.nioByteBuffer() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala deleted file mode 100644 index 5b6d08663083..000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - - -/** - * An interface for providing data for blocks. - * - * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer. - * - * Aside from unit tests, [[BlockManager]] is the main class that implements this. - */ -private[spark] trait BlockDataProvider { - def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d1bee3d2c033..abef0e171a5c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -206,21 +206,20 @@ private[spark] class BlockManager( } /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. */ - override def getBlockData(blockId: String): Option[ManagedBuffer] = { + override def getBlockData(blockId: String): ManagedBuffer = { val bid = BlockId(blockId) if (bid.isShuffle) { - Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])) + shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - Some(new NioByteBufferManagedBuffer(buffer)) + new NioManagedBuffer(buffer) } else { - None + throw new BlockNotFoundException(blockId) } } } @@ -334,17 +333,6 @@ private[spark] class BlockManager( locations } - /** - * A short-circuited method to get blocks directly from disk. This is used for getting - * shuffle blocks. It is safe to do so without a lock on block info since disk store - * never deletes (recent) items. - */ - def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) - val is = wrapForCompression(blockId, buf.inputStream()) - Some(serializer.newInstance().deserializeStream(is).asIterator) - } - /** * Get block from local block manager. */ @@ -804,7 +792,7 @@ private[spark] class BlockManager( try { blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) + peer.host, peer.port, blockId.toString, new NioManagedBuffer(data), tLevel) } catch { case e: Exception => logError(s"Failed to replicate block to $peer", e) @@ -1039,7 +1027,7 @@ private[spark] class BlockManager( } def stop(): Unit = { - blockTransferService.stop() + blockTransferService.close() diskBlockManager.stop() actorSystem.stop(slaveActor) blockInfo.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 71b276b5f18e..4e69a5d9e3ec 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -23,10 +23,10 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue -import org.apache.spark.{TaskContext, Logging} +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{CompletionIterator, Utils} /** @@ -88,17 +88,51 @@ final class ShuffleBlockFetcherIterator( */ private[this] val results = new LinkedBlockingQueue[FetchResult] - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight + /** + * Current [[FetchResult]] being processed. We track this so we can release the current buffer + * in case of a runtime exception when processing the current buffer. + */ + private[this] var currentResult: FetchResult = null + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + * the number of bytes in flight is limited to maxBytesInFlight. + */ private[this] val fetchRequests = new Queue[FetchRequest] - // Current bytes in flight from our requests + /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @volatile private[this] var isZombie = false + initialize() + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + // Release the current buffer if necessary + if (currentResult != null && !currentResult.failed) { + currentResult.buf.release() + } + + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + if (!result.failed) { + result.buf.release() + } + } + } + private[this] def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) @@ -110,27 +144,25 @@ final class ShuffleBlockFetcherIterator( blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, new BlockFetchingListener { - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), - () => serializer.newInstance().deserializeStream( - blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator - )) - shuffleMetrics.remoteBytesRead += data.size - shuffleMetrics.remoteBlocksFetched += 1 - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf)) + shuffleMetrics.remoteBytesRead += buf.size + shuffleMetrics.remoteBlocksFetched += 1 + } + logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } - override def onBlockFetchFailure(e: Throwable): Unit = { + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - // Note that there is a chance that some blocks have been fetched successfully, but we - // still add them to the failed queue. This is fine because when the caller see a - // FetchFailedException, it is going to fail the entire task anyway. - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } + results.put(new FetchResult(BlockId(blockId), -1, null)) } - } - ) + }) } private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { @@ -138,7 +170,7 @@ final class ShuffleBlockFetcherIterator( // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. @@ -185,26 +217,34 @@ final class ShuffleBlockFetcherIterator( remoteRequests } + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. + */ private[this] def fetchLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocks) { + val iter = localBlocks.iterator + while (iter.hasNext) { + val blockId = iter.next() try { + val buf = blockManager.getBlockData(blockId.toString) shuffleMetrics.localBlocksFetched += 1 - results.put(new FetchResult( - id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get)) - logDebug("Got local block " + id) + buf.retain() + results.put(new FetchResult(blockId, 0, buf)) } catch { case e: Exception => + // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(id, -1, null)) + results.put(new FetchResult(blockId, -1, null)) return } } } private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(_ => cleanup()) + // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order @@ -229,7 +269,8 @@ final class ShuffleBlockFetcherIterator( override def next(): (BlockId, Option[Iterator[Any]]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() - val result = results.take() + currentResult = results.take() + val result = currentResult val stopFetchWait = System.currentTimeMillis() shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) if (!result.failed) { @@ -240,7 +281,21 @@ final class ShuffleBlockFetcherIterator( (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) } - (result.blockId, if (result.failed) None else Some(result.deserialize())) + + val iteratorOpt: Option[Iterator[Any]] = if (result.failed) { + None + } else { + val is = blockManager.wrapForCompression(result.blockId, result.buf.inputStream()) + val iter = serializer.newInstance().deserializeStream(is).asIterator + Some(CompletionIterator[Any, Iterator[Any]](iter, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + currentResult = null + result.buf.release() + })) + } + + (result.blockId, iteratorOpt) } } @@ -254,7 +309,7 @@ object ShuffleBlockFetcherIterator { * @param blocks Sequence of tuple, where the first element is the block id, * and the second element is the estimated size, used to calculate bytesInFlight. */ - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { + case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } @@ -262,10 +317,11 @@ object ShuffleBlockFetcherIterator { * Result of a fetch from a remote block. A failure is represented as size == -1. * @param blockId block id * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. + * Note that this is NOT the exact bytes. -1 if failure is present. + * @param buf [[ManagedBuffer]] for the content. null is error. */ - class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { + case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) { def failed: Boolean = size == -1 + if (failed) assert(buf == null) else assert(buf != null) } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 1e35abaab535..2fc7c7d9b831 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -42,7 +42,7 @@ class StorageLevel private( extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - private def this(flags: Int, replication: Int) { + private[spark] def this(flags: Int, replication: Int) { this((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -98,6 +98,7 @@ class StorageLevel private( } override def writeExternal(out: ObjectOutput) { + /* If the wire protocol changes, please also update [[ClientRequestEncoder]] */ out.writeByte(toInt) out.writeByte(_replication) } diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala new file mode 100644 index 000000000000..2d4baafcf03d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import scala.concurrent.{Await, future} +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.SparkConf + + +class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { + + private val conf = new SparkConf + private var server1: BlockServer = _ + private var server2: BlockServer = _ + + override def beforeAll() { + server1 = new BlockServer(new NettyConfig(conf), null) + server2 = new BlockServer(new NettyConfig(conf), null) + } + + override def afterAll() { + if (server1 != null) { + server1.close() + } + if (server2 != null) { + server2.close() + } + } + + test("BlockClients created are active and reused") { + val factory = new BlockClientFactory(conf) + val c1 = factory.createClient(server1.hostName, server1.port) + val c2 = factory.createClient(server1.hostName, server1.port) + val c3 = factory.createClient(server2.hostName, server2.port) + assert(c1.isActive) + assert(c3.isActive) + assert(c1 === c2) + assert(c1 !== c3) + factory.close() + } + + test("never return inactive clients") { + val factory = new BlockClientFactory(conf) + val c1 = factory.createClient(server1.hostName, server1.port) + c1.close() + + // Block until c1 is no longer active + val f = future { + while (c1.isActive) { + Thread.sleep(10) + } + } + Await.result(f, 3.seconds) + assert(!c1.isActive) + + // Create c2, which should be different from c1 + val c2 = factory.createClient(server1.hostName, server1.port) + assert(c1 !== c2) + factory.close() + } + + test("BlockClients are close when BlockClientFactory is stopped") { + val factory = new BlockClientFactory(conf) + val c1 = factory.createClient(server1.hostName, server1.port) + val c2 = factory.createClient(server2.hostName, server2.port) + assert(c1.isActive) + assert(c2.isActive) + factory.close() + assert(!c1.isActive) + assert(!c2.isActive) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala new file mode 100644 index 000000000000..4c3a64908157 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer + +import io.netty.buffer.Unpooled +import io.netty.channel.embedded.EmbeddedChannel + +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, eq => meq} + +import org.scalatest.{FunSuite, PrivateMethodTester} + +import org.apache.spark.network._ + + +class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { + + /** Helper method to get num. outstanding requests from a private field using reflection. */ + private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { + val f = handler.getClass.getDeclaredField( + "org$apache$spark$network$netty$BlockClientHandler$$outstandingFetches") + f.setAccessible(true) + f.get(handler).asInstanceOf[java.util.Map[_, _]].size + } + + test("handling block data (successful fetch)") { + val blockId = "test_block" + val blockData = "blahblahblahblahblah" + val handler = new BlockClientHandler + val listener = mock(classOf[BlockFetchingListener]) + handler.addFetchRequest(blockId, listener) + assert(sizeOfOutstandingRequests(handler) === 1) + + val channel = new EmbeddedChannel(handler) + val buf = ByteBuffer.allocate(blockData.size) // 4 bytes for the length field itself + buf.put(blockData.getBytes) + buf.flip() + + channel.writeInbound(BlockFetchSuccess(blockId, new NioManagedBuffer(buf))) + verify(listener, times(1)).onBlockFetchSuccess(meq(blockId), any()) + assert(sizeOfOutstandingRequests(handler) === 0) + assert(channel.finish() === false) + } + + test("handling error message (failed fetch)") { + val blockId = "test_block" + val handler = new BlockClientHandler + val listener = mock(classOf[BlockFetchingListener]) + handler.addFetchRequest(blockId, listener) + assert(sizeOfOutstandingRequests(handler) === 1) + + val channel = new EmbeddedChannel(handler) + channel.writeInbound(BlockFetchFailure(blockId, "some error msg")) + verify(listener, times(0)).onBlockFetchSuccess(any(), any()) + verify(listener, times(1)).onBlockFetchFailure(meq(blockId), any()) + assert(sizeOfOutstandingRequests(handler) === 0) + assert(channel.finish() === false) + } + + test("clear all outstanding request upon uncaught exception") { + val handler = new BlockClientHandler + val listener = mock(classOf[BlockFetchingListener]) + handler.addFetchRequest("b1", listener) + handler.addFetchRequest("b2", listener) + handler.addFetchRequest("b3", listener) + assert(sizeOfOutstandingRequests(handler) === 3) + + val channel = new EmbeddedChannel(handler) + channel.writeInbound(BlockFetchSuccess("b1", new NettyManagedBuffer(Unpooled.buffer()))) + channel.pipeline().fireExceptionCaught(new Exception("duh duh duh")) + + // should fail both b2 and b3 + verify(listener, times(1)).onBlockFetchSuccess(any(), any()) + verify(listener, times(2)).onBlockFetchFailure(any(), any()) + assert(sizeOfOutstandingRequests(handler) === 0) + assert(channel.finish() === false) + } + + test("clear all outstanding request upon connection close") { + val handler = new BlockClientHandler + val listener = mock(classOf[BlockFetchingListener]) + handler.addFetchRequest("c1", listener) + handler.addFetchRequest("c2", listener) + handler.addFetchRequest("c3", listener) + assert(sizeOfOutstandingRequests(handler) === 3) + + val channel = new EmbeddedChannel(handler) + channel.writeInbound(BlockFetchSuccess("c1", new NettyManagedBuffer(Unpooled.buffer()))) + channel.finish() + + // should fail both b2 and b3 + verify(listener, times(1)).onBlockFetchSuccess(any(), any()) + verify(listener, times(2)).onBlockFetchFailure(any(), any()) + assert(sizeOfOutstandingRequests(handler) === 0) + assert(channel.finish() === false) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala new file mode 100644 index 000000000000..8d1b7276f408 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.FunSuite + +import org.apache.spark.api.java.StorageLevels + + +/** + * Test client/server encoder/decoder protocol. + */ +class ProtocolSuite extends FunSuite { + + /** + * Helper to test server to client message protocol by encoding a message and decoding it. + */ + private def testServerToClient(msg: ServerResponse) { + val serverChannel = new EmbeddedChannel(new ServerResponseEncoder) + serverChannel.writeOutbound(msg) + + val clientChannel = new EmbeddedChannel( + ProtocolUtils.createFrameDecoder(), + new ServerResponseDecoder) + + // Drain all server outbound messages and write them to the client's server decoder. + while (!serverChannel.outboundMessages().isEmpty) { + clientChannel.writeInbound(serverChannel.readOutbound()) + } + + assert(clientChannel.inboundMessages().size === 1) + // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is + // overridden. + assert(msg === clientChannel.readInbound()) + } + + /** + * Helper to test client to server message protocol by encoding a message and decoding it. + */ + private def testClientToServer(msg: ClientRequest) { + val clientChannel = new EmbeddedChannel(new ClientRequestEncoder) + clientChannel.writeOutbound(msg) + + val serverChannel = new EmbeddedChannel( + ProtocolUtils.createFrameDecoder(), + new ClientRequestDecoder) + + // Drain all client outbound messages and write them to the server's decoder. + while (!clientChannel.outboundMessages().isEmpty) { + serverChannel.writeInbound(clientChannel.readOutbound()) + } + + assert(serverChannel.inboundMessages().size === 1) + // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is + // overridden. + assert(msg === serverChannel.readInbound()) + } + + test("server to client protocol - BlockFetchSuccess(\"a1234\", new TestManagedBuffer(10))") { + testServerToClient(BlockFetchSuccess("a1234", new TestManagedBuffer(10))) + } + + test("server to client protocol - BlockFetchSuccess(\"\", new TestManagedBuffer(0))") { + testServerToClient(BlockFetchSuccess("", new TestManagedBuffer(0))) + } + + test("server to client protocol - BlockFetchFailure(\"abcd\", \"this is an error\")") { + testServerToClient(BlockFetchFailure("abcd", "this is an error")) + } + + test("server to client protocol - BlockFetchFailure(\"\", \"\")") { + testServerToClient(BlockFetchFailure("", "")) + } + + test("client to server protocol - BlockFetchRequest(Seq.empty[String])") { + testClientToServer(BlockFetchRequest(Seq.empty[String])) + } + + test("client to server protocol - BlockFetchRequest(Seq(\"b1\"))") { + testClientToServer(BlockFetchRequest(Seq("b1"))) + } + + test("client to server protocol - BlockFetchRequest(Seq(\"b1\", \"b2\", \"b3\"))") { + testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3"))) + } + + test("client to server protocol - BlockUploadRequest(\"\", new TestManagedBuffer(0))") { + testClientToServer( + BlockUploadRequest("", new TestManagedBuffer(0), StorageLevels.MEMORY_AND_DISK)) + } + + test("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") { + testClientToServer( + BlockUploadRequest("b_upload", new TestManagedBuffer(10), StorageLevels.MEMORY_AND_DISK_2)) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 02d0ffc86f58..35ff90a2dabc 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -24,26 +24,28 @@ import java.util.concurrent.{TimeUnit, Semaphore} import scala.collection.JavaConversions._ -import io.netty.buffer.{ByteBufUtil, Unpooled} +import io.netty.buffer.Unpooled import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.Span +import org.scalatest.time.Seconds import org.apache.spark.SparkConf -import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory} -import org.apache.spark.network.netty.server.BlockServer -import org.apache.spark.storage.{FileSegment, BlockDataProvider} +import org.apache.spark.network._ +import org.apache.spark.storage.{BlockNotFoundException, StorageLevel} /** - * Test suite that makes sure the server and the client implementations share the same protocol. - */ +* Test cases that create real clients and servers and connect. +*/ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { val bufSize = 100000 var buf: ByteBuffer = _ var testFile: File = _ var server: BlockServer = _ - var clientFactory: BlockFetchingClientFactory = _ + var clientFactory: BlockClientFactory = _ val bufferBlockId = "buffer_block" val fileBlockId = "file_block" @@ -63,24 +65,29 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { fp.write(fileContent) fp.close() - server = new BlockServer(new SparkConf, new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + server = new BlockServer(new NettyConfig(new SparkConf), new BlockDataManager { + override def getBlockData(blockId: String): ManagedBuffer = { if (blockId == bufferBlockId) { - Right(buf) + new NioManagedBuffer(buf) } else if (blockId == fileBlockId) { - Left(new FileSegment(testFile, 10, testFile.length - 25)) + new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25) } else { - throw new Exception("Unknown block id " + blockId) + throw new BlockNotFoundException(blockId) } } + + /** + * Put the block locally, using the given storage level. + */ + def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = ??? }) - clientFactory = new BlockFetchingClientFactory(new SparkConf) + clientFactory = new BlockClientFactory(new SparkConf) } override def afterAll() = { - server.stop() - clientFactory.stop() + server.close() + clientFactory.close() } /** A ByteBuf for buffer_block */ @@ -89,31 +96,30 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { /** A ByteBuf for file_block */ lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) = - { + def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = { val client = clientFactory.createClient(server.hostName, server.port) val sem = new Semaphore(0) val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) - val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer]) + val receivedBuffers = Collections.synchronizedSet(new HashSet[ManagedBuffer]) client.fetchBlocks( blockIds, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = { + new BlockFetchingListener { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { errorBlockIds.add(blockId) sem.release() } - override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { - receivedBlockIds.add(blockId) + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { data.retain() + receivedBlockIds.add(blockId) receivedBuffers.add(data) sem.release() } } ) - if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) { + if (!sem.tryAcquire(blockIds.size, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server") } client.close() @@ -123,7 +129,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { test("fetch a ByteBuffer block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds.isEmpty) buffers.foreach(_.release()) } @@ -131,7 +137,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { test("fetch a FileSegment block via zero-copy send") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) assert(blockIds === Set(fileBlockId)) - assert(buffers.map(_.underlying) === Set(fileBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(fileBlockReference)) assert(failBlockIds.isEmpty) buffers.foreach(_.release()) } @@ -141,12 +147,13 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds.isEmpty) assert(buffers.isEmpty) assert(failBlockIds === Set("random-block")) + buffers.foreach(_.release()) } test("fetch both ByteBuffer block and FileSegment block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) assert(blockIds === Set(bufferBlockId, fileBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference, fileBlockReference)) assert(failBlockIds.isEmpty) buffers.foreach(_.release()) } @@ -154,8 +161,14 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { test("fetch both ByteBuffer block and a non-existent block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds === Set("random-block")) buffers.foreach(_.release()) } + + test("shutting down server should also close client") { + val client = clientFactory.createClient(server.hostName, server.port) + server.close() + eventually(timeout(Span(5, Seconds))) { assert(!client.isActive) } + } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala new file mode 100644 index 000000000000..e47e4d03fa89 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.InputStream +import java.nio.ByteBuffer + +import io.netty.buffer.Unpooled + +import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} + + +/** + * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). + * + * Used for testing. + */ +class TestManagedBuffer(len: Int) extends ManagedBuffer { + + require(len <= Byte.MaxValue) + + private val byteArray: Array[Byte] = Array.tabulate[Byte](len)(_.toByte) + + private val underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)) + + override def size: Long = underlying.size + + override private[network] def convertToNetty(): AnyRef = underlying.convertToNetty() + + override def nioByteBuffer(): ByteBuffer = underlying.nioByteBuffer() + + override def inputStream(): InputStream = underlying.inputStream() + + override def toString: String = s"${getClass.getName}($len)" + + override def equals(other: Any): Boolean = other match { + case otherBuf: ManagedBuffer => + val nioBuf = otherBuf.nioByteBuffer() + if (nioBuf.remaining() != len) { + return false + } else { + var i = 0 + while (i < len) { + if (nioBuf.get() != i) { + return false + } + i += 1 + } + return true + } + case _ => false + } + + override def retain(): this.type = this + + override def release(): this.type = this +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala deleted file mode 100644 index 903ab09ae432..000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.{PrivateMethodTester, FunSuite} - - -class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester { - - test("handling block data (successful fetch)") { - val blockId = "test_block" - val blockData = "blahblahblahblahblah" - val totalLength = 4 + blockId.length + blockData.length - - var parsedBlockId: String = "" - var parsedBlockData: String = "" - val handler = new BlockFetchingClientHandler - handler.addRequest(blockId, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = ??? - override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = { - parsedBlockId = bid - val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining) - refCntBuf.byteBuffer().get(bytes) - parsedBlockData = new String(bytes) - } - } - ) - - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - assert(handler.invokePrivate(outstandingRequests()).size === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself - buf.putInt(totalLength) - buf.putInt(blockId.length) - buf.put(blockId.getBytes) - buf.put(blockData.getBytes) - buf.flip() - - channel.writeInbound(Unpooled.wrappedBuffer(buf)) - assert(parsedBlockId === blockId) - assert(parsedBlockData === blockData) - - assert(handler.invokePrivate(outstandingRequests()).size === 0) - - channel.close() - } - - test("handling error message (failed fetch)") { - val blockId = "test_block" - val errorMsg = "error erro5r error err4or error3 error6 error erro1r" - val totalLength = 4 + blockId.length + errorMsg.length - - var parsedBlockId: String = "" - var parsedErrorMsg: String = "" - val handler = new BlockFetchingClientHandler - handler.addRequest(blockId, new BlockClientListener { - override def onFetchFailure(bid: String, msg: String) ={ - parsedBlockId = bid - parsedErrorMsg = msg - } - override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ??? - }) - - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - assert(handler.invokePrivate(outstandingRequests()).size === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself - buf.putInt(totalLength) - buf.putInt(-blockId.length) - buf.put(blockId.getBytes) - buf.put(errorMsg.getBytes) - buf.flip() - - channel.writeInbound(Unpooled.wrappedBuffer(buf)) - assert(parsedBlockId === blockId) - assert(parsedErrorMsg === errorMsg) - - assert(handler.invokePrivate(outstandingRequests()).size === 0) - - channel.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala deleted file mode 100644 index 3ee281cb1350..000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - - -class BlockHeaderEncoderSuite extends FunSuite { - - test("encode normal block data") { - val blockId = "test_block" - val channel = new EmbeddedChannel(new BlockHeaderEncoder) - channel.writeOutbound(new BlockHeader(17, blockId, None)) - val out = channel.readOutbound().asInstanceOf[ByteBuf] - assert(out.readInt() === 4 + blockId.length + 17) - assert(out.readInt() === blockId.length) - - val blockIdBytes = new Array[Byte](blockId.length) - out.readBytes(blockIdBytes) - assert(new String(blockIdBytes) === blockId) - assert(out.readableBytes() === 0) - - channel.close() - } - - test("encode error message") { - val blockId = "error_block" - val errorMsg = "error encountered" - val channel = new EmbeddedChannel(new BlockHeaderEncoder) - channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg))) - val out = channel.readOutbound().asInstanceOf[ByteBuf] - assert(out.readInt() === 4 + blockId.length + errorMsg.length) - assert(out.readInt() === -blockId.length) - - val blockIdBytes = new Array[Byte](blockId.length) - out.readBytes(blockIdBytes) - assert(new String(blockIdBytes) === blockId) - - val errorMsgBytes = new Array[Byte](errorMsg.length) - out.readBytes(errorMsgBytes) - assert(new String(errorMsgBytes) === errorMsg) - assert(out.readableBytes() === 0) - - channel.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala deleted file mode 100644 index 3239c710f163..000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.io.{RandomAccessFile, File} -import java.nio.ByteBuffer - -import io.netty.buffer.{Unpooled, ByteBuf} -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion} -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - -import org.apache.spark.storage.{BlockDataProvider, FileSegment} - - -class BlockServerHandlerSuite extends FunSuite { - - test("ByteBuffer block") { - val expectedBlockId = "test_bytebuffer_block" - val buf = ByteBuffer.allocate(10000) - for (i <- 1 to 10000) { - buf.put(i.toByte) - } - buf.flip() - - val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf) - })) - - channel.writeInbound(expectedBlockId) - assert(channel.outboundMessages().size === 2) - - val out1 = channel.readOutbound().asInstanceOf[BlockHeader] - val out2 = channel.readOutbound().asInstanceOf[ByteBuf] - - assert(out1.blockId === expectedBlockId) - assert(out1.blockSize === buf.remaining) - assert(out1.error === None) - - assert(out2.equals(Unpooled.wrappedBuffer(buf))) - - channel.close() - } - - test("FileSegment block via zero-copy") { - val expectedBlockId = "test_file_block" - - // Create random file data - val fileContent = new Array[Byte](1024) - scala.util.Random.nextBytes(fileContent) - val testFile = File.createTempFile("netty-test-file", "txt") - val fp = new RandomAccessFile(testFile, "rw") - fp.write(fileContent) - fp.close() - - val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { - Left(new FileSegment(testFile, 15, testFile.length - 25)) - } - })) - - channel.writeInbound(expectedBlockId) - assert(channel.outboundMessages().size === 2) - - val out1 = channel.readOutbound().asInstanceOf[BlockHeader] - val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion] - - assert(out1.blockId === expectedBlockId) - assert(out1.blockSize === testFile.length - 25) - assert(out1.error === None) - - assert(out2.count === testFile.length - 25) - assert(out2.position === 15) - } - - test("pipeline exception propagation") { - val blockServerHandler = new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ??? - }) - val exceptionHandler = new SimpleChannelInboundHandler[String]() { - override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = { - throw new Exception("this is an error") - } - } - - val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler) - assert(channel.isOpen) - channel.writeInbound("a message to trigger the error") - assert(!channel.isOpen) - } -} diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala new file mode 100644 index 000000000000..0ade1bab18d7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{EOFException, OutputStream, InputStream} +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + + +/** + * A serializer implementation that always return a single element in a deserialization stream. + */ +class TestSerializer extends Serializer { + override def newInstance() = new TestSerializerInstance +} + + +class TestSerializerInstance extends SerializerInstance { + override def serialize[T: ClassTag](t: T): ByteBuffer = ??? + + override def serializeStream(s: OutputStream): SerializationStream = ??? + + override def deserializeStream(s: InputStream) = new TestDeserializationStream + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ??? + + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = ??? +} + + +class TestDeserializationStream extends DeserializationStream { + + private var count = 0 + + override def readObject[T: ClassTag](): T = { + count += 1 + if (count == 2) { + throw new EOFException + } + new Object().asInstanceOf[T] + } + + override def close(): Unit = {} +} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 809bd7092965..7d4086313fcc 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.storage -import org.apache.spark.TaskContext -import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} +import java.util.concurrent.Semaphore + +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} @@ -27,38 +29,63 @@ import org.mockito.stubbing.Answer import org.scalatest.FunSuite +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.network._ +import org.apache.spark.serializer.TestSerializer + class ShuffleBlockFetcherIteratorSuite extends FunSuite { + // Some of the tests are quite tricky because we are testing the cleanup behavior + // in the presence of faults. - test("handle local read failures in BlockManager") { + /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ + private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - val blockManager = mock(classOf[BlockManager]) - doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId - - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) - - val optItr = mock(classOf[Option[Iterator[Any]]]) - val answer = new Answer[Option[Iterator[Any]]] { - override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { - throw new Exception + when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val blocks = invocation.getArguments()(2).asInstanceOf[Seq[String]] + val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] + + for (blockId <- blocks) { + if (data.contains(BlockId(blockId))) { + listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) + } else { + listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId)) + } + } } + }) + transfer + } + + private val conf = new SparkConf + + test("successful 3 local reads + 2 remote reads") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure blockManager.getBlockData would return the blocks + val localBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + localBlocks.foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getBlockData(meq(blockId.toString)) } - // 3rd block is going to fail - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) - doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) + ) + + val transfer = createMockTransfer(remoteBlocks) - val bmId = BlockManagerId("test-client", "test-client", 1) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) ) val iterator = new ShuffleBlockFetcherIterator( @@ -66,118 +93,145 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { transfer, blockManager, blocksByAddress, - null, + new TestSerializer, 48 * 1024 * 1024) - // Without exhausting the iterator, the iterator should be lazy and not call - // getLocalShuffleFromDisk. - verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - // the 2nd element of the tuple returned by iterator.next should be defined when - // fetching successfully - assert(iterator.next()._2.isDefined, - "1st element should be defined but is not actually defined") - verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next()._2.isDefined, - "2nd element should be defined but is not actually defined") - verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - // 3rd fetch should be failed - intercept[Exception] { - iterator.next() + // 3 local blocks fetched in initialization + verify(blockManager, times(3)).getBlockData(any()) + + for (i <- 0 until 5) { + assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") + val (blockId, subIterator) = iterator.next() + assert(subIterator.isDefined, + s"iterator should have 5 elements defined but actually has $i elements") + + // Make sure we release the buffer once the iterator is exhausted. + val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + verify(mockBuf, times(0)).release() + subIterator.get.foreach(_ => Unit) // exhaust the iterator + verify(mockBuf, times(1)).release() } - verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any()) + + // 3 local blocks, and 2 remote blocks + // (but from the same block manager so one call to fetchBlocks) + verify(blockManager, times(3)).getBlockData(any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any()) } - test("handle local read successes") { - val transfer = mock(classOf[BlockTransferService]) + test("release current unexhausted buffer in case the task completes early") { val blockManager = mock(classOf[BlockManager]) - doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId - - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ) - val optItr = mock(classOf[Option[Iterator[Any]]]) + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) - // All blocks should be fetched successfully - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] + future { + // Return the first two blocks, and wait till task completion before returning the 3rd one + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0))) + sem.acquire() + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0))) + } + } + }) - val bmId = BlockManagerId("test-client", "test-client", 1) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) - ) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + val taskContext = new TaskContext(0, 0, 0) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + taskContext, transfer, blockManager, blocksByAddress, - null, + new TestSerializer, 48 * 1024 * 1024) - // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk. - verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 1st element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 2nd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 3rd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 4th element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 5th element is not actually defined") - - verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any()) + // Exhaust the first block, and then it should be released. + iterator.next()._2.get.foreach(_ => Unit) + verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() + + // Get the 2nd block but do not exhaust the iterator + val subIter = iterator.next()._2.get + + // Complete the task; then the 2nd block buffer should be exhausted + verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() + taskContext.markTaskCompleted() + verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release() + + // The 3rd block should not be retained because the iterator is already in zombie state + sem.release() + verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).retain() + verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release() } - test("handle remote fetch failures in BlockTransferService") { + test("fail all blocks if any of the remote request fails") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ) + + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) + val transfer = mock(classOf[BlockTransferService]) when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] - listener.onBlockFetchFailure(new Exception("blah")) + future { + // Return the first block, and then fail. + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchFailure( + ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah")) + listener.onBlockFetchFailure( + ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah")) + sem.release() + } } }) - val blockManager = mock(classOf[BlockManager]) - - when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1)) - - val blId1 = ShuffleBlockId(0, 0, 0) - val blId2 = ShuffleBlockId(0, 1, 0) - val bmId = BlockManagerId("test-server", "test-server", 1) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, Seq((blId1, 1L), (blId2, 1L)))) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + val taskContext = new TaskContext(0, 0, 0) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + taskContext, transfer, blockManager, blocksByAddress, - null, + new TestSerializer, 48 * 1024 * 1024) - iterator.foreach { case (_, iterOption) => - assert(!iterOption.isDefined) - } + // Continue only after the mock calls onBlockFetchFailure + sem.acquire() + + // The first block should be defined, and the last two are not defined (due to failure) + assert(iterator.next()._2.isDefined === true) + assert(iterator.next()._2.isDefined === false) + assert(iterator.next()._2.isDefined === false) } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 4076ebc6fc8d..0f8eefd2d429 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -38,6 +38,17 @@ object MimaExcludes { MimaBuild.excludeSparkPackage("deploy"), MimaBuild.excludeSparkPackage("graphx") ) ++ + // This is @DeveloperAPI, but Mima still gives false-positives: + MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ + Seq( + // This is @Experimental, but Mima still gives false-positives: + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.PathResolver"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.client.BlockClientListener") + ) ++ MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ Seq(