From 5bb88f5fb7b02557d2c5438275b49993d0956e80 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 9 Sep 2014 00:29:33 -0700 Subject: [PATCH 01/28] [SPARK-3453] Refactor Netty module to use BlockTransferService. Also includes some partial support for uploading blocks. --- .../apache/spark/network/ManagedBuffer.scala | 19 +- ...FetchingClient.scala => BlockClient.scala} | 41 ++- ...Factory.scala => BlockClientFactory.scala} | 38 +-- .../network/netty/BlockClientHandler.scala | 86 +++++++ .../netty/{server => }/BlockServer.scala | 39 +-- .../network/netty/BlockServerHandler.scala | 98 +++++++ .../netty/NettyBlockTransferService.scala | 83 ++++++ .../spark/network/netty/PathResolver.scala | 25 -- .../netty/client/BlockClientListener.scala | 29 --- .../client/BlockFetchingClientHandler.scala | 103 -------- .../netty/client/LazyInitIterator.scala | 44 ---- .../netty/client/ReferenceCountedBuffer.scala | 47 ---- .../apache/spark/network/netty/protocol.scala | 243 ++++++++++++++++++ .../network/netty/server/BlockHeader.scala | 32 --- .../netty/server/BlockHeaderEncoder.scala | 47 ---- .../BlockServerChannelInitializer.scala | 40 --- .../netty/server/BlockServerHandler.scala | 140 ---------- .../spark/storage/BlockDataProvider.scala | 32 --- .../netty/BlockClientHandlerSuite.scala | 129 ++++++++++ .../spark/network/netty/ProtocolSuite.scala | 88 +++++++ .../netty/ServerClientIntegrationSuite.scala | 90 ++++--- .../network/netty/TestManagedBuffer.scala | 68 +++++ .../BlockFetchingClientHandlerSuite.scala | 105 -------- .../server/BlockHeaderEncoderSuite.scala | 64 ----- .../server/BlockServerHandlerSuite.scala | 107 -------- 25 files changed, 899 insertions(+), 938 deletions(-) rename core/src/main/scala/org/apache/spark/network/netty/{client/BlockFetchingClient.scala => BlockClient.scala} (70%) rename core/src/main/scala/org/apache/spark/network/netty/{client/BlockFetchingClientFactory.scala => BlockClientFactory.scala} (66%) create mode 100644 
core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala rename core/src/main/scala/org/apache/spark/network/netty/{server => }/BlockServer.scala (77%) create mode 100644 core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala create mode 100644 core/src/main/scala/org/apache/spark/network/netty/protocol.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala delete mode 100644 core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala delete mode 100644 core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala delete mode 100644 
core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index dcecb6beeea9..d6e1216c0e4d 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -22,7 +22,8 @@ import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode import com.google.common.io.ByteStreams -import io.netty.buffer.{ByteBufInputStream, ByteBuf} +import io.netty.buffer.{Unpooled, ByteBufInputStream, ByteBuf} +import io.netty.channel.DefaultFileRegion import org.apache.spark.util.ByteBufferInputStream @@ -35,7 +36,7 @@ import org.apache.spark.util.ByteBufferInputStream * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf */ -sealed abstract class ManagedBuffer { +abstract class ManagedBuffer { // Note that all the methods are defined with parenthesis because their implementations can // have side effects (io operations). @@ -54,6 +55,11 @@ sealed abstract class ManagedBuffer { * it does not go over the limit. */ def inputStream(): InputStream + + /** + * Convert the buffer into a Netty object, used to write the data out. 
+ */ + private[network] def convertToNetty(): AnyRef } @@ -75,6 +81,11 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt is.skip(offset) ByteStreams.limit(is, length) } + + private[network] override def convertToNetty(): AnyRef = { + val fileChannel = new FileInputStream(file).getChannel + new DefaultFileRegion(fileChannel, offset, length) + } } @@ -88,6 +99,8 @@ final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def nioByteBuffer() = buf.duplicate() override def inputStream() = new ByteBufferInputStream(buf) + + private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf) } @@ -102,6 +115,8 @@ final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { override def inputStream() = new ByteBufInputStream(buf) + private[network] override def convertToNetty(): AnyRef = buf + // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it. def release(): Unit = buf.release() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala similarity index 70% rename from core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala rename to core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index 5aea7ba2f367..95af2565bcc3 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -15,36 +15,35 @@ * limitations under the License. 
*/ -package org.apache.spark.network.netty.client +package org.apache.spark.network.netty import java.util.concurrent.TimeoutException import io.netty.bootstrap.Bootstrap import io.netty.buffer.PooledByteBufAllocator import io.netty.channel.socket.SocketChannel -import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption} -import io.netty.handler.codec.LengthFieldBasedFrameDecoder -import io.netty.handler.codec.string.StringEncoder -import io.netty.util.CharsetUtil +import io.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelInitializer, ChannelOption} import org.apache.spark.Logging +import org.apache.spark.network.BlockFetchingListener + /** - * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]]. - * Use [[BlockFetchingClientFactory]] to instantiate this client. + * Client for [[NettyBlockTransferService]]. Use [[BlockClientFactory]] to + * instantiate this client. * * The constructor blocks until a connection is successfully established. * - * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol. - * * Concurrency: thread safe and can be called from multiple threads. */ @throws[TimeoutException] -private[spark] -class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int) +private[netty] +class BlockClient(factory: BlockClientFactory, hostname: String, port: Int) extends Logging { - private val handler = new BlockFetchingClientHandler + private val handler = new BlockClientHandler + private val encoder = new ClientRequestEncoder + private val decoder = new ServerResponseDecoder /** Netty Bootstrap for creating the TCP connection. 
*/ private val bootstrap: Bootstrap = { @@ -61,9 +60,9 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, b.handler(new ChannelInitializer[SocketChannel] { override def initChannel(ch: SocketChannel): Unit = { ch.pipeline - .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)) - // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4 - .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4)) + .addLast("clientRequestEncoder", encoder) + .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) + .addLast("serverResponseDecoder", decoder) .addLast("handler", handler) } }) @@ -86,12 +85,7 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, * @param blockIds sequence of block ids to fetch. * @param listener callback to fire on fetch success / failure. */ - def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = { - // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline. - // It's also best to limit the number of "flush" calls since it requires system calls. - // Let's concatenate the string and then call writeAndFlush once. - // This is also why this implementation might be more efficient than multiple, separate - // fetch block calls. 
+ def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = { var startTime: Long = 0 logTrace { startTime = System.nanoTime @@ -102,8 +96,7 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, handler.addRequest(blockId, listener) } - val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n") - writeFuture.addListener(new ChannelFutureListener { + cf.channel().writeAndFlush(BlockFetchRequest(blockIds)).addListener(new ChannelFutureListener { override def operationComplete(future: ChannelFuture): Unit = { if (future.isSuccess) { logTrace { @@ -116,9 +109,9 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}" logError(errorMsg, future.cause) blockIds.foreach { blockId => - listener.onFetchFailure(blockId, errorMsg) handler.removeRequest(blockId) } + listener.onBlockFetchFailure(new RuntimeException(errorMsg)) } } }) diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala similarity index 66% rename from core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala rename to core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 2b28402c52b4..0777275cd4fe 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -15,36 +15,34 @@ * limitations under the License. 
*/ -package org.apache.spark.network.netty.client +package org.apache.spark.network.netty -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel} +import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.nio.NioSocketChannel import io.netty.channel.socket.oio.OioSocketChannel -import io.netty.channel.{EventLoopGroup, Channel} +import io.netty.channel.{Channel, EventLoopGroup} import org.apache.spark.SparkConf -import org.apache.spark.network.netty.NettyConfig import org.apache.spark.util.Utils + /** - * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses + * Factory for creating [[BlockClient]] by using createClient. This factory reuses * the worker thread pool for Netty. - * - * Concurrency: createClient is safe to be called from multiple threads concurrently. */ -private[spark] -class BlockFetchingClientFactory(val conf: NettyConfig) { +private[netty] +class BlockClientFactory(val conf: NettyConfig) { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) /** A thread factory so the threads are named (for debugging). */ - val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") + private[netty] val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") /** The following two are instantiated by the [[init]] method, depending ioMode. */ - var socketChannelClass: Class[_ <: Channel] = _ - var workerGroup: EventLoopGroup = _ + private[netty] var socketChannelClass: Class[_ <: Channel] = _ + private[netty] var workerGroup: EventLoopGroup = _ init() @@ -63,20 +61,12 @@ class BlockFetchingClientFactory(val conf: NettyConfig) { workerGroup = new EpollEventLoopGroup(0, threadFactory) } + // For auto mode, first try epoll (only available on Linux), then nio. 
conf.ioMode match { case "nio" => initNio() case "oio" => initOio() case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. - try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } + case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() } } @@ -87,8 +77,8 @@ class BlockFetchingClientFactory(val conf: NettyConfig) { * * Concurrency: This method is safe to call from multiple threads. */ - def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = { - new BlockFetchingClient(this, remoteHost, remotePort) + def createClient(remoteHost: String, remotePort: Int): BlockClient = { + new BlockClient(this, remoteHost, remotePort) } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala new file mode 100644 index 000000000000..b41c831f3d7e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.Logging +import org.apache.spark.network.BlockFetchingListener + + +/** + * Handler that processes server responses. + * + * Concurrency: thread safe and can be called from multiple threads. + */ +private[netty] +class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { + + /** Tracks the list of outstanding requests and their listeners on success/failure. */ + private val outstandingRequests = java.util.Collections.synchronizedMap { + new java.util.HashMap[String, BlockFetchingListener] + } + + def addRequest(blockId: String, listener: BlockFetchingListener): Unit = { + outstandingRequests.put(blockId, listener) + } + + def removeRequest(blockId: String): Unit = { + outstandingRequests.remove(blockId) + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" + logError(errorMsg, cause) + + // Fire the failure callback for all outstanding blocks + outstandingRequests.synchronized { + val iter = outstandingRequests.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + entry.getValue.onBlockFetchFailure(cause) + } + outstandingRequests.clear() + } + + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, response: ServerResponse) { + val server = ctx.channel.remoteAddress.toString + response match { + case BlockFetchSuccess(blockId, buf) => + val listener = outstandingRequests.get(blockId) + if (listener == null) { + logWarning(s"Got a response for block $blockId from $server but it is not outstanding") + } else { + outstandingRequests.remove(blockId) + listener.onBlockFetchSuccess(blockId, buf) + } + case 
BlockFetchFailure(blockId, errorMsg) => + val listener = outstandingRequests.get(blockId) + if (listener == null) { + logWarning( + s"Got a response for block $blockId from $server ($errorMsg) but it is not outstanding") + } else { + outstandingRequests.remove(blockId) + listener.onBlockFetchFailure(new RuntimeException(errorMsg)) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala similarity index 77% rename from core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala rename to core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 7b2f9a8d4dfd..76f28aa00112 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -15,48 +15,33 @@ * limitations under the License. */ -package org.apache.spark.network.netty.server +package org.apache.spark.network.netty import java.net.InetSocketAddress import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption} import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.socket.oio.OioServerSocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil +import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} +import io.netty.handler.codec.LengthFieldBasedFrameDecoder import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.network.netty.NettyConfig -import org.apache.spark.storage.BlockDataProvider +import 
org.apache.spark.network.BlockDataManager import org.apache.spark.util.Utils /** - * Server for serving Spark data blocks. - * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]]. - * - * Protocol for requesting blocks (client to server): - * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n" - * - * Protocol for sending blocks (server to client): - * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data. - * - * frame-length should not include the length of itself. - * If block-id-length is negative, then this is an error message rather than block-data. The real - * length is the absolute value of the frame-length. - * + * Server for the [[NettyBlockTransferService]]. */ -private[spark] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging { +private[netty] +class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Logging { - def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = { + def this(sparkConf: SparkConf, dataProvider: BlockDataManager) = { this(new NettyConfig(sparkConf), dataProvider) } @@ -129,10 +114,10 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Lo bootstrap.childHandler(new ChannelInitializer[SocketChannel] { override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) + val p = ch.pipeline + .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) + .addLast("clientRequestDecoder", new ClientRequestDecoder) + .addLast("serverResponseEncoder", new ServerResponseEncoder) .addLast("handler", new BlockServerHandler(dataProvider)) } }) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala 
b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala new file mode 100644 index 000000000000..739526a4fc6b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.channel._ + +import org.apache.spark.Logging +import org.apache.spark.network.{ManagedBuffer, BlockDataManager} + + +/** + * A handler that processes requests from clients and writes block data back. + * + * The messages should have been processed by the pipeline set up in [[BlockServer]]. 
+ */ +private[netty] class BlockServerHandler(dataProvider: BlockDataManager) + extends SimpleChannelInboundHandler[ClientRequest] with Logging { + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) + ctx.close() + } + + override def channelRead0(ctx: ChannelHandlerContext, request: ClientRequest): Unit = { + request match { + case BlockFetchRequest(blockIds) => + blockIds.foreach(processBlockRequest(ctx, _)) + case BlockUploadRequest(blockId, data) => + // TODO(rxin): handle upload. + } + } // end of channelRead0 + + private def processBlockRequest(ctx: ChannelHandlerContext, blockId: String): Unit = { + // A helper function to send error message back to the client. + def client = ctx.channel.remoteAddress.toString + + def respondWithError(error: String): Unit = { + ctx.writeAndFlush(new BlockFetchFailure(blockId, error)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture) { + if (!future.isSuccess) { + // TODO: Maybe log the success case as well. + logError(s"Error sending error back to $client", future.cause) + ctx.close() + } + } + } + ) + } + + logTrace(s"Received request from $client to fetch block $blockId") + + // First make sure we can find the block. If not, send error back to the user. 
+ var blockData: Option[ManagedBuffer] = null + try { + blockData = dataProvider.getBlockData(blockId) + } catch { + case e: Exception => + logError(s"Error opening block $blockId for request from $client", e) + respondWithError(e.getMessage) + return + } + + blockData match { + case Some(buf) => + ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${buf.size} B) back to $client") + } else { + logError( + s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() + } + } + } + ) + case None => + respondWithError("Block not found") + } + } // end of processBlockRequest +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala new file mode 100644 index 000000000000..fa8bdfc96e8b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import scala.concurrent.Future + +import org.apache.spark.SparkConf +import org.apache.spark.network._ +import org.apache.spark.storage.StorageLevel + + +/** + * A [[BlockTransferService]] implementation based on Netty. + * + * See protocol.scala for the communication protocol between server and client + */ +final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { + + private[this] val nettyConf: NettyConfig = new NettyConfig(conf) + + private[this] var server: BlockServer = _ + private[this] var clientFactory: BlockClientFactory = _ + + override def init(blockDataManager: BlockDataManager): Unit = { + server = new BlockServer(nettyConf, blockDataManager) + clientFactory = new BlockClientFactory(nettyConf) + } + + override def stop(): Unit = { + if (server != null) { + server.stop() + } + if (clientFactory != null) { + clientFactory.stop() + } + } + + override def fetchBlocks( + hostName: String, + port: Int, + blockIds: Seq[String], + listener: BlockFetchingListener): Unit = { + clientFactory.createClient(hostName, port).fetchBlocks(blockIds, listener) + } + + override def uploadBlock( + hostname: String, + port: Int, + blockId: String, + blockData: ManagedBuffer, level: StorageLevel): Future[Unit] = { + // TODO(rxin): Implement uploadBlock. + ??? 
+ } + + override def hostName: String = { + if (server == null) { + throw new IllegalStateException("Server has not been started") + } + server.hostName + } + + override def port: Int = { + if (server == null) { + throw new IllegalStateException("Server has not been started") + } + server.port + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala deleted file mode 100644 index 0d7695072a7b..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import org.apache.spark.storage.{BlockId, FileSegment} - -trait PathResolver { - /** Get the file segment in which the given block resides. 
*/ - def getBlockLocation(blockId: BlockId): FileSegment -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala deleted file mode 100644 index e28219dd7745..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty.client - -import java.util.EventListener - - -trait BlockClientListener extends EventListener { - - def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit - - def onFetchFailure(blockId: String, errorMsg: String): Unit - -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala deleted file mode 100644 index 83265b164299..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import io.netty.buffer.ByteBuf -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging - - -/** - * Handler that processes server responses. It uses the protocol documented in - * [[org.apache.spark.network.netty.server.BlockServer]]. - * - * Concurrency: thread safe and can be called from multiple threads. 
- */ -private[client] -class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging { - - /** Tracks the list of outstanding requests and their listeners on success/failure. */ - private val outstandingRequests = java.util.Collections.synchronizedMap { - new java.util.HashMap[String, BlockClientListener] - } - - def addRequest(blockId: String, listener: BlockClientListener): Unit = { - outstandingRequests.put(blockId, listener) - } - - def removeRequest(blockId: String): Unit = { - outstandingRequests.remove(blockId) - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" - logError(errorMsg, cause) - - // Fire the failure callback for all outstanding blocks - outstandingRequests.synchronized { - val iter = outstandingRequests.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - entry.getValue.onFetchFailure(entry.getKey, errorMsg) - } - outstandingRequests.clear() - } - - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { - val totalLen = in.readInt() - val blockIdLen = in.readInt() - val blockIdBytes = new Array[Byte](math.abs(blockIdLen)) - in.readBytes(blockIdBytes) - val blockId = new String(blockIdBytes) - val blockSize = totalLen - math.abs(blockIdLen) - 4 - - def server = ctx.channel.remoteAddress.toString - - // blockIdLen is negative when it is an error message. 
- if (blockIdLen < 0) { - val errorMessageBytes = new Array[Byte](blockSize) - in.readBytes(errorMessageBytes) - val errorMsg = new String(errorMessageBytes) - logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server") - - val listener = outstandingRequests.get(blockId) - if (listener == null) { - // Ignore callback - logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") - } else { - outstandingRequests.remove(blockId) - listener.onFetchFailure(blockId, errorMsg) - } - } else { - logTrace(s"Received block $blockId ($blockSize B) from $server") - - val listener = outstandingRequests.get(blockId) - if (listener == null) { - // Ignore callback - logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") - } else { - outstandingRequests.remove(blockId) - listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in)) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala deleted file mode 100644 index 9740ee64d1f2..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -/** - * A simple iterator that lazily initializes the underlying iterator. - * - * The use case is that sometimes we might have many iterators open at the same time, and each of - * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer). - * This could lead to too many buffers open. If this iterator is used, we lazily initialize those - * buffers. - */ -private[spark] -class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] { - - lazy val proxy = createIterator - - override def hasNext: Boolean = { - val gotNext = proxy.hasNext - if (!gotNext) { - close() - } - gotNext - } - - override def next(): Any = proxy.next() - - def close(): Unit = Unit -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala deleted file mode 100644 index ea1abf5eccc2..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.io.InputStream -import java.nio.ByteBuffer - -import io.netty.buffer.{ByteBuf, ByteBufInputStream} - - -/** - * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty. - * This is a Scala value class. - * - * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of - * reference by the retain method and release method. - */ -private[spark] -class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal { - - /** Return the nio ByteBuffer view of the underlying buffer. */ - def byteBuffer(): ByteBuffer = underlying.nioBuffer - - /** Creates a new input stream that starts from the current position of the buffer. */ - def inputStream(): InputStream = new ByteBufInputStream(underlying) - - /** Increment the reference counter by one. */ - def retain(): Unit = underlying.retain() - - /** Decrement the reference counter by one and release the buffer if the ref count is 0. */ - def release(): Unit = underlying.release() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala new file mode 100644 index 000000000000..0159eca1d3b4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.util.{List => JList} + +import io.netty.buffer.ByteBuf +import io.netty.channel.ChannelHandlerContext +import io.netty.channel.ChannelHandler.Sharable +import io.netty.handler.codec._ + +import org.apache.spark.Logging +import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer} + + +sealed trait ClientRequest { + def id: Byte +} + +final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { + override def id = 0 +} + +final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest { + require(blockId.length <= Byte.MaxValue) + override def id = 1 +} + + +sealed trait ServerResponse { + def id: Byte +} + +final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 0 +} + +final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 1 +} + + +/** + * Encoder used by the client side to encode client-to-server requests.
+ */ +@Sharable +final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] { + override def encode(ctx: ChannelHandlerContext, in: ClientRequest, out: JList[Object]): Unit = { + in match { + case BlockFetchRequest(blocks) => + // 8 bytes: frame size + // 1 byte: BlockFetchRequest vs BlockUploadRequest + // 4 bytes: num blocks + // then for each block id write 1 byte for blockId.length and then blockId itself + val frameLength = 8 + 1 + 4 + blocks.size + blocks.map(_.size).fold(0)(_ + _) + val buf = ctx.alloc().buffer(frameLength) + + buf.writeLong(frameLength) + buf.writeByte(in.id) + buf.writeInt(blocks.size) + blocks.foreach { blockId => + ProtocolUtils.writeBlockId(buf, blockId) + } + + assert(buf.writableBytes() == 0) + out.add(buf) + + case BlockUploadRequest(blockId, data) => + // 8 bytes: frame size + // 1 byte: msg id (BlockFetchRequest vs BlockUploadRequest) + // 1 byte: blockId.length + // data itself (length can be derived from: frame size - 8 - 1 - 1 - blockId.length) + val headerLength = 8 + 1 + 1 + blockId.length + val frameLength = headerLength + data.size + val header = ctx.alloc().buffer(headerLength) + + // Call this before we add header to out so in case of exceptions + // we don't send anything at all. + val body = data.convertToNetty() + + header.writeLong(frameLength) + header.writeByte(in.id) + ProtocolUtils.writeBlockId(header, blockId) + + assert(header.writableBytes() == 0) + out.add(header) + out.add(body) + } + } +} + + +/** + * Decoder in the server side to decode client requests. + * + * This assumes the inbound messages have been processed by a frame decoder created by + * [[ProtocolUtils.createFrameDecoder()]].
+ */ +@Sharable +final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { + override protected def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = + { + val msgTypeId = in.readByte() + val decoded = msgTypeId match { + case 0 => // BlockFetchRequest + val numBlocks = in.readInt() + val blockIds = Seq.fill(numBlocks) { ProtocolUtils.readBlockId(in) } + BlockFetchRequest(blockIds) + + case 1 => // BlockUploadRequest + val blockId = ProtocolUtils.readBlockId(in) + in.retain() // retain the bytebuf so we don't recycle it immediately. + BlockUploadRequest(blockId, new NettyByteBufManagedBuffer(in)) + } + + assert(decoded.id == msgTypeId) + out.add(decoded) + } +} + + +/** + * Encoder used by the server side to encode server-to-client responses. + */ +@Sharable +final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging { + override def encode(ctx: ChannelHandlerContext, in: ServerResponse, out: JList[Object]): Unit = { + in match { + case BlockFetchSuccess(blockId, data) => + // Handle the body first so if we encounter an error getting the body, we can respond + // with an error instead. + var body: AnyRef = null + try { + body = data.convertToNetty() + } catch { + case e: Exception => + // Re-encode this message as BlockFetchFailure. 
+ logError(s"Error opening block $blockId for client ${ctx.channel.remoteAddress}", e) + encode(ctx, new BlockFetchFailure(blockId, e.getMessage), out) + return + } + + // If we got here, body cannot be null + // 8 bytes = long for frame length + // 1 byte = message id (type) + // 1 byte = block id length + // followed by block id itself + val headerLength = 8 + 1 + 1 + blockId.length + val frameLength = headerLength + data.size + val header = ctx.alloc().buffer(headerLength) + header.writeLong(frameLength) + header.writeByte(in.id) + ProtocolUtils.writeBlockId(header, blockId) + + assert(header.writableBytes() == 0) + out.add(header) + out.add(body) + + case BlockFetchFailure(blockId, error) => + val frameLength = 8 + 1 + 1 + blockId.length + error.length + val buf = ctx.alloc().buffer(frameLength) + buf.writeLong(frameLength) + buf.writeByte(in.id) + ProtocolUtils.writeBlockId(buf, blockId) + buf.writeBytes(error.getBytes) + + assert(buf.writableBytes() == 0) + out.add(buf) + } + } +} + + +/** + * Decoder in the client side to decode server responses. + * + * This assumes the inbound messages have been processed by a frame decoder created by + * [[ProtocolUtils.createFrameDecoder()]]. + */ +@Sharable +final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { + override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = { + val msgId = in.readByte() + val decoded = msgId match { + case 0 => // BlockFetchSuccess + val blockId = ProtocolUtils.readBlockId(in) + in.retain() + new BlockFetchSuccess(blockId, new NettyByteBufManagedBuffer(in)) + + case 1 => // BlockFetchFailure + val blockId = ProtocolUtils.readBlockId(in) + val errorBytes = new Array[Byte](in.readableBytes()) + in.readBytes(errorBytes) + new BlockFetchFailure(blockId, new String(errorBytes)) + } + + assert(decoded.id == msgId) + out.add(decoded) + } +} + + +private[netty] object ProtocolUtils { + + /** LengthFieldBasedFrameDecoder used before all decoders. 
*/ + def createFrameDecoder(): ByteToMessageDecoder = { + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 8 + // lengthAdjustment = -8, i.e. exclude the 8 byte length itself + // initialBytesToStrip = 8, i.e. strip out the length field itself + new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8) + } + + def readBlockId(in: ByteBuf): String = { + val numBytesToRead = in.readByte().toInt + val bytes = new Array[Byte](numBytesToRead) + in.readBytes(bytes) + new String(bytes) + } + + def writeBlockId(out: ByteBuf, blockId: String): Unit = { + out.writeByte(blockId.length) + out.writeBytes(blockId.getBytes) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala deleted file mode 100644 index 162e9cc6828d..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -/** - * Header describing a block. This is used only in the server pipeline. 
- * - * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it. - * - * @param blockSize length of the block content, excluding the length itself. - * If positive, this is the header for a block (not part of the header). - * If negative, this is the header and content for an error message. - * @param blockId block id - * @param error some error message from reading the block - */ -private[server] -class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala deleted file mode 100644 index 8e4dda4ef859..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.handler.codec.MessageToByteEncoder - -/** - * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol. 
- */ -private[server] -class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] { - override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = { - // message = message length (4 bytes) + block id length (4 bytes) + block id + block data - // message length = block id length (4 bytes) + size of block id + size of block data - val blockIdBytes = msg.blockId.getBytes - msg.error match { - case Some(errorMsg) => - val errorBytes = errorMsg.getBytes - out.writeInt(4 + blockIdBytes.length + errorBytes.size) - out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors - out.writeBytes(blockIdBytes) // next is blockId itself - out.writeBytes(errorBytes) // error message - case None => - out.writeInt(4 + blockIdBytes.length + msg.blockSize) - out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length - out.writeBytes(blockIdBytes) // next is blockId itself - // msg of size blockSize will be written by ServerHandler - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala deleted file mode 100644 index cc70bd0c5c47..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.channel.ChannelInitializer -import io.netty.channel.socket.SocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil -import org.apache.spark.storage.BlockDataProvider - - -/** Channel initializer that sets up the pipeline for the BlockServer. */ -private[netty] -class BlockServerChannelInitializer(dataProvider: BlockDataProvider) - extends ChannelInitializer[SocketChannel] { - - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala deleted file mode 100644 index 40dd5e5d1a2a..000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.io.FileInputStream -import java.nio.ByteBuffer -import java.nio.channels.FileChannel - -import io.netty.buffer.Unpooled -import io.netty.channel._ - -import org.apache.spark.Logging -import org.apache.spark.storage.{FileSegment, BlockDataProvider} - - -/** - * A handler that processes requests from clients and writes block data back. - * - * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first - * so channelRead0 is called once per line (i.e. per block id). - */ -private[server] -class BlockServerHandler(dataProvider: BlockDataProvider) - extends SimpleChannelInboundHandler[String] with Logging { - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = { - def client = ctx.channel.remoteAddress.toString - - // A helper function to send error message back to the client. - def respondWithError(error: String): Unit = { - ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (!future.isSuccess) { - // TODO: Maybe log the success case as well. 
- logError(s"Error sending error back to $client", future.cause) - ctx.close() - } - } - } - ) - } - - def writeFileSegment(segment: FileSegment): Unit = { - // Send error message back if the block is too large. Even though we are capable of sending - // large (2G+) blocks, the receiving end cannot handle it so let's fail fast. - // Once we fixed the receiving end to be able to process large blocks, this should be removed. - // Also make sure we update BlockHeaderEncoder to support length > 2G. - - // See [[BlockHeaderEncoder]] for the way length is encoded. - if (segment.length + blockId.length + 4 > Int.MaxValue) { - respondWithError(s"Block $blockId size ($segment.length) greater than 2G") - return - } - - var fileChannel: FileChannel = null - try { - fileChannel = new FileInputStream(segment.file).getChannel - } catch { - case e: Exception => - logError( - s"Error opening channel for $blockId in ${segment.file} for request from $client", e) - respondWithError(e.getMessage) - } - - // Found the block. Send it back. - if (fileChannel != null) { - // Write the header and block data. In the case of failures, the listener on the block data - // write should close the connection. 
- ctx.write(new BlockHeader(segment.length.toInt, blockId)) - - val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length) - ctx.writeAndFlush(region).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${segment.length} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - } - - def writeByteBuffer(buf: ByteBuffer): Unit = { - ctx.write(new BlockHeader(buf.remaining, blockId)) - ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - - logTrace(s"Received request from $client to fetch block $blockId") - - var blockData: Either[FileSegment, ByteBuffer] = null - - // First make sure we can find the block. If not, send error back to the user. - try { - blockData = dataProvider.getBlockData(blockId) - } catch { - case e: Exception => - logError(s"Error opening block $blockId for request from $client", e) - respondWithError(e.getMessage) - return - } - - blockData match { - case Left(segment) => writeFileSegment(segment) - case Right(buf) => writeByteBuffer(buf) - } - - } // end of channelRead0 -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala deleted file mode 100644 index 5b6d08663083..000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - - -/** - * An interface for providing data for blocks. - * - * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer. - * - * Aside from unit tests, [[BlockManager]] is the main class that implements this. - */ -private[spark] trait BlockDataProvider { - def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala new file mode 100644 index 000000000000..1358b2f9c807 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicInteger + +import io.netty.buffer.Unpooled +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.{FunSuite, PrivateMethodTester} + +import org.apache.spark.network._ + + +class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { + + private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { + val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) + handler.invokePrivate(outstandingRequests()).size + } + + test("handling block data (successful fetch)") { + val blockId = "test_block" + val blockData = "blahblahblahblahblah" + + var parsedBlockId: String = "" + var parsedBlockData: String = "" + val handler = new BlockClientHandler + handler.addRequest(blockId, + new BlockFetchingListener { + override def onBlockFetchFailure(exception: Throwable): Unit = { + throw new UnsupportedOperationException + } + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + parsedBlockId = blockId + val bytes = new Array[Byte](data.size.toInt) + data.nioByteBuffer().get(bytes) + parsedBlockData = new String(bytes) + } + } + ) + + val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) + assert(handler.invokePrivate(outstandingRequests()).size === 1) + + val channel = new EmbeddedChannel(handler) + val buf = ByteBuffer.allocate(blockData.size) // sized to hold exactly the block data + buf.put(blockData.getBytes) +
buf.flip() + + channel.writeInbound(BlockFetchSuccess(blockId, new NioByteBufferManagedBuffer(buf))) + + assert(parsedBlockId === blockId) + assert(parsedBlockData === blockData) + assert(handler.invokePrivate(outstandingRequests()).size === 0) + assert(channel.finish() === false) + } + + test("handling error message (failed fetch)") { + val blockId = "test_block" + val errorMsg = "error erro5r error err4or error3 error6 error erro1r" + + var parsedErrorMsg: String = "" + val handler = new BlockClientHandler + handler.addRequest(blockId, + new BlockFetchingListener { + override def onBlockFetchFailure(exception: Throwable): Unit = { + parsedErrorMsg = exception.getMessage + } + + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + throw new UnsupportedOperationException + } + } + ) + + assert(sizeOfOutstandingRequests(handler) === 1) + + val channel = new EmbeddedChannel(handler) + channel.writeInbound(BlockFetchFailure(blockId, errorMsg)) + assert(parsedErrorMsg === errorMsg) + assert(sizeOfOutstandingRequests(handler) === 0) + assert(channel.finish() === false) + } + + ignore("clear all outstanding request upon connection close") { + val errorCount = new AtomicInteger(0) + val successCount = new AtomicInteger(0) + val handler = new BlockClientHandler + + val listener = new BlockFetchingListener { + override def onBlockFetchFailure(exception: Throwable): Unit = { + errorCount.incrementAndGet() + } + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + successCount.incrementAndGet() + } + } + + handler.addRequest("b1", listener) + handler.addRequest("b2", listener) + handler.addRequest("b3", listener) + assert(sizeOfOutstandingRequests(handler) === 3) + + val channel = new EmbeddedChannel(handler) + channel.writeInbound(BlockFetchSuccess("b1", new NettyByteBufManagedBuffer(Unpooled.buffer()))) + // Need to figure out a way to generate an exception + assert(successCount.get() === 1) + 
assert(errorCount.get() === 2) + assert(sizeOfOutstandingRequests(handler) === 0) + assert(channel.finish() === false) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala new file mode 100644 index 000000000000..72034634a5bd --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import io.netty.channel.embedded.EmbeddedChannel + +import org.scalatest.FunSuite + + +/** + * Test client/server encoder/decoder protocol. + */ +class ProtocolSuite extends FunSuite { + + /** + * Helper to test server to client message protocol by encoding a message and decoding it. + */ + private def testServerToClient(msg: ServerResponse) { + val serverChannel = new EmbeddedChannel(new ServerResponseEncoder) + serverChannel.writeOutbound(msg) + + val clientChannel = new EmbeddedChannel( + ProtocolUtils.createFrameDecoder(), + new ServerResponseDecoder) + + // Drain all server outbound messages and write them to the client's server decoder. 
+ while (!serverChannel.outboundMessages().isEmpty) { + clientChannel.writeInbound(serverChannel.readOutbound()) + } + + assert(clientChannel.inboundMessages().size === 1) + // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is + // overridden. + assert(msg === clientChannel.readInbound()) + } + + /** + * Helper to test client to server message protocol by encoding a message and decoding it. + */ + private def testClientToServer(msg: ClientRequest) { + val clientChannel = new EmbeddedChannel(new ClientRequestEncoder) + clientChannel.writeOutbound(msg) + + val serverChannel = new EmbeddedChannel( + ProtocolUtils.createFrameDecoder(), + new ClientRequestDecoder) + + // Drain all client outbound messages and write them to the server's decoder. + while (!clientChannel.outboundMessages().isEmpty) { + serverChannel.writeInbound(clientChannel.readOutbound()) + } + + assert(serverChannel.inboundMessages().size === 1) + // Must put "msg === ..." instead of "... === msg" since only TestManagedBuffer equals is + // overridden. 
+ assert(msg === serverChannel.readInbound()) + } + + test("server to client protocol") { + testServerToClient(BlockFetchSuccess("a1234", new TestManagedBuffer(10))) + testServerToClient(BlockFetchSuccess("", new TestManagedBuffer(0))) + testServerToClient(BlockFetchFailure("abcd", "this is an error")) + testServerToClient(BlockFetchFailure("", "")) + } + + test("client to server protocol") { + testClientToServer(BlockFetchRequest(Seq.empty[String])) + testClientToServer(BlockFetchRequest(Seq("b1"))) + testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3"))) + testClientToServer(BlockUploadRequest("", new TestManagedBuffer(0))) + testClientToServer(BlockUploadRequest("b_upload", new TestManagedBuffer(10))) + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 02d0ffc86f58..a468764fb184 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -1,19 +1,19 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ package org.apache.spark.network.netty @@ -24,26 +24,25 @@ import java.util.concurrent.{TimeUnit, Semaphore} import scala.collection.JavaConversions._ -import io.netty.buffer.{ByteBufUtil, Unpooled} +import io.netty.buffer.Unpooled import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.SparkConf -import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory} -import org.apache.spark.network.netty.server.BlockServer -import org.apache.spark.storage.{FileSegment, BlockDataProvider} +import org.apache.spark.network._ +import org.apache.spark.storage.StorageLevel /** - * Test suite that makes sure the server and the client implementations share the same protocol. - */ +* Test suite that makes sure the server and the client implementations share the same protocol. 
+*/ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { val bufSize = 100000 var buf: ByteBuffer = _ var testFile: File = _ var server: BlockServer = _ - var clientFactory: BlockFetchingClientFactory = _ + var clientFactory: BlockClientFactory = _ val bufferBlockId = "buffer_block" val fileBlockId = "file_block" @@ -63,19 +62,24 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { fp.write(fileContent) fp.close() - server = new BlockServer(new SparkConf, new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { + server = new BlockServer(new SparkConf, new BlockDataManager { + override def getBlockData(blockId: String): Option[ManagedBuffer] = { if (blockId == bufferBlockId) { - Right(buf) + Some(new NioByteBufferManagedBuffer(buf)) } else if (blockId == fileBlockId) { - Left(new FileSegment(testFile, 10, testFile.length - 25)) + Some(new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25)) } else { - throw new Exception("Unknown block id " + blockId) + None } } + + /** + * Put the block locally, using the given storage level. + */ + def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = ??? 
}) - clientFactory = new BlockFetchingClientFactory(new SparkConf) + clientFactory = new BlockClientFactory(new SparkConf) } override def afterAll() = { @@ -89,31 +93,29 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { /** A ByteBuf for file_block */ lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) = + def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = { val client = clientFactory.createClient(server.hostName, server.port) val sem = new Semaphore(0) val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) val errorBlockIds = Collections.synchronizedSet(new HashSet[String]) - val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer]) + val receivedBuffers = Collections.synchronizedSet(new HashSet[ManagedBuffer]) client.fetchBlocks( blockIds, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = { - errorBlockIds.add(blockId) + new BlockFetchingListener { + override def onBlockFetchFailure(exception: Throwable): Unit = { sem.release() } - override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = { + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { receivedBlockIds.add(blockId) - data.retain() receivedBuffers.add(data) sem.release() } } ) - if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) { + if (!sem.tryAcquire(blockIds.size, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server") } client.close() @@ -123,20 +125,18 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { test("fetch a ByteBuffer block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId)) assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.underlying) === 
Set(byteBufferBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) } test("fetch a FileSegment block via zero-copy send") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId)) assert(blockIds === Set(fileBlockId)) - assert(buffers.map(_.underlying) === Set(fileBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(fileBlockReference)) assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) } - test("fetch a non-existent block") { + ignore("fetch a non-existent block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) assert(blockIds.isEmpty) assert(buffers.isEmpty) @@ -146,16 +146,14 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { test("fetch both ByteBuffer block and FileSegment block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId)) assert(blockIds === Set(bufferBlockId, fileBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference, fileBlockReference)) assert(failBlockIds.isEmpty) - buffers.foreach(_.release()) } - test("fetch both ByteBuffer block and a non-existent block") { + ignore("fetch both ByteBuffer block and a non-existent block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) assert(blockIds === Set(bufferBlockId)) - assert(buffers.map(_.underlying) === Set(byteBufferBlockReference)) + assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds === Set("random-block")) - buffers.foreach(_.release()) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala new file mode 100644 index 000000000000..6ae2d3b3faf9 --- /dev/null +++ 
b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.io.InputStream +import java.nio.ByteBuffer + +import io.netty.buffer.Unpooled + +import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer} + + +/** + * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). + * + * Used for testing. 
+ */ +class TestManagedBuffer(len: Int) extends ManagedBuffer { + + require(len <= Byte.MaxValue) + + private val byteArray: Array[Byte] = Array.tabulate[Byte](len)(_.toByte) + + private val underlying = new NettyByteBufManagedBuffer(Unpooled.wrappedBuffer(byteArray)) + + override def size: Long = underlying.size + + override private[network] def convertToNetty(): AnyRef = underlying.convertToNetty() + + override def nioByteBuffer(): ByteBuffer = underlying.nioByteBuffer() + + override def inputStream(): InputStream = underlying.inputStream() + + override def toString: String = s"${getClass.getName}($len)" + + override def equals(other: Any): Boolean = other match { + case otherBuf: ManagedBuffer => + val nioBuf = otherBuf.nioByteBuffer() + if (nioBuf.remaining() != len) { + return false + } else { + var i = 0 + while (i < len) { + if (nioBuf.get() != i) { + return false + } + i += 1 + } + return true + } + case _ => false + } +} diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala deleted file mode 100644 index 903ab09ae432..000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.nio.ByteBuffer - -import io.netty.buffer.Unpooled -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.{PrivateMethodTester, FunSuite} - - -class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester { - - test("handling block data (successful fetch)") { - val blockId = "test_block" - val blockData = "blahblahblahblahblah" - val totalLength = 4 + blockId.length + blockData.length - - var parsedBlockId: String = "" - var parsedBlockData: String = "" - val handler = new BlockFetchingClientHandler - handler.addRequest(blockId, - new BlockClientListener { - override def onFetchFailure(blockId: String, errorMsg: String): Unit = ??? 
- override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = { - parsedBlockId = bid - val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining) - refCntBuf.byteBuffer().get(bytes) - parsedBlockData = new String(bytes) - } - } - ) - - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - assert(handler.invokePrivate(outstandingRequests()).size === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself - buf.putInt(totalLength) - buf.putInt(blockId.length) - buf.put(blockId.getBytes) - buf.put(blockData.getBytes) - buf.flip() - - channel.writeInbound(Unpooled.wrappedBuffer(buf)) - assert(parsedBlockId === blockId) - assert(parsedBlockData === blockData) - - assert(handler.invokePrivate(outstandingRequests()).size === 0) - - channel.close() - } - - test("handling error message (failed fetch)") { - val blockId = "test_block" - val errorMsg = "error erro5r error err4or error3 error6 error erro1r" - val totalLength = 4 + blockId.length + errorMsg.length - - var parsedBlockId: String = "" - var parsedErrorMsg: String = "" - val handler = new BlockFetchingClientHandler - handler.addRequest(blockId, new BlockClientListener { - override def onFetchFailure(bid: String, msg: String) ={ - parsedBlockId = bid - parsedErrorMsg = msg - } - override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ??? 
- }) - - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - assert(handler.invokePrivate(outstandingRequests()).size === 1) - - val channel = new EmbeddedChannel(handler) - val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself - buf.putInt(totalLength) - buf.putInt(-blockId.length) - buf.put(blockId.getBytes) - buf.put(errorMsg.getBytes) - buf.flip() - - channel.writeInbound(Unpooled.wrappedBuffer(buf)) - assert(parsedBlockId === blockId) - assert(parsedErrorMsg === errorMsg) - - assert(handler.invokePrivate(outstandingRequests()).size === 0) - - channel.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala deleted file mode 100644 index 3ee281cb1350..000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - - -class BlockHeaderEncoderSuite extends FunSuite { - - test("encode normal block data") { - val blockId = "test_block" - val channel = new EmbeddedChannel(new BlockHeaderEncoder) - channel.writeOutbound(new BlockHeader(17, blockId, None)) - val out = channel.readOutbound().asInstanceOf[ByteBuf] - assert(out.readInt() === 4 + blockId.length + 17) - assert(out.readInt() === blockId.length) - - val blockIdBytes = new Array[Byte](blockId.length) - out.readBytes(blockIdBytes) - assert(new String(blockIdBytes) === blockId) - assert(out.readableBytes() === 0) - - channel.close() - } - - test("encode error message") { - val blockId = "error_block" - val errorMsg = "error encountered" - val channel = new EmbeddedChannel(new BlockHeaderEncoder) - channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg))) - val out = channel.readOutbound().asInstanceOf[ByteBuf] - assert(out.readInt() === 4 + blockId.length + errorMsg.length) - assert(out.readInt() === -blockId.length) - - val blockIdBytes = new Array[Byte](blockId.length) - out.readBytes(blockIdBytes) - assert(new String(blockIdBytes) === blockId) - - val errorMsgBytes = new Array[Byte](errorMsg.length) - out.readBytes(errorMsgBytes) - assert(new String(errorMsgBytes) === errorMsg) - assert(out.readableBytes() === 0) - - channel.close() - } -} diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala deleted file mode 100644 index 3239c710f163..000000000000 --- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.io.{RandomAccessFile, File} -import java.nio.ByteBuffer - -import io.netty.buffer.{Unpooled, ByteBuf} -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion} -import io.netty.channel.embedded.EmbeddedChannel - -import org.scalatest.FunSuite - -import org.apache.spark.storage.{BlockDataProvider, FileSegment} - - -class BlockServerHandlerSuite extends FunSuite { - - test("ByteBuffer block") { - val expectedBlockId = "test_bytebuffer_block" - val buf = ByteBuffer.allocate(10000) - for (i <- 1 to 10000) { - buf.put(i.toByte) - } - buf.flip() - - val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf) - })) - - channel.writeInbound(expectedBlockId) - assert(channel.outboundMessages().size === 2) - - val out1 = channel.readOutbound().asInstanceOf[BlockHeader] - val out2 = channel.readOutbound().asInstanceOf[ByteBuf] - - assert(out1.blockId === expectedBlockId) - assert(out1.blockSize === buf.remaining) - assert(out1.error === None) - - assert(out2.equals(Unpooled.wrappedBuffer(buf))) - - channel.close() - } - - test("FileSegment block via zero-copy") { - val 
expectedBlockId = "test_file_block" - - // Create random file data - val fileContent = new Array[Byte](1024) - scala.util.Random.nextBytes(fileContent) - val testFile = File.createTempFile("netty-test-file", "txt") - val fp = new RandomAccessFile(testFile, "rw") - fp.write(fileContent) - fp.close() - - val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = { - Left(new FileSegment(testFile, 15, testFile.length - 25)) - } - })) - - channel.writeInbound(expectedBlockId) - assert(channel.outboundMessages().size === 2) - - val out1 = channel.readOutbound().asInstanceOf[BlockHeader] - val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion] - - assert(out1.blockId === expectedBlockId) - assert(out1.blockSize === testFile.length - 25) - assert(out1.error === None) - - assert(out2.count === testFile.length - 25) - assert(out2.position === 15) - } - - test("pipeline exception propagation") { - val blockServerHandler = new BlockServerHandler(new BlockDataProvider { - override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ??? - }) - val exceptionHandler = new SimpleChannelInboundHandler[String]() { - override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = { - throw new Exception("this is an error") - } - } - - val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler) - assert(channel.isOpen) - channel.writeInbound("a message to trigger the error") - assert(!channel.isOpen) - } -} From 9b3b3973af78d5bcc680f46f8f162fb3d4bd69f8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 9 Sep 2014 00:42:37 -0700 Subject: [PATCH 02/28] Use Epoll.isAvailable in BlockServer as well. 
--- .../apache/spark/network/netty/BlockServer.scala | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 76f28aa00112..3433c5763ab3 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -21,14 +21,13 @@ import java.net.InetSocketAddress import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} +import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.socket.oio.OioServerSocketChannel import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} -import io.netty.handler.codec.LengthFieldBasedFrameDecoder import org.apache.spark.{Logging, SparkConf} import org.apache.spark.network.BlockDataManager @@ -85,16 +84,7 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log case "nio" => initNio() case "oio" => initOio() case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. - try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. 
- case e: Throwable => initNio() - } + case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() } // Use pooled buffers to reduce temporary buffer allocation @@ -114,7 +104,7 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log bootstrap.childHandler(new ChannelInitializer[SocketChannel] { override def initChannel(ch: SocketChannel): Unit = { - val p = ch.pipeline + ch.pipeline .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) .addLast("clientRequestDecoder", new ClientRequestDecoder) .addLast("serverResponseEncoder", new ServerResponseEncoder) From dd783ffb35d227aab301387edce2af38ca4f947b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 9 Sep 2014 14:36:31 -0700 Subject: [PATCH 03/28] Added more documentation. --- .../spark/network/netty/BlockClient.scala | 61 +++++-------------- .../network/netty/BlockClientFactory.scala | 44 ++++++++++++- .../network/netty/BlockClientHandler.scala | 5 +- .../spark/network/netty/BlockServer.scala | 4 +- .../apache/spark/network/netty/protocol.scala | 19 +++++- .../netty/ServerClientIntegrationSuite.scala | 5 +- 6 files changed, 80 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index 95af2565bcc3..9333fefa9295 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -19,68 +19,35 @@ package org.apache.spark.network.netty import java.util.concurrent.TimeoutException -import io.netty.bootstrap.Bootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.socket.SocketChannel -import io.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelInitializer, ChannelOption} +import io.netty.channel.{ChannelFuture, ChannelFutureListener} import org.apache.spark.Logging import org.apache.spark.network.BlockFetchingListener /** 
- * Client for [[NettyBlockTransferService]]. Use [[BlockClientFactory]] to - * instantiate this client. + * Client for [[NettyBlockTransferService]]. The connection to server must have been established + * using [[BlockClientFactory]] before instantiating this. * - * The constructor blocks until a connection is successfully established. + * This class is used to make requests to the server, while [[BlockClientHandler]] is responsible + * for handling responses from the server. * * Concurrency: thread safe and can be called from multiple threads. + * + * @param cf the ChannelFuture for the connection. + * @param handler [[BlockClientHandler]] for handling outstanding requests. */ @throws[TimeoutException] private[netty] -class BlockClient(factory: BlockClientFactory, hostname: String, port: Int) - extends Logging { - - private val handler = new BlockClientHandler - private val encoder = new ClientRequestEncoder - private val decoder = new ServerResponseDecoder - - /** Netty Bootstrap for creating the TCP connection. 
*/ - private val bootstrap: Bootstrap = { - val b = new Bootstrap - b.group(factory.workerGroup) - .channel(factory.socketChannelClass) - // Use pooled buffers to reduce temporary buffer allocation - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - // Disable Nagle's Algorithm since we don't want packets to wait - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs) - - b.handler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("clientRequestEncoder", encoder) - .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) - .addLast("serverResponseDecoder", decoder) - .addLast("handler", handler) - } - }) - b - } +class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Logging { - /** Netty ChannelFuture for the connection. */ - private val cf: ChannelFuture = bootstrap.connect(hostname, port) - if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) { - throw new TimeoutException( - s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)") - } + private[this] val serverAddr = cf.channel().remoteAddress().toString /** * Ask the remote server for a sequence of blocks, and execute the callback. * * Note that this is asynchronous and returns immediately. Upstream caller should throttle the - * rate of fetching; otherwise we could run out of memory. + * rate of fetching; otherwise we could run out of memory due to large outstanding fetches. * * @param blockIds sequence of block ids to fetch. * @param listener callback to fire on fetch success / failure. 
@@ -89,7 +56,7 @@ class BlockClient(factory: BlockClientFactory, hostname: String, port: Int) var startTime: Long = 0 logTrace { startTime = System.nanoTime - s"Sending request $blockIds to $hostname:$port" + s"Sending request $blockIds to $serverAddr" } blockIds.foreach { blockId => @@ -101,12 +68,12 @@ class BlockClient(factory: BlockClientFactory, hostname: String, port: Int) if (future.isSuccess) { logTrace { val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 - s"Sending request $blockIds to $hostname:$port took $timeTaken ms" + s"Sending request $blockIds to $serverAddr took $timeTaken ms" } } else { // Fail all blocks. val errorMsg = - s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}" + s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}" logError(errorMsg, future.cause) blockIds.foreach { blockId => handler.removeRequest(blockId) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 0777275cd4fe..f05f1419ded1 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -17,12 +17,17 @@ package org.apache.spark.network.netty +import java.util.concurrent.TimeoutException + +import io.netty.bootstrap.Bootstrap +import io.netty.buffer.PooledByteBufAllocator +import io.netty.channel._ import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} import io.netty.channel.nio.NioEventLoopGroup import io.netty.channel.oio.OioEventLoopGroup +import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel import io.netty.channel.socket.oio.OioSocketChannel -import io.netty.channel.{Channel, EventLoopGroup} import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -38,12 +43,16 @@ class 
BlockClientFactory(val conf: NettyConfig) { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) /** A thread factory so the threads are named (for debugging). */ - private[netty] val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") + private[netty] val threadFactory = Utils.namedThreadFactory("spark-netty-client") /** The following two are instantiated by the [[init]] method, depending ioMode. */ private[netty] var socketChannelClass: Class[_ <: Channel] = _ private[netty] var workerGroup: EventLoopGroup = _ + // The encoders are stateless and can be shared among multiple clients. + private[this] val encoder = new ClientRequestEncoder + private[this] val decoder = new ServerResponseDecoder + init() /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ @@ -78,7 +87,36 @@ class BlockClientFactory(val conf: NettyConfig) { * Concurrency: This method is safe to call from multiple threads. */ def createClient(remoteHost: String, remotePort: Int): BlockClient = { - new BlockClient(this, remoteHost, remotePort) + val handler = new BlockClientHandler + + val bootstrap = new Bootstrap + bootstrap.group(workerGroup) + .channel(socketChannelClass) + // Use pooled buffers to reduce temporary buffer allocation + .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) + .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) + .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs) + + bootstrap.handler(new ChannelInitializer[SocketChannel] { + override def initChannel(ch: SocketChannel): Unit = { + ch.pipeline + .addLast("clientRequestEncoder", encoder) + .addLast("frameDecoder", ProtocolUtils.createFrameDecoder()) + .addLast("serverResponseDecoder", decoder) + .addLast("handler", handler) + } + }) + + // Connect to the remote server + val cf: ChannelFuture = 
bootstrap.connect(remoteHost, remotePort) + if (!cf.awaitUninterruptibly(conf.connectTimeoutMs)) { + throw new TimeoutException( + s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)") + } + + new BlockClient(cf, handler) } def stop(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala index b41c831f3d7e..2a474cd71eab 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -24,7 +24,8 @@ import org.apache.spark.network.BlockFetchingListener /** - * Handler that processes server responses. + * Handler that processes server responses, in response to requests issued from [[BlockClient]]. + * It works by tracking the list of outstanding requests (and their callbacks). * * Concurrency: thread safe and can be called from multiple threads. */ @@ -32,7 +33,7 @@ private[netty] class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { /** Tracks the list of outstanding requests and their listeners on success/failure. */ - private val outstandingRequests = java.util.Collections.synchronizedMap { + private[this] val outstandingRequests = java.util.Collections.synchronizedMap { new java.util.HashMap[String, BlockFetchingListener] } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 3433c5763ab3..05443a74094d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -58,8 +58,8 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log /** Initialize the server. 
*/ private def init(): Unit = { bootstrap = new ServerBootstrap - val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss") - val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker") + val bossThreadFactory = Utils.namedThreadFactory("spark-netty-server-boss") + val workerThreadFactory = Utils.namedThreadFactory("spark-netty-server-worker") // Use only one thread to accept connections, and 2 * num_cores for worker. def initNio(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala index 0159eca1d3b4..ac6a4d00f654 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -28,29 +28,40 @@ import org.apache.spark.Logging import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer} +/** Messages from the client to the server. */ sealed trait ClientRequest { def id: Byte } +/** + * Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can + * correspond to multiple [[ServerResponse]]s. + */ final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { override def id = 0 } +/** + * Request to upload a block to the server. Currently the server does not ack the upload request. + */ final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest { require(blockId.length <= Byte.MaxValue) override def id = 1 } +/** Messages from server to client (usually in response to some [[ClientRequest]]). */ sealed trait ServerResponse { def id: Byte } +/** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched.
*/ final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse { require(blockId.length <= Byte.MaxValue) override def id = 0 } +/** Response to [[BlockFetchRequest]] when there is an error fetching the block. */ final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse { require(blockId.length <= Byte.MaxValue) override def id = 1 @@ -58,7 +69,9 @@ final case class BlockFetchFailure(blockId: String, error: String) extends Serve /** - * Encoder used by the client side to encode client-to-server responses. + * Encoder for [[ClientRequest]] used in client side. + * + * This encoder is stateless so it is safe to be shared by multiple threads. */ @Sharable final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] { @@ -109,6 +122,7 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] /** * Decoder in the server side to decode client requests. + * This decoder is stateless so it is safe to be shared by multiple threads. * * This assumes the inbound messages have been processed by a frame decoder created by * [[ProtocolUtils.createFrameDecoder()]]. @@ -138,6 +152,7 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { /** * Encoder used by the server side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. */ @Sharable final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging { @@ -190,6 +205,7 @@ final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse /** * Decoder in the client side to decode server responses. + * This decoder is stateless so it is safe to be shared by multiple threads. * * This assumes the inbound messages have been processed by a frame decoder created by * [[ProtocolUtils.createFrameDecoder()]]. 
@@ -229,6 +245,7 @@ private[netty] object ProtocolUtils { new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8) } + // TODO(rxin): Make sure these work for all charsets. def readBlockId(in: ByteBuf): String = { val numBytesToRead = in.readByte().toInt val bytes = new Array[Byte](numBytesToRead) diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index a468764fb184..178c60a048b9 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.storage.StorageLevel /** -* Test suite that makes sure the server and the client implementations share the same protocol. +* Test cases that create real clients and servers and connect. */ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { @@ -93,8 +93,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { /** A ByteBuf for file_block */ lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25) - def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = - { + def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = { val client = clientFactory.createClient(server.hostName, server.port) val sem = new Semaphore(0) val receivedBlockIds = Collections.synchronizedSet(new HashSet[String]) From b5b380ed9c50819a02e778c981670603102623e2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 9 Sep 2014 23:38:38 -0700 Subject: [PATCH 04/28] Reference count buffers and clean them up properly. 
--- .../scala/org/apache/spark/SparkEnv.scala | 9 +- .../spark/network/BlockDataManager.scala | 7 +- .../apache/spark/network/ManagedBuffer.scala | 41 ++++++- .../spark/network/netty/BlockServer.scala | 7 +- .../network/netty/BlockServerHandler.scala | 33 +++--- .../network/nio/NioBlockTransferService.scala | 3 +- .../apache/spark/storage/BlockManager.scala | 8 +- .../storage/ShuffleBlockFetcherIterator.scala | 111 ++++++++++++++---- .../netty/ServerClientIntegrationSuite.scala | 12 +- .../network/netty/TestManagedBuffer.scala | 4 + 10 files changed, 164 insertions(+), 71 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index dd95e406f2a8..5c8477a6e334 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -39,6 +40,7 @@ import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} + /** * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), @@ -226,7 +228,12 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - val blockTransferService = new NioBlockTransferService(conf, securityManager) + // TODO(rxin): Config option based on class name, similar to shuffle mgr and compression codec. 
+ val blockTransferService = if (conf.getBoolean("spark.shuffle.use.netty", false)) { + new NettyBlockTransferService(conf) + } else { + new NioBlockTransferService(conf, securityManager) + } val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index e0e91724271c..638e05f481f5 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -23,11 +23,10 @@ import org.apache.spark.storage.StorageLevel trait BlockDataManager { /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. */ - def getBlockData(blockId: String): Option[ManagedBuffer] + def getBlockData(blockId: String): ManagedBuffer /** * Put the block locally, using the given storage level. diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index d6e1216c0e4d..454dd477b455 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -32,9 +32,14 @@ import org.apache.spark.util.ByteBufferInputStream * This interface provides an immutable view for data in the form of bytes. 
The implementation * should specify how the data is provided: * - * - FileSegmentManagedBuffer: data backed by part of a file - * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer - * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf + * - [[FileSegmentManagedBuffer]]: data backed by part of a file + * - [[NioByteBufferManagedBuffer]]: data backed by a NIO ByteBuffer + * - [[NettyByteBufManagedBuffer]]: data backed by a Netty ByteBuf + * + * The concrete buffer implementation might be managed outside the JVM garbage collector. + * For example, in the case of [[NettyByteBufManagedBuffer]], the buffers are reference counted. + * In that case, if the buffer is going to be passed around to a different thread, retain/release + * should be called. */ abstract class ManagedBuffer { // Note that all the methods are defined with parenthesis because their implementations can @@ -56,6 +61,17 @@ abstract class ManagedBuffer { */ def inputStream(): InputStream + /** + * Increment the reference count by one if applicable. + */ + def retain(): this.type + + /** + * If applicable, decrement the reference count by one and deallocate the buffer if the + * reference count reaches zero. + */ + def release(): this.type + /** * Convert the buffer into an Netty object, used to write the data out. */ @@ -86,6 +102,10 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt val fileChannel = new FileInputStream(file).getChannel new DefaultFileRegion(fileChannel, offset, length) } + + // Content of file segments is not in-memory, so no need to reference count.
+ override def retain(): this.type = this + override def release(): this.type = this } @@ -101,6 +121,10 @@ final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def inputStream() = new ByteBufferInputStream(buf) private[network] override def convertToNetty(): AnyRef = Unpooled.wrappedBuffer(buf) + + // [[ByteBuffer]] is managed by the JVM garbage collector itself. + override def retain(): this.type = this + override def release(): this.type = this } @@ -117,6 +141,13 @@ final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { private[network] override def convertToNetty(): AnyRef = buf - // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it. - def release(): Unit = buf.release() + override def retain(): this.type = { + buf.retain() + this + } + + override def release(): this.type = { + buf.release() + this + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 05443a74094d..ceae31efac93 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -40,10 +40,6 @@ import org.apache.spark.util.Utils private[netty] class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Logging { - def this(sparkConf: SparkConf, dataProvider: BlockDataManager) = { - this(new NettyConfig(sparkConf), dataProvider) - } - def port: Int = _port def hostName: String = _hostName @@ -117,7 +113,8 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] _port = addr.getPort - _hostName = addr.getHostName + //_hostName = addr.getHostName + _hostName = Utils.localHostName() } /** Shutdown the server. 
*/ diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala index 739526a4fc6b..c3b4d41829f4 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala @@ -66,9 +66,9 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager) logTrace(s"Received request from $client to fetch block $blockId") // First make sure we can find the block. If not, send error back to the user. - var blockData: Option[ManagedBuffer] = null + var buf: ManagedBuffer = null try { - blockData = dataProvider.getBlockData(blockId) + buf = dataProvider.getBlockData(blockId) } catch { case e: Exception => logError(s"Error opening block $blockId for request from $client", e) @@ -76,23 +76,18 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager) return } - blockData match { - case Some(buf) => - ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.size} B) back to $client") - } else { - logError( - s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } + ctx.writeAndFlush(new BlockFetchSuccess(blockId, buf)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace(s"Sent block $blockId (${buf.size} B) back to $client") + } else { + logError( + s"Error sending block $blockId to $client; closing connection", future.cause) + ctx.close() } - ) - case None => - respondWithError("Block not found") - } + } + } + ) } // end of processBlockRequest } diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala 
b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 59958ee89423..a8c396833af8 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -197,7 +197,8 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa private def getBlock(blockId: String): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId).orNull + // TODO(rxin): propagate error back to the client? + val buffer = blockDataManager.getBlockData(blockId) logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) buffer.nioByteBuffer() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d1bee3d2c033..c1b1c02e0059 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -210,17 +210,17 @@ private[spark] class BlockManager( * * @return Some(buffer) if the block exists locally, and None if it doesn't. 
*/ - override def getBlockData(blockId: String): Option[ManagedBuffer] = { + override def getBlockData(blockId: String): ManagedBuffer = { val bid = BlockId(blockId) if (bid.isShuffle) { - Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])) + shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - Some(new NioByteBufferManagedBuffer(buffer)) + new NioByteBufferManagedBuffer(buffer) } else { - None + throw new BlockNotFoundException(blockId) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c8e708aa6b1b..92eb72089323 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -23,10 +23,10 @@ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashSet import scala.collection.mutable.Queue -import org.apache.spark.{TaskContext, Logging, SparkException} +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils +import org.apache.spark.util.{CompletionIterator, Utils} /** @@ -88,17 +88,49 @@ final class ShuffleBlockFetcherIterator( */ private[this] val results = new LinkedBlockingQueue[FetchResult] - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight + /** + * Current [[FetchResult]] being processed. 
We track this so we can release the current buffer + * in case of a runtime exception when processing the current buffer. + */ + private[this] var currentResult: FetchResult = null + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + * the number of bytes in flight is limited to maxBytesInFlight. + */ private[this] val fetchRequests = new Queue[FetchRequest] - // Current bytes in flight from our requests + /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @volatile private[this] var isZombie = false + initialize() + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + // Release the current buffer if necessary + if (currentResult != null && currentResult.buf != null) { + currentResult.buf.release() + } + + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result.buf.release() + } + } + private[this] def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) @@ -110,13 +142,17 @@ final class ShuffleBlockFetcherIterator( blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, new BlockFetchingListener { - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), - () => serializer.newInstance().deserializeStream( - blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator - )) - shuffleMetrics.remoteBytesRead += 
data.size - shuffleMetrics.remoteBlocksFetched += 1 + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. + if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), buf)) + shuffleMetrics.remoteBytesRead += buf.size + shuffleMetrics.remoteBlocksFetched += 1 + } logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } @@ -138,7 +174,7 @@ final class ShuffleBlockFetcherIterator( // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. @@ -185,26 +221,34 @@ final class ShuffleBlockFetcherIterator( remoteRequests } + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. + */ private[this] def fetchLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. 
Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocks) { + val iter = localBlocks.iterator + while (iter.hasNext) { + val blockId = iter.next() try { + val buf = blockManager.getBlockData(blockId.toString) shuffleMetrics.localBlocksFetched += 1 - results.put(new FetchResult( - id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get)) - logDebug("Got local block " + id) + buf.retain() + results.put(new FetchResult(blockId, 0, buf)) } catch { case e: Exception => + // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(id, -1, null)) + results.put(new FetchResult(blockId, -1, null)) return } } } private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. + context.addTaskCompletionListener(_ => cleanup()) + // Split local and remote blocks. 
val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order @@ -229,7 +273,8 @@ final class ShuffleBlockFetcherIterator( override def next(): (BlockId, Option[Iterator[Any]]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() - val result = results.take() + currentResult = results.take() + val result = currentResult val stopFetchWait = System.currentTimeMillis() shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) if (!result.failed) { @@ -240,7 +285,21 @@ final class ShuffleBlockFetcherIterator( (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) } - (result.blockId, if (result.failed) None else Some(result.deserialize())) + + val iteratorOpt: Option[Iterator[Any]] = if (result.failed) { + None + } else { + val is = blockManager.wrapForCompression(result.blockId, result.buf.inputStream()) + val iter = serializer.newInstance().deserializeStream(is).asIterator + Some(CompletionIterator[Any, Iterator[Any]](iter, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + currentResult = null + result.buf.release() + })) + } + + (result.blockId, iteratorOpt) } } @@ -262,10 +321,10 @@ object ShuffleBlockFetcherIterator { * Result of a fetch from a remote block. A failure is represented as size == -1. * @param blockId block id * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. + * Note that this is NOT the exact bytes. -1 if failure is present. + * @param buf [[ManagedBuffer]] for the content. null on error.
*/ - class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { + class FetchResult(val blockId: BlockId, val size: Long, val buf: ManagedBuffer) { def failed: Boolean = size == -1 } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 178c60a048b9..72d7c4b53109 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -30,7 +30,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.apache.spark.SparkConf import org.apache.spark.network._ -import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.{BlockNotFoundException, StorageLevel} /** @@ -62,14 +62,14 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { fp.write(fileContent) fp.close() - server = new BlockServer(new SparkConf, new BlockDataManager { - override def getBlockData(blockId: String): Option[ManagedBuffer] = { + server = new BlockServer(new NettyConfig(new SparkConf), new BlockDataManager { + override def getBlockData(blockId: String): ManagedBuffer = { if (blockId == bufferBlockId) { - Some(new NioByteBufferManagedBuffer(buf)) + new NioByteBufferManagedBuffer(buf) } else if (blockId == fileBlockId) { - Some(new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25)) + new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25) } else { - None + throw new BlockNotFoundException(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala index 6ae2d3b3faf9..1d13fd92e1f2 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala +++ 
b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala @@ -65,4 +65,8 @@ class TestManagedBuffer(len: Int) extends ManagedBuffer { } case _ => false } + + override def retain(): this.type = this + + override def release(): this.type = this } From 1474824478df006b5ed3288fd9a6eb5cd0504086 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 01:09:44 -0700 Subject: [PATCH 05/28] Fixed ShuffleBlockFetcherIteratorSuite. --- .../apache/spark/storage/BlockManager.scala | 5 +- .../netty/ServerClientIntegrationSuite.scala | 30 ++-- .../ShuffleBlockFetcherIteratorSuite.scala | 134 ++++-------------- 3 files changed, 47 insertions(+), 122 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c1b1c02e0059..06d6ee68d185 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -206,9 +206,8 @@ private[spark] class BlockManager( } /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. */ override def getBlockData(blockId: String): ManagedBuffer = { val bid = BlockId(blockId) diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 72d7c4b53109..3dacc0fb69be 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -1,19 +1,19 @@ /* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. 
See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ package org.apache.spark.network.netty diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 809bd7092965..d4c5e7bc39b8 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,9 +17,6 @@ package org.apache.spark.storage -import org.apache.spark.TaskContext -import org.apache.spark.network.{BlockFetchingListener, BlockTransferService} - import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} import org.mockito.invocation.InvocationOnMock @@ -27,126 +24,55 @@ import org.mockito.stubbing.Answer import org.scalatest.FunSuite +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.network._ +import org.apache.spark.serializer.TestSerializer -class ShuffleBlockFetcherIteratorSuite extends FunSuite { - - test("handle local read failures in BlockManager") { - val transfer = mock(classOf[BlockTransferService]) - val blockManager = mock(classOf[BlockManager]) - doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId - val blIds = Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) - - val optItr = mock(classOf[Option[Iterator[Any]]]) - val answer = new Answer[Option[Iterator[Any]]] { - override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] { - throw new Exception - } - } - - // 3rd block is going to fail - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) - doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) - 
doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) - - val bmId = BlockManagerId("test-client", "test-client", 1) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) - ) - - val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), - transfer, - blockManager, - blocksByAddress, - null, - 48 * 1024 * 1024) +class ShuffleBlockFetcherIteratorSuite extends FunSuite { - // Without exhausting the iterator, the iterator should be lazy and not call - // getLocalShuffleFromDisk. - verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - // the 2nd element of the tuple returned by iterator.next should be defined when - // fetching successfully - assert(iterator.next()._2.isDefined, - "1st element should be defined but is not actually defined") - verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next()._2.isDefined, - "2nd element should be defined but is not actually defined") - verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - // 3rd fetch should be failed - intercept[Exception] { - iterator.next() - } - verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any()) - } + val conf = new SparkConf - test("handle local read successes") { - val transfer = mock(classOf[BlockTransferService]) + test("handle successful local reads") { + val buf = mock(classOf[ManagedBuffer]) val blockManager = mock(classOf[BlockManager]) doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId - val blIds 
= Array[BlockId]( - ShuffleBlockId(0,0,0), - ShuffleBlockId(0,1,0), - ShuffleBlockId(0,2,0), - ShuffleBlockId(0,3,0), - ShuffleBlockId(0,4,0)) - - val optItr = mock(classOf[Option[Iterator[Any]]]) + val blockIds = Array[BlockId]( + ShuffleBlockId(0, 0, 0), + ShuffleBlockId(0, 1, 0), + ShuffleBlockId(0, 2, 0), + ShuffleBlockId(0, 3, 0), + ShuffleBlockId(0, 4, 0)) // All blocks should be fetched successfully - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any()) - doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any()) + blockIds.foreach { blockId => + doReturn(buf).when(blockManager).getBlockData(meq(blockId.toString)) + } val bmId = BlockManagerId("test-client", "test-client", 1) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq) + (bmId, blockIds.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) ) val iterator = new ShuffleBlockFetcherIterator( new TaskContext(0, 0, 0), - transfer, + mock(classOf[BlockTransferService]), blockManager, blocksByAddress, - null, + new TestSerializer, 48 * 1024 * 1024) - // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk. 
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any()) - - assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 1st element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 2nd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 3rd element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 4th element is not actually defined") - assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements") - assert(iterator.next()._2.isDefined, - "All elements should be defined but 5th element is not actually defined") - - verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any()) + // Local blocks are fetched immediately. + verify(blockManager, times(5)).getBlockData(any()) + + for (i <- 0 until 5) { + assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") + assert(iterator.next()._2.isDefined, + s"iterator should have 5 elements defined but actually has $i elements") + } + // No more fetching of local blocks. 
+ verify(blockManager, times(5)).getBlockData(any()) } test("handle remote fetch failures in BlockTransferService") { @@ -173,7 +99,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { transfer, blockManager, blocksByAddress, - null, + new TestSerializer, 48 * 1024 * 1024) iterator.foreach { case (_, iterOption) => From b404da3b4b03fe409fa0e3e3af83eb114d8d4a46 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 01:10:04 -0700 Subject: [PATCH 06/28] Forgot to add TestSerializer to the commit list. --- .../spark/serializer/TestSerializer.scala | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala new file mode 100644 index 000000000000..0ade1bab18d7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.serializer + +import java.io.{EOFException, OutputStream, InputStream} +import java.nio.ByteBuffer + +import scala.reflect.ClassTag + + +/** + * A serializer implementation that always returns a single element in a deserialization stream. + */ +class TestSerializer extends Serializer { + override def newInstance() = new TestSerializerInstance +} + + +class TestSerializerInstance extends SerializerInstance { + override def serialize[T: ClassTag](t: T): ByteBuffer = ??? + + override def serializeStream(s: OutputStream): SerializationStream = ??? + + override def deserializeStream(s: InputStream) = new TestDeserializationStream + + override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ??? + + override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = ??? +} + + +class TestDeserializationStream extends DeserializationStream { + + private var count = 0 + + override def readObject[T: ClassTag](): T = { + count += 1 + if (count == 2) { + throw new EOFException + } + new Object().asInstanceOf[T] + } + + override def close(): Unit = {} +} From fbf882daf121b41ca9ba2790062c03031325026f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 01:11:40 -0700 Subject: [PATCH 07/28] Shorten NioManagedBuffer and NettyManagedBuffer class names.
--- .../scala/org/apache/spark/network/ManagedBuffer.scala | 10 +++++----- .../org/apache/spark/network/netty/protocol.scala | 6 +++--- .../spark/network/nio/NioBlockTransferService.scala | 4 ++-- .../scala/org/apache/spark/storage/BlockManager.scala | 4 ++-- .../spark/network/netty/BlockClientHandlerSuite.scala | 4 ++-- .../network/netty/ServerClientIntegrationSuite.scala | 2 +- .../apache/spark/network/netty/TestManagedBuffer.scala | 4 ++-- 7 files changed, 17 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 454dd477b455..9e1c83197eab 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -33,11 +33,11 @@ import org.apache.spark.util.ByteBufferInputStream * should specify how the data is provided: * * - [[FileSegmentManagedBuffer]]: data backed by part of a file - * - [[NioByteBufferManagedBuffer]]: data backed by a NIO ByteBuffer - * - [[NettyByteBufManagedBuffer]]: data backed by a Netty ByteBuf + * - [[NioManagedBuffer]]: data backed by a NIO ByteBuffer + * - [[NettyManagedBuffer]]: data backed by a Netty ByteBuf * * The concrete buffer implementation might be managed outside the JVM garbage collector. - * For example, in the case of [[NettyByteBufManagedBuffer]], the buffers are reference counted. + * For example, in the case of [[NettyManagedBuffer]], the buffers are reference counted. * In that case, if the buffer is going to be passed around to a different thread, retain/release * should be called. */ @@ -112,7 +112,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt /** * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. 
*/ -final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { +final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def size: Long = buf.remaining() @@ -131,7 +131,7 @@ final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { /** * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. */ -final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { +final class NettyManagedBuffer(buf: ByteBuf) extends ManagedBuffer { override def size: Long = buf.readableBytes() diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala index ac6a4d00f654..ac9d2097c93e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -25,7 +25,7 @@ import io.netty.channel.ChannelHandler.Sharable import io.netty.handler.codec._ import org.apache.spark.Logging -import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer} +import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} /** Messages from the client to the server. */ @@ -141,7 +141,7 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { case 1 => // BlockUploadRequest val blockId = ProtocolUtils.readBlockId(in) in.retain() // retain the bytebuf so we don't recycle it immediately. 
- BlockUploadRequest(blockId, new NettyByteBufManagedBuffer(in)) + BlockUploadRequest(blockId, new NettyManagedBuffer(in)) } assert(decoded.id == msgTypeId) @@ -218,7 +218,7 @@ final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { case 0 => // BlockFetchSuccess val blockId = ProtocolUtils.readBlockId(in) in.retain() - new BlockFetchSuccess(blockId, new NettyByteBufManagedBuffer(in)) + new BlockFetchSuccess(blockId, new NettyManagedBuffer(in)) case 1 => // BlockFetchFailure val blockId = ProtocolUtils.readBlockId(in) diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index a8c396833af8..6f73d3fa930c 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -104,7 +104,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val blockId = blockMessage.getId val networkSize = blockMessage.getData.limit() listener.onBlockFetchSuccess( - blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData)) + blockId.toString, new NioManagedBuffer(blockMessage.getData)) } } }(cm.futureExecContext) @@ -189,7 +189,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) - blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level) + blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " with data size: " + bytes.limit) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 06d6ee68d185..caaf5cfccd06 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -217,7 +217,7 @@ private[spark] class BlockManager( val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - new NioByteBufferManagedBuffer(buffer) + new NioManagedBuffer(buffer) } else { throw new BlockNotFoundException(blockId) } @@ -803,7 +803,7 @@ private[spark] class BlockManager( try { blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) + peer.host, peer.port, blockId.toString, new NioManagedBuffer(data), tLevel) } catch { case e: Exception => logError(s"Failed to replicate block to $peer", e) diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala index 1358b2f9c807..f2ed404ed8d4 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -64,7 +64,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { buf.put(blockData.getBytes) buf.flip() - channel.writeInbound(BlockFetchSuccess(blockId, new NioByteBufferManagedBuffer(buf))) + channel.writeInbound(BlockFetchSuccess(blockId, new NioManagedBuffer(buf))) assert(parsedBlockId === blockId) assert(parsedBlockData === blockData) @@ -119,7 +119,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { assert(sizeOfOutstandingRequests(handler) === 3) val channel = new EmbeddedChannel(handler) - channel.writeInbound(BlockFetchSuccess("b1", new NettyByteBufManagedBuffer(Unpooled.buffer()))) + 
channel.writeInbound(BlockFetchSuccess("b1", new NettyManagedBuffer(Unpooled.buffer()))) // Need to figure out a way to generate an exception assert(successCount.get() === 1) assert(errorCount.get() === 2) diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 3dacc0fb69be..fa3512768d9a 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -65,7 +65,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { server = new BlockServer(new NettyConfig(new SparkConf), new BlockDataManager { override def getBlockData(blockId: String): ManagedBuffer = { if (blockId == bufferBlockId) { - new NioByteBufferManagedBuffer(buf) + new NioManagedBuffer(buf) } else if (blockId == fileBlockId) { new FileSegmentManagedBuffer(testFile, 10, testFile.length - 25) } else { diff --git a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala index 1d13fd92e1f2..e47e4d03fa89 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/TestManagedBuffer.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import io.netty.buffer.Unpooled -import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer} +import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} /** @@ -36,7 +36,7 @@ class TestManagedBuffer(len: Int) extends ManagedBuffer { private val byteArray: Array[Byte] = Array.tabulate[Byte](len)(_.toByte) - private val underlying = new NettyByteBufManagedBuffer(Unpooled.wrappedBuffer(byteArray)) + private val underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)) override def 
size: Long = underlying.size From b32c3fe568163614a6eb424e523f7dd545d8ce9e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 19:01:23 -0700 Subject: [PATCH 08/28] Added more test cases covering cleanup when fault happens in ShuffleBlockFetcherIteratorSuite --- .../storage/ShuffleBlockFetcherIterator.scala | 11 +- .../ShuffleBlockFetcherIteratorSuite.scala | 189 +++++++++++++++--- 2 files changed, 164 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 92eb72089323..1d3e650cb059 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -119,7 +119,7 @@ final class ShuffleBlockFetcherIterator( private[this] def cleanup() { isZombie = true // Release the current buffer if necessary - if (currentResult != null && currentResult.buf != null) { + if (currentResult != null && !currentResult.failed) { currentResult.buf.release() } @@ -127,7 +127,9 @@ final class ShuffleBlockFetcherIterator( val iter = results.iterator() while (iter.hasNext) { val result = iter.next() - result.buf.release() + if (!result.failed) { + result.buf.release() + } } } @@ -313,7 +315,7 @@ object ShuffleBlockFetcherIterator { * @param blocks Sequence of tuple, where the first element is the block id, * and the second element is the estimated size, used to calculate bytesInFlight. */ - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { + case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } @@ -324,7 +326,8 @@ object ShuffleBlockFetcherIterator { * Note that this is NOT the exact bytes. -1 if failure is present. * @param buf [[ManagedBuffer]] for the content. null is error. 
*/ - class FetchResult(val blockId: BlockId, val size: Long, val buf: ManagedBuffer) { + case class FetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) { def failed: Boolean = size == -1 + if (failed) assert(buf == null) else assert(buf != null) } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index d4c5e7bc39b8..b4700f38a678 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,6 +17,11 @@ package org.apache.spark.storage +import java.util.concurrent.Semaphore + +import scala.concurrent.future +import scala.concurrent.ExecutionContext.Implicits.global + import org.mockito.Mockito._ import org.mockito.Matchers.{any, eq => meq} import org.mockito.invocation.InvocationOnMock @@ -30,80 +35,200 @@ import org.apache.spark.serializer.TestSerializer class ShuffleBlockFetcherIteratorSuite extends FunSuite { + // Some of the tests are quite tricky because we are testing the cleanup behavior + // in the presence of faults. - val conf = new SparkConf + /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ + private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val blocks = invocation.getArguments()(2).asInstanceOf[Seq[String]] + val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] - test("handle successful local reads") { - val buf = mock(classOf[ManagedBuffer]) - val blockManager = mock(classOf[BlockManager]) - doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId + for (blockId <- blocks) { + if (data.contains(BlockId(blockId))) { + listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) + } else { + listener.onBlockFetchFailure(new BlockNotFoundException(blockId)) + } + } + } + }) + transfer + } - val blockIds = Array[BlockId]( - ShuffleBlockId(0, 0, 0), - ShuffleBlockId(0, 1, 0), - ShuffleBlockId(0, 2, 0), - ShuffleBlockId(0, 3, 0), - ShuffleBlockId(0, 4, 0)) + private val conf = new SparkConf - // All blocks should be fetched successfully - blockIds.foreach { blockId => + test("successful 3 local reads + 2 remote reads") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure blockManager.getBlockData would return the blocks + val localBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + localBlocks.foreach { case (blockId, buf) => doReturn(buf).when(blockManager).getBlockData(meq(blockId.toString)) } - val bmId = BlockManagerId("test-client", "test-client", 1) + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", 
"test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) + ) + + val transfer = createMockTransfer(remoteBlocks) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, blockIds.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) + (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) ) val iterator = new ShuffleBlockFetcherIterator( new TaskContext(0, 0, 0), - mock(classOf[BlockTransferService]), + transfer, blockManager, blocksByAddress, new TestSerializer, 48 * 1024 * 1024) - // Local blocks are fetched immediately. - verify(blockManager, times(5)).getBlockData(any()) + // 3 local blocks fetched in initialization + verify(blockManager, times(3)).getBlockData(any()) for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - assert(iterator.next()._2.isDefined, + val (blockId, subIterator) = iterator.next() + assert(subIterator.isDefined, s"iterator should have 5 elements defined but actually has $i elements") + + // Make sure we release the buffer once the iterator is exhausted. + val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + verify(mockBuf, times(0)).release() + subIterator.get.foreach(_ => Unit) // exhaust the iterator + verify(mockBuf, times(1)).release() } - // No more fetching of local blocks. 
- verify(blockManager, times(5)).getBlockData(any()) + + // 3 local blocks, and 2 remote blocks + // (but from the same block manager so one call to fetchBlocks) + verify(blockManager, times(3)).getBlockData(any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any()) } - test("handle remote fetch failures in BlockTransferService") { + test("release current unexhausted buffer in case the task completes early") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ) + + // Semaphore to coordinate event sequence in two different threads. 
+ val sem = new Semaphore(0) + val transfer = mock(classOf[BlockTransferService]) when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] - listener.onBlockFetchFailure(new Exception("blah")) + future { + // Return the first two blocks, and wait till task completion before returning the 3rd one + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0))) + sem.acquire() + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0))) + } } }) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + + val taskContext = new TaskContext(0, 0, 0) + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + new TestSerializer, + 48 * 1024 * 1024) + + // Exhaust the first block, and then it should be released. 
+ iterator.next()._2.get.foreach(_ => Unit) + verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() + + // Get the 2nd block but do not exhaust the iterator + val subIter = iterator.next()._2.get + + // Complete the task; then the 2nd block buffer should be exhausted + verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() + taskContext.markTaskCompleted() + verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release() + + // The 3rd block should not be retained because the iterator is already in zombie state + sem.release() + verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).retain() + verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release() + } + + test("fail all blocks if any of the remote request fails") { val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure remote blocks would return + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val blocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), + ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) + ) - when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1)) + // Semaphore to coordinate event sequence in two different threads. + val sem = new Semaphore(0) + + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener] + future { + // Return the first block, and then fail. 
+ listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchFailure(new BlockNotFoundException("blah")) + sem.release() + } + } + }) - val blId1 = ShuffleBlockId(0, 0, 0) - val blId2 = ShuffleBlockId(0, 1, 0) - val bmId = BlockManagerId("test-server", "test-server", 1) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (bmId, Seq((blId1, 1L), (blId2, 1L)))) + (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + val taskContext = new TaskContext(0, 0, 0) val iterator = new ShuffleBlockFetcherIterator( - new TaskContext(0, 0, 0), + taskContext, transfer, blockManager, blocksByAddress, new TestSerializer, 48 * 1024 * 1024) - iterator.foreach { case (_, iterOption) => - assert(!iterOption.isDefined) - } + // Continue only after the mock calls onBlockFetchFailure + sem.acquire() + + // The first block should be defined, and the last two are not defined (due to failure) + assert(iterator.next()._2.isDefined === true) + assert(iterator.next()._2.isDefined === false) + assert(iterator.next()._2.isDefined === false) } } From d135fa38801467b0dd870063c00103ddd45438c7 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 19:55:54 -0700 Subject: [PATCH 09/28] Fixed style violation. 
--- .../main/scala/org/apache/spark/network/netty/BlockServer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index ceae31efac93..d95ab8dd8496 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -113,7 +113,7 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] _port = addr.getPort - //_hostName = addr.getHostName + // _hostName = addr.getHostName _hostName = Utils.localHostName() } From 1e0d2770e6cfb24b82d8e0f64f0480ff80dfd0d6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 10 Sep 2014 21:04:56 -0700 Subject: [PATCH 10/28] Fixed BlockClientHandlerSuite --- .../spark/network/netty/BlockClientHandlerSuite.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala index f2ed404ed8d4..7ed3dc915bb7 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -31,8 +31,9 @@ import org.apache.spark.network._ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - handler.invokePrivate(outstandingRequests()).size + val f = handler.getClass.getDeclaredField("outstandingRequests") + f.setAccessible(true) + f.get(handler).asInstanceOf[java.util.Map[_, _]].size } test("handling block data (successful fetch)") { @@ 
-56,8 +57,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { } ) - val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests) - assert(handler.invokePrivate(outstandingRequests()).size === 1) + assert(sizeOfOutstandingRequests(handler) === 1) val channel = new EmbeddedChannel(handler) val buf = ByteBuffer.allocate(blockData.size) // 4 bytes for the length field itself @@ -68,7 +68,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { assert(parsedBlockId === blockId) assert(parsedBlockData === blockData) - assert(handler.invokePrivate(outstandingRequests()).size === 0) + assert(sizeOfOutstandingRequests(handler) === 0) assert(channel.finish() === false) } From 55266d1f7f9e2bf95b310a1d4d603c0df9e7b996 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 11 Sep 2014 18:28:45 -0700 Subject: [PATCH 11/28] Incorporated feedback from Norman: - use same pool for boss and worker - remove ioratio - disable caching of byte buf allocator - childoption sendbuf/receivebuf - fire exception through pipeline In addition: - fire failure handler BlockFetchingListener at least once per block. 
- enabled a bunch of ignored tests --- .../spark/network/BlockFetchingListener.scala | 4 +- .../spark/network/BlockTransferService.scala | 2 +- .../spark/network/netty/BlockClient.scala | 2 +- .../network/netty/BlockClientFactory.scala | 30 +++++- .../network/netty/BlockClientHandler.scala | 47 +++++++--- .../spark/network/netty/BlockServer.scala | 21 ++--- .../network/nio/NioBlockTransferService.scala | 12 ++- .../storage/ShuffleBlockFetcherIterator.scala | 9 +- .../netty/BlockClientHandlerSuite.scala | 94 ++++++++----------- .../netty/ServerClientIntegrationSuite.scala | 7 +- .../ShuffleBlockFetcherIteratorSuite.scala | 5 +- 11 files changed, 129 insertions(+), 104 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index 34acaa563ca5..83fe497ad744 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -31,7 +31,7 @@ trait BlockFetchingListener extends EventListener { def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit /** - * Called upon failures. For each failure, this is called only once (i.e. not once per block). + * Called at least once per block upon failures. 
*/ - def onBlockFetchFailure(exception: Throwable): Unit + def onBlockFetchFailure(blockId: String, exception: Throwable): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 84d991fa6808..4833b8a6abf3 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -83,7 +83,7 @@ abstract class BlockTransferService { val lock = new Object @volatile var result: Either[ManagedBuffer, Throwable] = null fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { lock.synchronized { result = Right(exception) lock.notify() diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index 9333fefa9295..6f67187adcb3 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -77,8 +77,8 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin logError(errorMsg, future.cause) blockIds.foreach { blockId => handler.removeRequest(blockId) + listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) } - listener.onBlockFetchFailure(new RuntimeException(errorMsg)) } } }) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index f05f1419ded1..1414d0966e3d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -28,6 +28,7 @@ import 
io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel import io.netty.channel.socket.oio.OioSocketChannel +import io.netty.util.internal.PlatformDependent import org.apache.spark.SparkConf import org.apache.spark.util.Utils @@ -92,13 +93,14 @@ class BlockClientFactory(val conf: NettyConfig) { val bootstrap = new Bootstrap bootstrap.group(workerGroup) .channel(socketChannelClass) - // Use pooled buffers to reduce temporary buffer allocation - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) // Disable Nagle's Algorithm since we don't want packets to wait .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs) + // Use pooled buffers to reduce temporary buffer allocation + bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator()) + bootstrap.handler(new ChannelInitializer[SocketChannel] { override def initChannel(ch: SocketChannel): Unit = { ch.pipeline @@ -124,4 +126,28 @@ class BlockClientFactory(val conf: NettyConfig) { workerGroup.shutdownGracefully() } } + + /** + * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches + * are disabled because the ByteBufs are allocated by the event loop thread, but released by the + * executor thread rather than the event loop thread. Those thread-local caches actually delay + * the recycling of buffers, leading to larger memory usage. 
+ */ + private def createPooledByteBufAllocator(): PooledByteBufAllocator = { + def getPrivateStaticField(name: String): Int = { + val f = PooledByteBufAllocator.DEFAULT.getClass.getDeclaredField(name) + f.setAccessible(true) + f.getInt(null) + } + new PooledByteBufAllocator( + PlatformDependent.directBufferPreferred(), + getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), + getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), + getPrivateStaticField("DEFAULT_PAGE_SIZE"), + getPrivateStaticField("DEFAULT_MAX_ORDER"), + 0, // tinyCacheSize + 0, // smallCacheSize + 0 // normalCacheSize + ) + } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala index 2a474cd71eab..1a74c6649f28 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -17,6 +17,8 @@ package org.apache.spark.network.netty +import java.util.concurrent.ConcurrentHashMap + import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.Logging @@ -33,9 +35,8 @@ private[netty] class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { /** Tracks the list of outstanding requests and their listeners on success/failure. 
*/ - private[this] val outstandingRequests = java.util.Collections.synchronizedMap { - new java.util.HashMap[String, BlockFetchingListener] - } + private[this] val outstandingRequests: java.util.Map[String, BlockFetchingListener] = + new ConcurrentHashMap[String, BlockFetchingListener] def addRequest(blockId: String, listener: BlockFetchingListener): Unit = { outstandingRequests.put(blockId, listener) @@ -45,20 +46,36 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit outstandingRequests.remove(blockId) } - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" - logError(errorMsg, cause) + /** + * Fire the failure callback for all outstanding requests. This is called when we have an + * uncaught exception or premature connection termination. + */ + private def failOutstandingRequests(cause: Throwable): Unit = { + val iter = outstandingRequests.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + entry.getValue.onBlockFetchFailure(entry.getKey, cause) + } + // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests + // as well. But I guess that is ok given the caller will fail as soon as any requests fail.
+ outstandingRequests.clear() + } - // Fire the failure callback for all outstanding blocks - outstandingRequests.synchronized { - val iter = outstandingRequests.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - entry.getValue.onBlockFetchFailure(cause) - } - outstandingRequests.clear() + override def channelUnregistered(ctx: ChannelHandlerContext): Unit = { + if (outstandingRequests.size() > 0) { + logError("Still have " + outstandingRequests.size() + " requests outstanding " + + s"when connection from ${ctx.channel.remoteAddress} is closed") + failOutstandingRequests(new RuntimeException( + s"Connection from ${ctx.channel.remoteAddress} closed")) } + } + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + if (outstandingRequests.size() > 0) { + logError( + s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}", cause) + failOutstandingRequests(cause) + } ctx.close() } @@ -80,7 +97,7 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit s"Got a response for block $blockId from $server ($errorMsg) but it is not outstanding") } else { outstandingRequests.remove(blockId) - listener.onBlockFetchFailure(new RuntimeException(errorMsg)) + listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) } } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index d95ab8dd8496..bd28d48c1a5e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -54,25 +54,22 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log /** Initialize the server. 
*/ private def init(): Unit = { bootstrap = new ServerBootstrap - val bossThreadFactory = Utils.namedThreadFactory("spark-netty-server-boss") - val workerThreadFactory = Utils.namedThreadFactory("spark-netty-server-worker") + val threadFactory = Utils.namedThreadFactory("spark-netty-server") // Use only one thread to accept connections, and 2 * num_cores for worker. def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new NioEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) + val bossGroup = new NioEventLoopGroup(0, threadFactory) + val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) } def initOio(): Unit = { - val bossGroup = new OioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new OioEventLoopGroup(0, workerThreadFactory) + val bossGroup = new OioEventLoopGroup(0, threadFactory) + val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) } def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory) - val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) + val bossGroup = new EpollEventLoopGroup(0, threadFactory) + val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) } @@ -92,10 +89,10 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) } conf.receiveBuf.foreach { receiveBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) + bootstrap.childOption[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) } conf.sendBuf.foreach { sendBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) + bootstrap.childOption[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) } 
bootstrap.childHandler(new ChannelInitializer[SocketChannel] { diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 6f73d3fa930c..e7eac75a9b4e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -96,10 +96,12 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { + for (blockMessage: BlockMessage <- blockMessageArray) { if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - listener.onBlockFetchFailure( - new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + if (blockMessage.getId != null) { + listener.onBlockFetchFailure(blockMessage.getId.toString, + new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) + } } else { val blockId = blockMessage.getId val networkSize = blockMessage.getData.limit() @@ -110,7 +112,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa }(cm.futureExecContext) future.onFailure { case exception => - listener.onBlockFetchFailure(exception) + blockIds.foreach { blockId => + listener.onBlockFetchFailure(blockId, exception) + } }(cm.futureExecContext) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 1d3e650cb059..c139ac206161 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -158,14 +158,9 @@ final class ShuffleBlockFetcherIterator( 
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } - override def onBlockFetchFailure(e: Throwable): Unit = { + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError("Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - // Note that there is a chance that some blocks have been fetched successfully, but we - // still add them to the failed queue. This is fine because when the caller see a - // FetchFailedException, it is going to fail the entire task anyway. - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } + results.put(new FetchResult(BlockId(blockId), -1, null)) } } ) diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala index 7ed3dc915bb7..c470bff825ba 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -18,11 +18,13 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer -import java.util.concurrent.atomic.AtomicInteger import io.netty.buffer.Unpooled import io.netty.channel.embedded.EmbeddedChannel +import org.mockito.Mockito._ +import org.mockito.Matchers.{any, eq => meq} + import org.scalatest.{FunSuite, PrivateMethodTester} import org.apache.spark.network._ @@ -31,7 +33,8 @@ import org.apache.spark.network._ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { - val f = handler.getClass.getDeclaredField("outstandingRequests") + val f = handler.getClass.getDeclaredField( + "org$apache$spark$network$netty$BlockClientHandler$$outstandingRequests") f.setAccessible(true) f.get(handler).asInstanceOf[java.util.Map[_, _]].size } @@ -39,24 +42,9 @@ class BlockClientHandlerSuite 
extends FunSuite with PrivateMethodTester { test("handling block data (successful fetch)") { val blockId = "test_block" val blockData = "blahblahblahblahblah" - - var parsedBlockId: String = "" - var parsedBlockData: String = "" val handler = new BlockClientHandler - handler.addRequest(blockId, - new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { - throw new UnsupportedOperationException - } - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - parsedBlockId = blockId - val bytes = new Array[Byte](data.size.toInt) - data.nioByteBuffer().get(bytes) - parsedBlockData = new String(bytes) - } - } - ) - + val listener = mock(classOf[BlockFetchingListener]) + handler.addRequest(blockId, listener) assert(sizeOfOutstandingRequests(handler) === 1) val channel = new EmbeddedChannel(handler) @@ -65,54 +53,29 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { buf.flip() channel.writeInbound(BlockFetchSuccess(blockId, new NioManagedBuffer(buf))) - - assert(parsedBlockId === blockId) - assert(parsedBlockData === blockData) + verify(listener, times(1)).onBlockFetchSuccess(meq(blockId), any()) assert(sizeOfOutstandingRequests(handler) === 0) assert(channel.finish() === false) } test("handling error message (failed fetch)") { val blockId = "test_block" - val errorMsg = "error erro5r error err4or error3 error6 error erro1r" - - var parsedErrorMsg: String = "" val handler = new BlockClientHandler - handler.addRequest(blockId, - new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { - parsedErrorMsg = exception.getMessage - } - - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - throw new UnsupportedOperationException - } - } - ) - + val listener = mock(classOf[BlockFetchingListener]) + handler.addRequest(blockId, listener) assert(sizeOfOutstandingRequests(handler) === 1) val channel = new EmbeddedChannel(handler) 
- channel.writeInbound(BlockFetchFailure(blockId, errorMsg)) - assert(parsedErrorMsg === errorMsg) + channel.writeInbound(BlockFetchFailure(blockId, "some error msg")) + verify(listener, times(0)).onBlockFetchSuccess(any(), any()) + verify(listener, times(1)).onBlockFetchFailure(meq(blockId), any()) assert(sizeOfOutstandingRequests(handler) === 0) assert(channel.finish() === false) } - ignore("clear all outstanding request upon connection close") { - val errorCount = new AtomicInteger(0) - val successCount = new AtomicInteger(0) + test("clear all outstanding request upon uncaught exception") { val handler = new BlockClientHandler - - val listener = new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { - errorCount.incrementAndGet() - } - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - successCount.incrementAndGet() - } - } - + val listener = mock(classOf[BlockFetchingListener]) handler.addRequest("b1", listener) handler.addRequest("b2", listener) handler.addRequest("b3", listener) @@ -120,9 +83,30 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { val channel = new EmbeddedChannel(handler) channel.writeInbound(BlockFetchSuccess("b1", new NettyManagedBuffer(Unpooled.buffer()))) - // Need to figure out a way to generate an exception - assert(successCount.get() === 1) - assert(errorCount.get() === 2) + channel.pipeline().fireExceptionCaught(new Exception("duh duh duh")) + + // should fail both b2 and b3 + verify(listener, times(1)).onBlockFetchSuccess(any(), any()) + verify(listener, times(2)).onBlockFetchFailure(any(), any()) + assert(sizeOfOutstandingRequests(handler) === 0) + assert(channel.finish() === false) + } + + test("clear all outstanding request upon connection close") { + val handler = new BlockClientHandler + val listener = mock(classOf[BlockFetchingListener]) + handler.addRequest("c1", listener) + handler.addRequest("c2", listener) + 
+ handler.addRequest("c3", listener) + assert(sizeOfOutstandingRequests(handler) === 3) + + val channel = new EmbeddedChannel(handler) + channel.writeInbound(BlockFetchSuccess("c1", new NettyManagedBuffer(Unpooled.buffer()))) + channel.finish() + + // should fail both c2 and c3 + verify(listener, times(1)).onBlockFetchSuccess(any(), any()) + verify(listener, times(2)).onBlockFetchFailure(any(), any()) assert(sizeOfOutstandingRequests(handler) === 0) assert(channel.finish() === false) } diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index fa3512768d9a..e3f98ff173ad 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -103,7 +103,8 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { client.fetchBlocks( blockIds, new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + errorBlockIds.add(blockId) sem.release() } @@ -135,7 +136,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(failBlockIds.isEmpty) } - ignore("fetch a non-existent block") { + test("fetch a non-existent block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block")) assert(blockIds.isEmpty) assert(buffers.isEmpty) @@ -149,7 +150,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(failBlockIds.isEmpty) } - ignore("fetch both ByteBuffer block and a non-existent block") { + test("fetch both ByteBuffer block and a non-existent block") { val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block")) assert(blockIds === Set(bufferBlockId))
assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index b4700f38a678..5a36614d1f59 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -50,7 +50,7 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { if (data.contains(BlockId(blockId))) { listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) } else { - listener.onBlockFetchFailure(new BlockNotFoundException(blockId)) + listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId)) } } } @@ -205,7 +205,8 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { // Return the first block, and then fail. listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) - listener.onBlockFetchFailure(new BlockNotFoundException("blah")) + listener.onBlockFetchFailure( + ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah")) sem.release() } } From f83611e9cc8cfffe60df218cc788ab680dca18d5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 11 Sep 2014 22:12:08 -0700 Subject: [PATCH 12/28] Added connection pooling. 
--- .../spark/network/netty/BlockClient.scala | 11 +-- .../network/netty/BlockClientFactory.scala | 42 +++++++-- .../netty/BlockClientFactorySuite.scala | 91 +++++++++++++++++++ .../netty/BlockClientHandlerSuite.scala | 1 + .../netty/ServerClientIntegrationSuite.scala | 9 ++ 5 files changed, 140 insertions(+), 14 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index 6f67187adcb3..2768f98e9c1f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -43,6 +43,8 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin private[this] val serverAddr = cf.channel().remoteAddress().toString + def isActive: Boolean = cf.channel().isActive + /** * Ask the remote server for a sequence of blocks, and execute the callback. 
* @@ -55,7 +57,7 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = { var startTime: Long = 0 logTrace { - startTime = System.nanoTime + startTime = System.nanoTime() s"Sending request $blockIds to $serverAddr" } @@ -67,7 +69,7 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin override def operationComplete(future: ChannelFuture): Unit = { if (future.isSuccess) { logTrace { - val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 + val timeTaken = (System.nanoTime() - startTime).toDouble / 1000000 s"Sending request $blockIds to $serverAddr took $timeTaken ms" } } else { @@ -84,9 +86,6 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Loggin }) } - def waitForClose(): Unit = { - cf.channel().closeFuture().sync() - } - + /** Close the connection. This does NOT block till the connection is closed. */ def close(): Unit = cf.channel().close() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 1414d0966e3d..01fc73fe728a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.network.netty -import java.util.concurrent.TimeoutException +import java.util.concurrent.{ConcurrentHashMap, TimeoutException} import io.netty.bootstrap.Bootstrap import io.netty.buffer.PooledByteBufAllocator @@ -35,8 +35,10 @@ import org.apache.spark.util.Utils /** - * Factory for creating [[BlockClient]] by using createClient. This factory reuses - * the worker thread pool for Netty. + * Factory for creating [[BlockClient]] by using createClient. 
+ * + * The factory maintains a connection pool to other hosts and should return the same [[BlockClient]] + * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s. */ private[netty] class BlockClientFactory(val conf: NettyConfig) { @@ -44,11 +46,15 @@ class BlockClientFactory(val conf: NettyConfig) { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) /** A thread factory so the threads are named (for debugging). */ - private[netty] val threadFactory = Utils.namedThreadFactory("spark-netty-client") + private[this] val threadFactory = Utils.namedThreadFactory("spark-netty-client") + + /** Socket channel type, initialized by [[init]] depending on ioMode. */ + private[this] var socketChannelClass: Class[_ <: Channel] = _ - /** The following two are instantiated by the [[init]] method, depending ioMode. */ - private[netty] var socketChannelClass: Class[_ <: Channel] = _ - private[netty] var workerGroup: EventLoopGroup = _ + /** Thread pool shared by all clients. */ + private[this] var workerGroup: EventLoopGroup = _ + + private[this] val connectionPool = new ConcurrentHashMap[(String, Int), BlockClient] // The encoders are stateless and can be shared among multiple clients. private[this] val encoder = new ClientRequestEncoder @@ -88,6 +94,16 @@ class BlockClientFactory(val conf: NettyConfig) { * Concurrency: This method is safe to call from multiple threads. */ def createClient(remoteHost: String, remotePort: Int): BlockClient = { + // Get connection from the connection pool first. + // If it is not found or not active, create a new one. + val cachedClient = connectionPool.get((remoteHost, remotePort)) + if (cachedClient != null && cachedClient.isActive) { + return cachedClient + } + + // There is a chance two threads are creating two different clients connecting to the same host. + // But that's probably ok ...
+ val handler = new BlockClientHandler val bootstrap = new Bootstrap @@ -118,10 +134,20 @@ class BlockClientFactory(val conf: NettyConfig) { s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)") } - new BlockClient(cf, handler) + val client = new BlockClient(cf, handler) + connectionPool.put((remoteHost, remotePort), client) + client } + /** Close all connections in the connection pool, and shutdown the worker thread pool. */ def stop(): Unit = { + val iter = connectionPool.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + entry.getValue.close() + connectionPool.remove(entry.getKey) + } + if (workerGroup != null) { workerGroup.shutdownGracefully() } diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala new file mode 100644 index 000000000000..b2dcebfc8cee --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.netty + +import scala.concurrent.{Await, future} +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.SparkConf + + +class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { + + private val conf = new SparkConf + private var server1: BlockServer = _ + private var server2: BlockServer = _ + + override def beforeAll() { + server1 = new BlockServer(new NettyConfig(conf), null) + server2 = new BlockServer(new NettyConfig(conf), null) + } + + override def afterAll() { + if (server1 != null) { + server1.stop() + } + if (server2 != null) { + server2.stop() + } + } + + test("BlockClients created are active and reused") { + val factory = new BlockClientFactory(conf) + val c1 = factory.createClient(server1.hostName, server1.port) + val c2 = factory.createClient(server1.hostName, server1.port) + val c3 = factory.createClient(server2.hostName, server2.port) + assert(c1.isActive) + assert(c3.isActive) + assert(c1 === c2) + assert(c1 !== c3) + factory.stop() + } + + test("never return inactive clients") { + val factory = new BlockClientFactory(conf) + val c1 = factory.createClient(server1.hostName, server1.port) + c1.close() + + // Block until c1 is no longer active + val f = future { + while (c1.isActive) { + Thread.sleep(10) + } + } + Await.result(f, 3 seconds) + assert(!c1.isActive) + + // Create c2, which should be different from c1 + val c2 = factory.createClient(server1.hostName, server1.port) + assert(c1 !== c2) + factory.stop() + } + + test("BlockClients are close when BlockClientFactory is stopped") { + val factory = new BlockClientFactory(conf) + val c1 = factory.createClient(server1.hostName, server1.port) + val c2 = factory.createClient(server2.hostName, server2.port) + assert(c1.isActive) + assert(c2.isActive) + factory.stop() + assert(!c1.isActive) + assert(!c2.isActive) + } +} diff --git 
a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala index c470bff825ba..7b80fe6aa364 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.network._ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { + /** Helper method to get num. outstanding requests from a private field using reflection. */ private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { val f = handler.getClass.getDeclaredField( "org$apache$spark$network$netty$BlockClientHandler$$outstandingRequests") diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index e3f98ff173ad..789df1f70dcd 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -27,6 +27,9 @@ import scala.collection.JavaConversions._ import io.netty.buffer.Unpooled import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.Span +import org.scalatest.time.Seconds import org.apache.spark.SparkConf import org.apache.spark.network._ @@ -156,4 +159,10 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds === Set("random-block")) } + + test("shutting down server should also close client") { + val client = clientFactory.createClient(server.hostName, server.port) + server.stop() + eventually(timeout(Span(5, Seconds))) { assert(!client.isActive) } + } } From 
6ddaa5d6fc893a759be81347a401296eef8c566c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 11 Sep 2014 22:13:02 -0700 Subject: [PATCH 13/28] Removed BlockManager.getLocalShuffleFromDisk. --- .../scala/org/apache/spark/storage/BlockManager.scala | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index caaf5cfccd06..e2a3576bb1eb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -333,17 +333,6 @@ private[spark] class BlockManager( locations } - /** - * A short-circuited method to get blocks directly from disk. This is used for getting - * shuffle blocks. It is safe to do so without a lock on block info since disk store - * never deletes (recent) items. - */ - def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) - val is = wrapForCompression(blockId, buf.inputStream()) - Some(serializer.newInstance().deserializeStream(is).asIterator) - } - /** * Get block from local block manager. */ From 8295561a9befcfa3c2a56d8836e05a42ce4ab6b0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 12 Sep 2014 00:40:53 -0700 Subject: [PATCH 14/28] Fixed test hanging. 
--- .../apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5a36614d1f59..7d4086313fcc 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -207,6 +207,8 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite { ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) listener.onBlockFetchFailure( ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah")) + listener.onBlockFetchFailure( + ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah")) sem.release() } } From d7d0aac3b5c67c875da1af655008472253ad51e3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 12 Sep 2014 14:18:58 -0700 Subject: [PATCH 15/28] Mark private package visibility and MimaExcludes. 
--- .../org/apache/spark/network/BlockDataManager.scala | 1 + .../apache/spark/network/BlockFetchingListener.scala | 1 + .../apache/spark/network/BlockTransferService.scala | 1 + .../scala/org/apache/spark/network/ManagedBuffer.scala | 4 ++++ .../network/netty/NettyBlockTransferService.scala | 1 + .../org/apache/spark/network/netty/protocol.scala | 10 ++++++++++ project/MimaExcludes.scala | 6 +++++- 7 files changed, 23 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 638e05f481f5..0eeffe0e7c5e 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.network import org.apache.spark.storage.StorageLevel +private[spark] trait BlockDataManager { /** diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala index 83fe497ad744..dd70e2664793 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala @@ -23,6 +23,7 @@ import java.util.EventListener /** * Listener callback interface for [[BlockTransferService.fetchBlocks]]. 
*/ +private[spark] trait BlockFetchingListener extends EventListener { /** diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 4833b8a6abf3..d894eac374b7 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -23,6 +23,7 @@ import scala.concurrent.duration.Duration import org.apache.spark.storage.StorageLevel +private[spark] abstract class BlockTransferService { /** diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 9e1c83197eab..aacb1b246a5b 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -41,6 +41,7 @@ import org.apache.spark.util.ByteBufferInputStream * In that case, if the buffer is going to be passed around to a different thread, retain/release * should be called. */ +private[spark] abstract class ManagedBuffer { // Note that all the methods are defined with parenthesis because their implementations can // have side effects (io operations). @@ -82,6 +83,7 @@ abstract class ManagedBuffer { /** * A [[ManagedBuffer]] backed by a segment in a file */ +private[spark] final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) extends ManagedBuffer { @@ -112,6 +114,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt /** * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. */ +private[spark] final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { override def size: Long = buf.remaining() @@ -131,6 +134,7 @@ final class NioManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { /** * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. 
*/ +private[spark] final class NettyManagedBuffer(buf: ByteBuf) extends ManagedBuffer { override def size: Long = buf.readableBytes() diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index fa8bdfc96e8b..30dc812c4e7d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -29,6 +29,7 @@ import org.apache.spark.storage.StorageLevel * * See protocol.scala for the communication protocol between server and client */ +private[spark] final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService { private[this] val nettyConf: NettyConfig = new NettyConfig(conf) diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala index ac9d2097c93e..6a14ad26dbf2 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -29,6 +29,7 @@ import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} /** Messages from the client to the server. */ +private[netty] sealed trait ClientRequest { def id: Byte } @@ -37,6 +38,7 @@ sealed trait ClientRequest { * Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can * correspond to multiple [[ServerResponse]]s. */ +private[netty] final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { override def id = 0 } @@ -44,6 +46,7 @@ final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { /** * Request to upload a block to the server. Currently the server does not ack the upload request. 
*/ +private[netty] final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest { require(blockId.length <= Byte.MaxValue) override def id = 1 @@ -51,17 +54,20 @@ final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extend /** Messages from server to client (usually in response to some [[ClientRequest]]. */ +private[netty] sealed trait ServerResponse { def id: Byte } /** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */ +private[netty] final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse { require(blockId.length <= Byte.MaxValue) override def id = 0 } /** Response to [[BlockFetchRequest]] when there is an error fetching the block. */ +private[netty] final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse { require(blockId.length <= Byte.MaxValue) override def id = 1 @@ -74,6 +80,7 @@ final case class BlockFetchFailure(blockId: String, error: String) extends Serve * This encoder is stateless so it is safe to be shared by multiple threads. */ @Sharable +private[netty] final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] { override def encode(ctx: ChannelHandlerContext, in: ClientRequest, out: JList[Object]): Unit = { in match { @@ -128,6 +135,7 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] * [[ProtocolUtils.createFrameDecoder()]]. */ @Sharable +private[netty] final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { override protected def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = { @@ -155,6 +163,7 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { * This encoder is stateless so it is safe to be shared by multiple threads. 
*/ @Sharable +private[netty] final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging { override def encode(ctx: ChannelHandlerContext, in: ServerResponse, out: JList[Object]): Unit = { in match { @@ -211,6 +220,7 @@ final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse * [[ProtocolUtils.createFrameDecoder()]]. */ @Sharable +private[netty] final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { override def decode(ctx: ChannelHandlerContext, in: ByteBuf, out: JList[AnyRef]): Unit = { val msgId = in.readByte() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 46b78bd5c706..e4bd64ec05f5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -43,7 +43,11 @@ object MimaExcludes { Seq( // This is @Experimental, but Mima still gives false-positives: ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.api.java.JavaRDDLike.foreachAsync") + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.PathResolver"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.client.BlockClientListener") ) case v if v.startsWith("1.1") => Seq( From 29fe0cc9ea6ca1285d054b335d989d1131aa4b69 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 12 Sep 2014 22:42:32 -0700 Subject: [PATCH 16/28] Implement java.io.Closeable interface. 
--- .../apache/spark/network/BlockTransferService.scala | 6 ++++-- .../org/apache/spark/network/netty/BlockClient.scala | 3 ++- .../spark/network/netty/BlockClientFactory.scala | 5 +++-- .../org/apache/spark/network/netty/BlockServer.scala | 8 +++++--- .../network/netty/NettyBlockTransferService.scala | 6 +++--- .../spark/network/nio/NioBlockTransferService.scala | 2 +- .../scala/org/apache/spark/storage/BlockManager.scala | 2 +- .../spark/network/netty/BlockClientFactorySuite.scala | 10 +++++----- .../network/netty/ServerClientIntegrationSuite.scala | 6 +++--- 9 files changed, 27 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index d894eac374b7..a8379a207a18 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -17,6 +17,8 @@ package org.apache.spark.network +import java.io.Closeable + import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration @@ -24,7 +26,7 @@ import org.apache.spark.storage.StorageLevel private[spark] -abstract class BlockTransferService { +abstract class BlockTransferService extends Closeable { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -35,7 +37,7 @@ abstract class BlockTransferService { /** * Tear down the transfer service. */ - def stop(): Unit + def close(): Unit /** * Port number the service is listening on, available only after [[init]] is invoked. 
diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index 2768f98e9c1f..fb50b1547429 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -17,6 +17,7 @@ package org.apache.spark.network.netty +import java.io.Closeable import java.util.concurrent.TimeoutException import io.netty.channel.{ChannelFuture, ChannelFutureListener} @@ -39,7 +40,7 @@ import org.apache.spark.network.BlockFetchingListener */ @throws[TimeoutException] private[netty] -class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Logging { +class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closeable with Logging { private[this] val serverAddr = cf.channel().remoteAddress().toString diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 01fc73fe728a..e264f91142ec 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -17,6 +17,7 @@ package org.apache.spark.network.netty +import java.io.Closeable import java.util.concurrent.{ConcurrentHashMap, TimeoutException} import io.netty.bootstrap.Bootstrap @@ -41,7 +42,7 @@ import org.apache.spark.util.Utils * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s. */ private[netty] -class BlockClientFactory(val conf: NettyConfig) { +class BlockClientFactory(val conf: NettyConfig) extends Closeable { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) @@ -140,7 +141,7 @@ class BlockClientFactory(val conf: NettyConfig) { } /** Close all connections in the connection pool, and shutdown the worker thread pool. 
*/ - def stop(): Unit = { + override def close(): Unit = { val iter = connectionPool.entrySet().iterator() while (iter.hasNext) { val entry = iter.next() diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index bd28d48c1a5e..9a8ffabd04c8 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -17,6 +17,7 @@ package org.apache.spark.network.netty +import java.io.Closeable import java.net.InetSocketAddress import io.netty.bootstrap.ServerBootstrap @@ -29,7 +30,7 @@ import io.netty.channel.socket.nio.NioServerSocketChannel import io.netty.channel.socket.oio.OioServerSocketChannel import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager import org.apache.spark.util.Utils @@ -38,7 +39,8 @@ import org.apache.spark.util.Utils * Server for the [[NettyBlockTransferService]]. */ private[netty] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Logging { +class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) + extends Closeable with Logging { def port: Int = _port @@ -115,7 +117,7 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log } /** Shutdown the server. 
*/ - def stop(): Unit = { + def close(): Unit = { if (channelFuture != null) { channelFuture.channel().close().awaitUninterruptibly() channelFuture = null diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 30dc812c4e7d..14df5161cb0f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -42,12 +42,12 @@ final class NettyBlockTransferService(conf: SparkConf) extends BlockTransferServ clientFactory = new BlockClientFactory(nettyConf) } - override def stop(): Unit = { + override def close(): Unit = { if (server != null) { - server.stop() + server.close() } if (clientFactory != null) { - clientFactory.stop() + clientFactory.close() } } diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index e7eac75a9b4e..5a3ca5ad902a 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -71,7 +71,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa /** * Tear down the transfer service. 
*/ - override def stop(): Unit = { + override def close(): Unit = { if (cm != null) { cm.stop() } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index e2a3576bb1eb..abef0e171a5c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1027,7 +1027,7 @@ private[spark] class BlockManager( } def stop(): Unit = { - blockTransferService.stop() + blockTransferService.close() diskBlockManager.stop() actorSystem.stop(slaveActor) blockInfo.clear() diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala index b2dcebfc8cee..5075688b1b27 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala @@ -39,10 +39,10 @@ class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { override def afterAll() { if (server1 != null) { - server1.stop() + server1.close() } if (server2 != null) { - server2.stop() + server2.close() } } @@ -55,7 +55,7 @@ class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { assert(c3.isActive) assert(c1 === c2) assert(c1 !== c3) - factory.stop() + factory.close() } test("never return inactive clients") { @@ -75,7 +75,7 @@ class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { // Create c2, which should be different from c1 val c2 = factory.createClient(server1.hostName, server1.port) assert(c1 !== c2) - factory.stop() + factory.close() } test("BlockClients are close when BlockClientFactory is stopped") { @@ -84,7 +84,7 @@ class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { val c2 = factory.createClient(server2.hostName, server2.port) assert(c1.isActive) 
assert(c2.isActive) - factory.stop() + factory.close() assert(!c1.isActive) assert(!c2.isActive) } diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 789df1f70dcd..98e896221f91 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -86,8 +86,8 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { } override def afterAll() = { - server.stop() - clientFactory.stop() + server.close() + clientFactory.close() } /** A ByteBuf for buffer_block */ @@ -162,7 +162,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { test("shutting down server should also close client") { val client = clientFactory.createClient(server.hostName, server.port) - server.stop() + server.close() eventually(timeout(Span(5, Seconds))) { assert(!client.isActive) } } } From a79a25918a96171c4b20c5c9153e5815bc23698e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 16 Sep 2014 22:51:11 -0700 Subject: [PATCH 17/28] Added logging. 
--- .../org/apache/spark/network/netty/BlockClientFactory.scala | 6 ++++-- .../scala/org/apache/spark/network/netty/BlockServer.scala | 2 ++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index e264f91142ec..6278e69c2200 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -31,7 +31,7 @@ import io.netty.channel.socket.nio.NioSocketChannel import io.netty.channel.socket.oio.OioSocketChannel import io.netty.util.internal.PlatformDependent -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.util.Utils @@ -42,7 +42,7 @@ import org.apache.spark.util.Utils * for the same remote host. It also shares a single worker thread pool for all [[BlockClient]]s. */ private[netty] -class BlockClientFactory(val conf: NettyConfig) extends Closeable { +class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable { def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) @@ -102,6 +102,8 @@ class BlockClientFactory(val conf: NettyConfig) extends Closeable { return cachedClient } + logInfo(s"Creating new connection to $remoteHost:$remotePort") + // There is a chance two threads are creating two different clients connecting to the same host. // But that's probably ok ... 
diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 9a8ffabd04c8..2611f2eacdb3 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -114,6 +114,8 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) _port = addr.getPort // _hostName = addr.getHostName _hostName = Utils.localHostName() + + logInfo(s"Server started ${_hostName}:${_port}") } /** Shutdown the server. */ From 088ed8ac46bc59bf3d13d4a0be1d4c616a22d698 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 16 Sep 2014 23:30:01 -0700 Subject: [PATCH 18/28] Fixed error message. (cherry picked from commit eacb82832de7e3c38ebfc22e57bdcff15f445ca3) Signed-off-by: Reynold Xin --- .../org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c139ac206161..38486c1ded9e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -159,7 +159,7 @@ final class ShuffleBlockFetcherIterator( } override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { - logError("Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) results.put(new FetchResult(BlockId(blockId), -1, null)) } } From 323dfec9734111c923772b2b0203766b620b0153 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 11:13:44 -0700 Subject: [PATCH 19/28] Add more debug message. 
--- .../apache/spark/network/ManagedBuffer.scala | 39 +++++++++++++++++-- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index e990c1da6730..37b9939f90bf 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -17,11 +17,13 @@ package org.apache.spark.network -import java.io.{FileInputStream, RandomAccessFile, File, InputStream} +import java.io._ import java.nio.ByteBuffer import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode +import scala.util.Try + import com.google.common.io.ByteStreams import io.netty.buffer.{ByteBufInputStream, ByteBuf} @@ -71,6 +73,14 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt try { channel = new RandomAccessFile(file, "r").getChannel channel.map(MapMode.READ_ONLY, offset, length) + } catch { + case e: IOException => + Try(channel.size).toOption match { + case Some(fileLen) => + throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) + case None => + throw new IOException(s"Error in opening $this", e) + } } finally { if (channel != null) { channel.close() @@ -79,10 +89,31 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt } override def inputStream(): InputStream = { - val is = new FileInputStream(file) - is.skip(offset) - ByteStreams.limit(is, length) + var is: FileInputStream = null + try { + is = new FileInputStream(file) + is.skip(offset) + ByteStreams.limit(is, length) + } catch { + case e: IOException => + if (is != null) { + is.close() + } + Try(file.length).toOption match { + case Some(fileLen) => + throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) + case None => + throw new IOException(s"Error in opening $this", e) + } + case e: Throwable 
=> + if (is != null) { + is.close() + } + throw e + } } + + override def toString: String = s"${getClass.getName}($file, $offset, $length)" } From 5814292ec2f9f438ef2eb822504ef3d9a5100f96 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 11:30:13 -0700 Subject: [PATCH 20/28] Logging close() in case close() fails. --- .../org/apache/spark/network/ManagedBuffer.scala | 8 ++++---- .../main/scala/org/apache/spark/util/Utils.scala | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index 37b9939f90bf..a4409181ec90 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -27,7 +27,7 @@ import scala.util.Try import com.google.common.io.ByteStreams import io.netty.buffer.{ByteBufInputStream, ByteBuf} -import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.{ByteBufferInputStream, Utils} /** @@ -83,7 +83,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt } } finally { if (channel != null) { - channel.close() + Utils.tryLog(channel.close()) } } } @@ -97,7 +97,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt } catch { case e: IOException => if (is != null) { - is.close() + Utils.tryLog(is.close()) } Try(file.length).toOption match { case Some(fileLen) => @@ -107,7 +107,7 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt } case e: Throwable => if (is != null) { - is.close() + Utils.tryLog(is.close()) } throw e } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2755887feeef..10d440828e32 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ 
b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1304,6 +1304,20 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block in a Try, logging any uncaught exceptions. */ + def tryLog[T](f: => T): Try[T] = { + try { + val res = f + scala.util.Success(res) + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + scala.util.Failure(t) + } + } + /** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */ def isFatalError(e: Throwable): Boolean = { e match { From ba8c44138a8bedcfd2fd716fce3de59b1c03cc55 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 12:07:53 -0700 Subject: [PATCH 21/28] Fixed tests. --- .../network/netty/BlockClientHandler.scala | 2 ++ .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../spark/network/netty/ProtocolSuite.scala | 25 +++++++++++++++++-- .../netty/ServerClientIntegrationSuite.scala | 6 +++++ 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala index 1a74c6649f28..466ece99b9b9 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -86,9 +86,11 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit val listener = outstandingRequests.get(blockId) if (listener == null) { logWarning(s"Got a response for block $blockId from $server but it is not outstanding") + buf.release() } else { outstandingRequests.remove(blockId) listener.onBlockFetchSuccess(blockId, buf) + buf.release() } case BlockFetchFailure(blockId, errorMsg) => val listener = outstandingRequests.get(blockId) diff --git 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 9de5a0f91e90..4e69a5d9e3ec 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -155,7 +155,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.remoteBytesRead += buf.size shuffleMetrics.remoteBlocksFetched += 1 } - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala index 72034634a5bd..46604ea1fb62 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala @@ -71,18 +71,39 @@ class ProtocolSuite extends FunSuite { assert(msg === serverChannel.readInbound()) } - test("server to client protocol") { + test("server to client protocol - BlockFetchSuccess(\"a1234\", new TestManagedBuffer(10))") { testServerToClient(BlockFetchSuccess("a1234", new TestManagedBuffer(10))) + } + + test("server to client protocol - BlockFetchSuccess(\"\", new TestManagedBuffer(0))") { testServerToClient(BlockFetchSuccess("", new TestManagedBuffer(0))) + } + + test("server to client protocol - BlockFetchFailure(\"abcd\", \"this is an error\")") { testServerToClient(BlockFetchFailure("abcd", "this is an error")) + } + + test("server to client protocol - BlockFetchFailure(\"\", \"\")") { testServerToClient(BlockFetchFailure("", "")) } - test("client to server protocol") { + test("client to server protocol - BlockFetchRequest(Seq.empty[String])") { 
testClientToServer(BlockFetchRequest(Seq.empty[String])) + } + + test("client to server protocol - BlockFetchRequest(Seq(\"b1\"))") { testClientToServer(BlockFetchRequest(Seq("b1"))) + } + + test("client to server protocol - BlockFetchRequest(Seq(\"b1\", \"b2\", \"b3\"))") { testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3"))) + } + + ignore("client to server protocol - BlockUploadRequest(\"\", new TestManagedBuffer(0))") { testClientToServer(BlockUploadRequest("", new TestManagedBuffer(0))) + } + + ignore("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") { testClientToServer(BlockUploadRequest("b_upload", new TestManagedBuffer(10))) } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala index 98e896221f91..35ff90a2dabc 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala @@ -112,6 +112,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + data.retain() receivedBlockIds.add(blockId) receivedBuffers.add(data) sem.release() @@ -130,6 +131,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds === Set(bufferBlockId)) assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) } test("fetch a FileSegment block via zero-copy send") { @@ -137,6 +139,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds === Set(fileBlockId)) assert(buffers.map(_.convertToNetty()) === Set(fileBlockReference)) assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) } test("fetch a non-existent block") 
{ @@ -144,6 +147,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds.isEmpty) assert(buffers.isEmpty) assert(failBlockIds === Set("random-block")) + buffers.foreach(_.release()) } test("fetch both ByteBuffer block and FileSegment block") { @@ -151,6 +155,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds === Set(bufferBlockId, fileBlockId)) assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference, fileBlockReference)) assert(failBlockIds.isEmpty) + buffers.foreach(_.release()) } test("fetch both ByteBuffer block and a non-existent block") { @@ -158,6 +163,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll { assert(blockIds === Set(bufferBlockId)) assert(buffers.map(_.convertToNetty()) === Set(byteBufferBlockReference)) assert(failBlockIds === Set("random-block")) + buffers.foreach(_.release()) } test("shutting down server should also close client") { From dfc2c34b8951f0b3c3256b6ed744d1d8ece78583 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 12:22:01 -0700 Subject: [PATCH 22/28] Removed OIO and added num threads settings. 
--- .../network/netty/BlockClientFactory.scala | 13 +++---------- .../spark/network/netty/BlockServer.scala | 12 ++---------- .../spark/network/netty/NettyConfig.scala | 18 ++++++++++-------- 3 files changed, 15 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala index 6278e69c2200..8021cfdf42d1 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala @@ -25,10 +25,8 @@ import io.netty.buffer.PooledByteBufAllocator import io.netty.channel._ import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel} import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioSocketChannel -import io.netty.channel.socket.oio.OioSocketChannel import io.netty.util.internal.PlatformDependent import org.apache.spark.{Logging, SparkConf} @@ -65,23 +63,18 @@ class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable { /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ private def init(): Unit = { - def initOio(): Unit = { - socketChannelClass = classOf[OioSocketChannel] - workerGroup = new OioEventLoopGroup(0, threadFactory) - } def initNio(): Unit = { socketChannelClass = classOf[NioSocketChannel] - workerGroup = new NioEventLoopGroup(0, threadFactory) + workerGroup = new NioEventLoopGroup(conf.clientThreads, threadFactory) } def initEpoll(): Unit = { socketChannelClass = classOf[EpollSocketChannel] - workerGroup = new EpollEventLoopGroup(0, threadFactory) + workerGroup = new EpollEventLoopGroup(conf.clientThreads, threadFactory) } // For auto mode, first try epoll (only available on Linux), then nio. 
conf.ioMode match { case "nio" => initNio() - case "oio" => initOio() case "epoll" => initEpoll() case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() } @@ -102,7 +95,7 @@ class BlockClientFactory(val conf: NettyConfig) extends Logging with Closeable { return cachedClient } - logInfo(s"Creating new connection to $remoteHost:$remotePort") + logDebug(s"Creating new connection to $remoteHost:$remotePort") // There is a chance two threads are creating two different clients connecting to the same host. // But that's probably ok ... diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala index 2611f2eacdb3..e2eb7c379f14 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala @@ -24,10 +24,8 @@ import io.netty.bootstrap.ServerBootstrap import io.netty.buffer.PooledByteBufAllocator import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollServerSocketChannel} import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup import io.netty.channel.socket.SocketChannel import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.socket.oio.OioServerSocketChannel import io.netty.channel.{ChannelInitializer, ChannelFuture, ChannelOption} import org.apache.spark.Logging @@ -60,24 +58,18 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) // Use only one thread to accept connections, and 2 * num_cores for worker. 
def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(0, threadFactory) + val bossGroup = new NioEventLoopGroup(conf.serverThreads, threadFactory) val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) } - def initOio(): Unit = { - val bossGroup = new OioEventLoopGroup(0, threadFactory) - val workerGroup = bossGroup - bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) - } def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(0, threadFactory) + val bossGroup = new EpollEventLoopGroup(conf.serverThreads, threadFactory) val workerGroup = bossGroup bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) } conf.ioMode match { case "nio" => initNio() - case "oio" => initOio() case "epoll" => initEpoll() case "auto" => if (Epoll.isAvailable) initEpoll() else initNio() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala index b5870152c5a6..d5078e417d6d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala @@ -31,18 +31,20 @@ class NettyConfig(conf: SparkConf) { /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase - /** Connect timeout in secs. Default 60 secs. */ - private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000 - - /** - * Percentage of the desired amount of time spent for I/O in the child event loops. - * Only applicable in nio and epoll. - */ - private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80) + /** Connect timeout in secs. Default 120 secs. 
*/ + private[netty] val connectTimeoutMs = { + conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000 + } /** Requested maximum length of the queue of incoming connections. */ private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) + /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ + private[netty] val serverThreads: Int = conf.getInt("spark.shuffle.io.serverThreads", 0) + + /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ + private[netty] val clientThreads: Int = conf.getInt("spark.shuffle.io.clientThreads", 0) + /** * Receive buffer size (SO_RCVBUF). * Note: the optimal size for receive buffer and send buffer should be From 69f5d0a2434396abbbd98886e047bc08a9e65565 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 15:45:05 -0700 Subject: [PATCH 23/28] Copy the buffer in fetchBlockSync. --- .../org/apache/spark/network/BlockTransferService.scala | 5 ++++- .../apache/spark/network/netty/BlockClientFactorySuite.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index a8379a207a18..c874bddcf4a6 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -18,6 +18,7 @@ package org.apache.spark.network import java.io.Closeable +import java.nio.ByteBuffer import scala.concurrent.{Await, Future} import scala.concurrent.duration.Duration @@ -94,7 +95,9 @@ abstract class BlockTransferService extends Closeable { } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { lock.synchronized { - result = Left(data) + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + result = Left(new NioManagedBuffer(ret)) 
lock.notify() } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala index 5075688b1b27..2d4baafcf03d 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientFactorySuite.scala @@ -69,7 +69,7 @@ class BlockClientFactorySuite extends FunSuite with BeforeAndAfterAll { Thread.sleep(10) } } - Await.result(f, 3 seconds) + Await.result(f, 3.seconds) assert(!c1.isActive) // Create c2, which should be different from c1 From bc9ed22d4d9fd36599a076a0a201a1809ea3a24c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 19:22:34 -0700 Subject: [PATCH 24/28] Implemented block uploads. --- .../spark/network/BlockTransferService.scala | 3 - .../org/apache/spark/network/exceptions.scala | 31 ++++++++ .../spark/network/netty/BlockClient.scala | 42 +++++++++-- .../network/netty/BlockClientHandler.scala | 68 +++++++++++++----- .../network/netty/BlockServerHandler.scala | 40 +++++++++-- .../netty/NettyBlockTransferService.scala | 6 +- .../apache/spark/network/netty/protocol.scala | 72 ++++++++++++++++--- .../apache/spark/storage/StorageLevel.scala | 3 +- .../netty/BlockClientHandlerSuite.scala | 18 ++--- .../spark/network/netty/ProtocolSuite.scala | 12 ++-- 10 files changed, 240 insertions(+), 55 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/network/exceptions.scala diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index c874bddcf4a6..2a0a1a0bc0a1 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -54,9 +54,6 @@ abstract class BlockTransferService extends Closeable { * Fetch a sequence of blocks from a 
remote node asynchronously, * available only after [[init]] is invoked. * - * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block, - * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block). - * * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. diff --git a/core/src/main/scala/org/apache/spark/network/exceptions.scala b/core/src/main/scala/org/apache/spark/network/exceptions.scala new file mode 100644 index 000000000000..d918d358c4ad --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/exceptions.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network + +class BlockFetchFailureException(blockId: String, errorMsg: String, cause: Throwable) + extends Exception(errorMsg, cause) { + + def this(blockId: String, errorMsg: String) = this(blockId, errorMsg, null) +} + + +class BlockUploadFailureException(blockId: String, cause: Throwable) + extends Exception(s"Failed to fetch block $blockId", cause) { + + def this(blockId: String) = this(blockId, null) +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index fb50b1547429..c77a7ae1ccb0 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -20,10 +20,13 @@ package org.apache.spark.network.netty import java.io.Closeable import java.util.concurrent.TimeoutException +import scala.concurrent.{Future, promise} + import io.netty.channel.{ChannelFuture, ChannelFutureListener} import org.apache.spark.Logging -import org.apache.spark.network.BlockFetchingListener +import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener} +import org.apache.spark.storage.StorageLevel /** @@ -58,19 +61,19 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closea def fetchBlocks(blockIds: Seq[String], listener: BlockFetchingListener): Unit = { var startTime: Long = 0 logTrace { - startTime = System.nanoTime() + startTime = System.currentTimeMillis() s"Sending request $blockIds to $serverAddr" } blockIds.foreach { blockId => - handler.addRequest(blockId, listener) + handler.addFetchRequest(blockId, listener) } cf.channel().writeAndFlush(BlockFetchRequest(blockIds)).addListener(new ChannelFutureListener { override def operationComplete(future: ChannelFuture): Unit = { if (future.isSuccess) { logTrace { - val timeTaken = (System.nanoTime() - startTime).toDouble / 1000000 + val timeTaken = System.currentTimeMillis() - 
startTime s"Sending request $blockIds to $serverAddr took $timeTaken ms" } } else { @@ -79,7 +82,7 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closea s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}" logError(errorMsg, future.cause) blockIds.foreach { blockId => - handler.removeRequest(blockId) + handler.removeFetchRequest(blockId) listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) } } @@ -87,6 +90,35 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closea }) } + def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] = { + var startTime: Long = 0 + logTrace { + startTime = System.currentTimeMillis() + s"Uploading block ($blockId) to $serverAddr" + } + val f = cf.channel().writeAndFlush(new BlockUploadRequest(blockId, data, storageLevel)) + + val p = promise[Unit]() + handler.addUploadRequest(blockId, p) + f.addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (future.isSuccess) { + logTrace { + val timeTaken = System.currentTimeMillis() - startTime + s"Uploading block ($blockId) to $serverAddr took $timeTaken ms" + } + } else { + // Fail all blocks. + val errorMsg = + s"Failed to upload block $blockId to $serverAddr: ${future.cause.getMessage}" + logError(errorMsg, future.cause) + } + } + }) + + p.future + } + /** Close the connection. This does NOT block till the connection is closed. 
*/ def close(): Unit = cf.channel().close() } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala index 466ece99b9b9..5e28a07a461f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala @@ -19,10 +19,12 @@ package org.apache.spark.network.netty import java.util.concurrent.ConcurrentHashMap +import scala.concurrent.Promise + import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import org.apache.spark.Logging -import org.apache.spark.network.BlockFetchingListener +import org.apache.spark.network.{BlockFetchFailureException, BlockUploadFailureException, BlockFetchingListener} /** @@ -35,15 +37,22 @@ private[netty] class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging { /** Tracks the list of outstanding requests and their listeners on success/failure. 
*/ - private[this] val outstandingRequests: java.util.Map[String, BlockFetchingListener] = + private[this] val outstandingFetches: java.util.Map[String, BlockFetchingListener] = new ConcurrentHashMap[String, BlockFetchingListener] - def addRequest(blockId: String, listener: BlockFetchingListener): Unit = { - outstandingRequests.put(blockId, listener) + private[this] val outstandingUploads: java.util.Map[String, Promise[Unit]] = + new ConcurrentHashMap[String, Promise[Unit]] + + def addFetchRequest(blockId: String, listener: BlockFetchingListener): Unit = { + outstandingFetches.put(blockId, listener) } - def removeRequest(blockId: String): Unit = { - outstandingRequests.remove(blockId) + def removeFetchRequest(blockId: String): Unit = { + outstandingFetches.remove(blockId) + } + + def addUploadRequest(blockId: String, promise: Promise[Unit]): Unit = { + outstandingUploads.put(blockId, promise) } /** @@ -51,19 +60,26 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit * uncaught exception or pre-mature connection termination. */ private def failOutstandingRequests(cause: Throwable): Unit = { - val iter = outstandingRequests.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() + val iter1 = outstandingFetches.entrySet().iterator() + while (iter1.hasNext) { + val entry = iter1.next() entry.getValue.onBlockFetchFailure(entry.getKey, cause) } // TODO(rxin): Maybe we need to synchronize the access? Otherwise we could clear new requests // as well. But I guess that is ok given the caller will fail as soon as any requests fail. 
- outstandingRequests.clear() + outstandingFetches.clear() + + val iter2 = outstandingUploads.entrySet().iterator() + while (iter2.hasNext) { + val entry = iter2.next() + entry.getValue.failure(new RuntimeException(s"Failed to upload block ${entry.getKey}")) + } + outstandingUploads.clear() } override def channelUnregistered(ctx: ChannelHandlerContext): Unit = { - if (outstandingRequests.size() > 0) { - logError("Still have " + outstandingRequests.size() + " requests outstanding " + + if (outstandingFetches.size() > 0) { + logError("Still have " + outstandingFetches.size() + " requests outstanding " + s"when connection from ${ctx.channel.remoteAddress} is closed") failOutstandingRequests(new RuntimeException( s"Connection from ${ctx.channel.remoteAddress} closed")) @@ -71,7 +87,7 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit } override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - if (outstandingRequests.size() > 0) { + if (outstandingFetches.size() > 0) { logError( s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}", cause) failOutstandingRequests(cause) @@ -83,23 +99,39 @@ class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] wit val server = ctx.channel.remoteAddress.toString response match { case BlockFetchSuccess(blockId, buf) => - val listener = outstandingRequests.get(blockId) + val listener = outstandingFetches.get(blockId) if (listener == null) { logWarning(s"Got a response for block $blockId from $server but it is not outstanding") buf.release() } else { - outstandingRequests.remove(blockId) + outstandingFetches.remove(blockId) listener.onBlockFetchSuccess(blockId, buf) buf.release() } case BlockFetchFailure(blockId, errorMsg) => - val listener = outstandingRequests.get(blockId) + val listener = outstandingFetches.get(blockId) if (listener == null) { logWarning( s"Got a response for block $blockId from $server ($errorMsg) but it is 
not outstanding") } else { - outstandingRequests.remove(blockId) - listener.onBlockFetchFailure(blockId, new RuntimeException(errorMsg)) + outstandingFetches.remove(blockId) + listener.onBlockFetchFailure(blockId, new BlockFetchFailureException(blockId, errorMsg)) + } + case BlockUploadSuccess(blockId) => + val p = outstandingUploads.get(blockId) + if (p == null) { + logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") + } else { + outstandingUploads.remove(blockId) + p.success(Unit) + } + case BlockUploadFailure(blockId, error) => + val p = outstandingUploads.get(blockId) + if (p == null) { + logWarning(s"Got a response for upload $blockId from $server but it is not outstanding") + } else { + outstandingUploads.remove(blockId) + p.failure(new BlockUploadFailureException(blockId)) } } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala index c3b4d41829f4..44687f0b770e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala @@ -21,6 +21,7 @@ import io.netty.channel._ import org.apache.spark.Logging import org.apache.spark.network.{ManagedBuffer, BlockDataManager} +import org.apache.spark.storage.StorageLevel /** @@ -39,13 +40,13 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager) override def channelRead0(ctx: ChannelHandlerContext, request: ClientRequest): Unit = { request match { case BlockFetchRequest(blockIds) => - blockIds.foreach(processBlockRequest(ctx, _)) - case BlockUploadRequest(blockId, data) => - // TODO(rxin): handle upload. 
+ blockIds.foreach(processFetchRequest(ctx, _)) + case BlockUploadRequest(blockId, data, level) => + processUploadRequest(ctx, blockId, data, level) } } // end of channelRead0 - private def processBlockRequest(ctx: ChannelHandlerContext, blockId: String): Unit = { + private def processFetchRequest(ctx: ChannelHandlerContext, blockId: String): Unit = { // A helper function to send error message back to the client. def client = ctx.channel.remoteAddress.toString @@ -90,4 +91,35 @@ private[netty] class BlockServerHandler(dataProvider: BlockDataManager) } ) } // end of processBlockRequest + + private def processUploadRequest( + ctx: ChannelHandlerContext, + blockId: String, + data: ManagedBuffer, + level: StorageLevel): Unit = { + // A helper function to send error message back to the client. + def client = ctx.channel.remoteAddress.toString + + try { + dataProvider.putBlockData(blockId, data, level) + ctx.writeAndFlush(BlockUploadSuccess(blockId)).addListener(new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (!future.isSuccess) { + logError(s"Error sending an ACK back to client $client") + } + } + }) + } catch { + case e: Throwable => + logError(s"Error processing uploaded block $blockId", e) + ctx.writeAndFlush(BlockUploadFailure(blockId, e.getMessage)).addListener( + new ChannelFutureListener { + override def operationComplete(future: ChannelFuture): Unit = { + if (!future.isSuccess) { + logError(s"Error sending an ACK back to client $client") + } + } + }) + } + } // end of processUploadRequest } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 14df5161cb0f..b7f979dccd0f 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -63,9 +63,9 @@ final class 
NettyBlockTransferService(conf: SparkConf) extends BlockTransferServ hostname: String, port: Int, blockId: String, - blockData: ManagedBuffer, level: StorageLevel): Future[Unit] = { - // TODO(rxin): Implement uploadBlock. - ??? + blockData: ManagedBuffer, + level: StorageLevel): Future[Unit] = { + clientFactory.createClient(hostName, port).uploadBlock(blockId, blockData, level) } override def hostName: String = { diff --git a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala index 6a14ad26dbf2..13942f3d0adc 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/protocol.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/protocol.scala @@ -17,6 +17,7 @@ package org.apache.spark.network.netty +import java.nio.ByteBuffer import java.util.{List => JList} import io.netty.buffer.ByteBuf @@ -25,7 +26,8 @@ import io.netty.channel.ChannelHandler.Sharable import io.netty.handler.codec._ import org.apache.spark.Logging -import org.apache.spark.network.{NettyManagedBuffer, ManagedBuffer} +import org.apache.spark.network.{NioManagedBuffer, NettyManagedBuffer, ManagedBuffer} +import org.apache.spark.storage.StorageLevel /** Messages from the client to the server. */ @@ -47,7 +49,11 @@ final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest { * Request to upload a block to the server. Currently the server does not ack the upload request. */ private[netty] -final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest { +final case class BlockUploadRequest( + blockId: String, + data: ManagedBuffer, + level: StorageLevel) + extends ClientRequest { require(blockId.length <= Byte.MaxValue) override def id = 1 } @@ -73,6 +79,20 @@ final case class BlockFetchFailure(blockId: String, error: String) extends Serve override def id = 1 } +/** Response to [[BlockUploadRequest]] when a block is successfully uploaded. 
*/ +private[netty] +final case class BlockUploadSuccess(blockId: String) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 2 +} + +/** Response to [[BlockUploadRequest]] when there is an error uploading the block. */ +private[netty] +final case class BlockUploadFailure(blockId: String, error: String) extends ServerResponse { + require(blockId.length <= Byte.MaxValue) + override def id = 3 +} + /** * Encoder for [[ClientRequest]] used in client side. @@ -102,12 +122,12 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] assert(buf.writableBytes() == 0) out.add(buf) - case BlockUploadRequest(blockId, data) => + case BlockUploadRequest(blockId, data, level) => // 8 bytes: frame size // 1 byte: msg id (BlockFetchRequest vs BlockUploadRequest) // 1 byte: blockId.length // data itself (length can be derived from: frame size - 1 - blockId.length) - val headerLength = 8 + 1 + 1 + blockId.length + val headerLength = 8 + 1 + 1 + blockId.length + 5 val frameLength = headerLength + data.size val header = ctx.alloc().buffer(headerLength) @@ -118,6 +138,8 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] header.writeLong(frameLength) header.writeByte(in.id) ProtocolUtils.writeBlockId(header, blockId) + header.writeInt(level.toInt) + header.writeByte(level.replication) assert(header.writableBytes() == 0) out.add(header) @@ -148,8 +170,12 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] { case 1 => // BlockUploadRequest val blockId = ProtocolUtils.readBlockId(in) - in.retain() // retain the bytebuf so we don't recycle it immediately. 
- BlockUploadRequest(blockId, new NettyManagedBuffer(in)) + val level = new StorageLevel(in.readInt(), in.readByte()) + + val ret = ByteBuffer.allocate(in.readableBytes()) + ret.put(in.nioBuffer()) + ret.flip() + BlockUploadRequest(blockId, new NioManagedBuffer(ret), level) } assert(decoded.id == msgTypeId) @@ -205,6 +231,27 @@ final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse ProtocolUtils.writeBlockId(buf, blockId) buf.writeBytes(error.getBytes) + assert(buf.writableBytes() == 0) + out.add(buf) + + case BlockUploadSuccess(blockId) => + val frameLength = 8 + 1 + 1 + blockId.length + val buf = ctx.alloc().buffer(frameLength) + buf.writeLong(frameLength) + buf.writeByte(in.id) + ProtocolUtils.writeBlockId(buf, blockId) + + assert(buf.writableBytes() == 0) + out.add(buf) + + case BlockUploadFailure(blockId, error) => + val frameLength = 8 + 1 + 1 + blockId.length + + error.length + val buf = ctx.alloc().buffer(frameLength) + buf.writeLong(frameLength) + buf.writeByte(in.id) + ProtocolUtils.writeBlockId(buf, blockId) + buf.writeBytes(error.getBytes) + assert(buf.writableBytes() == 0) out.add(buf) } @@ -228,13 +275,22 @@ final class ServerResponseDecoder extends MessageToMessageDecoder[ByteBuf] { case 0 => // BlockFetchSuccess val blockId = ProtocolUtils.readBlockId(in) in.retain() - new BlockFetchSuccess(blockId, new NettyManagedBuffer(in)) + BlockFetchSuccess(blockId, new NettyManagedBuffer(in)) case 1 => // BlockFetchFailure val blockId = ProtocolUtils.readBlockId(in) val errorBytes = new Array[Byte](in.readableBytes()) in.readBytes(errorBytes) - new BlockFetchFailure(blockId, new String(errorBytes)) + BlockFetchFailure(blockId, new String(errorBytes)) + + case 2 => // BlockUploadSuccess + BlockUploadSuccess(ProtocolUtils.readBlockId(in)) + + case 3 => // BlockUploadFailure + val blockId = ProtocolUtils.readBlockId(in) + val errorBytes = new Array[Byte](in.readableBytes()) + in.readBytes(errorBytes) + BlockUploadFailure(blockId, new 
String(errorBytes)) } assert(decoded.id == msgId) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 1e35abaab535..2fc7c7d9b831 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -42,7 +42,7 @@ class StorageLevel private( extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. - private def this(flags: Int, replication: Int) { + private[spark] def this(flags: Int, replication: Int) { this((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } @@ -98,6 +98,7 @@ class StorageLevel private( } override def writeExternal(out: ObjectOutput) { + /* If the wire protocol changes, please also update [[ClientRequestEncoder]] */ out.writeByte(toInt) out.writeByte(_replication) } diff --git a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala index 7b80fe6aa364..4c3a64908157 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/BlockClientHandlerSuite.scala @@ -35,7 +35,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { /** Helper method to get num. outstanding requests from a private field using reflection. 
*/ private def sizeOfOutstandingRequests(handler: BlockClientHandler): Int = { val f = handler.getClass.getDeclaredField( - "org$apache$spark$network$netty$BlockClientHandler$$outstandingRequests") + "org$apache$spark$network$netty$BlockClientHandler$$outstandingFetches") f.setAccessible(true) f.get(handler).asInstanceOf[java.util.Map[_, _]].size } @@ -45,7 +45,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { val blockData = "blahblahblahblahblah" val handler = new BlockClientHandler val listener = mock(classOf[BlockFetchingListener]) - handler.addRequest(blockId, listener) + handler.addFetchRequest(blockId, listener) assert(sizeOfOutstandingRequests(handler) === 1) val channel = new EmbeddedChannel(handler) @@ -63,7 +63,7 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { val blockId = "test_block" val handler = new BlockClientHandler val listener = mock(classOf[BlockFetchingListener]) - handler.addRequest(blockId, listener) + handler.addFetchRequest(blockId, listener) assert(sizeOfOutstandingRequests(handler) === 1) val channel = new EmbeddedChannel(handler) @@ -77,9 +77,9 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { test("clear all outstanding request upon uncaught exception") { val handler = new BlockClientHandler val listener = mock(classOf[BlockFetchingListener]) - handler.addRequest("b1", listener) - handler.addRequest("b2", listener) - handler.addRequest("b3", listener) + handler.addFetchRequest("b1", listener) + handler.addFetchRequest("b2", listener) + handler.addFetchRequest("b3", listener) assert(sizeOfOutstandingRequests(handler) === 3) val channel = new EmbeddedChannel(handler) @@ -96,9 +96,9 @@ class BlockClientHandlerSuite extends FunSuite with PrivateMethodTester { test("clear all outstanding request upon connection close") { val handler = new BlockClientHandler val listener = mock(classOf[BlockFetchingListener]) - handler.addRequest("c1", listener) - 
handler.addRequest("c2", listener) - handler.addRequest("c3", listener) + handler.addFetchRequest("c1", listener) + handler.addFetchRequest("c2", listener) + handler.addFetchRequest("c3", listener) assert(sizeOfOutstandingRequests(handler) === 3) val channel = new EmbeddedChannel(handler) diff --git a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala index 46604ea1fb62..8d1b7276f408 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/ProtocolSuite.scala @@ -21,6 +21,8 @@ import io.netty.channel.embedded.EmbeddedChannel import org.scalatest.FunSuite +import org.apache.spark.api.java.StorageLevels + /** * Test client/server encoder/decoder protocol. @@ -99,11 +101,13 @@ class ProtocolSuite extends FunSuite { testClientToServer(BlockFetchRequest(Seq("b1", "b2", "b3"))) } - ignore("client to server protocol - BlockUploadRequest(\"\", new TestManagedBuffer(0))") { - testClientToServer(BlockUploadRequest("", new TestManagedBuffer(0))) + test("client to server protocol - BlockUploadRequest(\"\", new TestManagedBuffer(0))") { + testClientToServer( + BlockUploadRequest("", new TestManagedBuffer(0), StorageLevels.MEMORY_AND_DISK)) } - ignore("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") { - testClientToServer(BlockUploadRequest("b_upload", new TestManagedBuffer(10))) + test("client to server protocol - BlockUploadRequest(\"b_upload\", new TestManagedBuffer(10))") { + testClientToServer( + BlockUploadRequest("b_upload", new TestManagedBuffer(10), StorageLevels.MEMORY_AND_DISK_2)) } } From a3a09f6485950bc859b3724a20cea39fbee0be2b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 19:37:28 -0700 Subject: [PATCH 25/28] Fix style violation. 
--- .../scala/org/apache/spark/network/netty/BlockClient.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala index c77a7ae1ccb0..6bdbf88d337c 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala @@ -90,7 +90,8 @@ class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Closea }) } - def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] = { + def uploadBlock(blockId: String, data: ManagedBuffer, storageLevel: StorageLevel): Future[Unit] = + { var startTime: Long = 0 logTrace { startTime = System.currentTimeMillis() From 0dae31022fa26abc806db94e02fa7c15a031d1c1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 23:30:17 -0700 Subject: [PATCH 26/28] Merge with latest master. --- .../org/apache/spark/network/nio/NioBlockTransferService.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 3d72155f8db8..e942b43d9cc4 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -201,10 +201,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa private def getBlock(blockId: String): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) - // TODO(rxin): propagate error back to the client? 
val buffer = blockDataManager.getBlockData(blockId) logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) - if (buffer == null) null else buffer.nioByteBuffer() + buffer.nioByteBuffer() } } From ad092361f649b82dff64c44a30b50af1e9cccc0c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Sep 2014 00:56:32 -0700 Subject: [PATCH 27/28] Flip buffer. --- .../scala/org/apache/spark/network/BlockTransferService.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 2a0a1a0bc0a1..d3ed683c7e88 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -94,6 +94,7 @@ abstract class BlockTransferService extends Closeable { lock.synchronized { val ret = ByteBuffer.allocate(data.size.toInt) ret.put(data.nioByteBuffer()) + ret.flip() result = Left(new NioManagedBuffer(ret)) lock.notify() } From bdab2c74111c8bce382323f68732f87ca9b080a9 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Sep 2014 12:28:21 -0700 Subject: [PATCH 28/28] Fixed spark.shuffle.io.receiveBuffer setting. 
--- .../main/scala/org/apache/spark/network/netty/NettyConfig.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala index d5078e417d6d..7c3074e93979 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala @@ -53,7 +53,7 @@ class NettyConfig(conf: SparkConf) { * buffer size should be ~ 1.25MB */ private[netty] val receiveBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) + conf.getOption("spark.shuffle.io.receiveBuffer").map(_.toInt) /** Send buffer size (SO_SNDBUF). */ private[netty] val sendBuf: Option[Int] =