diff --git a/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferInputStream.java b/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferInputStream.java new file mode 100644 index 0000000000000..a4b1e2571af56 --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferInputStream.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.buffer; + +import java.io.InputStream; + +import com.google.common.annotations.VisibleForTesting; + +/** + * Reads data from a LargeByteBuffer, and optionally cleans it up using buffer.dispose() + * when the stream is closed (e.g. to close a memory-mapped file). + */ +public class LargeByteBufferInputStream extends InputStream { + + private LargeByteBuffer buffer; + private final boolean dispose; + + public LargeByteBufferInputStream(LargeByteBuffer buffer, boolean dispose) { + this.buffer = buffer; + this.dispose = dispose; + } + + public LargeByteBufferInputStream(LargeByteBuffer buffer) { + this(buffer, false); + } + + @Override + public int read() { + if (buffer == null || buffer.remaining() == 0) { + return -1; + } else { + return buffer.get() & 0xFF; + } + } + + @Override + public int read(byte[] dest) { + return read(dest, 0, dest.length); + } + + @Override + public int read(byte[] dest, int offset, int length) { + if (buffer == null || buffer.remaining() == 0) { + return -1; + } else { + int amountToGet = (int) Math.min(buffer.remaining(), length); + buffer.get(dest, offset, amountToGet); + return amountToGet; + } + } + + @Override + public long skip(long toSkip) { + if (buffer != null) { + return buffer.skip(toSkip); + } else { + return 0L; + } + } + + // only for testing + @VisibleForTesting + boolean disposed = false; + + /** + * Clean up the buffer, and potentially dispose of it + */ + @Override + public void close() { + if (buffer != null) { + if (dispose) { + buffer.dispose(); + disposed = true; + } + buffer = null; + } + } +} diff --git a/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferOutputStream.java b/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferOutputStream.java new file mode 100644 index 0000000000000..975de7b10f65c --- /dev/null +++ b/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferOutputStream.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import com.google.common.annotations.VisibleForTesting; +import org.apache.spark.util.io.ByteArrayChunkOutputStream; + +/** + * An OutputStream that will write all data to memory. It supports writing over 2GB + * and the resulting data can be retrieved as a + * {@link org.apache.spark.network.buffer.LargeByteBuffer} + */ +public class LargeByteBufferOutputStream extends OutputStream { + + private final ByteArrayChunkOutputStream output; + + /** + * Create a new LargeByteBufferOutputStream which writes to byte arrays of the given size. Note + * that chunkSize has no effect on the LargeByteBuffer returned by + * {@link #largeBuffer()}. + * + * @param chunkSize size of the byte arrays used by this output stream, in bytes + */ + public LargeByteBufferOutputStream(int chunkSize) { + output = new ByteArrayChunkOutputStream(chunkSize); + } + + @Override + public void write(int b) { + output.write(b); + } + + @Override + public void write(byte[] bytes, int off, int len) { + output.write(bytes, off, len); + } + + /** + * Get all of the data written to the stream so far as a LargeByteBuffer. This method can be + * called multiple times, and each returned buffer will be completely independent (the data + * is copied for each returned buffer). It does not close the stream. + * + * @return the data written to the stream as a LargeByteBuffer + */ + public LargeByteBuffer largeBuffer() { + return largeBuffer(LargeByteBufferHelper.MAX_CHUNK_SIZE); + } + + /** + * exposed for testing. You don't really ever want to call this method -- the returned + * buffer will not implement {{asByteBuffer}} correctly. 
+ */ + @VisibleForTesting + LargeByteBuffer largeBuffer(int maxChunk) { + long totalSize = output.size(); + int chunksNeeded = (int) ((totalSize + maxChunk - 1) / maxChunk); + ByteBuffer[] chunks = new ByteBuffer[chunksNeeded]; + long remaining = totalSize; + long pos = 0; + for (int idx = 0; idx < chunksNeeded; idx++) { + int nextSize = (int) Math.min(maxChunk, remaining); + chunks[idx] = ByteBuffer.wrap(output.slice(pos, pos + nextSize)); + pos += nextSize; + remaining -= nextSize; + } + return new WrappedLargeByteBuffer(chunks, maxChunk); + } + + @Override + public void close() throws IOException { + output.close(); + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index a0c9b5e63c744..8a516709dd2b8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -18,7 +18,6 @@ package org.apache.spark.broadcast import java.io._ -import java.nio.ByteBuffer import scala.collection.JavaConversions.asJavaEnumeration import scala.reflect.ClassTag @@ -26,9 +25,10 @@ import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec +import org.apache.spark.network.buffer.{LargeByteBufferInputStream, LargeByteBufferHelper, LargeByteBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.{ByteBufferInputStream, Utils} +import org.apache.spark.util.Utils import org.apache.spark.util.io.ByteArrayChunkOutputStream /** @@ -111,10 +111,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } /** Fetch torrent blocks from the driver and/or other executors. */ - private def readBlocks(): Array[ByteBuffer] = { + private def readBlocks(): Array[LargeByteBuffer] = { // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported // to the driver, so other executors can pull these chunks from this executor as well. - val blocks = new Array[ByteBuffer](numBlocks) + val blocks = new Array[LargeByteBuffer](numBlocks) val bm = SparkEnv.get.blockManager for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { @@ -123,8 +123,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) // First try getLocalBytes because there is a chance that previous attempts to fetch the // broadcast blocks have already fetched some of the blocks. In that case, some blocks // would be available locally (on this executor). - def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId) - def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block => + def getLocal: Option[LargeByteBuffer] = bm.getLocalBytes(pieceId) + def getRemote: Option[LargeByteBuffer] = bm.getRemoteBytes(pieceId).map { block => // If we found the block from remote executors/driver's BlockManager, put the block // in this executor's BlockManager. 
SparkEnv.get.blockManager.putBytes( @@ -134,7 +134,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) tellMaster = true) block } - val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse( + val block: LargeByteBuffer = getLocal.orElse(getRemote).getOrElse( throw new SparkException(s"Failed to get $pieceId of $broadcastId")) blocks(pid) = block } @@ -195,22 +195,22 @@ private object TorrentBroadcast extends Logging { obj: T, blockSize: Int, serializer: Serializer, - compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = { + compressionCodec: Option[CompressionCodec]): Array[LargeByteBuffer] = { val bos = new ByteArrayChunkOutputStream(blockSize) val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos) val ser = serializer.newInstance() val serOut = ser.serializeStream(out) serOut.writeObject[T](obj).close() - bos.toArrays.map(ByteBuffer.wrap) + bos.toArrays.map(LargeByteBufferHelper.asLargeByteBuffer) } def unBlockifyObject[T: ClassTag]( - blocks: Array[ByteBuffer], + blocks: Array[LargeByteBuffer], serializer: Serializer, compressionCodec: Option[CompressionCodec]): T = { require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") val is = new SequenceInputStream( - asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) + asJavaEnumeration(blocks.iterator.map(block => new LargeByteBufferInputStream(block)))) val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 42a85e42ea2b6..17da70500f9df 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,6 +23,8 @@ import java.net.URL import java.nio.ByteBuffer import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import org.apache.spark.network.buffer.LargeByteBufferHelper + import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal @@ -266,7 +268,8 @@ private[spark] class Executor( } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( - blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) + blockId, LargeByteBufferHelper.asLargeByteBuffer(serializedDirectResult), + StorageLevel.MEMORY_AND_DISK_SER) logInfo( s"Finished $taskName (TID $taskId). 
$resultSize bytes result sent via BlockManager)") ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index dcbda5a8515dd..14462df32410f 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -24,9 +24,9 @@ import scala.concurrent.{Promise, Await, Future} import scala.concurrent.duration.Duration import org.apache.spark.Logging -import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{LargeByteBufferHelper, NioManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener} -import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel} +import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index b089da8596e2b..c099f7ec7f79a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -23,12 +23,12 @@ import scala.collection.JavaConversions._ import org.apache.spark.Logging import org.apache.spark.network.BlockDataManager -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.buffer.{BufferTooLargeException, ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.storage.{ShuffleRemoteBlockSizeLimitException, BlockId, StorageLevel} /** * Serves requests to open blocks by simply registering one chunk per block requested. 
@@ -53,11 +53,24 @@ class NettyBlockRpcServer( message match { case openBlocks: OpenBlocks => - val blocks: Seq[ManagedBuffer] = - openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(blocks.iterator) - logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) + try { + val blocks: Seq[ManagedBuffer] = + openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) + val streamId = streamManager.registerStream(blocks.iterator) + logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) + } catch { + // shouldn't ever happen, b/c we should prevent writing 2GB shuffle files, + // but just to be safe + case ex: BufferTooLargeException => + // throw & catch this helper exception, just to get full stack trace + try { + throw new ShuffleRemoteBlockSizeLimitException(ex) + } catch { + case ex2: ShuffleRemoteBlockSizeLimitException => + responseContext.onFailure(ex2) + } + } case uploadBlock: UploadBlock => // StorageLevel is serialized as bytes using our JavaSerializer. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 46a6f6537e2ee..2c0b3f085ca76 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -77,7 +77,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul return } val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( - serializedTaskResult.get) + serializedTaskResult.get.asByteBuffer()) sparkEnv.blockManager.master.removeBlock(blockId) (deserializedResult, size) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index 4342b0d598b16..81aea33ee41b4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle import java.nio.ByteBuffer + import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index eedb27942e841..dd2a198714bac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{ExecutionContext, Await, Future} import scala.concurrent.duration._ @@ -31,7 +32,7 @@ import org.apache.spark._ import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.buffer._ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import 
org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo @@ -42,7 +43,7 @@ import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util._ private[spark] sealed trait BlockValues -private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues +private[spark] case class ByteBufferValues(buffer: LargeByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues private[spark] case class ArrayValues(buffer: Array[Any]) extends BlockValues @@ -300,10 +301,10 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) - .asInstanceOf[Option[ByteBuffer]] + .asInstanceOf[Option[LargeByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - new NioManagedBuffer(buffer) + new NioManagedBuffer(buffer.asByteBuffer()) } else { throw new BlockNotFoundException(blockId.toString) } @@ -314,7 +315,7 @@ private[spark] class BlockManager( * Put the block locally, using the given storage level. */ override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = { - putBytes(blockId, data.nioByteBuffer(), level) + putBytes(blockId, LargeByteBufferHelper.asLargeByteBuffer(data.nioByteBuffer()), level) } /** @@ -432,7 +433,7 @@ private[spark] class BlockManager( /** * Get block from the local block manager as serialized bytes. */ - def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = { + def getLocalBytes(blockId: BlockId): Option[LargeByteBuffer] = { logDebug(s"Getting local block $blockId as bytes") // As an optimization for map output fetches, if the block is for a shuffle, return it // without acquiring a lock; the disk store never deletes (recent) items so this should work @@ -440,10 +441,10 @@ private[spark] class BlockManager( val shuffleBlockResolver = shuffleManager.shuffleBlockResolver // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. - Option( - shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) + Option(LargeByteBufferHelper.asLargeByteBuffer( + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())) } else { - doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[LargeByteBuffer]] } } @@ -509,13 +510,13 @@ private[spark] class BlockManager( // Look for block on disk, potentially storing it back in memory if required if (level.useDisk) { logDebug(s"Getting block $blockId from disk") - val bytes: ByteBuffer = diskStore.getBytes(blockId) match { + val bytes: LargeByteBuffer = diskStore.getBytes(blockId) match { case Some(b) => b case None => throw new BlockException( blockId, s"Block $blockId not found on disk, though it should be") } - assert(0 == bytes.position()) + assert(0L == bytes.position()) if (!level.useMemory) { // If the block shouldn't be stored in memory, we can just return it @@ -531,13 +532,12 @@ private[spark] class BlockManager( /* We'll store the bytes in memory if the block's storage level includes * "memory serialized", or if it should be cached as objects in memory * but we only requested its serialized bytes. 
*/ - memoryStore.putBytes(blockId, bytes.limit, () => { + memoryStore.putBytes(blockId, bytes.size(), () => { // https://issues.apache.org/jira/browse/SPARK-6076 // If the file size is bigger than the free memory, OOM will happen. So if we cannot // put it into MemoryStore, copyForMemory should not be created. That's why this // action is put into a `() => ByteBuffer` and created lazily. - val copyForMemory = ByteBuffer.allocate(bytes.limit) - copyForMemory.put(bytes) + bytes.deepCopy() }) bytes.rewind() } @@ -582,9 +582,9 @@ private[spark] class BlockManager( /** * Get block from remote block managers as serialized bytes. */ - def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = { + def getRemoteBytes(blockId: BlockId): Option[LargeByteBuffer] = { logDebug(s"Getting remote block $blockId as bytes") - doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[LargeByteBuffer]] } private def doGetRemote(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { @@ -592,17 +592,18 @@ private[spark] class BlockManager( val locations = Random.shuffle(master.getLocations(blockId)) for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") - val data = blockTransferService.fetchBlockSync( + // the fetch will always be one byte buffer till we fix SPARK-5928 + val data: ByteBuffer = blockTransferService.fetchBlockSync( loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() if (data != null) { if (asBlockResult) { return Some(new BlockResult( - dataDeserialize(blockId, data), + dataDeserialize(blockId, LargeByteBufferHelper.asLargeByteBuffer(data)), DataReadMethod.Network, data.limit())) } else { - return Some(data) + return Some(LargeByteBufferHelper.asLargeByteBuffer(data)) } } logDebug(s"The value of block $blockId is null") @@ -675,7 +676,7 @@ private[spark] class BlockManager( */ def putBytes( blockId: BlockId, - bytes: ByteBuffer, + bytes: LargeByteBuffer, level: StorageLevel, tellMaster: Boolean = true, effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = { @@ -737,7 +738,7 @@ private[spark] class BlockManager( var valuesAfterPut: Iterator[Any] = null // Ditto for the bytes after the put - var bytesAfterPut: ByteBuffer = null + var bytesAfterPut: LargeByteBuffer = null // Size of the block in bytes var size = 0L @@ -745,119 +746,135 @@ private[spark] class BlockManager( // The level we actually use to put the block val putLevel = effectiveStorageLevel.getOrElse(level) - // If we're storing bytes, then initiate the replication before storing them locally. - // This is faster as data is already serialized and ready to send. - val replicationFuture = data match { - case b: ByteBufferValues if putLevel.replication > 1 => - // Duplicate doesn't copy the bytes, but just creates a wrapper - val bufferView = b.buffer.duplicate() - Future { - // This is a blocking action and should run in futureExecutionContext which is a cached - // thread pool - replicate(blockId, bufferView, putLevel) - }(futureExecutionContext) - case _ => null - } - - putBlockInfo.synchronized { - logTrace("Put for block %s took %s to get into synchronized block" - .format(blockId, Utils.getUsedTimeMs(startTimeMs))) + try { + // If we're storing bytes, then initiate the replication before storing them locally. + // This is faster as data is already serialized and ready to send. 
+ val replicationFuture = data match { + case b: ByteBufferValues if putLevel.replication > 1 => + // Duplicate doesn't copy the bytes, but just creates a wrapper + val bufferView = try { + b.buffer.asByteBuffer() + } catch { + case ex: BufferTooLargeException => + throw new ReplicationBlockSizeLimitException(ex) + } + Future { + // This is a blocking action and should run in futureExecutionContext which is a cached + // thread pool + replicate(blockId, bufferView, putLevel) + }(futureExecutionContext) + case _ => null + } - var marked = false - try { - // returnValues - Whether to return the values put - // blockStore - The type of storage to put these values into - val (returnValues, blockStore: BlockStore) = { - if (putLevel.useMemory) { - // Put it in memory first, even if it also has useDisk set to true; - // We will drop it to disk later if the memory store can't hold it. - (true, memoryStore) - } else if (putLevel.useOffHeap) { - // Use external block store - (false, externalBlockStore) - } else if (putLevel.useDisk) { - // Don't get back the bytes from put unless we replicate them - (putLevel.replication > 1, diskStore) - } else { - assert(putLevel == StorageLevel.NONE) - throw new BlockException( - blockId, s"Attempted to put block $blockId without specifying storage level!") + putBlockInfo.synchronized { + logTrace("Put for block %s took %s to get into synchronized block" + .format(blockId, Utils.getUsedTimeMs(startTimeMs))) + + var marked = false + try { + // returnValues - Whether to return the values put + // blockStore - The type of storage to put these values into + val (returnValues, blockStore: BlockStore) = { + if (putLevel.useMemory) { + // Put it in memory first, even if it also has useDisk set to true; + // We will drop it to disk later if the memory store can't hold it. 
+ (true, memoryStore) + } else if (putLevel.useOffHeap) { + // Use external block store + (false, externalBlockStore) + } else if (putLevel.useDisk) { + // Don't get back the bytes from put unless we replicate them + (putLevel.replication > 1, diskStore) + } else { + assert(putLevel == StorageLevel.NONE) + throw new BlockException( + blockId, s"Attempted to put block $blockId without specifying storage level!") + } } - } - // Actually put the values - val result = data match { - case IteratorValues(iterator) => - blockStore.putIterator(blockId, iterator, putLevel, returnValues) - case ArrayValues(array) => - blockStore.putArray(blockId, array, putLevel, returnValues) - case ByteBufferValues(bytes) => - bytes.rewind() - blockStore.putBytes(blockId, bytes, putLevel) - } - size = result.size - result.data match { - case Left (newIterator) if putLevel.useMemory => valuesAfterPut = newIterator - case Right (newBytes) => bytesAfterPut = newBytes - case _ => - } + // Actually put the values + val result = data match { + case IteratorValues(iterator) => + blockStore.putIterator(blockId, iterator, putLevel, returnValues) + case ArrayValues(array) => + blockStore.putArray(blockId, array, putLevel, returnValues) + case ByteBufferValues(bytes) => + bytes.rewind() + blockStore.putBytes(blockId, bytes, putLevel) + } + size = result.size + result.data match { + case Left(newIterator) if putLevel.useMemory => valuesAfterPut = newIterator + case Right(newBytes) => bytesAfterPut = newBytes + case _ => + } - // Keep track of which blocks are dropped from memory - if (putLevel.useMemory) { - result.droppedBlocks.foreach { updatedBlocks += _ } - } + // Keep track of which blocks are dropped from memory + if (putLevel.useMemory) { + result.droppedBlocks.foreach { + updatedBlocks += _ + } + } - val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo) - if (putBlockStatus.storageLevel != StorageLevel.NONE) { - // Now that the block is in either the memory, externalBlockStore, or disk store, - // let other threads read it, and tell the master about it. - marked = true - putBlockInfo.markReady(size) - if (tellMaster) { - reportBlockStatus(blockId, putBlockInfo, putBlockStatus) + val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo) + if (putBlockStatus.storageLevel != StorageLevel.NONE) { + // Now that the block is in either the memory, externalBlockStore, or disk store, + // let other threads read it, and tell the master about it. + marked = true + putBlockInfo.markReady(size) + if (tellMaster) { + reportBlockStatus(blockId, putBlockInfo, putBlockStatus) + } + updatedBlocks += ((blockId, putBlockStatus)) + } + } finally { + // If we failed in putting the block to memory/disk, notify other possible readers + // that it has failed, and then remove it from the block info map. + if (!marked) { + // Note that the remove must happen before markFailure otherwise another thread + // could've inserted a new BlockInfo before we remove it. + blockInfo.remove(blockId) + putBlockInfo.markFailure() + logWarning(s"Putting block $blockId failed") } - updatedBlocks += ((blockId, putBlockStatus)) - } - } finally { - // If we failed in putting the block to memory/disk, notify other possible readers - // that it has failed, and then remove it from the block info map. - if (!marked) { - // Note that the remove must happen before markFailure otherwise another thread - // could've inserted a new BlockInfo before we remove it. 
- blockInfo.remove(blockId) - putBlockInfo.markFailure() - logWarning(s"Putting block $blockId failed") } } - } - logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) + logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) - // Either we're storing bytes and we asynchronously started replication, or we're storing - // values and need to serialize and replicate them now: - if (putLevel.replication > 1) { - data match { - case ByteBufferValues(bytes) => - if (replicationFuture != null) { - Await.ready(replicationFuture, Duration.Inf) - } - case _ => - val remoteStartTime = System.currentTimeMillis - // Serialize the block if not already done - if (bytesAfterPut == null) { - if (valuesAfterPut == null) { - throw new SparkException( - "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + // Either we're storing bytes and we asynchronously started replication, or we're storing + // values and need to serialize and replicate them now: + if (putLevel.replication > 1) { + data match { + case ByteBufferValues(bytes) => + if (replicationFuture != null) { + Await.ready(replicationFuture, Duration.Inf) } - bytesAfterPut = dataSerialize(blockId, valuesAfterPut) - } - replicate(blockId, bytesAfterPut, putLevel) - logDebug("Put block %s remotely took %s" - .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) + case _ => + val remoteStartTime = System.currentTimeMillis + // Serialize the block if not already done + if (bytesAfterPut == null) { + if (valuesAfterPut == null) { + throw new SparkException( + "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.") + } + bytesAfterPut = dataSerialize(blockId, valuesAfterPut) + } + try { + replicate(blockId, bytesAfterPut.asByteBuffer(), putLevel) + } catch { + case ex: BufferTooLargeException => + throw new ReplicationBlockSizeLimitException(ex) + } + logDebug("Put block %s remotely took %s" + .format(blockId, Utils.getUsedTimeMs(remoteStartTime))) + } + } + } finally { + if (bytesAfterPut != null) { + bytesAfterPut.dispose() } } - BlockManager.dispose(bytesAfterPut) - if (putLevel.replication > 1) { logDebug("Putting block %s with replication took %s" .format(blockId, Utils.getUsedTimeMs(startTimeMs))) @@ -998,7 +1015,7 @@ private[spark] class BlockManager( def dropFromMemory( blockId: BlockId, - data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + data: Either[Array[Any], LargeByteBuffer]): Option[BlockStatus] = { dropFromMemory(blockId, () => data) } @@ -1012,7 +1029,7 @@ private[spark] class BlockManager( */ def dropFromMemory( blockId: BlockId, - data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { + data: () => Either[Array[Any], LargeByteBuffer]): Option[BlockStatus] = { logInfo(s"Dropping block $blockId from memory") val info = blockInfo.get(blockId).orNull @@ -1194,10 +1211,10 @@ private[spark] class BlockManager( def dataSerialize( blockId: BlockId, values: Iterator[Any], - serializer: Serializer = defaultSerializer): ByteBuffer = { - val byteStream = new ByteArrayOutputStream(4096) + serializer: Serializer = defaultSerializer): LargeByteBuffer = { + val byteStream = new LargeByteBufferOutputStream(65536) dataSerializeStream(blockId, byteStream, values, serializer) - ByteBuffer.wrap(byteStream.toByteArray) + byteStream.largeBuffer } /** @@ -1206,10 +1223,10 @@ private[spark] class BlockManager( */ def dataDeserialize( blockId: BlockId, - bytes: ByteBuffer, + bytes: LargeByteBuffer, serializer: 
Serializer = defaultSerializer): Iterator[Any] = { bytes.rewind() - dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer) + dataDeserializeStream(blockId, new LargeByteBufferInputStream(bytes, true), serializer) } /** @@ -1291,3 +1308,51 @@ private[spark] object BlockManager extends Logging { blockManagers.toMap } } + + +abstract class BlockSizeLimitException(msg: String, cause: BufferTooLargeException) + extends SparkException(msg, cause) + +object BlockSizeLimitException { + def sizeMsg(cause: BufferTooLargeException): String = { + s"that was ${Utils.bytesToString(cause.actualSize)} (too " + + s"large by ${Utils.bytesToString(cause.extra)} / " + + s"${cause.actualSize.toDouble / LargeByteBufferHelper.MAX_CHUNK_SIZE}x)." + } + + def sizeMsgAndAdvice(cause: BufferTooLargeException): String = { + sizeMsg(cause) + + " You should figure out which stage created these partitions, then increase the number of " + + "partitions used by that stage. That way, you will have less data per partition. You may " + + "want to make the number of partitions an easily configurable parameter so you can " + + "continue to update it as needed." + } + +} + +class ReplicationBlockSizeLimitException(cause: BufferTooLargeException) + extends BlockSizeLimitException("Spark cannot replicate partitions that are greater than 2GB. " + + "You tried to replicate a partition " + BlockSizeLimitException.sizeMsgAndAdvice(cause) + + " Or, you can turn off replication.", cause) + +class TachyonBlockSizeLimitException(cause: BufferTooLargeException) + extends BlockSizeLimitException("Spark cannot store partitions greater than 2GB in tachyon. " + + "You tried to store a partition " + BlockSizeLimitException.sizeMsgAndAdvice(cause) + + " Or, you can use a different storage mechanism.", cause) + +class ShuffleBlockSizeLimitException(size: Long) + extends SparkException("Spark cannot shuffle partitions that are greater than 2GB. " + + "You tried to shuffle a block that was at least " + Utils.bytesToString(size) + ". " + + "You should try to increase the number of partitions of this shuffle, and / or " + + "figure out which stage created the partitions before the shuffle, and increase the number " + + "of partitions for that stage. You may want to make both of these numbers easily " + + "configurable parameters so you can continue to update as needed.") + +class ShuffleRemoteBlockSizeLimitException(cause: BufferTooLargeException) + extends BlockSizeLimitException("Spark cannot shuffle partitions that are greater than 2GB. " + + "You tried to shuffle a block that was at least " + BlockSizeLimitException.sizeMsg(cause) + + " You should try to increase the number of partitions of this shuffle, and / or " + + "figure out which stage created the partitions before the shuffle, and increase the number " + + "of partitions for that stage.
You may want to make both of these numbers easily " + + "configurable parameters so you can continue to update as needed.", cause) + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala index 69985c9759e2d..ebbe06fde68ae 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala @@ -17,18 +17,15 @@ package org.apache.spark.storage -import java.nio.ByteBuffer - -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.Logging +import org.apache.spark.network.buffer.LargeByteBuffer /** * Abstract class to store blocks. */ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging { - def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult + def putBytes(blockId: BlockId, bytes: LargeByteBuffer, level: StorageLevel) : PutResult /** * Put in a block and, possibly, also return its content as either bytes or another Iterator. @@ -54,7 +51,7 @@ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends */ def getSize(blockId: BlockId): Long - def getBytes(blockId: BlockId): Option[ByteBuffer] + def getBytes(blockId: BlockId): Option[LargeByteBuffer] def getValues(blockId: BlockId): Option[Iterator[Any]] diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 49d9154f95a5b..ee35281c93ee7 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -21,6 +21,7 @@ import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} import java.nio.channels.FileChannel import org.apache.spark.Logging +import org.apache.spark.network.buffer.LargeByteBufferHelper import org.apache.spark.serializer.{SerializerInstance, SerializationStream} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.util.Utils @@ -108,6 +109,12 @@ private[spark] class DiskBlockObjectWriter( objOut.close() } + finalPosition = file.length() + val length = finalPosition - initialPosition + if (length > LargeByteBufferHelper.MAX_CHUNK_SIZE) { + throw new ShuffleBlockSizeLimitException(length) + } + channel = null bs = null fos = null @@ -201,6 +208,10 @@ private[spark] class DiskBlockObjectWriter( if (numRecordsWritten % 32 == 0) { updateBytesWritten() + val length = reportedPosition - initialPosition + if (length > LargeByteBufferHelper.MAX_CHUNK_SIZE) { + throw new ShuffleBlockSizeLimitException(length) + } } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 1f45956282166..1d4b023e6f24e 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode import org.apache.spark.Logging +import org.apache.spark.network.buffer.{LargeByteBufferHelper, LargeByteBuffer} import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils @@ -37,7 +38,10 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc diskManager.getFile(blockId.name).length } - override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult 
= { + override def putBytes( + blockId: BlockId, + _bytes: LargeByteBuffer, + level: StorageLevel): PutResult = { // So that we do not modify the input offsets ! // duplicate does not copy buffer, so inexpensive val bytes = _bytes.duplicate() @@ -46,16 +50,14 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val file = diskManager.getFile(blockId) val channel = new FileOutputStream(file).getChannel Utils.tryWithSafeFinally { - while (bytes.remaining > 0) { - channel.write(bytes) - } + bytes.writeTo(channel) } { channel.close() } val finishTime = System.currentTimeMillis logDebug("Block %s stored as %s file on disk in %d ms".format( - file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime)) - PutResult(bytes.limit(), Right(bytes.duplicate())) + file.getName, Utils.bytesToString(bytes.size()), finishTime - startTime)) + PutResult(bytes.size(), Right(bytes.duplicate())) } override def putArray( @@ -106,7 +108,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } } - private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = { + private def getBytes(file: File, offset: Long, length: Long): Option[LargeByteBuffer] = { val channel = new RandomAccessFile(file, "r").getChannel Utils.tryWithSafeFinally { // For small files, directly read rather than memory map @@ -120,21 +122,22 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc } } buf.flip() - Some(buf) + Some(LargeByteBufferHelper.asLargeByteBuffer(buf)) } else { - Some(channel.map(MapMode.READ_ONLY, offset, length)) + logTrace(s"mapping file: $file:$offset+$length") + Some(LargeByteBufferHelper.mapFile(channel, MapMode.READ_ONLY, offset, length)) } } { channel.close() } } - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { + override def getBytes(blockId: BlockId): Option[LargeByteBuffer] = { val file = diskManager.getFile(blockId.name) getBytes(file, 0, file.length) } - def getBytes(segment: FileSegment): Option[ByteBuffer] = { + def getBytes(segment: FileSegment): Option[LargeByteBuffer] = { getBytes(segment.file, segment.offset, segment.length) } diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala index f39325a12d244..d71253539e914 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.storage import java.nio.ByteBuffer +import org.apache.spark.network.buffer.LargeByteBuffer + /** * An abstract class that the concrete external block manager has to inherit. * The class has to have a no-argument constructor, and will be initialized by init, @@ -75,7 +77,7 @@ private[spark] abstract class ExternalBlockManager { * * @throws java.io.IOException if there is any file system failure in putting the block. */ - def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit + def putBytes(blockId: BlockId, bytes: LargeByteBuffer): Unit def putValues(blockId: BlockId, values: Iterator[_]): Unit = { val bytes = blockManager.dataSerialize(blockId, values) @@ -89,7 +91,7 @@ private[spark] abstract class ExternalBlockManager { * * @throws java.io.IOException if there is any file system failure in getting the block. */ - def getBytes(blockId: BlockId): Option[ByteBuffer] + def getBytes(blockId: BlockId): Option[LargeByteBuffer] /** * Retrieve the block data. 
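// A minimal usage sketch (not part of the patch) showing how the stream classes added above
// are intended to compose: LargeByteBufferOutputStream buffers writes in fixed-size chunks with
// no 2GB ceiling, largeBuffer() returns an independent copy as a LargeByteBuffer, and
// LargeByteBufferInputStream reads it back, optionally disposing the buffer on close(). The
// object name, method name, and chunk size below are illustrative assumptions only.
import org.apache.spark.network.buffer.{LargeByteBufferInputStream, LargeByteBufferOutputStream}

object LargeBufferStreamRoundTripSketch {
  def roundTrip(payload: Array[Byte]): Array[Byte] = {
    val out = new LargeByteBufferOutputStream(64 * 1024)   // 64KB backing chunks
    out.write(payload, 0, payload.length)
    out.close()

    // Each largeBuffer() call copies everything written so far into a fresh LargeByteBuffer.
    val buf = out.largeBuffer()
    val in = new LargeByteBufferInputStream(buf, true)     // dispose = true: free the buffer on close()
    val read = new Array[Byte](payload.length)
    var pos = 0
    var n = 0
    while (pos < read.length && n != -1) {
      n = in.read(read, pos, read.length - pos)            // returns -1 once the buffer is exhausted
      if (n > 0) pos += n
    }
    in.close()                                             // disposes the underlying LargeByteBuffer
    read
  }
}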
diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala index db965d54bafd6..c2271bccc8f4c 100644 --- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import scala.util.control.NonFatal import org.apache.spark.Logging +import org.apache.spark.network.buffer.LargeByteBuffer import org.apache.spark.util.Utils @@ -47,7 +48,10 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: } } - override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = { + override def putBytes( + blockId: BlockId, + bytes: LargeByteBuffer, + level: StorageLevel): PutResult = { putIntoExternalBlockStore(blockId, bytes, returnValues = true) } @@ -100,7 +104,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: private def putIntoExternalBlockStore( blockId: BlockId, - bytes: ByteBuffer, + bytes: LargeByteBuffer, returnValues: Boolean): PutResult = { logTrace(s"Attempting to put block $blockId into ExternalBlockStore") // we should never hit here if externalBlockManager is None. Handle it anyway for safety. @@ -110,7 +114,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: val byteBuffer = bytes.duplicate() byteBuffer.rewind() externalBlockManager.get.putBytes(blockId, byteBuffer) - val size = bytes.limit() + val size = bytes.size() val data = if (returnValues) { Right(bytes) } else { @@ -152,7 +156,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId: } } - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { + override def getBytes(blockId: BlockId): Option[LargeByteBuffer] = { try { externalBlockManager.flatMap(_.getBytes(blockId)) } catch { diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 6f27f00307f8c..157c49ebdc919 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -17,9 +17,10 @@ package org.apache.spark.storage -import java.nio.ByteBuffer import java.util.LinkedHashMap +import org.apache.spark.network.buffer.LargeByteBuffer + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -86,7 +87,10 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = { + override def putBytes( + blockId: BlockId, + _bytes: LargeByteBuffer, + level: StorageLevel): PutResult = { // Work on a duplicate - since the original input might be used elsewhere. 
val bytes = _bytes.duplicate() bytes.rewind() @@ -94,8 +98,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val values = blockManager.dataDeserialize(blockId, bytes) putIterator(blockId, values, level, returnValues = true) } else { - val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false) - PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks) + val putAttempt = tryToPut(blockId, bytes, bytes.size(), deserialized = false) + PutResult(bytes.size(), Right(bytes.duplicate()), putAttempt.droppedBlocks) } } @@ -105,13 +109,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * * The caller should guarantee that `size` is correct. */ - def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = { + def putBytes( + blockId: BlockId, + size: Long, + _bytes: () => LargeByteBuffer): PutResult = { // Work on a duplicate - since the original input might be used elsewhere. - lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer] + lazy val bytes = _bytes().duplicate().rewind() val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false) val data = if (putAttempt.success) { - assert(bytes.limit == size) + assert(bytes.size() == size) Right(bytes.duplicate()) } else { null @@ -130,8 +137,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) PutResult(sizeEstimate, Left(values.iterator), putAttempt.droppedBlocks) } else { val bytes = blockManager.dataSerialize(blockId, values.iterator) - val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false) - PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks) + val putAttempt = tryToPut(blockId, bytes, bytes.size(), deserialized = false) + PutResult(bytes.size(), Right(bytes.duplicate()), putAttempt.droppedBlocks) } } @@ -181,7 +188,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } } - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { + override def getBytes(blockId: BlockId): Option[LargeByteBuffer] = { val entry = entries.synchronized { entries.get(blockId) } @@ -190,7 +197,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } else if (entry.deserialized) { Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[Array[Any]].iterator)) } else { - Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data + Some(entry.value.asInstanceOf[LargeByteBuffer].duplicate()) // Doesn't actually copy the data } } @@ -203,7 +210,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } else if (entry.deserialized) { Some(entry.value.asInstanceOf[Array[Any]].iterator) } else { - val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data + // Doesn't actually copy data + val buffer = entry.value.asInstanceOf[LargeByteBuffer].duplicate() Some(blockManager.dataDeserialize(blockId, buffer)) } } @@ -392,7 +400,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) lazy val data = if (deserialized) { Left(value().asInstanceOf[Array[Any]]) } else { - Right(value().asInstanceOf[ByteBuffer].duplicate()) + Right(value().asInstanceOf[LargeByteBuffer].duplicate()) } val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } @@ -463,7 +471,7 @@ private[spark] class 
MemoryStore(blockManager: BlockManager, maxMemory: Long) val data = if (entry.deserialized) { Left(entry.value.asInstanceOf[Array[Any]]) } else { - Right(entry.value.asInstanceOf[ByteBuffer].duplicate()) + Right(entry.value.asInstanceOf[LargeByteBuffer].duplicate()) } val droppedBlockStatus = blockManager.dropFromMemory(blockId, data) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } diff --git a/core/src/main/scala/org/apache/spark/storage/PutResult.scala b/core/src/main/scala/org/apache/spark/storage/PutResult.scala index f0eac7594ecf6..aa9176791b319 100644 --- a/core/src/main/scala/org/apache/spark/storage/PutResult.scala +++ b/core/src/main/scala/org/apache/spark/storage/PutResult.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.nio.ByteBuffer +import org.apache.spark.network.buffer.LargeByteBuffer /** * Result of adding a block into a BlockStore. This case class contains a few things: @@ -28,5 +28,5 @@ import java.nio.ByteBuffer */ private[spark] case class PutResult( size: Long, - data: Either[Iterator[_], ByteBuffer], + data: Either[Iterator[_], LargeByteBuffer], droppedBlocks: Seq[(BlockId, BlockStatus)] = Seq.empty) diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 22878783fca67..f806407c73ead 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -18,7 +18,6 @@ package org.apache.spark.storage import java.io.IOException -import java.nio.ByteBuffer import java.text.SimpleDateFormat import java.util.{Date, Random} @@ -32,6 +31,7 @@ import tachyon.TachyonURI import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode +import org.apache.spark.network.buffer.{BufferTooLargeException, LargeByteBufferHelper, LargeByteBuffer} import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -99,12 +99,14 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log fileExists(file) } - override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = { + override def putBytes(blockId: BlockId, bytes: LargeByteBuffer): Unit = { val file = getFile(blockId) val os = file.getOutStream(WriteType.TRY_CACHE) try { - os.write(bytes.array()) + os.write(bytes.asByteBuffer().array()) } catch { + case tooLarge: BufferTooLargeException => + throw new TachyonBlockSizeLimitException(tooLarge) case NonFatal(e) => logWarning(s"Failed to put bytes of block $blockId into Tachyon", e) os.cancel() @@ -127,7 +129,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log } } - override def getBytes(blockId: BlockId): Option[ByteBuffer] = { + override def getBytes(blockId: BlockId): Option[LargeByteBuffer] = { val file = getFile(blockId) if (file == null || file.getLocationHosts.size == 0) { return None @@ -135,9 +137,10 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log val is = file.getInStream(ReadType.CACHE) try { val size = file.length + // TODO get tachyon to support large blocks val bs = new Array[Byte](size.asInstanceOf[Int]) ByteStreams.readFully(is, bs) - Some(ByteBuffer.wrap(bs)) + Some(LargeByteBufferHelper.asLargeByteBuffer(bs)) } catch { case NonFatal(e) => logWarning(s"Failed to get bytes of block $blockId from Tachyon", e) diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala 
b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala index daac6f971eb20..d48eb2f330321 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala @@ -21,7 +21,6 @@ import java.io.OutputStream import scala.collection.mutable.ArrayBuffer - /** * An OutputStream that writes to fixed-size chunks of byte arrays. * @@ -43,10 +42,13 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream { */ private var position = chunkSize + private[spark] var size: Long = 0L + override def write(b: Int): Unit = { allocateNewChunkIfNeeded() chunks(lastChunkIndex)(position) = b.toByte position += 1 + size += 1 } override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { @@ -58,6 +60,7 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream { written += thisBatch position += thisBatch } + size += len } @inline @@ -91,4 +94,44 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream { ret } } + + /** + * Get a copy of the data between the two endpoints, start <= idx < until. Always returns + * an array of size (until - start). Throws an IllegalArgumentException unless + * 0 <= start <= until <= size + */ + def slice(start: Long, until: Long): Array[Byte] = { + require((until - start) < Integer.MAX_VALUE, "max slice length = Integer.MAX_VALUE") + require(start >= 0 && start <= until, s"start ($start) must be >= 0 and <= until ($until)") + require(until >= start && until <= size, + s"until ($until) must be >= start ($start) and <= size ($size)") + var chunkStart = 0L + var chunkIdx = 0 + val length = (until - start).toInt + var foundStart = false + val result = new Array[Byte](length) + while (!foundStart) { + val nextChunkStart = chunkStart + chunks(chunkIdx).size + if (nextChunkStart > start) { + foundStart = true + } else { + chunkStart = nextChunkStart + chunkIdx += 1 + } + } + + var remaining = length + var pos = 0 + var offsetInChunk = (start - chunkStart).toInt + while (remaining > 0) { + val lenToCopy = math.min(remaining, chunks(chunkIdx).size - offsetInChunk) + System.arraycopy(chunks(chunkIdx), offsetInChunk, result, pos, lenToCopy) + chunkIdx += 1 + offsetInChunk = 0 + pos += lenToCopy + remaining -= lenToCopy + } + result + } + } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 600c1403b0344..4d488a0b78f5b 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.Matchers import org.scalatest.time.{Millis, Span} +import org.apache.spark.network.buffer.LargeByteBufferHelper import org.apache.spark.storage.{RDDBlockId, StorageLevel} class NotSerializableClass @@ -196,7 +197,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString) - val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer()) + val deserialized = blockManager.dataDeserialize(blockId, + LargeByteBufferHelper.asLargeByteBuffer(bytes.nioByteBuffer())) .asInstanceOf[Iterator[Int]].toList assert(deserialized === (1 to 100).toList) } diff --git 
a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index d91b799ecfc08..e28818098b918 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.ShuffleSuite.NonJavaSerializableClass import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} +import org.apache.spark.storage.{ShuffleBlockSizeLimitException, BlockSizeLimitException, ShuffleDataBlockId, ShuffleBlockId} import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { @@ -283,6 +283,39 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC rdd.count() } + ignore("shuffle total > 2GB ok if each block is small") { + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(1 to 1e6.toInt, 1).map{ i => + val n = 3e3.toInt + val arr = new Array[Byte](n) + // need to make sure the array doesn't compress to something small + scala.util.Random.nextBytes(arr) + (i, arr) + } + rdd.partitionBy(new HashPartitioner(100)).count() + } + + ignore("shuffle blocks > 2GB fail with sane exception") { + // note that this *could* succeed in local mode, b/c local shuffles actually don't + // have a limit at 2GB. BUT, we make them fail in any case, b/c its better to have + // a consistent failure, and not have success depend on where tasks get scheduled + + sc = new SparkContext("local", "test", conf) + val rdd = sc.parallelize(1 to 1e6.toInt, 1).map{ i => + val n = 3e3.toInt + val arr = new Array[Byte](n) + // need to make sure the array doesn't compress to something small + scala.util.Random.nextBytes(arr) + (2 * i, arr) + } + + val exc = intercept[SparkException] { + rdd.partitionBy(new org.apache.spark.HashPartitioner(2)).count() + } + + exc.getCause shouldBe a[ShuffleBlockSizeLimitException] + } + test("metrics for shuffle without aggregation") { sc = new SparkContext("local", "test", conf.clone()) val numRecords = 10000 diff --git a/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferInputStreamSuite.scala b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferInputStreamSuite.scala new file mode 100644 index 0000000000000..d8e48db32f78c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferInputStreamSuite.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.network.buffer + +import java.io.{File, FileInputStream, FileOutputStream, OutputStream} +import java.nio.channels.FileChannel.MapMode + +import org.junit.Assert._ +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class LargeByteBufferInputStreamSuite extends SparkFunSuite with Matchers { + + test("read from large mapped file") { + val testFile = File.createTempFile("large-buffer-input-stream-test", ".bin") + + try { + val out: OutputStream = new FileOutputStream(testFile) + val buffer: Array[Byte] = new Array[Byte](1 << 16) + val len: Long = buffer.length.toLong + Integer.MAX_VALUE + 1 + (0 until buffer.length).foreach { idx => + buffer(idx) = idx.toByte + } + (0 until (len / buffer.length).toInt).foreach { idx => + out.write(buffer) + } + out.close + + val channel = new FileInputStream(testFile).getChannel + val buf = LargeByteBufferHelper.mapFile(channel, MapMode.READ_ONLY, 0, len) + val in = new LargeByteBufferInputStream(buf, true) + + val read = new Array[Byte](buffer.length) + (0 until (len / buffer.length).toInt).foreach { idx => + in.disposed should be(false) + in.read(read) should be(read.length) + (0 until buffer.length).foreach { arrIdx => + assertEquals(buffer(arrIdx), read(arrIdx)) + } + } + in.disposed should be(false) + in.read(read) should be(-1) + in.disposed should be(false) + in.close() + in.disposed should be(true) + } finally { + testFile.delete() + } + } + + test("dispose on close") { + // don't need to read to the end -- dispose anytime we close + val data = new Array[Byte](10) + val in = new LargeByteBufferInputStream(LargeByteBufferHelper.asLargeByteBuffer(data), true) + in.disposed should be (false) + in.close() + in.disposed should be (true) + } + + test("io stream roundtrip") { + val out = new LargeByteBufferOutputStream(128) + (0 until 200).foreach { idx => out.write(idx) } + out.close() + + val lb = out.largeBuffer(128) + // just make sure that we test reading from multiple chunks + lb.asInstanceOf[WrappedLargeByteBuffer].underlying.size should be > 1 + + val rawIn = new LargeByteBufferInputStream(lb) + val arr = new Array[Byte](500) + val nRead = rawIn.read(arr, 0, 500) + nRead should be (200) + (0 until 200).foreach { idx => + arr(idx) should be (idx.toByte) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferOutputStreamSuite.scala new file mode 100644 index 0000000000000..72c98b7feacab --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferOutputStreamSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.network.buffer + +import scala.util.Random + +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class LargeByteBufferOutputStreamSuite extends SparkFunSuite with Matchers { + + test("merged buffers for < 2GB") { + val out = new LargeByteBufferOutputStream(10) + val bytes = new Array[Byte](100) + Random.nextBytes(bytes) + out.write(bytes) + + val buffer = out.largeBuffer + buffer.position() should be (0) + buffer.size() should be (100) + val nioBuffer = buffer.asByteBuffer() + nioBuffer.position() should be (0) + nioBuffer.capacity() should be (100) + nioBuffer.limit() should be (100) + + val read = new Array[Byte](100) + buffer.get(read, 0, 100) + read should be (bytes) + + buffer.rewind() + nioBuffer.get(read) + read should be (bytes) + } + + test("chunking") { + val out = new LargeByteBufferOutputStream(10) + val bytes = new Array[Byte](100) + Random.nextBytes(bytes) + out.write(bytes) + + (10 to 100 by 10).foreach { chunkSize => + val buffer = out.largeBuffer(chunkSize).asInstanceOf[WrappedLargeByteBuffer] + buffer.position() should be (0) + buffer.size() should be (100) + val read = new Array[Byte](100) + buffer.get(read, 0, 100) + read should be (bytes) + } + + } + +} diff --git a/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferTestHelper.scala b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferTestHelper.scala new file mode 100644 index 0000000000000..a04bb41fae366 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferTestHelper.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer + +import java.nio.ByteBuffer +import java.util.{List => JList} + +/** + * cheat to access package-protected members in test + */ +object LargeByteBufferTestHelper { + def nioBuffers(wbb: WrappedLargeByteBuffer): JList[ByteBuffer] = { + wbb.nioBuffers() + } +} diff --git a/core/src/test/scala/org/apache/spark/rdd/LargePartitionCachingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LargePartitionCachingSuite.scala new file mode 100644 index 0000000000000..dfcc90df32a18 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/LargePartitionCachingSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rdd + +import org.apache.spark._ +import org.apache.spark.storage.{ReplicationBlockSizeLimitException, StorageLevel} +import org.scalatest.Matchers + +class LargePartitionCachingSuite extends SparkFunSuite with SharedSparkContext with Matchers { + + def largePartitionRdd: RDD[Array[Byte]] = { + sc.parallelize(1 to 1e6.toInt, 1).map{i => new Array[Byte](2.2e3.toInt)} + } + + // just don't want to kill the test server + ignore("memory serialized cache large partitions") { + largePartitionRdd.persist(StorageLevel.MEMORY_ONLY_SER).count() should be (1e6.toInt) + } + + ignore("disk cache large partitions") { + largePartitionRdd.persist(StorageLevel.DISK_ONLY).count() should be (1e6.toInt) + } + + ignore("disk cache large partitions with replications") { + val conf = new SparkConf() + .setMaster("local-cluster[2, 1, 1024]") + .setAppName("test-cluster") + .set("spark.task.maxFailures", "1") + .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data + val clusterSc = new SparkContext(conf) + try { + val exc = intercept[SparkException]{ + val myRDD = clusterSc.parallelize(1 to 1e6.toInt, 1).map{i => new Array[Byte](2.2e3.toInt)} + .persist(StorageLevel.DISK_ONLY_2) + myRDD.count() + } + exc.getCause() shouldBe a [ReplicationBlockSizeLimitException] + } finally { + clusterSc.stop() + } + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index f480fd107a0c2..d1d6778180a4c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays +import org.apache.spark.network.buffer.{LargeByteBufferTestHelper, WrappedLargeByteBuffer, LargeByteBufferHelper, LargeByteBuffer} + import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.implicitConversions @@ -169,8 +171,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(master.getLocations("a3").size === 0, "master was told about a3") // Drop a1 and a2 from memory; this should be reported back to the master - store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) - store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a1", null: Either[Array[Any], LargeByteBuffer]) + store.dropFromMemory("a2", null: Either[Array[Any], LargeByteBuffer]) assert(store.getSingle("a1") === None, "a1 not removed from store") assert(store.getSingle("a2") === None, "a2 not removed from store") assert(master.getLocations("a1").size === 0, "master did not remove a1") @@ -410,8 +412,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE t2.join() t3.join() - store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer]) - store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer]) + store.dropFromMemory("a1", null: Either[Array[Any], 
LargeByteBuffer]) + store.dropFromMemory("a2", null: Either[Array[Any], LargeByteBuffer]) store.waitForAsyncReregister() } } @@ -807,7 +809,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE var counter = 0.toByte def incr: Byte = {counter = (counter + 1).toByte; counter;} val bytes = Array.fill[Byte](1000)(incr) - val byteBuffer = ByteBuffer.wrap(bytes) + val byteBuffer = LargeByteBufferHelper.asLargeByteBuffer(bytes) val blockId = BlockId("rdd_1_2") @@ -820,21 +822,22 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val diskStoreMapped = new DiskStore(blockManager, diskBlockManager) diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) - val mapped = diskStoreMapped.getBytes(blockId).get + val mapped = diskStoreMapped.getBytes(blockId).get.asInstanceOf[WrappedLargeByteBuffer] when(blockManager.conf).thenReturn(conf.clone.set(confKey, "1m")) val diskStoreNotMapped = new DiskStore(blockManager, diskBlockManager) diskStoreNotMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY) - val notMapped = diskStoreNotMapped.getBytes(blockId).get + val notMapped = diskStoreNotMapped.getBytes(blockId).get.asInstanceOf[WrappedLargeByteBuffer] // Not possible to do isInstanceOf due to visibility of HeapByteBuffer - assert(notMapped.getClass.getName.endsWith("HeapByteBuffer"), - "Expected HeapByteBuffer for un-mapped read") - assert(mapped.isInstanceOf[MappedByteBuffer], "Expected MappedByteBuffer for mapped read") - - def arrayFromByteBuffer(in: ByteBuffer): Array[Byte] = { - val array = new Array[Byte](in.remaining()) - in.get(array) + assert(LargeByteBufferTestHelper.nioBuffers(notMapped).get(0).getClass.getName + .endsWith("HeapByteBuffer"), "Expected HeapByteBuffer for un-mapped read") + assert(LargeByteBufferTestHelper.nioBuffers(mapped).get(0).isInstanceOf[MappedByteBuffer], + "Expected MappedByteBuffer for mapped read") + + def arrayFromByteBuffer(in: LargeByteBuffer): Array[Byte] = { + val array = new Array[Byte](in.remaining().toInt) + in.get(array, 0, in.remaining().toInt) array } @@ -1242,13 +1245,23 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(12000) val memoryStore = store.memoryStore val blockId = BlockId("rdd_3_10") - var bytes: ByteBuffer = null + var bytes: LargeByteBuffer = null val result = memoryStore.putBytes(blockId, 10000, () => { - bytes = ByteBuffer.allocate(10000) + bytes = LargeByteBufferHelper.allocate(10000) bytes }) assert(result.size === 10000) - assert(result.data === Right(bytes)) + assert(result.data.isRight) + assertEquivalentByteBufs(result.data.right.get, bytes) assert(result.droppedBlocks === Nil) } + + def assertEquivalentByteBufs(exp: LargeByteBuffer, act: LargeByteBuffer): Unit = { + assert(exp.size() === act.size()) + val expBytes = new Array[Byte](exp.size().toInt) + exp.get(expBytes, 0, exp.size().toInt) + val actBytes = new Array[Byte](act.size().toInt) + act.get(actBytes, 0, act.size().toInt) + assert(expBytes === actBytes) + } } diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala index 361ec95654f47..38bc24528f3a7 100644 --- a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala @@ -21,7 +21,6 @@ import scala.util.Random import org.apache.spark.SparkFunSuite - 
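To make the intent of the new slice() concrete before the tests below, here is a rough usage sketch (not part of this patch; the chunk size and values are illustrative): a range that crosses a chunk boundary is copied into one contiguous array.

    import org.apache.spark.util.io.ByteArrayChunkOutputStream

    val out = new ByteArrayChunkOutputStream(5)      // 5-byte chunks
    out.write(Array.tabulate[Byte](12)(_.toByte))    // fills chunks [0..4], [5..9], [10..11]
    val mid = out.slice(3L, 8L)                      // range spans the first chunk boundary
    assert(mid.toSeq == Seq[Byte](3, 4, 5, 6, 7))    // always an array of length (until - start)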
class ByteArrayChunkOutputStreamSuite extends SparkFunSuite { test("empty output") { @@ -106,4 +105,30 @@ class ByteArrayChunkOutputStreamSuite extends SparkFunSuite { assert(arrays(1).toSeq === ref.slice(10, 20)) assert(arrays(2).toSeq === ref.slice(20, 30)) } + + test("slice") { + val ref = new Array[Byte](30) + Random.nextBytes(ref) + val o = new ByteArrayChunkOutputStream(5) + o.write(ref) + + for { + start <- (0 until 30) + end <- (start to 30) + } { + withClue(s"start = $start; end = $end") { + try { + assert(o.slice(start, end).toSeq === ref.slice(start, end)) + } catch { + case ex => fail(ex) + } + } + } + + // errors on bad bounds + intercept[IllegalArgumentException]{o.slice(31, 31)} + intercept[IllegalArgumentException]{o.slice(-1, 10)} + intercept[IllegalArgumentException]{o.slice(10, 5)} + intercept[IllegalArgumentException]{o.slice(10, 35)} + } } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/BufferTooLargeException.java b/network/common/src/main/java/org/apache/spark/network/buffer/BufferTooLargeException.java new file mode 100644 index 0000000000000..4e1a85ba1f126 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/BufferTooLargeException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.buffer; + +import java.io.IOException; + +public class BufferTooLargeException extends IOException { + public final long actualSize; + public final long extra; + public final long maxSize; + + public BufferTooLargeException(long actualSize, long maxSize) { + super(String.format("LargeByteBuffer is too large to convert. 
Size: %d; Size Limit: %d (%d " + + "too big)", actualSize, maxSize, + actualSize - maxSize)); + this.extra = actualSize - maxSize; + this.actualSize = actualSize; + this.maxSize = maxSize; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index 844eff4f4c701..2d534f12abd62 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -73,6 +73,9 @@ public ByteBuffer nioByteBuffer() throws IOException { buf.flip(); return buf; } else { + if (length > LargeByteBufferHelper.MAX_CHUNK_SIZE) { + throw new BufferTooLargeException(length, LargeByteBufferHelper.MAX_CHUNK_SIZE); + } return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); } } catch (IOException e) { diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBuffer.java new file mode 100644 index 0000000000000..beeb007e2197e --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBuffer.java @@ -0,0 +1,148 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +/** + * A byte buffer which can hold over 2GB. + *

+ * This is roughly similar to {@link java.nio.ByteBuffer}, with a limited set of operations relevant + * to use in Spark, and without the capacity restrictions of a ByteBuffer. + *

+ * Unlike ByteBuffers, this is read-only, and only supports reading bytes (with both single and bulk + * get methods). It supports random access via skip to move around the + * buffer. + *

+ * In general, implementations are expected to support O(1) random access. Furthermore, + * neighboring locations in the buffer are likely to be neighboring in memory, so sequential access + * will avoid cache misses. However, these are only rough guidelines, which individual + * implementations may not follow. + *

+ * Any code which expects a ByteBuffer can obtain one via {@link #asByteBuffer} when possible -- see + * that method for a full description of its limitations. + *
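A rough usage sketch of the interface described above (not part of this patch; the sizes are illustrative, and the buffer is small enough that asByteBuffer() succeeds):

    import org.apache.spark.network.buffer.LargeByteBufferHelper

    val buf = LargeByteBufferHelper.asLargeByteBuffer(Array.fill[Byte](16)(1))
    val header = new Array[Byte](4)
    buf.get(header, 0, 4)         // bulk read; advances position() by 4
    buf.skip(8L)                  // random access; clamped at the end of the buffer
    assert(buf.remaining() == 4L)
    buf.rewind()                  // back to position 0
    val nio = buf.asByteBuffer()  // ok here: a single chunk well under MAX_CHUNK_SIZE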

+ * Instances of this class can be created with + * {@link org.apache.spark.network.buffer.LargeByteBufferHelper}, + * with a LargeByteBufferOutputStream, + * or directly from the implementation + * {@link org.apache.spark.network.buffer.WrappedLargeByteBuffer}. + */ +public interface LargeByteBuffer { + public byte get(); + + + /** + * Bulk copy data from this buffer into the given array. First checks there is sufficient + * data in this buffer; if not, throws a {@link java.nio.BufferUnderflowException}. Behaves + * in the exact same way as get(dst, 0, dst.length) + * + * @param dst the destination array + * @return this buffer + */ + public LargeByteBuffer get(byte[] dst); + + /** + * Bulk copy data from this buffer into the given array. First checks there is sufficient + * data in this buffer; if not, throws a {@link java.nio.BufferUnderflowException}. + * + * @param dst the destination array + * @param offset the offset within the destination array to write to + * @param length how many bytes to write + * @return this buffer + */ + public LargeByteBuffer get(byte[] dst, int offset, int length); + + + public LargeByteBuffer rewind(); + + /** + * Return a deep copy of this buffer. + * The returned buffer will have position == 0. The position + * of this buffer will not change as a result of copying. + * + * @return a new buffer with a full copy of this buffer's data + */ + public LargeByteBuffer deepCopy(); + + /** + * Advance the position in this buffer by up to n bytes. n may be + * positive or negative. It will move the full n unless that moves + * it past the end (or beginning) of the buffer, in which case it will move to the end + * (or beginning). + * + * @return the number of bytes moved forward (can be negative if n is negative) + */ + public long skip(long n); + + public long position(); + + /** + * Creates a new byte buffer that shares this buffer's content. + *

+ * The content of the new buffer will be that of this buffer. Changes + * to this buffer's content will be visible in the new buffer, and vice + * versa; the two buffers' positions will be independent. + *

+ * The new buffer's position will be identical to that of this buffer. + */ + public LargeByteBuffer duplicate(); + + public long remaining(); + + /** + * Total number of bytes in this buffer + */ + public long size(); + + /** + * Writes the data from the current position() to the end of this buffer + * to the given channel. The position() will be moved to the end of + * the buffer after this. + *

+ * Note that this method will continually attempt to push data to the given channel. If the + * channel cannot accept more data, this will continuously retry until the channel accepts + * the data. + * + * @param channel + * @return the number of bytes written to the channel + * @throws IOException + */ + public long writeTo(WritableByteChannel channel) throws IOException; + + /** + * Get the entire contents of this as one ByteBuffer, if possible. The returned ByteBuffer + * will always have the position set to 0, and the limit set to the end of the data. Each + * call will return a new ByteBuffer, but will not require copying the data (eg., it will + * use ByteBuffer#duplicate()). The returned byte buffer will share data with this buffer. The + * returned buffers will never be larger than + * {@link org.apache.spark.network.buffer.LargeByteBufferHelper#MAX_CHUNK_SIZE} + * + * @throws BufferTooLargeException if this buffer is too large to fit in one {@link ByteBuffer} + */ + public ByteBuffer asByteBuffer() throws BufferTooLargeException; + + /** + * Attempt to clean this up if it is memory-mapped. This uses an *unsafe* Sun API that + * might cause errors if one attempts to read from the unmapped buffer, but it's better than + * waiting for the GC to find it because that could lead to huge numbers of open files. There's + * unfortunately no standard API to do this. + */ + public void dispose(); +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBufferHelper.java b/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBufferHelper.java new file mode 100644 index 0000000000000..4941ed6559ea9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBufferHelper.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; + +import com.google.common.annotations.VisibleForTesting; + +/** + * Utils for creating {@link org.apache.spark.network.buffer.LargeByteBuffer}s, either from + * pre-allocated byte arrays, ByteBuffers, or by memory mapping a file. + */ +public class LargeByteBufferHelper { + + // netty can't quite send msgs that are a full 2GB -- they need to be slightly smaller + // not sure what the exact limit is, but 200 seems OK. + /** + * The maximum size of any ByteBuffer. + * {@link org.apache.spark.network.buffer.LargeByteBuffer#asByteBuffer} will never return a + * ByteBuffer larger than this. This is close to the max ByteBuffer size (2GB), minus a small + * amount for message overhead. 
+ */ + public static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 200; + + public static LargeByteBuffer asLargeByteBuffer(ByteBuffer buffer) { + return new WrappedLargeByteBuffer(new ByteBuffer[] { buffer } ); + } + + public static LargeByteBuffer asLargeByteBuffer(byte[] bytes) { + return asLargeByteBuffer(ByteBuffer.wrap(bytes)); + } + + public static LargeByteBuffer allocate(long size) { + return allocate(size, MAX_CHUNK_SIZE); + } + + @VisibleForTesting + static LargeByteBuffer allocate(long size, int maxChunk) { + int chunksNeeded = (int) ((size + maxChunk - 1) / maxChunk); + ByteBuffer[] chunks = new ByteBuffer[chunksNeeded]; + long remaining = size; + for (int i = 0; i < chunksNeeded; i++) { + int nextSize = (int) Math.min(remaining, maxChunk); + ByteBuffer next = ByteBuffer.allocate(nextSize); + remaining -= nextSize; + chunks[i] = next; + } + if (remaining != 0) { + throw new IllegalStateException("remaining = " + remaining); + } + return new WrappedLargeByteBuffer(chunks, maxChunk); + } + + public static LargeByteBuffer mapFile( + FileChannel channel, + FileChannel.MapMode mode, + long offset, + long length + ) throws IOException { + int chunksNeeded = (int) ((length - 1) / MAX_CHUNK_SIZE) + 1; + ByteBuffer[] chunks = new ByteBuffer[chunksNeeded]; + long curPos = offset; + long end = offset + length; + for (int i = 0; i < chunksNeeded; i++) { + long nextPos = Math.min(curPos + MAX_CHUNK_SIZE, end); + chunks[i] = channel.map(mode, curPos, nextPos - curPos); + curPos = nextPos; + } + return new WrappedLargeByteBuffer(chunks); + } + +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/WrappedLargeByteBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/WrappedLargeByteBuffer.java new file mode 100644 index 0000000000000..58a621249386f --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/WrappedLargeByteBuffer.java @@ -0,0 +1,292 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Arrays; +import java.util.List; + +import com.google.common.annotations.VisibleForTesting; +import sun.nio.ch.DirectBuffer; + +/** + * A {@link org.apache.spark.network.buffer.LargeByteBuffer} which may contain multiple + * {@link java.nio.ByteBuffer}s. In order to support asByteBuffer, all + * of the underlying ByteBuffers must have size equal to + * {@link org.apache.spark.network.buffer.LargeByteBufferHelper#MAX_CHUNK_SIZE} (except that last + * one). The underlying ByteBuffers may be on-heap, direct, or memory-mapped. 
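A rough sketch (not part of this patch) of the invariant described above: the helper splits data into MAX_CHUNK_SIZE chunks, so asByteBuffer() only succeeds when everything fits in a single chunk. The 3 GB allocation below is purely illustrative and would really use that much heap:

    import org.apache.spark.network.buffer.{BufferTooLargeException, LargeByteBufferHelper}

    val small = LargeByteBufferHelper.allocate(100L)   // one chunk
    val nio = small.asByteBuffer()                     // fine: size <= MAX_CHUNK_SIZE
    val big = LargeByteBufferHelper.allocate(3L * 1024 * 1024 * 1024)  // two chunks
    try {
      big.asByteBuffer()                               // spans more than one chunk
    } catch {
      case _: BufferTooLargeException =>               // expected; use get()/writeTo() instead
    }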
+ */ +public class WrappedLargeByteBuffer implements LargeByteBuffer { + + @VisibleForTesting + final ByteBuffer[] underlying; + + private final long size; + /** + * each sub-ByteBuffer (except for the last one) must be exactly this size. Note that this + * class *really* expects this to be LargeByteBufferHelper.MAX_CHUNK_SIZE. The only reason it isn't + * is so that we can do tests without creating ginormous buffers. Public methods force it to + * be LargeByteBufferHelper.MAX_CHUNK_SIZE + */ + private final int subBufferSize; + private long _pos; + @VisibleForTesting + int currentBufferIdx; + @VisibleForTesting + ByteBuffer currentBuffer; + + /** + * Construct a WrappedLargeByteBuffer from the given ByteBuffers. Each of the ByteBuffers must + * have size equal to {@link org.apache.spark.network.buffer.LargeByteBufferHelper#MAX_CHUNK_SIZE} + * except for the final one. The buffers are duplicated, so the position of the + * given buffers and the returned buffer will be independent, though the underlying data will be + * shared. The constructed buffer will always have position == 0. + */ + public WrappedLargeByteBuffer(ByteBuffer[] underlying) { + this(underlying, LargeByteBufferHelper.MAX_CHUNK_SIZE); + } + + /** + * you do **not** want to call this version. It leads to a buffer which doesn't properly + * support {@link #asByteBuffer}. The only reason it exists is so we can have tests which + * don't require 2GB of memory + * + * @param underlying + * @param subBufferSize + */ + @VisibleForTesting + WrappedLargeByteBuffer(ByteBuffer[] underlying, int subBufferSize) { + if (underlying.length == 0) { + throw new IllegalArgumentException("must wrap at least one ByteBuffer"); + } + this.underlying = new ByteBuffer[underlying.length]; + this.subBufferSize = subBufferSize; + long sum = 0L; + + for (int i = 0; i < underlying.length; i++) { + ByteBuffer b = underlying[i].duplicate(); + b.position(0); + this.underlying[i] = b; + if (i != underlying.length - 1 && b.capacity() != subBufferSize) { + // this is to make sure that asByteBuffer() is implemented correctly. We need the first + // subBuffer to be LargeByteBufferHelper.MAX_CHUNK_SIZE. We don't *have* to check all the + // subBuffers, but I figure it makes it more consistent this way. (Also, this check + // really only serves a purpose when using the public constructor -- subBufferSize is + // a parameter just to allow small tests.)
+ throw new IllegalArgumentException("All buffers, except for the final one, must have " + + "size = " + subBufferSize); + } + sum += b.capacity(); + } + _pos = 0; + currentBufferIdx = 0; + currentBuffer = this.underlying[0]; + size = sum; + } + + + @Override + public WrappedLargeByteBuffer get(byte[] dest) { + return get(dest, 0, dest.length); + } + + @Override + public WrappedLargeByteBuffer get(byte[] dest, int offset, int length) { + if (length > remaining()) { + throw new BufferUnderflowException(); + } + int moved = 0; + while (moved < length) { + int toRead = Math.min(length - moved, currentBuffer.remaining()); + currentBuffer.get(dest, offset + moved, toRead); + moved += toRead; + updateCurrentBufferIfNeeded(); + } + _pos += moved; + return this; + } + + @Override + public LargeByteBuffer rewind() { + if (currentBuffer != null) { + currentBuffer.rewind(); + } + while (currentBufferIdx > 0) { + currentBufferIdx -= 1; + currentBuffer = underlying[currentBufferIdx]; + currentBuffer.rewind(); + } + _pos = 0; + return this; + } + + @Override + public WrappedLargeByteBuffer deepCopy() { + ByteBuffer[] dataCopy = new ByteBuffer[underlying.length]; + for (int i = 0; i < underlying.length; i++) { + ByteBuffer b = underlying[i]; + dataCopy[i] = ByteBuffer.allocate(b.capacity()); + int originalPosition = b.position(); + b.rewind(); + dataCopy[i].put(b); + dataCopy[i].position(0); + b.position(originalPosition); + } + return new WrappedLargeByteBuffer(dataCopy, subBufferSize); + } + + @Override + public byte get() { + if (remaining() < 1L) { + throw new BufferUnderflowException(); + } + byte r = currentBuffer.get(); + _pos += 1; + updateCurrentBufferIfNeeded(); + return r; + } + + /** + * If we've read to the end of the current buffer, move on to the next one. Safe to call + * even if we haven't moved to the next buffer + */ + private void updateCurrentBufferIfNeeded() { + while (currentBuffer != null && !currentBuffer.hasRemaining()) { + currentBufferIdx += 1; + currentBuffer = currentBufferIdx < underlying.length ? 
underlying[currentBufferIdx] : null; + } + } + + @Override + public long position() { + return _pos; + } + + @Override + public long skip(long n) { + if (n < 0) { + final long moveTotal = Math.min(-n, _pos); + long toMove = moveTotal; + // move backwards and update the position of every buffer as we go + if (currentBuffer != null) { + currentBufferIdx += 1; + } + while (toMove > 0) { + currentBufferIdx -= 1; + currentBuffer = underlying[currentBufferIdx]; + int thisMove = (int) Math.min(toMove, currentBuffer.position()); + currentBuffer.position(currentBuffer.position() - thisMove); + toMove -= thisMove; + } + _pos -= moveTotal; + return -moveTotal; + } else if (n > 0) { + final long moveTotal = Math.min(n, remaining()); + long toMove = moveTotal; + // move forwards and update the position of every buffer as we go + currentBufferIdx -= 1; + while (toMove > 0) { + currentBufferIdx += 1; + currentBuffer = underlying[currentBufferIdx]; + int thisMove = (int) Math.min(toMove, currentBuffer.remaining()); + currentBuffer.position(currentBuffer.position() + thisMove); + toMove -= thisMove; + } + _pos += moveTotal; + return moveTotal; + } else { + return 0; + } + } + + @Override + public long remaining() { + return size - _pos; + } + + @Override + public WrappedLargeByteBuffer duplicate() { + // the constructor will duplicate the underlying buffers for us + WrappedLargeByteBuffer dup = new WrappedLargeByteBuffer(underlying, subBufferSize); + dup.skip(position()); + return dup; + } + + @Override + public long size() { + return size; + } + + @Override + public long writeTo(WritableByteChannel channel) throws IOException { + long written = 0L; + for (; currentBufferIdx < underlying.length; currentBufferIdx++) { + currentBuffer = underlying[currentBufferIdx]; + written += currentBuffer.remaining(); + while (currentBuffer.hasRemaining()) + channel.write(currentBuffer); + } + _pos = size(); + return written; + } + + @Override + public ByteBuffer asByteBuffer() throws BufferTooLargeException { + if (underlying.length == 1) { + ByteBuffer b = underlying[0].duplicate(); + b.rewind(); + return b; + } else { + // NOTE: if subBufferSize != LargeByteBufferHelper.MAX_CAPACITY, in theory + // we could copy the data into a new buffer. But we don't want to do any copying. + // The only reason we allow smaller subBufferSize is so that we can have tests which + // don't require 2GB of memory + throw new BufferTooLargeException(size(), underlying[0].capacity()); + } + } + + @VisibleForTesting + List nioBuffers() { + return Arrays.asList(underlying); + } + + /** + * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that + * might cause errors if one attempts to read from the unmapped buffer, but it's better than + * waiting for the GC to find it because that could lead to huge numbers of open files. There's + * unfortunately no standard API to do this. 
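A rough sketch (not part of this patch) of the mapped-file lifecycle that this dispose logic supports; the file path and read size are illustrative:

    import java.io.FileInputStream
    import java.nio.channels.FileChannel.MapMode
    import org.apache.spark.network.buffer.LargeByteBufferHelper

    val channel = new FileInputStream("/tmp/some-large-block.bin").getChannel
    val mapped = LargeByteBufferHelper.mapFile(channel, MapMode.READ_ONLY, 0L, channel.size())
    try {
      val chunk = new Array[Byte](64 * 1024)
      while (mapped.remaining() > 0) {
        val n = math.min(chunk.length.toLong, mapped.remaining()).toInt
        mapped.get(chunk, 0, n)   // consume the bytes here
      }
    } finally {
      mapped.dispose()            // unmap eagerly rather than waiting for GC
      channel.close()
    }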
+ */ + private static void dispose(ByteBuffer buffer) { + if (buffer != null && buffer instanceof DirectBuffer) { + DirectBuffer db = (DirectBuffer) buffer; + if (db.cleaner() != null) { + db.cleaner().clean(); + } + } + } + + @Override + public void dispose() { + for (ByteBuffer bb : underlying) { + dispose(bb); + } + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/buffer/LargeByteBufferHelperSuite.java b/network/common/src/test/java/org/apache/spark/network/buffer/LargeByteBufferHelperSuite.java new file mode 100644 index 0000000000000..9e636fc032928 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/buffer/LargeByteBufferHelperSuite.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.buffer; + +import java.io.*; +import java.nio.channels.FileChannel; +import java.util.Random; + +import org.junit.Test; + +import static org.junit.Assert.*; + +public class LargeByteBufferHelperSuite { + + @Test + public void testMapFile() throws IOException { + File testFile = File.createTempFile("large-byte-buffer-test", ".bin"); + try { + testFile.deleteOnExit(); + OutputStream out = new FileOutputStream(testFile); + byte[] buffer = new byte[1 << 16]; + Random rng = new XORShiftRandom(0L); + long len = ((long)buffer.length) + Integer.MAX_VALUE + 1; + for (int i = 0; i < len / buffer.length; i++) { + rng.nextBytes(buffer); + out.write(buffer); + } + out.close(); + + FileChannel in = new FileInputStream(testFile).getChannel(); + + //fail quickly on bad bounds + try { + LargeByteBufferHelper.mapFile(in, FileChannel.MapMode.READ_ONLY, 0, len + 1); + fail("expected exception"); + } catch (IOException ioe) { + } + try { + LargeByteBufferHelper.mapFile(in, FileChannel.MapMode.READ_ONLY, -1, 10); + fail("expected exception"); + } catch (IllegalArgumentException iae) { + } + + //now try to read from the buffer + LargeByteBuffer buf = LargeByteBufferHelper.mapFile(in, FileChannel.MapMode.READ_ONLY, 0, len); + assertEquals(len, buf.size()); + byte[] read = new byte[buffer.length]; + byte[] expected = new byte[buffer.length]; + Random rngExpected = new XORShiftRandom(0L); + for (int i = 0; i < len / buffer.length; i++) { + buf.get(read, 0, buffer.length); + // assertArrayEquals() is really slow + rngExpected.nextBytes(expected); + for (int j = 0; j < buffer.length; j++) { + if (read[j] != expected[j]) + fail("bad byte at (i,j) = (" + i + "," + j + ")"); + } + } + } finally { + testFile.delete(); + } + } + + @Test + public void testAllocate() { + WrappedLargeByteBuffer buf = (WrappedLargeByteBuffer) LargeByteBufferHelper.allocate(95,10); + assertEquals(10, buf.underlying.length); + for (int i = 0 ; i < 9; i++) { + assertEquals(10, buf.underlying[i].capacity()); + } + 
assertEquals(5, buf.underlying[9].capacity()); + } + + + private class XORShiftRandom extends Random { + + XORShiftRandom(long init) { + super(init); + seed = new Random(init).nextLong(); + } + + long seed; + + // we need to just override next - this will be called by nextInt, nextDouble, + // nextGaussian, nextLong, etc. + @Override + protected int next(int bits) { + long nextSeed = seed ^ (seed << 21); + nextSeed ^= (nextSeed >>> 35); + nextSeed ^= (nextSeed << 4); + seed = nextSeed; + return (int) (nextSeed & ((1L << bits) -1)); + } + } + +} diff --git a/network/common/src/test/java/org/apache/spark/network/buffer/WrappedLargeByteBufferSuite.java b/network/common/src/test/java/org/apache/spark/network/buffer/WrappedLargeByteBufferSuite.java new file mode 100644 index 0000000000000..3cbd2d8710304 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/buffer/WrappedLargeByteBufferSuite.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.buffer; + +import java.io.*; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.util.Arrays; +import java.util.Random; + +import org.junit.Test; + +import static org.junit.Assert.*; + +public class WrappedLargeByteBufferSuite { + + private byte[] data = new byte[500]; + { + new Random(1234).nextBytes(data); + } + + private WrappedLargeByteBuffer testDataBuf() { + ByteBuffer[] bufs = new ByteBuffer[10]; + for (int i = 0; i < 10; i++) { + byte[] b = new byte[50]; + System.arraycopy(data, i * 50, b, 0, 50); + bufs[i] = ByteBuffer.wrap(b); + } + return new WrappedLargeByteBuffer(bufs, 50); + } + + @Test + public void asByteBuffer() throws BufferTooLargeException { + // test that it works when buffer is small + LargeByteBuffer buf = LargeByteBufferHelper.asLargeByteBuffer(new byte[100]); + ByteBuffer nioBuf = buf.asByteBuffer(); + assertEquals(0, nioBuf.position()); + assertEquals(100, nioBuf.remaining()); + // if we move the large byte buffer, the nio.ByteBuffer we have doesn't change + buf.skip(10); + assertEquals(0, nioBuf.position()); + assertEquals(100, nioBuf.remaining()); + // if we grab another byte buffer while the large byte buffer's position != 0, + // the returned buffer still has position 0 + ByteBuffer nioBuf2 = buf.asByteBuffer(); + assertEquals(0, nioBuf2.position()); + assertEquals(100, nioBuf2.remaining()); + // the two byte buffers we grabbed are independent + nioBuf2.position(20); + assertEquals(0, nioBuf.position()); + assertEquals(100, nioBuf.remaining()); + assertEquals(20, nioBuf2.position()); + assertEquals(80, nioBuf2.remaining()); + + // the right error when the buffer is too big + try { + WrappedLargeByteBuffer buf2 = new 
WrappedLargeByteBuffer( + new ByteBuffer[]{ByteBuffer.allocate(10), ByteBuffer.allocate(10)}, 10); + // you really shouldn't ever construct a WrappedLargeByteBuffer with + // multiple small chunks, so this is somewhat contrived + buf2.asByteBuffer(); + fail("expected an exception"); + } catch (BufferTooLargeException btl) { + } + } + + @Test + public void checkSizesOfInternalBuffers() { + errorOnBuffersSized(10, new int[]{9,10}); + errorOnBuffersSized(10, new int[]{10,10,0,10}); + errorOnBuffersSized(20, new int[]{10,10,10,10}); + } + + private void errorOnBuffersSized(int chunkSize, int[] sizes) { + ByteBuffer[] bufs = new ByteBuffer[sizes.length]; + for (int i = 0; i < sizes.length; i++) { + bufs[i] = ByteBuffer.allocate(sizes[i]); + } + try { + new WrappedLargeByteBuffer(bufs, chunkSize); + fail("expected exception"); + } catch (IllegalArgumentException iae) { + } + } + + @Test + public void deepCopy() { + WrappedLargeByteBuffer b = testDataBuf(); + //intentionally move around sporadically + for (int initialPosition: new int[]{10,475, 0, 19, 58, 499, 498, 32, 234, 378}) { + b.rewind(); + b.skip(initialPosition); + WrappedLargeByteBuffer copy = b.deepCopy(); + assertEquals(0, copy.position()); + assertConsistent(copy); + assertConsistent(b); + assertEquals(b.size(), copy.size()); + assertEquals(initialPosition, b.position()); + byte[] copyData = new byte[500]; + copy.get(copyData, 0, 500); + assertArrayEquals(data, copyData); + } + } + + @Test + public void skipAndGet() { + WrappedLargeByteBuffer b = testDataBuf(); + int position = 0; + for (int move: new int[]{20, 50, 100, 0, -80, 0, 200, -175, 500, 0, -1000, 0}) { + long moved = b.skip(move); + assertConsistent(b); + long expMoved = move > 0 ? Math.min(move, 500 - position) : Math.max(move, -position); + position += moved; + assertEquals(expMoved, moved); + assertEquals(position, b.position()); + byte[] copyData = new byte[500 - position]; + b.get(copyData, 0, 500 - position); + assertConsistent(b); + byte[] dataSubset = new byte[500 - position]; + System.arraycopy(data, position, dataSubset, 0, 500 - position); + assertArrayEquals(dataSubset, copyData); + b.rewind(); + assertConsistent(b); + b.skip(position); + assertConsistent(b); + + int copy2Length = Math.min(20, 500 - position); + byte[] copy2 = new byte[copy2Length]; + b.rewind(); + b.skip(position); + b.get(copy2); + assertSubArrayEquals(data, position, copy2, 0, copy2Length); + + b.rewind(); + b.skip(position); + } + } + + @Test + public void get() { + WrappedLargeByteBuffer b = testDataBuf(); + byte[] into = new byte[500]; + for (int[] offsetAndLength: new int[][]{{0, 200}, {10,10}, {300, 20}, {30, 100}}) { + int offset = offsetAndLength[0]; + int length = offsetAndLength[1]; + b.rewind(); + b.get(into, offset, length); + assertConsistent(b); + assertSubArrayEquals(data, 0, into, offset, length); + + byte[] into2 = new byte[length]; + b.rewind(); + b.get(into2); + assertConsistent(b); + assertSubArrayEquals(data, 0, into2, 0, length); + } + + try { + b.rewind(); + b.skip(400); + b.get(into, 0, 500); + fail("expected exception"); + } catch (BufferUnderflowException bue) { + } + + try { + b.rewind(); + b.skip(1); + b.get(into); + fail("expected exception"); + } catch (BufferUnderflowException bue) { + } + + b.rewind(); + b.skip(495); + assertEquals(data[495], b.get()); + assertEquals(data[496], b.get()); + assertEquals(data[497], b.get()); + assertEquals(data[498], b.get()); + assertEquals(data[499], b.get()); + try { + b.get(); + fail("expected exception"); + } catch 
(BufferUnderflowException bue) { + } + } + + @Test + public void writeTo() throws IOException { + for (int initialPosition: new int[]{0,20, 400}) { + File testFile = File.createTempFile("WrappedLargeByteBuffer-writeTo-" + initialPosition,".bin"); + testFile.deleteOnExit(); + FileChannel channel = new FileOutputStream(testFile).getChannel(); + WrappedLargeByteBuffer buf = testDataBuf(); + buf.skip(initialPosition); + assertEquals(initialPosition, buf.position()); + int expN = 500 - initialPosition; + long bytesWritten = buf.writeTo(channel); + assertEquals(expN, bytesWritten); + channel.close(); + + byte[] fileBytes = new byte[expN]; + FileInputStream in = new FileInputStream(testFile); + int n = 0; + while (n < expN) { + n += in.read(fileBytes, n, expN - n); + } + assertEquals(-1, in.read()); + byte[] dataSlice = Arrays.copyOfRange(data, initialPosition, 500); + assertArrayEquals(dataSlice, fileBytes); + assertEquals(0, buf.remaining()); + assertEquals(500, buf.position()); + } + } + + @Test + public void duplicate() { + for (int initialPosition: new int[]{0,20, 400}) { + WrappedLargeByteBuffer buf = testDataBuf(); + buf.skip(initialPosition); + + WrappedLargeByteBuffer dup = buf.duplicate(); + assertEquals(initialPosition, buf.position()); + assertEquals(initialPosition, dup.position()); + assertEquals(500, buf.size()); + assertEquals(500, dup.size()); + assertEquals(500 - initialPosition, buf.remaining()); + assertEquals(500 - initialPosition, dup.remaining()); + assertConsistent(buf); + assertConsistent(dup); + + // check positions of both buffers are independent + buf.skip(20); + assertEquals(initialPosition + 20, buf.position()); + assertEquals(initialPosition, dup.position()); + assertConsistent(buf); + assertConsistent(dup); + } + } + + @Test(expected=IllegalArgumentException.class) + public void testRequireAtLeastOneBuffer() { + new WrappedLargeByteBuffer( new ByteBuffer[0]); + } + + @Test + public void positionIndependentOfInitialBuffers() { + ByteBuffer[] byteBufs = testDataBuf().underlying; + byteBufs[0].position(50); + for (int initialPosition: new int[]{0,20, 400}) { + WrappedLargeByteBuffer buf = new WrappedLargeByteBuffer(byteBufs, 50); + assertEquals(0L, buf.position()); + assertEquals(50, byteBufs[0].position()); + buf.skip(initialPosition); + assertEquals(initialPosition, buf.position()); + assertEquals(50, byteBufs[0].position()); + } + } + + private void assertConsistent(WrappedLargeByteBuffer buffer) { + long pos = buffer.position(); + long bufferStartPos = 0; + if (buffer.currentBufferIdx < buffer.underlying.length) { + assertEquals(buffer.currentBuffer, buffer.underlying[buffer.currentBufferIdx]); + } else { + assertNull(buffer.currentBuffer); + } + for (ByteBuffer p: buffer.nioBuffers()) { + if (pos < bufferStartPos) { + assertEquals(0, p.position()); + } else if (pos < bufferStartPos + p.capacity()) { + assertEquals(pos - bufferStartPos, p.position()); + } else { + assertEquals(p.capacity(), p.position()); + } + bufferStartPos += p.capacity(); + } + } + + private void assertSubArrayEquals( + byte[] exp, + int expOffset, + byte[] act, + int actOffset, + int length) { + byte[] expCopy = new byte[length]; + byte[] actCopy = new byte[length]; + System.arraycopy(exp, expOffset, expCopy, 0, length); + System.arraycopy(act, actOffset, actCopy, 0, length); + assertArrayEquals(expCopy, actCopy); + } + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index 620b8a36a2baf..80331f3378230 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.spark._ +import org.apache.spark.network.buffer.LargeByteBufferHelper import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.streaming.util._ @@ -156,11 +157,13 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logInfo(s"Read partition data of $this from write ahead log, record handle " + partition.walRecordHandle) if (storeInBlockManager) { - blockManager.putBytes(blockId, dataRead, storageLevel) + blockManager.putBytes(blockId, LargeByteBufferHelper.asLargeByteBuffer(dataRead), + storageLevel) logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } - blockManager.dataDeserialize(blockId, dataRead).asInstanceOf[Iterator[T]] + blockManager.dataDeserialize(blockId, LargeByteBufferHelper.asLargeByteBuffer(dataRead)) + .asInstanceOf[Iterator[T]] } if (partition.isBlockIdValid) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index c8dd6e06812dc..791d97ba598e0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -17,13 +17,15 @@ package org.apache.spark.streaming.receiver -import scala.concurrent.duration._ + import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ import scala.language.{existentials, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.network.buffer.LargeByteBufferHelper import org.apache.spark.storage._ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} @@ -83,7 +85,8 @@ private[streaming] class BlockManagerBasedBlockHandler( numRecords = countIterator.count putResult case ByteBufferBlock(byteBuffer) => - blockManager.putBytes(blockId, byteBuffer, storageLevel, tellMaster = true) + blockManager.putBytes(blockId, LargeByteBufferHelper.asLargeByteBuffer(byteBuffer), + storageLevel, tellMaster = true) case o => throw new SparkException( s"Could not store $blockId to block manager, unexpected block type ${o.getClass.getName}") @@ -177,7 +180,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( numRecords = countIterator.count serializedBlock case ByteBufferBlock(byteBuffer) => - byteBuffer + LargeByteBufferHelper.asLargeByteBuffer(byteBuffer) case _ => throw new Exception(s"Could not push $blockId to block manager, unexpected block type") } @@ -194,7 +197,7 @@ private[streaming] class WriteAheadLogBasedBlockHandler( // Store the block in write ahead log val storeInWriteAheadLogFuture = Future { - writeAheadLog.write(serializedBlock, clock.getTimeMillis()) + writeAheadLog.write(serializedBlock.asByteBuffer(), clock.getTimeMillis()) } // Combine the futures, wait for both to complete, and return the write ahead log record handle diff --git 
a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 6c0c926755c20..1772cbd3a8de9 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.streaming import java.io.File import java.nio.ByteBuffer +import org.apache.spark.network.buffer.LargeByteBufferHelper + import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps @@ -142,7 +144,7 @@ class ReceivedBlockHandlerSuite val loggedData = walSegments.flatMap { walSegment => val fileSegment = walSegment.asInstanceOf[FileBasedWriteAheadLogSegment] val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) - val bytes = reader.read(fileSegment) + val bytes = LargeByteBufferHelper.asLargeByteBuffer(reader.read(fileSegment)) reader.close() blockManager.dataDeserialize(generateBlockId(), bytes).toList } @@ -326,7 +328,7 @@ class ReceivedBlockHandlerSuite storeAndVerify(blocks.map { b => IteratorBlock(b.toIterator) }) storeAndVerify(blocks.map { b => ArrayBufferBlock(new ArrayBuffer ++= b) }) - storeAndVerify(blocks.map { b => ByteBufferBlock(dataToByteBuffer(b)) }) + storeAndVerify(blocks.map { b => ByteBufferBlock(dataToByteBuffer(b).asByteBuffer()) }) } /** Test error handling when blocks that cannot be stored */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index cb017b798b2a4..11867a06a0042 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -213,7 +213,7 @@ class WriteAheadLogBackedBlockRDDSuite require(blockData.size === blockIds.size) val writer = new FileBasedWriteAheadLogWriter(new File(dir, "logFile").toString, hadoopConf) val segments = blockData.zip(blockIds).map { case (data, id) => - writer.write(blockManager.dataSerialize(id, data.iterator)) + writer.write(blockManager.dataSerialize(id, data.iterator).asByteBuffer()) } writer.close() segments