diff --git a/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferInputStream.java b/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferInputStream.java
new file mode 100644
index 0000000000000..a4b1e2571af56
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferInputStream.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.buffer;
+
+import java.io.InputStream;
+
+import com.google.common.annotations.VisibleForTesting;
+
+/**
+ * Reads data from a LargeByteBuffer, and optionally cleans it up using buffer.dispose()
+ * when the stream is closed (e.g. to close a memory-mapped file).
+ */
+public class LargeByteBufferInputStream extends InputStream {
+
+ private LargeByteBuffer buffer;
+ private final boolean dispose;
+
+ public LargeByteBufferInputStream(LargeByteBuffer buffer, boolean dispose) {
+ this.buffer = buffer;
+ this.dispose = dispose;
+ }
+
+ public LargeByteBufferInputStream(LargeByteBuffer buffer) {
+ this(buffer, false);
+ }
+
+ @Override
+ public int read() {
+ if (buffer == null || buffer.remaining() == 0) {
+ return -1;
+ } else {
+ return buffer.get() & 0xFF;
+ }
+ }
+
+ @Override
+ public int read(byte[] dest) {
+ return read(dest, 0, dest.length);
+ }
+
+ @Override
+ public int read(byte[] dest, int offset, int length) {
+ if (buffer == null || buffer.remaining() == 0) {
+ return -1;
+ } else {
+ int amountToGet = (int) Math.min(buffer.remaining(), length);
+ buffer.get(dest, offset, amountToGet);
+ return amountToGet;
+ }
+ }
+
+ @Override
+ public long skip(long toSkip) {
+ if (buffer != null) {
+ return buffer.skip(toSkip);
+ } else {
+ return 0L;
+ }
+ }
+
+ // only for testing
+ @VisibleForTesting
+ boolean disposed = false;
+
+ /**
+   * Clean up the buffer, and dispose of it if this stream was created with dispose=true.
+ */
+ @Override
+ public void close() {
+ if (buffer != null) {
+ if (dispose) {
+ buffer.dispose();
+ disposed = true;
+ }
+ buffer = null;
+ }
+ }
+}
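A minimal usage sketch (not part of the patch) showing how the stream is meant to be consumed; `readOneObject` is a hypothetical helper, and the buffer is assumed to come from elsewhere, e.g. a memory-mapped block:

```scala
import java.io.ObjectInputStream

import org.apache.spark.network.buffer.{LargeByteBuffer, LargeByteBufferInputStream}

// Hypothetical helper: deserialize a single Java-serialized object from a buffer.
def readOneObject(buffer: LargeByteBuffer): AnyRef = {
  // dispose = true: closing the stream also calls buffer.dispose(), e.g. to unmap a file
  val in = new LargeByteBufferInputStream(buffer, true)
  try {
    new ObjectInputStream(in).readObject()
  } finally {
    in.close() // drops the buffer reference and disposes of it
  }
}
```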
diff --git a/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferOutputStream.java b/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferOutputStream.java
new file mode 100644
index 0000000000000..975de7b10f65c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/network/buffer/LargeByteBufferOutputStream.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.util.io.ByteArrayChunkOutputStream;
+
+/**
+ * An OutputStream that will write all data to memory. It supports writing over 2GB
+ * and the resulting data can be retrieved as a
+ * {@link org.apache.spark.network.buffer.LargeByteBuffer}
+ */
+public class LargeByteBufferOutputStream extends OutputStream {
+
+ private final ByteArrayChunkOutputStream output;
+
+ /**
+ * Create a new LargeByteBufferOutputStream which writes to byte arrays of the given size. Note
+ * that chunkSize has no effect on the LargeByteBuffer returned by
+ * {@link #largeBuffer()}.
+ *
+ * @param chunkSize size of the byte arrays used by this output stream, in bytes
+ */
+ public LargeByteBufferOutputStream(int chunkSize) {
+ output = new ByteArrayChunkOutputStream(chunkSize);
+ }
+
+ @Override
+ public void write(int b) {
+ output.write(b);
+ }
+
+ @Override
+ public void write(byte[] bytes, int off, int len) {
+ output.write(bytes, off, len);
+ }
+
+ /**
+ * Get all of the data written to the stream so far as a LargeByteBuffer. This method can be
+ * called multiple times, and each returned buffer will be completely independent (the data
+ * is copied for each returned buffer). It does not close the stream.
+ *
+ * @return the data written to the stream as a LargeByteBuffer
+ */
+ public LargeByteBuffer largeBuffer() {
+ return largeBuffer(LargeByteBufferHelper.MAX_CHUNK_SIZE);
+ }
+
+ /**
+   * Exposed for testing. You should not normally call this method -- the returned
+   * buffer will not implement {@code asByteBuffer} correctly.
+ */
+ @VisibleForTesting
+ LargeByteBuffer largeBuffer(int maxChunk) {
+ long totalSize = output.size();
+ int chunksNeeded = (int) ((totalSize + maxChunk - 1) / maxChunk);
+ ByteBuffer[] chunks = new ByteBuffer[chunksNeeded];
+ long remaining = totalSize;
+ long pos = 0;
+ for (int idx = 0; idx < chunksNeeded; idx++) {
+ int nextSize = (int) Math.min(maxChunk, remaining);
+ chunks[idx] = ByteBuffer.wrap(output.slice(pos, pos + nextSize));
+ pos += nextSize;
+ remaining -= nextSize;
+ }
+ return new WrappedLargeByteBuffer(chunks, maxChunk);
+ }
+
+ @Override
+ public void close() throws IOException {
+ output.close();
+ }
+}
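A short sketch of the write-then-snapshot cycle this class supports (illustrative only; the sizes are arbitrary):

```scala
import org.apache.spark.network.buffer.{LargeByteBuffer, LargeByteBufferOutputStream}

val out = new LargeByteBufferOutputStream(64 * 1024)  // 64KB backing chunks
out.write(new Array[Byte](100000))                    // spans several chunks internally
val snapshot: LargeByteBuffer = out.largeBuffer()     // independent copy of the bytes so far
out.write(new Array[Byte](100000))                    // later writes do not affect `snapshot`
out.close()
assert(snapshot.size() == 100000L)
```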
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index a0c9b5e63c744..8a516709dd2b8 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -18,7 +18,6 @@
package org.apache.spark.broadcast
import java.io._
-import java.nio.ByteBuffer
import scala.collection.JavaConversions.asJavaEnumeration
import scala.reflect.ClassTag
@@ -26,9 +25,10 @@ import scala.util.Random
import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.network.buffer.{LargeByteBufferInputStream, LargeByteBufferHelper, LargeByteBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.{ByteBufferInputStream, Utils}
+import org.apache.spark.util.Utils
import org.apache.spark.util.io.ByteArrayChunkOutputStream
/**
@@ -111,10 +111,10 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}
/** Fetch torrent blocks from the driver and/or other executors. */
- private def readBlocks(): Array[ByteBuffer] = {
+ private def readBlocks(): Array[LargeByteBuffer] = {
// Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
// to the driver, so other executors can pull these chunks from this executor as well.
- val blocks = new Array[ByteBuffer](numBlocks)
+ val blocks = new Array[LargeByteBuffer](numBlocks)
val bm = SparkEnv.get.blockManager
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
@@ -123,8 +123,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
// First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
- def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
- def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
+ def getLocal: Option[LargeByteBuffer] = bm.getLocalBytes(pieceId)
+ def getRemote: Option[LargeByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
// If we found the block from remote executors/driver's BlockManager, put the block
// in this executor's BlockManager.
SparkEnv.get.blockManager.putBytes(
@@ -134,7 +134,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
tellMaster = true)
block
}
- val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
+ val block: LargeByteBuffer = getLocal.orElse(getRemote).getOrElse(
throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
blocks(pid) = block
}
@@ -195,22 +195,22 @@ private object TorrentBroadcast extends Logging {
obj: T,
blockSize: Int,
serializer: Serializer,
- compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
+ compressionCodec: Option[CompressionCodec]): Array[LargeByteBuffer] = {
val bos = new ByteArrayChunkOutputStream(blockSize)
val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
val ser = serializer.newInstance()
val serOut = ser.serializeStream(out)
serOut.writeObject[T](obj).close()
- bos.toArrays.map(ByteBuffer.wrap)
+ bos.toArrays.map(LargeByteBufferHelper.asLargeByteBuffer)
}
def unBlockifyObject[T: ClassTag](
- blocks: Array[ByteBuffer],
+ blocks: Array[LargeByteBuffer],
serializer: Serializer,
compressionCodec: Option[CompressionCodec]): T = {
require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
val is = new SequenceInputStream(
- asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
+ asJavaEnumeration(blocks.iterator.map(block => new LargeByteBufferInputStream(block))))
val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
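To make the blockifyObject / unBlockifyObject changes easier to follow, here is a sketch of the two buffer conversions in isolation, using only calls that appear in this diff; `wrapChunks` and `concat` are hypothetical names:

```scala
import java.io.{InputStream, SequenceInputStream}

import scala.collection.JavaConversions.asJavaEnumeration

import org.apache.spark.network.buffer.LargeByteBufferInputStream
import org.apache.spark.network.buffer.{LargeByteBuffer, LargeByteBufferHelper}

// blockify side: each serialized chunk of the broadcast value becomes one LargeByteBuffer
def wrapChunks(chunks: Array[Array[Byte]]): Array[LargeByteBuffer] =
  chunks.map(c => LargeByteBufferHelper.asLargeByteBuffer(c))

// unblockify side: concatenate the per-block streams before handing them to the deserializer
def concat(blocks: Array[LargeByteBuffer]): InputStream =
  new SequenceInputStream(
    asJavaEnumeration(blocks.iterator.map(b => new LargeByteBufferInputStream(b))))
```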
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 42a85e42ea2b6..17da70500f9df 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -23,6 +23,8 @@ import java.net.URL
import java.nio.ByteBuffer
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import org.apache.spark.network.buffer.LargeByteBufferHelper
+
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
@@ -266,7 +268,8 @@ private[spark] class Executor(
} else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
- blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
+ blockId, LargeByteBufferHelper.asLargeByteBuffer(serializedDirectResult),
+ StorageLevel.MEMORY_AND_DISK_SER)
logInfo(
s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index dcbda5a8515dd..14462df32410f 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -24,9 +24,9 @@ import scala.concurrent.{Promise, Await, Future}
import scala.concurrent.duration.Duration
import org.apache.spark.Logging
-import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{LargeByteBufferHelper, NioManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener}
-import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel}
+import org.apache.spark.storage.{BlockId, StorageLevel}
private[spark]
abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index b089da8596e2b..c099f7ec7f79a 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -23,12 +23,12 @@ import scala.collection.JavaConversions._
import org.apache.spark.Logging
import org.apache.spark.network.BlockDataManager
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.buffer.{BufferTooLargeException, ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.storage.{ShuffleRemoteBlockSizeLimitException, BlockId, StorageLevel}
/**
* Serves requests to open blocks by simply registering one chunk per block requested.
@@ -53,11 +53,24 @@ class NettyBlockRpcServer(
message match {
case openBlocks: OpenBlocks =>
- val blocks: Seq[ManagedBuffer] =
- openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
- val streamId = streamManager.registerStream(blocks.iterator)
- logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
- responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
+ try {
+ val blocks: Seq[ManagedBuffer] =
+ openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
+ val streamId = streamManager.registerStream(blocks.iterator)
+ logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
+ responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
+ } catch {
+          // This shouldn't ever happen, because we prevent writing shuffle files over 2GB,
+          // but handle it here just to be safe.
+ case ex: BufferTooLargeException =>
+            // Throw and catch this helper exception, just to get a full stack trace.
+ try {
+ throw new ShuffleRemoteBlockSizeLimitException(ex)
+ } catch {
+ case ex2: ShuffleRemoteBlockSizeLimitException =>
+ responseContext.onFailure(ex2)
+ }
+ }
case uploadBlock: UploadBlock =>
// StorageLevel is serialized as bytes using our JavaSerializer.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 46a6f6537e2ee..2c0b3f085ca76 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -77,7 +77,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
return
}
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
- serializedTaskResult.get)
+ serializedTaskResult.get.asByteBuffer())
sparkEnv.blockManager.master.removeBlock(blockId)
(deserializedResult, size)
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
index 4342b0d598b16..81aea33ee41b4 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
@@ -18,6 +18,7 @@
package org.apache.spark.shuffle
import java.nio.ByteBuffer
+
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.ShuffleBlockId
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index eedb27942e841..dd2a198714bac 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io._
import java.nio.{ByteBuffer, MappedByteBuffer}
+import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{ExecutionContext, Await, Future}
import scala.concurrent.duration._
@@ -31,7 +32,7 @@ import org.apache.spark._
import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.buffer._
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
@@ -42,7 +43,7 @@ import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.util._
private[spark] sealed trait BlockValues
-private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
+private[spark] case class ByteBufferValues(buffer: LargeByteBuffer) extends BlockValues
private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues
private[spark] case class ArrayValues(buffer: Array[Any]) extends BlockValues
@@ -300,10 +301,10 @@ private[spark] class BlockManager(
shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
- .asInstanceOf[Option[ByteBuffer]]
+ .asInstanceOf[Option[LargeByteBuffer]]
if (blockBytesOpt.isDefined) {
val buffer = blockBytesOpt.get
- new NioManagedBuffer(buffer)
+ new NioManagedBuffer(buffer.asByteBuffer())
} else {
throw new BlockNotFoundException(blockId.toString)
}
@@ -314,7 +315,7 @@ private[spark] class BlockManager(
* Put the block locally, using the given storage level.
*/
override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = {
- putBytes(blockId, data.nioByteBuffer(), level)
+ putBytes(blockId, LargeByteBufferHelper.asLargeByteBuffer(data.nioByteBuffer()), level)
}
/**
@@ -432,7 +433,7 @@ private[spark] class BlockManager(
/**
* Get block from the local block manager as serialized bytes.
*/
- def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
+ def getLocalBytes(blockId: BlockId): Option[LargeByteBuffer] = {
logDebug(s"Getting local block $blockId as bytes")
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
@@ -440,10 +441,10 @@ private[spark] class BlockManager(
val shuffleBlockResolver = shuffleManager.shuffleBlockResolver
// TODO: This should gracefully handle case where local block is not available. Currently
// downstream code will throw an exception.
- Option(
- shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
+ Option(LargeByteBufferHelper.asLargeByteBuffer(
+ shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()))
} else {
- doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[LargeByteBuffer]]
}
}
@@ -509,13 +510,13 @@ private[spark] class BlockManager(
// Look for block on disk, potentially storing it back in memory if required
if (level.useDisk) {
logDebug(s"Getting block $blockId from disk")
- val bytes: ByteBuffer = diskStore.getBytes(blockId) match {
+ val bytes: LargeByteBuffer = diskStore.getBytes(blockId) match {
case Some(b) => b
case None =>
throw new BlockException(
blockId, s"Block $blockId not found on disk, though it should be")
}
- assert(0 == bytes.position())
+ assert(0L == bytes.position())
if (!level.useMemory) {
// If the block shouldn't be stored in memory, we can just return it
@@ -531,13 +532,12 @@ private[spark] class BlockManager(
/* We'll store the bytes in memory if the block's storage level includes
* "memory serialized", or if it should be cached as objects in memory
* but we only requested its serialized bytes. */
- memoryStore.putBytes(blockId, bytes.limit, () => {
+ memoryStore.putBytes(blockId, bytes.size(), () => {
// https://issues.apache.org/jira/browse/SPARK-6076
// If the file size is bigger than the free memory, OOM will happen. So if we cannot
// put it into MemoryStore, copyForMemory should not be created. That's why this
// action is put into a `() => ByteBuffer` and created lazily.
- val copyForMemory = ByteBuffer.allocate(bytes.limit)
- copyForMemory.put(bytes)
+ bytes.deepCopy()
})
bytes.rewind()
}
@@ -582,9 +582,9 @@ private[spark] class BlockManager(
/**
* Get block from remote block managers as serialized bytes.
*/
- def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
+ def getRemoteBytes(blockId: BlockId): Option[LargeByteBuffer] = {
logDebug(s"Getting remote block $blockId as bytes")
- doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[LargeByteBuffer]]
}
private def doGetRemote(blockId: BlockId, asBlockResult: Boolean): Option[Any] = {
@@ -592,17 +592,18 @@ private[spark] class BlockManager(
val locations = Random.shuffle(master.getLocations(blockId))
for (loc <- locations) {
logDebug(s"Getting remote block $blockId from $loc")
- val data = blockTransferService.fetchBlockSync(
+      // the fetch will always be one byte buffer until we fix SPARK-5928
+ val data: ByteBuffer = blockTransferService.fetchBlockSync(
loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
if (data != null) {
if (asBlockResult) {
return Some(new BlockResult(
- dataDeserialize(blockId, data),
+ dataDeserialize(blockId, LargeByteBufferHelper.asLargeByteBuffer(data)),
DataReadMethod.Network,
data.limit()))
} else {
- return Some(data)
+ return Some(LargeByteBufferHelper.asLargeByteBuffer(data))
}
}
logDebug(s"The value of block $blockId is null")
@@ -675,7 +676,7 @@ private[spark] class BlockManager(
*/
def putBytes(
blockId: BlockId,
- bytes: ByteBuffer,
+ bytes: LargeByteBuffer,
level: StorageLevel,
tellMaster: Boolean = true,
effectiveStorageLevel: Option[StorageLevel] = None): Seq[(BlockId, BlockStatus)] = {
@@ -737,7 +738,7 @@ private[spark] class BlockManager(
var valuesAfterPut: Iterator[Any] = null
// Ditto for the bytes after the put
- var bytesAfterPut: ByteBuffer = null
+ var bytesAfterPut: LargeByteBuffer = null
// Size of the block in bytes
var size = 0L
@@ -745,119 +746,135 @@ private[spark] class BlockManager(
// The level we actually use to put the block
val putLevel = effectiveStorageLevel.getOrElse(level)
- // If we're storing bytes, then initiate the replication before storing them locally.
- // This is faster as data is already serialized and ready to send.
- val replicationFuture = data match {
- case b: ByteBufferValues if putLevel.replication > 1 =>
- // Duplicate doesn't copy the bytes, but just creates a wrapper
- val bufferView = b.buffer.duplicate()
- Future {
- // This is a blocking action and should run in futureExecutionContext which is a cached
- // thread pool
- replicate(blockId, bufferView, putLevel)
- }(futureExecutionContext)
- case _ => null
- }
-
- putBlockInfo.synchronized {
- logTrace("Put for block %s took %s to get into synchronized block"
- .format(blockId, Utils.getUsedTimeMs(startTimeMs)))
+ try {
+ // If we're storing bytes, then initiate the replication before storing them locally.
+ // This is faster as data is already serialized and ready to send.
+ val replicationFuture = data match {
+ case b: ByteBufferValues if putLevel.replication > 1 =>
+ // Duplicate doesn't copy the bytes, but just creates a wrapper
+ val bufferView = try {
+ b.buffer.asByteBuffer()
+ } catch {
+ case ex: BufferTooLargeException =>
+ throw new ReplicationBlockSizeLimitException(ex)
+ }
+ Future {
+ // This is a blocking action and should run in futureExecutionContext which is a cached
+ // thread pool
+ replicate(blockId, bufferView, putLevel)
+ }(futureExecutionContext)
+ case _ => null
+ }
- var marked = false
- try {
- // returnValues - Whether to return the values put
- // blockStore - The type of storage to put these values into
- val (returnValues, blockStore: BlockStore) = {
- if (putLevel.useMemory) {
- // Put it in memory first, even if it also has useDisk set to true;
- // We will drop it to disk later if the memory store can't hold it.
- (true, memoryStore)
- } else if (putLevel.useOffHeap) {
- // Use external block store
- (false, externalBlockStore)
- } else if (putLevel.useDisk) {
- // Don't get back the bytes from put unless we replicate them
- (putLevel.replication > 1, diskStore)
- } else {
- assert(putLevel == StorageLevel.NONE)
- throw new BlockException(
- blockId, s"Attempted to put block $blockId without specifying storage level!")
+ putBlockInfo.synchronized {
+ logTrace("Put for block %s took %s to get into synchronized block"
+ .format(blockId, Utils.getUsedTimeMs(startTimeMs)))
+
+ var marked = false
+ try {
+ // returnValues - Whether to return the values put
+ // blockStore - The type of storage to put these values into
+ val (returnValues, blockStore: BlockStore) = {
+ if (putLevel.useMemory) {
+ // Put it in memory first, even if it also has useDisk set to true;
+ // We will drop it to disk later if the memory store can't hold it.
+ (true, memoryStore)
+ } else if (putLevel.useOffHeap) {
+ // Use external block store
+ (false, externalBlockStore)
+ } else if (putLevel.useDisk) {
+ // Don't get back the bytes from put unless we replicate them
+ (putLevel.replication > 1, diskStore)
+ } else {
+ assert(putLevel == StorageLevel.NONE)
+ throw new BlockException(
+ blockId, s"Attempted to put block $blockId without specifying storage level!")
+ }
}
- }
- // Actually put the values
- val result = data match {
- case IteratorValues(iterator) =>
- blockStore.putIterator(blockId, iterator, putLevel, returnValues)
- case ArrayValues(array) =>
- blockStore.putArray(blockId, array, putLevel, returnValues)
- case ByteBufferValues(bytes) =>
- bytes.rewind()
- blockStore.putBytes(blockId, bytes, putLevel)
- }
- size = result.size
- result.data match {
- case Left (newIterator) if putLevel.useMemory => valuesAfterPut = newIterator
- case Right (newBytes) => bytesAfterPut = newBytes
- case _ =>
- }
+ // Actually put the values
+ val result = data match {
+ case IteratorValues(iterator) =>
+ blockStore.putIterator(blockId, iterator, putLevel, returnValues)
+ case ArrayValues(array) =>
+ blockStore.putArray(blockId, array, putLevel, returnValues)
+ case ByteBufferValues(bytes) =>
+ bytes.rewind()
+ blockStore.putBytes(blockId, bytes, putLevel)
+ }
+ size = result.size
+ result.data match {
+ case Left(newIterator) if putLevel.useMemory => valuesAfterPut = newIterator
+ case Right(newBytes) => bytesAfterPut = newBytes
+ case _ =>
+ }
- // Keep track of which blocks are dropped from memory
- if (putLevel.useMemory) {
- result.droppedBlocks.foreach { updatedBlocks += _ }
- }
+ // Keep track of which blocks are dropped from memory
+ if (putLevel.useMemory) {
+ result.droppedBlocks.foreach {
+ updatedBlocks += _
+ }
+ }
- val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo)
- if (putBlockStatus.storageLevel != StorageLevel.NONE) {
- // Now that the block is in either the memory, externalBlockStore, or disk store,
- // let other threads read it, and tell the master about it.
- marked = true
- putBlockInfo.markReady(size)
- if (tellMaster) {
- reportBlockStatus(blockId, putBlockInfo, putBlockStatus)
+ val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo)
+ if (putBlockStatus.storageLevel != StorageLevel.NONE) {
+ // Now that the block is in either the memory, externalBlockStore, or disk store,
+ // let other threads read it, and tell the master about it.
+ marked = true
+ putBlockInfo.markReady(size)
+ if (tellMaster) {
+ reportBlockStatus(blockId, putBlockInfo, putBlockStatus)
+ }
+ updatedBlocks += ((blockId, putBlockStatus))
+ }
+ } finally {
+ // If we failed in putting the block to memory/disk, notify other possible readers
+ // that it has failed, and then remove it from the block info map.
+ if (!marked) {
+ // Note that the remove must happen before markFailure otherwise another thread
+ // could've inserted a new BlockInfo before we remove it.
+ blockInfo.remove(blockId)
+ putBlockInfo.markFailure()
+ logWarning(s"Putting block $blockId failed")
}
- updatedBlocks += ((blockId, putBlockStatus))
- }
- } finally {
- // If we failed in putting the block to memory/disk, notify other possible readers
- // that it has failed, and then remove it from the block info map.
- if (!marked) {
- // Note that the remove must happen before markFailure otherwise another thread
- // could've inserted a new BlockInfo before we remove it.
- blockInfo.remove(blockId)
- putBlockInfo.markFailure()
- logWarning(s"Putting block $blockId failed")
}
}
- }
- logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
+ logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs)))
- // Either we're storing bytes and we asynchronously started replication, or we're storing
- // values and need to serialize and replicate them now:
- if (putLevel.replication > 1) {
- data match {
- case ByteBufferValues(bytes) =>
- if (replicationFuture != null) {
- Await.ready(replicationFuture, Duration.Inf)
- }
- case _ =>
- val remoteStartTime = System.currentTimeMillis
- // Serialize the block if not already done
- if (bytesAfterPut == null) {
- if (valuesAfterPut == null) {
- throw new SparkException(
- "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
+ // Either we're storing bytes and we asynchronously started replication, or we're storing
+ // values and need to serialize and replicate them now:
+ if (putLevel.replication > 1) {
+ data match {
+ case ByteBufferValues(bytes) =>
+ if (replicationFuture != null) {
+ Await.ready(replicationFuture, Duration.Inf)
}
- bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
- }
- replicate(blockId, bytesAfterPut, putLevel)
- logDebug("Put block %s remotely took %s"
- .format(blockId, Utils.getUsedTimeMs(remoteStartTime)))
+ case _ =>
+ val remoteStartTime = System.currentTimeMillis
+ // Serialize the block if not already done
+ if (bytesAfterPut == null) {
+ if (valuesAfterPut == null) {
+ throw new SparkException(
+ "Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
+ }
+ bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
+ }
+ try {
+ replicate(blockId, bytesAfterPut.asByteBuffer(), putLevel)
+ } catch {
+ case ex: BufferTooLargeException =>
+ throw new ReplicationBlockSizeLimitException(ex)
+ }
+ logDebug("Put block %s remotely took %s"
+ .format(blockId, Utils.getUsedTimeMs(remoteStartTime)))
+ }
+ }
+ } finally {
+ if (bytesAfterPut != null) {
+ bytesAfterPut.dispose()
}
}
- BlockManager.dispose(bytesAfterPut)
-
if (putLevel.replication > 1) {
logDebug("Putting block %s with replication took %s"
.format(blockId, Utils.getUsedTimeMs(startTimeMs)))
@@ -998,7 +1015,7 @@ private[spark] class BlockManager(
def dropFromMemory(
blockId: BlockId,
- data: Either[Array[Any], ByteBuffer]): Option[BlockStatus] = {
+ data: Either[Array[Any], LargeByteBuffer]): Option[BlockStatus] = {
dropFromMemory(blockId, () => data)
}
@@ -1012,7 +1029,7 @@ private[spark] class BlockManager(
*/
def dropFromMemory(
blockId: BlockId,
- data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = {
+ data: () => Either[Array[Any], LargeByteBuffer]): Option[BlockStatus] = {
logInfo(s"Dropping block $blockId from memory")
val info = blockInfo.get(blockId).orNull
@@ -1194,10 +1211,10 @@ private[spark] class BlockManager(
def dataSerialize(
blockId: BlockId,
values: Iterator[Any],
- serializer: Serializer = defaultSerializer): ByteBuffer = {
- val byteStream = new ByteArrayOutputStream(4096)
+ serializer: Serializer = defaultSerializer): LargeByteBuffer = {
+ val byteStream = new LargeByteBufferOutputStream(65536)
dataSerializeStream(blockId, byteStream, values, serializer)
- ByteBuffer.wrap(byteStream.toByteArray)
+ byteStream.largeBuffer
}
/**
@@ -1206,10 +1223,10 @@ private[spark] class BlockManager(
*/
def dataDeserialize(
blockId: BlockId,
- bytes: ByteBuffer,
+ bytes: LargeByteBuffer,
serializer: Serializer = defaultSerializer): Iterator[Any] = {
bytes.rewind()
- dataDeserializeStream(blockId, new ByteBufferInputStream(bytes, true), serializer)
+ dataDeserializeStream(blockId, new LargeByteBufferInputStream(bytes, true), serializer)
}
/**
@@ -1291,3 +1308,51 @@ private[spark] object BlockManager extends Logging {
blockManagers.toMap
}
}
+
+
+abstract class BlockSizeLimitException(msg: String, cause: BufferTooLargeException)
+ extends SparkException(msg, cause)
+
+object BlockSizeLimitException {
+ def sizeMsg(cause: BufferTooLargeException): String = {
+ s"that was ${Utils.bytesToString(cause.actualSize)} (too " +
+ s"large by ${Utils.bytesToString(cause.extra)} / " +
+ s"${cause.actualSize.toDouble / LargeByteBufferHelper.MAX_CHUNK_SIZE}x)."
+ }
+
+ def sizeMsgAndAdvice(cause: BufferTooLargeException): String = {
+ sizeMsg(cause) +
+ " You should figure out which stage created these partitions, then increase the number of " +
+ "partitions used by that stage. That way, you will have less data per partition. You may " +
+ "want to make the number of partitions an easily configurable parameter so you can " +
+ "continue to update it as needed."
+ }
+
+}
+
+class ReplicationBlockSizeLimitException(cause: BufferTooLargeException)
+ extends BlockSizeLimitException("Spark cannot replicate partitions that are greater than 2GB. " +
+ "You tried to replicate a partition " + BlockSizeLimitException.sizeMsgAndAdvice(cause) +
+ " Or, you can turn off replication.", cause)
+
+class TachyonBlockSizeLimitException(cause: BufferTooLargeException)
+  extends BlockSizeLimitException("Spark cannot store partitions greater than 2GB in Tachyon. " +
+ "You tried to store a partition " + BlockSizeLimitException.sizeMsgAndAdvice(cause) +
+ " Or, you can use a different storage mechanism.", cause)
+
+class ShuffleBlockSizeLimitException(size: Long)
+ extends SparkException("Spark cannot shuffle partitions that are greater than 2GB. " +
+ "You tried to shuffle a block that was at least " + Utils.bytesToString(size) + ". " +
+ "You should try to increase the number of partitions of this shuffle, and / or " +
+ "figure out which stage created the partitions before the shuffle, and increase the number " +
+ "of partitions for that stage. You may want to make both of these numbers easily " +
+ "configurable parameters so you can continue to update as needed.")
+
+class ShuffleRemoteBlockSizeLimitException(cause: BufferTooLargeException)
+ extends BlockSizeLimitException("Spark cannot shuffle partitions that are greater than 2GB. " +
+ "You tried to shuffle a block that was at least " + BlockSizeLimitException.sizeMsg(cause) +
+ "You should try to increase the number of partitions of this shuffle, and / or increase the " +
+ "figure out which stage created the partitions before the shuffle, and increase the number " +
+ "of partitions for that stage. You may want to make both of these numbers easily " +
+ "configurable parameters so you can continue to update as needed.", cause)
+
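A worked example of the message `sizeMsg` produces; it assumes MAX_CHUNK_SIZE equals Int.MaxValue (the ~2GB ByteBuffer limit) and that `BufferTooLargeException.extra` is the byte count over that limit, neither of which is shown in this diff:

```scala
// Illustrative arithmetic only; see the assumptions above.
val maxChunk   = Int.MaxValue.toLong             // assumed value of MAX_CHUNK_SIZE (~2.0 GB)
val actualSize = 3L * 1024 * 1024 * 1024         // a 3 GB partition
val extra      = actualSize - maxChunk           // ~1 GB over the limit
val ratio      = actualSize.toDouble / maxChunk  // ~1.5
// sizeMsg(cause) would then read roughly:
//   "that was 3.0 GB (too large by ~1 GB / ~1.5x)."
```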
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
index 69985c9759e2d..ebbe06fde68ae 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockStore.scala
@@ -17,18 +17,15 @@
package org.apache.spark.storage
-import java.nio.ByteBuffer
-
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.Logging
+import org.apache.spark.network.buffer.LargeByteBuffer
/**
* Abstract class to store blocks.
*/
private[spark] abstract class BlockStore(val blockManager: BlockManager) extends Logging {
- def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult
+  def putBytes(blockId: BlockId, bytes: LargeByteBuffer, level: StorageLevel): PutResult
/**
* Put in a block and, possibly, also return its content as either bytes or another Iterator.
@@ -54,7 +51,7 @@ private[spark] abstract class BlockStore(val blockManager: BlockManager) extends
*/
def getSize(blockId: BlockId): Long
- def getBytes(blockId: BlockId): Option[ByteBuffer]
+ def getBytes(blockId: BlockId): Option[LargeByteBuffer]
def getValues(blockId: BlockId): Option[Iterator[Any]]
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 49d9154f95a5b..ee35281c93ee7 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -21,6 +21,7 @@ import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream}
import java.nio.channels.FileChannel
import org.apache.spark.Logging
+import org.apache.spark.network.buffer.LargeByteBufferHelper
import org.apache.spark.serializer.{SerializerInstance, SerializationStream}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.util.Utils
@@ -108,6 +109,12 @@ private[spark] class DiskBlockObjectWriter(
objOut.close()
}
+ finalPosition = file.length()
+ val length = finalPosition - initialPosition
+ if (length > LargeByteBufferHelper.MAX_CHUNK_SIZE) {
+ throw new ShuffleBlockSizeLimitException(length)
+ }
+
channel = null
bs = null
fos = null
@@ -201,6 +208,10 @@ private[spark] class DiskBlockObjectWriter(
if (numRecordsWritten % 32 == 0) {
updateBytesWritten()
+ val length = reportedPosition - initialPosition
+ if (length > LargeByteBufferHelper.MAX_CHUNK_SIZE) {
+ throw new ShuffleBlockSizeLimitException(length)
+ }
}
}
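The same size guard, pulled out as a standalone sketch so the commit-time check and the every-32-records check above are easier to compare; `checkBlockSize` is a hypothetical name:

```scala
import org.apache.spark.network.buffer.LargeByteBufferHelper
import org.apache.spark.storage.ShuffleBlockSizeLimitException

// Assumed semantics: a single shuffle block must fit in one ByteBuffer-sized chunk.
def checkBlockSize(initialPosition: Long, currentPosition: Long): Unit = {
  val length = currentPosition - initialPosition
  if (length > LargeByteBufferHelper.MAX_CHUNK_SIZE) {
    // fail while still writing, instead of producing a block that can never be fetched remotely
    throw new ShuffleBlockSizeLimitException(length)
  }
}
```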
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 1f45956282166..1d4b023e6f24e 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode
import org.apache.spark.Logging
+import org.apache.spark.network.buffer.{LargeByteBufferHelper, LargeByteBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
@@ -37,7 +38,10 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
diskManager.getFile(blockId.name).length
}
- override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = {
+ override def putBytes(
+ blockId: BlockId,
+ _bytes: LargeByteBuffer,
+ level: StorageLevel): PutResult = {
// So that we do not modify the input offsets !
// duplicate does not copy buffer, so inexpensive
val bytes = _bytes.duplicate()
@@ -46,16 +50,14 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
val file = diskManager.getFile(blockId)
val channel = new FileOutputStream(file).getChannel
Utils.tryWithSafeFinally {
- while (bytes.remaining > 0) {
- channel.write(bytes)
- }
+ bytes.writeTo(channel)
} {
channel.close()
}
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
- file.getName, Utils.bytesToString(bytes.limit), finishTime - startTime))
- PutResult(bytes.limit(), Right(bytes.duplicate()))
+ file.getName, Utils.bytesToString(bytes.size()), finishTime - startTime))
+ PutResult(bytes.size(), Right(bytes.duplicate()))
}
override def putArray(
@@ -106,7 +108,7 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
}
}
- private def getBytes(file: File, offset: Long, length: Long): Option[ByteBuffer] = {
+ private def getBytes(file: File, offset: Long, length: Long): Option[LargeByteBuffer] = {
val channel = new RandomAccessFile(file, "r").getChannel
Utils.tryWithSafeFinally {
// For small files, directly read rather than memory map
@@ -120,21 +122,22 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
}
}
buf.flip()
- Some(buf)
+ Some(LargeByteBufferHelper.asLargeByteBuffer(buf))
} else {
- Some(channel.map(MapMode.READ_ONLY, offset, length))
+ logTrace(s"mapping file: $file:$offset+$length")
+ Some(LargeByteBufferHelper.mapFile(channel, MapMode.READ_ONLY, offset, length))
}
} {
channel.close()
}
}
- override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+ override def getBytes(blockId: BlockId): Option[LargeByteBuffer] = {
val file = diskManager.getFile(blockId.name)
getBytes(file, 0, file.length)
}
- def getBytes(segment: FileSegment): Option[ByteBuffer] = {
+ def getBytes(segment: FileSegment): Option[LargeByteBuffer] = {
getBytes(segment.file, segment.offset, segment.length)
}
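A sketch combining the two read paths above, assuming `mapFile` can expose a region larger than 2GB as a single LargeByteBuffer (which is the point of this change); `readAsLargeBuffer` and `mapThreshold` are hypothetical names:

```scala
import java.io.{IOException, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode

import org.apache.spark.network.buffer.{LargeByteBuffer, LargeByteBufferHelper}

def readAsLargeBuffer(
    path: String, offset: Long, length: Long, mapThreshold: Long): LargeByteBuffer = {
  val channel = new RandomAccessFile(path, "r").getChannel
  try {
    if (length < mapThreshold) {
      // small segment: read it straight into a heap buffer (assumes mapThreshold <= 2GB)
      val buf = ByteBuffer.allocate(length.toInt)
      channel.position(offset)
      while (buf.remaining() > 0) {
        if (channel.read(buf) == -1) {
          throw new IOException("Reached EOF before reading the full segment")
        }
      }
      buf.flip()
      LargeByteBufferHelper.asLargeByteBuffer(buf)
    } else {
      // large segment: memory-map it; the mapping stays valid after the channel is closed
      LargeByteBufferHelper.mapFile(channel, MapMode.READ_ONLY, offset, length)
    }
  } finally {
    channel.close()
  }
}
```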
diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala
index f39325a12d244..d71253539e914 100644
--- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockManager.scala
@@ -19,6 +19,8 @@ package org.apache.spark.storage
import java.nio.ByteBuffer
+import org.apache.spark.network.buffer.LargeByteBuffer
+
/**
* An abstract class that the concrete external block manager has to inherit.
* The class has to have a no-argument constructor, and will be initialized by init,
@@ -75,7 +77,7 @@ private[spark] abstract class ExternalBlockManager {
*
* @throws java.io.IOException if there is any file system failure in putting the block.
*/
- def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit
+ def putBytes(blockId: BlockId, bytes: LargeByteBuffer): Unit
def putValues(blockId: BlockId, values: Iterator[_]): Unit = {
val bytes = blockManager.dataSerialize(blockId, values)
@@ -89,7 +91,7 @@ private[spark] abstract class ExternalBlockManager {
*
* @throws java.io.IOException if there is any file system failure in getting the block.
*/
- def getBytes(blockId: BlockId): Option[ByteBuffer]
+ def getBytes(blockId: BlockId): Option[LargeByteBuffer]
/**
* Retrieve the block data.
diff --git a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
index db965d54bafd6..c2271bccc8f4c 100644
--- a/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ExternalBlockStore.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import scala.util.control.NonFatal
import org.apache.spark.Logging
+import org.apache.spark.network.buffer.LargeByteBuffer
import org.apache.spark.util.Utils
@@ -47,7 +48,10 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId:
}
}
- override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = {
+ override def putBytes(
+ blockId: BlockId,
+ bytes: LargeByteBuffer,
+ level: StorageLevel): PutResult = {
putIntoExternalBlockStore(blockId, bytes, returnValues = true)
}
@@ -100,7 +104,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId:
private def putIntoExternalBlockStore(
blockId: BlockId,
- bytes: ByteBuffer,
+ bytes: LargeByteBuffer,
returnValues: Boolean): PutResult = {
logTrace(s"Attempting to put block $blockId into ExternalBlockStore")
// we should never hit here if externalBlockManager is None. Handle it anyway for safety.
@@ -110,7 +114,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId:
val byteBuffer = bytes.duplicate()
byteBuffer.rewind()
externalBlockManager.get.putBytes(blockId, byteBuffer)
- val size = bytes.limit()
+ val size = bytes.size()
val data = if (returnValues) {
Right(bytes)
} else {
@@ -152,7 +156,7 @@ private[spark] class ExternalBlockStore(blockManager: BlockManager, executorId:
}
}
- override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+ override def getBytes(blockId: BlockId): Option[LargeByteBuffer] = {
try {
externalBlockManager.flatMap(_.getBytes(blockId))
} catch {
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index 6f27f00307f8c..157c49ebdc919 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -17,9 +17,10 @@
package org.apache.spark.storage
-import java.nio.ByteBuffer
import java.util.LinkedHashMap
+import org.apache.spark.network.buffer.LargeByteBuffer
+
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
@@ -86,7 +87,10 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel): PutResult = {
+ override def putBytes(
+ blockId: BlockId,
+ _bytes: LargeByteBuffer,
+ level: StorageLevel): PutResult = {
// Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate()
bytes.rewind()
@@ -94,8 +98,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
val values = blockManager.dataDeserialize(blockId, bytes)
putIterator(blockId, values, level, returnValues = true)
} else {
- val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false)
- PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks)
+ val putAttempt = tryToPut(blockId, bytes, bytes.size(), deserialized = false)
+ PutResult(bytes.size(), Right(bytes.duplicate()), putAttempt.droppedBlocks)
}
}
@@ -105,13 +109,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
*
* The caller should guarantee that `size` is correct.
*/
- def putBytes(blockId: BlockId, size: Long, _bytes: () => ByteBuffer): PutResult = {
+ def putBytes(
+ blockId: BlockId,
+ size: Long,
+ _bytes: () => LargeByteBuffer): PutResult = {
// Work on a duplicate - since the original input might be used elsewhere.
- lazy val bytes = _bytes().duplicate().rewind().asInstanceOf[ByteBuffer]
+ lazy val bytes = _bytes().duplicate().rewind()
val putAttempt = tryToPut(blockId, () => bytes, size, deserialized = false)
val data =
if (putAttempt.success) {
- assert(bytes.limit == size)
+ assert(bytes.size() == size)
Right(bytes.duplicate())
} else {
null
@@ -130,8 +137,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
PutResult(sizeEstimate, Left(values.iterator), putAttempt.droppedBlocks)
} else {
val bytes = blockManager.dataSerialize(blockId, values.iterator)
- val putAttempt = tryToPut(blockId, bytes, bytes.limit, deserialized = false)
- PutResult(bytes.limit(), Right(bytes.duplicate()), putAttempt.droppedBlocks)
+ val putAttempt = tryToPut(blockId, bytes, bytes.size(), deserialized = false)
+ PutResult(bytes.size(), Right(bytes.duplicate()), putAttempt.droppedBlocks)
}
}
@@ -181,7 +188,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
- override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+ override def getBytes(blockId: BlockId): Option[LargeByteBuffer] = {
val entry = entries.synchronized {
entries.get(blockId)
}
@@ -190,7 +197,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} else if (entry.deserialized) {
Some(blockManager.dataSerialize(blockId, entry.value.asInstanceOf[Array[Any]].iterator))
} else {
- Some(entry.value.asInstanceOf[ByteBuffer].duplicate()) // Doesn't actually copy the data
+ Some(entry.value.asInstanceOf[LargeByteBuffer].duplicate()) // Doesn't actually copy the data
}
}
@@ -203,7 +210,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
} else if (entry.deserialized) {
Some(entry.value.asInstanceOf[Array[Any]].iterator)
} else {
- val buffer = entry.value.asInstanceOf[ByteBuffer].duplicate() // Doesn't actually copy data
+ // Doesn't actually copy data
+ val buffer = entry.value.asInstanceOf[LargeByteBuffer].duplicate()
Some(blockManager.dataDeserialize(blockId, buffer))
}
}
@@ -392,7 +400,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
lazy val data = if (deserialized) {
Left(value().asInstanceOf[Array[Any]])
} else {
- Right(value().asInstanceOf[ByteBuffer].duplicate())
+ Right(value().asInstanceOf[LargeByteBuffer].duplicate())
}
val droppedBlockStatus = blockManager.dropFromMemory(blockId, () => data)
droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
@@ -463,7 +471,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
val data = if (entry.deserialized) {
Left(entry.value.asInstanceOf[Array[Any]])
} else {
- Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
+ Right(entry.value.asInstanceOf[LargeByteBuffer].duplicate())
}
val droppedBlockStatus = blockManager.dropFromMemory(blockId, data)
droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
diff --git a/core/src/main/scala/org/apache/spark/storage/PutResult.scala b/core/src/main/scala/org/apache/spark/storage/PutResult.scala
index f0eac7594ecf6..aa9176791b319 100644
--- a/core/src/main/scala/org/apache/spark/storage/PutResult.scala
+++ b/core/src/main/scala/org/apache/spark/storage/PutResult.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.nio.ByteBuffer
+import org.apache.spark.network.buffer.LargeByteBuffer
/**
* Result of adding a block into a BlockStore. This case class contains a few things:
@@ -28,5 +28,5 @@ import java.nio.ByteBuffer
*/
private[spark] case class PutResult(
size: Long,
- data: Either[Iterator[_], ByteBuffer],
+ data: Either[Iterator[_], LargeByteBuffer],
droppedBlocks: Seq[(BlockId, BlockStatus)] = Seq.empty)
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
index 22878783fca67..f806407c73ead 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
@@ -18,7 +18,6 @@
package org.apache.spark.storage
import java.io.IOException
-import java.nio.ByteBuffer
import java.text.SimpleDateFormat
import java.util.{Date, Random}
@@ -32,6 +31,7 @@ import tachyon.TachyonURI
import org.apache.spark.Logging
import org.apache.spark.executor.ExecutorExitCode
+import org.apache.spark.network.buffer.{BufferTooLargeException, LargeByteBufferHelper, LargeByteBuffer}
import org.apache.spark.util.{ShutdownHookManager, Utils}
@@ -99,12 +99,14 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log
fileExists(file)
}
- override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = {
+ override def putBytes(blockId: BlockId, bytes: LargeByteBuffer): Unit = {
val file = getFile(blockId)
val os = file.getOutStream(WriteType.TRY_CACHE)
try {
- os.write(bytes.array())
+ os.write(bytes.asByteBuffer().array())
} catch {
+ case tooLarge: BufferTooLargeException =>
+ throw new TachyonBlockSizeLimitException(tooLarge)
case NonFatal(e) =>
logWarning(s"Failed to put bytes of block $blockId into Tachyon", e)
os.cancel()
@@ -127,7 +129,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log
}
}
- override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
+ override def getBytes(blockId: BlockId): Option[LargeByteBuffer] = {
val file = getFile(blockId)
if (file == null || file.getLocationHosts.size == 0) {
return None
@@ -135,9 +137,10 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log
val is = file.getInStream(ReadType.CACHE)
try {
val size = file.length
+ // TODO get tachyon to support large blocks
val bs = new Array[Byte](size.asInstanceOf[Int])
ByteStreams.readFully(is, bs)
- Some(ByteBuffer.wrap(bs))
+ Some(LargeByteBufferHelper.asLargeByteBuffer(bs))
} catch {
case NonFatal(e) =>
logWarning(s"Failed to get bytes of block $blockId from Tachyon", e)
diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
index daac6f971eb20..d48eb2f330321 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala
@@ -21,7 +21,6 @@ import java.io.OutputStream
import scala.collection.mutable.ArrayBuffer
-
/**
* An OutputStream that writes to fixed-size chunks of byte arrays.
*
@@ -43,10 +42,13 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
*/
private var position = chunkSize
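+  /** Total number of bytes written to this stream so far. */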
+ private[spark] var size: Long = 0L
+
override def write(b: Int): Unit = {
allocateNewChunkIfNeeded()
chunks(lastChunkIndex)(position) = b.toByte
position += 1
+ size += 1
}
override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
@@ -58,6 +60,7 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
written += thisBatch
position += thisBatch
}
+ size += len
}
@inline
@@ -91,4 +94,44 @@ class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream {
ret
}
}
+
+ /**
+ * Get a copy of the data between the two endpoints, start <= idx < until. Always returns
+ * an array of size (until - start). Throws an IllegalArgumentException unless
+ * 0 <= start <= until <= size
+ */
+ def slice(start: Long, until: Long): Array[Byte] = {
+    require((until - start) < Integer.MAX_VALUE, "slice length must be < Integer.MAX_VALUE")
+ require(start >= 0 && start <= until, s"start ($start) must be >= 0 and <= until ($until)")
+ require(until >= start && until <= size,
+ s"until ($until) must be >= start ($start) and <= size ($size)")
+ var chunkStart = 0L
+ var chunkIdx = 0
+ val length = (until - start).toInt
+ var foundStart = false
+ val result = new Array[Byte](length)
+ while (!foundStart) {
+ val nextChunkStart = chunkStart + chunks(chunkIdx).size
+ if (nextChunkStart > start) {
+ foundStart = true
+ } else {
+ chunkStart = nextChunkStart
+ chunkIdx += 1
+ }
+ }
+
+ var remaining = length
+ var pos = 0
+ var offsetInChunk = (start - chunkStart).toInt
+ while (remaining > 0) {
+ val lenToCopy = math.min(remaining, chunks(chunkIdx).size - offsetInChunk)
+ System.arraycopy(chunks(chunkIdx), offsetInChunk, result, pos, lenToCopy)
+ chunkIdx += 1
+ offsetInChunk = 0
+ pos += lenToCopy
+ remaining -= lenToCopy
+ }
+ result
+ }
+
}
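A small usage example of `slice`, with an artificially tiny chunk size so a single slice crosses a chunk boundary:

```scala
import org.apache.spark.util.io.ByteArrayChunkOutputStream

val out = new ByteArrayChunkOutputStream(4)                 // 4-byte chunks, for illustration
out.write(Array[Byte](0, 1, 2, 3, 4, 5, 6, 7, 8, 9))        // stored as [0..3], [4..7], [8..9]
assert(out.slice(2L, 7L).toSeq == Seq[Byte](2, 3, 4, 5, 6)) // copy spanning the first two chunks
out.close()
```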
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 600c1403b0344..4d488a0b78f5b 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.concurrent.Timeouts._
import org.scalatest.Matchers
import org.scalatest.time.{Millis, Span}
+import org.apache.spark.network.buffer.LargeByteBufferHelper
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
class NotSerializableClass
@@ -196,7 +197,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
blockManager.master.getLocations(blockId).foreach { cmId =>
val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
blockId.toString)
- val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer())
+ val deserialized = blockManager.dataDeserialize(blockId,
+ LargeByteBufferHelper.asLargeByteBuffer(bytes.nioByteBuffer()))
.asInstanceOf[Iterator[Int]].toList
assert(deserialized === (1 to 100).toList)
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index d91b799ecfc08..e28818098b918 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId}
+import org.apache.spark.storage.{ShuffleBlockSizeLimitException, BlockSizeLimitException, ShuffleDataBlockId, ShuffleBlockId}
import org.apache.spark.util.MutablePair
abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext {
@@ -283,6 +283,39 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
rdd.count()
}
+ ignore("shuffle total > 2GB ok if each block is small") {
+ sc = new SparkContext("local", "test", conf)
+ val rdd = sc.parallelize(1 to 1e6.toInt, 1).map{ i =>
+ val n = 3e3.toInt
+ val arr = new Array[Byte](n)
+ // need to make sure the array doesn't compress to something small
+ scala.util.Random.nextBytes(arr)
+ (i, arr)
+ }
+ rdd.partitionBy(new HashPartitioner(100)).count()
+ }
+
+ ignore("shuffle blocks > 2GB fail with sane exception") {
+ // note that this *could* succeed in local mode, because local shuffles don't actually
+ // have a 2GB limit. BUT, we make them fail in any case, because it's better to have
+ // a consistent failure than to have success depend on where tasks get scheduled
+
+ sc = new SparkContext("local", "test", conf)
+ val rdd = sc.parallelize(1 to 1e6.toInt, 1).map{ i =>
+ val n = 3e3.toInt
+ val arr = new Array[Byte](n)
+ // need to make sure the array doesn't compress to something small
+ scala.util.Random.nextBytes(arr)
+ (2 * i, arr)
+ }
+
+ val exc = intercept[SparkException] {
+ rdd.partitionBy(new org.apache.spark.HashPartitioner(2)).count()
+ }
+
+ exc.getCause shouldBe a[ShuffleBlockSizeLimitException]
+ }
+
test("metrics for shuffle without aggregation") {
sc = new SparkContext("local", "test", conf.clone())
val numRecords = 10000
diff --git a/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferInputStreamSuite.scala b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferInputStreamSuite.scala
new file mode 100644
index 0000000000000..d8e48db32f78c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferInputStreamSuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.buffer
+
+import java.io.{File, FileInputStream, FileOutputStream, OutputStream}
+import java.nio.channels.FileChannel.MapMode
+
+import org.junit.Assert._
+import org.scalatest.Matchers
+
+import org.apache.spark.SparkFunSuite
+
+class LargeByteBufferInputStreamSuite extends SparkFunSuite with Matchers {
+
+ test("read from large mapped file") {
+ val testFile = File.createTempFile("large-buffer-input-stream-test", ".bin")
+
+ try {
+ val out: OutputStream = new FileOutputStream(testFile)
+ val buffer: Array[Byte] = new Array[Byte](1 << 16)
+ val len: Long = buffer.length.toLong + Integer.MAX_VALUE + 1
+ (0 until buffer.length).foreach { idx =>
+ buffer(idx) = idx.toByte
+ }
+ (0 until (len / buffer.length).toInt).foreach { idx =>
+ out.write(buffer)
+ }
+ out.close()
+
+ val channel = new FileInputStream(testFile).getChannel
+ val buf = LargeByteBufferHelper.mapFile(channel, MapMode.READ_ONLY, 0, len)
+ val in = new LargeByteBufferInputStream(buf, true)
+
+ val read = new Array[Byte](buffer.length)
+ (0 until (len / buffer.length).toInt).foreach { idx =>
+ in.disposed should be(false)
+ in.read(read) should be(read.length)
+ (0 until buffer.length).foreach { arrIdx =>
+ assertEquals(buffer(arrIdx), read(arrIdx))
+ }
+ }
+ in.disposed should be(false)
+ in.read(read) should be(-1)
+ in.disposed should be(false)
+ in.close()
+ in.disposed should be(true)
+ } finally {
+ testFile.delete()
+ }
+ }
+
+ test("dispose on close") {
+ // don't need to read to the end -- dispose anytime we close
+ val data = new Array[Byte](10)
+ val in = new LargeByteBufferInputStream(LargeByteBufferHelper.asLargeByteBuffer(data), true)
+ in.disposed should be (false)
+ in.close()
+ in.disposed should be (true)
+ }
+
+ test("io stream roundtrip") {
+ val out = new LargeByteBufferOutputStream(128)
+ (0 until 200).foreach { idx => out.write(idx) }
+ out.close()
+
+ val lb = out.largeBuffer(128)
+ // just make sure that we test reading from multiple chunks
+ lb.asInstanceOf[WrappedLargeByteBuffer].underlying.size should be > 1
+
+ val rawIn = new LargeByteBufferInputStream(lb)
+ val arr = new Array[Byte](500)
+ val nRead = rawIn.read(arr, 0, 500)
+ nRead should be (200)
+ (0 until 200).foreach { idx =>
+ arr(idx) should be (idx.toByte)
+ }
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferOutputStreamSuite.scala
new file mode 100644
index 0000000000000..72c98b7feacab
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferOutputStreamSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.buffer
+
+import scala.util.Random
+
+import org.scalatest.Matchers
+
+import org.apache.spark.SparkFunSuite
+
+class LargeByteBufferOutputStreamSuite extends SparkFunSuite with Matchers {
+
+ test("merged buffers for < 2GB") {
+ val out = new LargeByteBufferOutputStream(10)
+ val bytes = new Array[Byte](100)
+ Random.nextBytes(bytes)
+ out.write(bytes)
+
+ val buffer = out.largeBuffer
+ buffer.position() should be (0)
+ buffer.size() should be (100)
+ val nioBuffer = buffer.asByteBuffer()
+ nioBuffer.position() should be (0)
+ nioBuffer.capacity() should be (100)
+ nioBuffer.limit() should be (100)
+
+ val read = new Array[Byte](100)
+ buffer.get(read, 0, 100)
+ read should be (bytes)
+
+ buffer.rewind()
+ nioBuffer.get(read)
+ read should be (bytes)
+ }
+
+ test("chunking") {
+ val out = new LargeByteBufferOutputStream(10)
+ val bytes = new Array[Byte](100)
+ Random.nextBytes(bytes)
+ out.write(bytes)
+
+ (10 to 100 by 10).foreach { chunkSize =>
+ val buffer = out.largeBuffer(chunkSize).asInstanceOf[WrappedLargeByteBuffer]
+ buffer.position() should be (0)
+ buffer.size() should be (100)
+ val read = new Array[Byte](100)
+ buffer.get(read, 0, 100)
+ read should be (bytes)
+ }
+
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferTestHelper.scala b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferTestHelper.scala
new file mode 100644
index 0000000000000..a04bb41fae366
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/network/buffer/LargeByteBufferTestHelper.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer
+
+import java.nio.ByteBuffer
+import java.util.{List => JList}
+
+/**
+ * cheat to access package-private members in tests
+ */
+object LargeByteBufferTestHelper {
+ def nioBuffers(wbb: WrappedLargeByteBuffer): JList[ByteBuffer] = {
+ wbb.nioBuffers()
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/LargePartitionCachingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LargePartitionCachingSuite.scala
new file mode 100644
index 0000000000000..dfcc90df32a18
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/LargePartitionCachingSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.rdd
+
+import org.apache.spark._
+import org.apache.spark.storage.{ReplicationBlockSizeLimitException, StorageLevel}
+import org.scalatest.Matchers
+
+class LargePartitionCachingSuite extends SparkFunSuite with SharedSparkContext with Matchers {
+
+ def largePartitionRdd: RDD[Array[Byte]] = {
+ sc.parallelize(1 to 1e6.toInt, 1).map{i => new Array[Byte](2.2e3.toInt)}
+ }
+
+ // just don't want to kill the test server
+ ignore("memory serialized cache large partitions") {
+ largePartitionRdd.persist(StorageLevel.MEMORY_ONLY_SER).count() should be (1e6.toInt)
+ }
+
+ ignore("disk cache large partitions") {
+ largePartitionRdd.persist(StorageLevel.DISK_ONLY).count() should be (1e6.toInt)
+ }
+
+ ignore("disk cache large partitions with replications") {
+ val conf = new SparkConf()
+ .setMaster("local-cluster[2, 1, 1024]")
+ .setAppName("test-cluster")
+ .set("spark.task.maxFailures", "1")
+ .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data
+ val clusterSc = new SparkContext(conf)
+ try {
+ val exc = intercept[SparkException]{
+ val myRDD = clusterSc.parallelize(1 to 1e6.toInt, 1).map{i => new Array[Byte](2.2e3.toInt)}
+ .persist(StorageLevel.DISK_ONLY_2)
+ myRDD.count()
+ }
+ exc.getCause() shouldBe a [ReplicationBlockSizeLimitException]
+ } finally {
+ clusterSc.stop()
+ }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index f480fd107a0c2..d1d6778180a4c 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.storage
import java.nio.{ByteBuffer, MappedByteBuffer}
import java.util.Arrays
+import org.apache.spark.network.buffer.{LargeByteBufferTestHelper, WrappedLargeByteBuffer, LargeByteBufferHelper, LargeByteBuffer}
+
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.language.implicitConversions
@@ -169,8 +171,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
assert(master.getLocations("a3").size === 0, "master was told about a3")
// Drop a1 and a2 from memory; this should be reported back to the master
- store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer])
- store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer])
+ store.dropFromMemory("a1", null: Either[Array[Any], LargeByteBuffer])
+ store.dropFromMemory("a2", null: Either[Array[Any], LargeByteBuffer])
assert(store.getSingle("a1") === None, "a1 not removed from store")
assert(store.getSingle("a2") === None, "a2 not removed from store")
assert(master.getLocations("a1").size === 0, "master did not remove a1")
@@ -410,8 +412,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
t2.join()
t3.join()
- store.dropFromMemory("a1", null: Either[Array[Any], ByteBuffer])
- store.dropFromMemory("a2", null: Either[Array[Any], ByteBuffer])
+ store.dropFromMemory("a1", null: Either[Array[Any], LargeByteBuffer])
+ store.dropFromMemory("a2", null: Either[Array[Any], LargeByteBuffer])
store.waitForAsyncReregister()
}
}
@@ -807,7 +809,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
var counter = 0.toByte
def incr: Byte = {counter = (counter + 1).toByte; counter;}
val bytes = Array.fill[Byte](1000)(incr)
- val byteBuffer = ByteBuffer.wrap(bytes)
+ val byteBuffer = LargeByteBufferHelper.asLargeByteBuffer(bytes)
val blockId = BlockId("rdd_1_2")
@@ -820,21 +822,22 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
val diskStoreMapped = new DiskStore(blockManager, diskBlockManager)
diskStoreMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY)
- val mapped = diskStoreMapped.getBytes(blockId).get
+ val mapped = diskStoreMapped.getBytes(blockId).get.asInstanceOf[WrappedLargeByteBuffer]
when(blockManager.conf).thenReturn(conf.clone.set(confKey, "1m"))
val diskStoreNotMapped = new DiskStore(blockManager, diskBlockManager)
diskStoreNotMapped.putBytes(blockId, byteBuffer, StorageLevel.DISK_ONLY)
- val notMapped = diskStoreNotMapped.getBytes(blockId).get
+ val notMapped = diskStoreNotMapped.getBytes(blockId).get.asInstanceOf[WrappedLargeByteBuffer]
// Not possible to do isInstanceOf due to visibility of HeapByteBuffer
- assert(notMapped.getClass.getName.endsWith("HeapByteBuffer"),
- "Expected HeapByteBuffer for un-mapped read")
- assert(mapped.isInstanceOf[MappedByteBuffer], "Expected MappedByteBuffer for mapped read")
-
- def arrayFromByteBuffer(in: ByteBuffer): Array[Byte] = {
- val array = new Array[Byte](in.remaining())
- in.get(array)
+ assert(LargeByteBufferTestHelper.nioBuffers(notMapped).get(0).getClass.getName
+ .endsWith("HeapByteBuffer"), "Expected HeapByteBuffer for un-mapped read")
+ assert(LargeByteBufferTestHelper.nioBuffers(mapped).get(0).isInstanceOf[MappedByteBuffer],
+ "Expected MappedByteBuffer for mapped read")
+
+ def arrayFromByteBuffer(in: LargeByteBuffer): Array[Byte] = {
+ val array = new Array[Byte](in.remaining().toInt)
+ in.get(array, 0, in.remaining().toInt)
array
}
@@ -1242,13 +1245,23 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
store = makeBlockManager(12000)
val memoryStore = store.memoryStore
val blockId = BlockId("rdd_3_10")
- var bytes: ByteBuffer = null
+ var bytes: LargeByteBuffer = null
val result = memoryStore.putBytes(blockId, 10000, () => {
- bytes = ByteBuffer.allocate(10000)
+ bytes = LargeByteBufferHelper.allocate(10000)
bytes
})
assert(result.size === 10000)
- assert(result.data === Right(bytes))
+ assert(result.data.isRight)
+ assertEquivalentByteBufs(result.data.right.get, bytes)
assert(result.droppedBlocks === Nil)
}
+
+ def assertEquivalentByteBufs(exp: LargeByteBuffer, act: LargeByteBuffer): Unit = {
+ assert(exp.size() === act.size())
+ val expBytes = new Array[Byte](exp.size().toInt)
+ exp.get(expBytes, 0, exp.size().toInt)
+ val actBytes = new Array[Byte](act.size().toInt)
+ act.get(actBytes, 0, act.size().toInt)
+ assert(expBytes === actBytes)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala
index 361ec95654f47..38bc24528f3a7 100644
--- a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala
@@ -21,7 +21,6 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
-
class ByteArrayChunkOutputStreamSuite extends SparkFunSuite {
test("empty output") {
@@ -106,4 +105,30 @@ class ByteArrayChunkOutputStreamSuite extends SparkFunSuite {
assert(arrays(1).toSeq === ref.slice(10, 20))
assert(arrays(2).toSeq === ref.slice(20, 30))
}
+
+ test("slice") {
+ val ref = new Array[Byte](30)
+ Random.nextBytes(ref)
+ val o = new ByteArrayChunkOutputStream(5)
+ o.write(ref)
+
+ for {
+ start <- (0 until 30)
+ end <- (start to 30)
+ } {
+ withClue(s"start = $start; end = $end") {
+ try {
+ assert(o.slice(start, end).toSeq === ref.slice(start, end))
+ } catch {
+ case ex: Throwable => fail(ex)
+ }
+ }
+ }
+
+ // errors on bad bounds
+ intercept[IllegalArgumentException]{o.slice(31, 31)}
+ intercept[IllegalArgumentException]{o.slice(-1, 10)}
+ intercept[IllegalArgumentException]{o.slice(10, 5)}
+ intercept[IllegalArgumentException]{o.slice(10, 35)}
+ }
}
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/BufferTooLargeException.java b/network/common/src/main/java/org/apache/spark/network/buffer/BufferTooLargeException.java
new file mode 100644
index 0000000000000..4e1a85ba1f126
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/BufferTooLargeException.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+
+public class BufferTooLargeException extends IOException {
+ public final long actualSize;
+ public final long extra;
+ public final long maxSize;
+
+ public BufferTooLargeException(long actualSize, long maxSize) {
+ super(String.format("LargeByteBuffer is too large to convert. Size: %d; Size Limit: %d (%d " +
+ "too big)", actualSize, maxSize,
+ actualSize - maxSize));
+ this.extra = actualSize - maxSize;
+ this.actualSize = actualSize;
+ this.maxSize = maxSize;
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
index 844eff4f4c701..2d534f12abd62 100644
--- a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
@@ -73,6 +73,9 @@ public ByteBuffer nioByteBuffer() throws IOException {
buf.flip();
return buf;
} else {
+ if (length > LargeByteBufferHelper.MAX_CHUNK_SIZE) {
+ throw new BufferTooLargeException(length, LargeByteBufferHelper.MAX_CHUNK_SIZE);
+ }
return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
}
} catch (IOException e) {
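
A sketch (not part of the patch) of how a caller might react to the BufferTooLargeException introduced above. It assumes only ManagedBuffer.nioByteBuffer() and the exception's public fields; the object name and error message are invented for illustration.

    import org.apache.spark.network.buffer.{BufferTooLargeException, ManagedBuffer}

    object SegmentSizeSketch {
      // Returns the readable size of a small segment, or fails with a clear message when the
      // segment is over LargeByteBufferHelper.MAX_CHUNK_SIZE and nioByteBuffer() refuses to map it.
      def segmentSize(buf: ManagedBuffer): Long = {
        try {
          buf.nioByteBuffer().remaining()
        } catch {
          case e: BufferTooLargeException =>
            // the exception carries the sizes involved, so callers can produce a sane error
            sys.error(s"segment of ${e.actualSize} bytes exceeds the ${e.maxSize} byte limit " +
              s"by ${e.extra} bytes")
        }
      }
    }
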
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBuffer.java
new file mode 100644
index 0000000000000..beeb007e2197e
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBuffer.java
@@ -0,0 +1,148 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+
+/**
+ * A byte buffer which can hold over 2GB.
+ *
+ * It is read-only, and only supports reading bytes (with both single and bulk
+ * get methods). It supports random access via skip to move around the
+ * buffer.
+ *
+ * In general, implementations are expected to support O(1) random access. Furthermore,
+ * neighboring locations in the buffer are likely to be neighboring in memory, so sequential access
+ * will avoid cache-misses. However, these are only rough guidelines which may differ in
+ * implementations.
+ *
+ * Any code which expects a ByteBuffer can obtain one via {@link #asByteBuffer} when possible -- see
+ * that method for a full description of its limitations.
+ *
+ * Instances of this class can be created with
+ * {@link org.apache.spark.network.buffer.LargeByteBufferHelper},
+ * with a LargeByteBufferOutputStream,
+ * or directly from the implementation
+ * {@link org.apache.spark.network.buffer.WrappedLargeByteBuffer}.
+ */
+public interface LargeByteBuffer {
+ public byte get();
+
+
+ /**
+ * Bulk copy data from this buffer into the given array. First checks there is sufficient
+ * data in this buffer; if not, throws a {@link java.nio.BufferUnderflowException}. Behaves
+ * exactly like get(dst, 0, dst.length).
+ *
+ * @param dst the destination array
+ * @return this buffer
+ */
+ public LargeByteBuffer get(byte[] dst);
+
+ /**
+ * Bulk copy data from this buffer into the given array. First checks there is sufficient
+ * data in this buffer; if not, throws a {@link java.nio.BufferUnderflowException}.
+ *
+ * @param dst the destination array
+ * @param offset the offset within the destination array to write to
+ * @param length how many bytes to write
+ * @return this buffer
+ */
+ public LargeByteBuffer get(byte[] dst, int offset, int length);
+
+
+ public LargeByteBuffer rewind();
+
+ /**
+ * Return a deep copy of this buffer.
+ * The returned buffer will have position == 0. The position
+ * of this buffer will not change as a result of copying.
+ *
+ * @return a new buffer with a full copy of this buffer's data
+ */
+ public LargeByteBuffer deepCopy();
+
+ /**
+ * Advance the position in this buffer by up to n bytes. n may be
+ * positive or negative. It will move the full n unless that moves
+ * it past the end (or beginning) of the buffer, in which case it will move to the end
+ * (or beginning).
+ *
+ * @return the number of bytes moved forward (can be negative if n is negative)
+ */
+ public long skip(long n);
+
+ public long position();
+
+ /**
+ * Creates a new byte buffer that shares this buffer's content.
+ *
+ * The content of the new buffer will be that of this buffer. Changes
+ * to this buffer's content will be visible in the new buffer, and vice
+ * versa; the two buffers' positions will be independent.
+ *
+ * The new buffer's position will be identical to that of this buffer.
+ */
+ public LargeByteBuffer duplicate();
+
+ public long remaining();
+
+ /**
+ * Total number of bytes in this buffer
+ */
+ public long size();
+
+ /**
+ * Writes the data from the current position() to the end of this buffer
+ * to the given channel. The position() will be moved to the end of
+ * the buffer after this.
+ *
+ * Note that this method will repeatedly attempt to push data to the given channel; if the
+ * channel cannot accept more data, it will retry until all of the remaining bytes are written.
+ *
+ * @param channel the channel to write this buffer's remaining bytes to
+ * @return the number of bytes written to the channel
+ * @throws IOException if writing to the channel fails
+ */
+ public long writeTo(WritableByteChannel channel) throws IOException;
+
+ /**
+ * Get the entire contents of this as one ByteBuffer, if possible. The returned ByteBuffer
+ * will always have the position set to 0, and the limit set to the end of the data. Each
+ * call will return a new ByteBuffer, but will not require copying the data (e.g., it will
+ * use ByteBuffer#duplicate()). The returned byte buffer will share data with this buffer. The
+ * returned buffers will never be larger than
+ * {@link org.apache.spark.network.buffer.LargeByteBufferHelper#MAX_CHUNK_SIZE}
+ *
+ * @throws BufferTooLargeException if this buffer is too large to fit in one {@link ByteBuffer}
+ */
+ public ByteBuffer asByteBuffer() throws BufferTooLargeException;
+
+ /**
+ * Attempt to clean this up if it is memory-mapped. This uses an *unsafe* Sun API that
+ * might cause errors if one attempts to read from the unmapped buffer, but it's better than
+ * waiting for the GC to find it because that could lead to huge numbers of open files. There's
+ * unfortunately no standard API to do this.
+ */
+ public void dispose();
+}
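
To make the interface above concrete, here is a small usage sketch (not part of the patch). It only assumes the methods declared in LargeByteBuffer and the asLargeByteBuffer factory shown in the next file; the object name and values are arbitrary.

    import org.apache.spark.network.buffer.{BufferTooLargeException, LargeByteBuffer, LargeByteBufferHelper}

    object LargeByteBufferSketch {
      def main(args: Array[String]): Unit = {
        // Wrap a small array; the same read-only API applies to buffers spanning many chunks.
        val data = Array.tabulate[Byte](16)(_.toByte)
        val buf: LargeByteBuffer = LargeByteBufferHelper.asLargeByteBuffer(data)

        val head = new Array[Byte](4)
        buf.get(head, 0, 4)        // bulk read, advances position() to 4
        buf.skip(8)                // random access: jump forward to position 12
        val twelfth = buf.get()    // single-byte read
        assert(twelfth == 12.toByte)
        buf.rewind()               // back to position 0

        // asByteBuffer() only succeeds when the data fits in a single ByteBuffer chunk.
        try {
          assert(buf.asByteBuffer().remaining() == 16)
        } catch {
          case _: BufferTooLargeException => // only thrown for buffers over MAX_CHUNK_SIZE
        }
      }
    }
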
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBufferHelper.java b/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBufferHelper.java
new file mode 100644
index 0000000000000..4941ed6559ea9
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/LargeByteBufferHelper.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+
+import com.google.common.annotations.VisibleForTesting;
+
+/**
+ * Utils for creating {@link org.apache.spark.network.buffer.LargeByteBuffer}s, either from
+ * pre-allocated byte arrays, ByteBuffers, or by memory mapping a file.
+ */
+public class LargeByteBufferHelper {
+
+ // Netty can't quite send messages that are a full 2GB -- they need to be slightly smaller.
+ // The exact limit isn't clear, but a 200-byte margin seems to be enough.
+ /**
+ * The maximum size of any ByteBuffer.
+ * {@link org.apache.spark.network.buffer.LargeByteBuffer#asByteBuffer} will never return a
+ * ByteBuffer larger than this. This is close to the max ByteBuffer size (2GB), minus a small
+ * amount for message overhead.
+ */
+ public static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 200;
+
+ public static LargeByteBuffer asLargeByteBuffer(ByteBuffer buffer) {
+ return new WrappedLargeByteBuffer(new ByteBuffer[] { buffer } );
+ }
+
+ public static LargeByteBuffer asLargeByteBuffer(byte[] bytes) {
+ return asLargeByteBuffer(ByteBuffer.wrap(bytes));
+ }
+
+ public static LargeByteBuffer allocate(long size) {
+ return allocate(size, MAX_CHUNK_SIZE);
+ }
+
+ @VisibleForTesting
+ static LargeByteBuffer allocate(long size, int maxChunk) {
+ int chunksNeeded = (int) ((size + maxChunk - 1) / maxChunk);
+ ByteBuffer[] chunks = new ByteBuffer[chunksNeeded];
+ long remaining = size;
+ for (int i = 0; i < chunksNeeded; i++) {
+ int nextSize = (int) Math.min(remaining, maxChunk);
+ ByteBuffer next = ByteBuffer.allocate(nextSize);
+ remaining -= nextSize;
+ chunks[i] = next;
+ }
+ if (remaining != 0) {
+ throw new IllegalStateException("remaining = " + remaining);
+ }
+ return new WrappedLargeByteBuffer(chunks, maxChunk);
+ }
+
+ public static LargeByteBuffer mapFile(
+ FileChannel channel,
+ FileChannel.MapMode mode,
+ long offset,
+ long length
+ ) throws IOException {
+ int chunksNeeded = (int) ((length - 1) / MAX_CHUNK_SIZE) + 1;
+ ByteBuffer[] chunks = new ByteBuffer[chunksNeeded];
+ long curPos = offset;
+ long end = offset + length;
+ for (int i = 0; i < chunksNeeded; i++) {
+ long nextPos = Math.min(curPos + MAX_CHUNK_SIZE, end);
+ chunks[i] = channel.map(mode, curPos, nextPos - curPos);
+ curPos = nextPos;
+ }
+ return new WrappedLargeByteBuffer(chunks);
+ }
+
+}
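
Another sketch (not part of the patch) showing the two factory paths above. The object name and file path are placeholders; an allocation above MAX_CHUNK_SIZE is only described in a comment because actually creating one needs several GB of heap.

    import java.io.RandomAccessFile
    import java.nio.channels.FileChannel.MapMode
    import org.apache.spark.network.buffer.LargeByteBufferHelper

    object HelperSketch {
      def main(args: Array[String]): Unit = {
        // Small allocations fit in a single chunk; anything larger than MAX_CHUNK_SIZE
        // (Integer.MAX_VALUE - 200) would be split into ceil(size / MAX_CHUNK_SIZE) chunks.
        val small = LargeByteBufferHelper.allocate(4096)
        assert(small.size() == 4096)

        // Memory-mapping maps the file in MAX_CHUNK_SIZE pieces; the path is a placeholder.
        val channel = new RandomAccessFile("some-large-file.bin", "r").getChannel
        val mapped = LargeByteBufferHelper.mapFile(channel, MapMode.READ_ONLY, 0, channel.size())
        try {
          assert(mapped.size() == channel.size())
        } finally {
          mapped.dispose()   // unmap eagerly rather than waiting for GC to release file handles
          channel.close()
        }
      }
    }
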
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/WrappedLargeByteBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/WrappedLargeByteBuffer.java
new file mode 100644
index 0000000000000..58a621249386f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/WrappedLargeByteBuffer.java
@@ -0,0 +1,292 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.nio.BufferUnderflowException;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import java.util.Arrays;
+import java.util.List;
+
+import com.google.common.annotations.VisibleForTesting;
+import sun.nio.ch.DirectBuffer;
+
+/**
+ * A {@link org.apache.spark.network.buffer.LargeByteBuffer} which may contain multiple
+ * {@link java.nio.ByteBuffer}s. In order to support asByteBuffer, all
+ * of the underlying ByteBuffers must have size equal to
+ * {@link org.apache.spark.network.buffer.LargeByteBufferHelper#MAX_CHUNK_SIZE} (except the last
+ * one). The underlying ByteBuffers may be on-heap, direct, or memory-mapped.
+ */
+public class WrappedLargeByteBuffer implements LargeByteBuffer {
+
+ @VisibleForTesting
+ final ByteBuffer[] underlying;
+
+ private final long size;
+ /**
+ * Each sub-ByteBuffer (except for the last one) must be exactly this size. Note that this
+ * class *really* expects this to be LargeByteBufferHelper.MAX_CHUNK_SIZE. The only reason it
+ * is configurable is so that we can run tests without creating enormous buffers. Public
+ * constructors force it to be LargeByteBufferHelper.MAX_CHUNK_SIZE.
+ */
+ private final int subBufferSize;
+ private long _pos;
+ @VisibleForTesting
+ int currentBufferIdx;
+ @VisibleForTesting
+ ByteBuffer currentBuffer;
+
+ /**
+ * Construct a WrappedLargeByteBuffer from the given ByteBuffers. Each of the ByteBuffers must
+ * have size equal to {@link org.apache.spark.network.buffer.LargeByteBufferHelper#MAX_CHUNK_SIZE}
+ * except for the final one. The buffers are duplicated, so the position of the
+ * given buffers and the returned buffer will be independent, though the underlying data will be
+ * shared. The constructed buffer will always have position == 0.
+ */
+ public WrappedLargeByteBuffer(ByteBuffer[] underlying) {
+ this(underlying, LargeByteBufferHelper.MAX_CHUNK_SIZE);
+ }
+
+ /**
+ * You do **not** want to call this version. It leads to a buffer which doesn't properly
+ * support {@link #asByteBuffer}. The only reason it exists is so that we can have tests
+ * which don't require 2GB of memory.
+ *
+ * @param underlying the ByteBuffers to wrap
+ * @param subBufferSize the required size of every buffer except the last one
+ */
+ @VisibleForTesting
+ WrappedLargeByteBuffer(ByteBuffer[] underlying, int subBufferSize) {
+ if (underlying.length == 0) {
+ throw new IllegalArgumentException("must wrap at least one ByteBuffer");
+ }
+ this.underlying = new ByteBuffer[underlying.length];
+ this.subBufferSize = subBufferSize;
+ long sum = 0L;
+
+ for (int i = 0; i < underlying.length; i++) {
+ ByteBuffer b = underlying[i].duplicate();
+ b.position(0);
+ this.underlying[i] = b;
+ if (i != underlying.length - 1 && b.capacity() != subBufferSize) {
+ // this is to make sure that asByteBuffer() is implemented correctly. We need the first
+ // subBuffer to be LargeByteBufferHelper.MAX_CHUNK_SIZE. We don't *have* to check all the
+ // subBuffers, but I figure it makes things more consistent this way. (Also, this check
+ // really only serves a purpose when using the public constructor -- subBufferSize is a
+ // parameter just to allow small tests.)
+ throw new IllegalArgumentException("All buffers, except for the final one, must have " +
+ "size = " + subBufferSize);
+ }
+ sum += b.capacity();
+ }
+ _pos = 0;
+ currentBufferIdx = 0;
+ currentBuffer = this.underlying[0];
+ size = sum;
+ }
+
+
+ @Override
+ public WrappedLargeByteBuffer get(byte[] dest) {
+ return get(dest, 0, dest.length);
+ }
+
+ @Override
+ public WrappedLargeByteBuffer get(byte[] dest, int offset, int length) {
+ if (length > remaining()) {
+ throw new BufferUnderflowException();
+ }
+ int moved = 0;
+ while (moved < length) {
+ int toRead = Math.min(length - moved, currentBuffer.remaining());
+ currentBuffer.get(dest, offset + moved, toRead);
+ moved += toRead;
+ updateCurrentBufferIfNeeded();
+ }
+ _pos += moved;
+ return this;
+ }
+
+ @Override
+ public LargeByteBuffer rewind() {
+ if (currentBuffer != null) {
+ currentBuffer.rewind();
+ }
+ while (currentBufferIdx > 0) {
+ currentBufferIdx -= 1;
+ currentBuffer = underlying[currentBufferIdx];
+ currentBuffer.rewind();
+ }
+ _pos = 0;
+ return this;
+ }
+
+ @Override
+ public WrappedLargeByteBuffer deepCopy() {
+ ByteBuffer[] dataCopy = new ByteBuffer[underlying.length];
+ for (int i = 0; i < underlying.length; i++) {
+ ByteBuffer b = underlying[i];
+ dataCopy[i] = ByteBuffer.allocate(b.capacity());
+ int originalPosition = b.position();
+ b.rewind();
+ dataCopy[i].put(b);
+ dataCopy[i].position(0);
+ b.position(originalPosition);
+ }
+ return new WrappedLargeByteBuffer(dataCopy, subBufferSize);
+ }
+
+ @Override
+ public byte get() {
+ if (remaining() < 1L) {
+ throw new BufferUnderflowException();
+ }
+ byte r = currentBuffer.get();
+ _pos += 1;
+ updateCurrentBufferIfNeeded();
+ return r;
+ }
+
+ /**
+ * If we've read to the end of the current buffer, move on to the next one. Safe to call
+ * even if the current buffer still has bytes remaining (it's a no-op in that case).
+ */
+ private void updateCurrentBufferIfNeeded() {
+ while (currentBuffer != null && !currentBuffer.hasRemaining()) {
+ currentBufferIdx += 1;
+ currentBuffer = currentBufferIdx < underlying.length ? underlying[currentBufferIdx] : null;
+ }
+ }
+
+ @Override
+ public long position() {
+ return _pos;
+ }
+
+ @Override
+ public long skip(long n) {
+ if (n < 0) {
+ final long moveTotal = Math.min(-n, _pos);
+ long toMove = moveTotal;
+ // move backwards and update the position of every buffer as we go
+ if (currentBuffer != null) {
+ currentBufferIdx += 1;
+ }
+ while (toMove > 0) {
+ currentBufferIdx -= 1;
+ currentBuffer = underlying[currentBufferIdx];
+ int thisMove = (int) Math.min(toMove, currentBuffer.position());
+ currentBuffer.position(currentBuffer.position() - thisMove);
+ toMove -= thisMove;
+ }
+ _pos -= moveTotal;
+ return -moveTotal;
+ } else if (n > 0) {
+ final long moveTotal = Math.min(n, remaining());
+ long toMove = moveTotal;
+ // move forwards and update the position of every buffer as we go
+ currentBufferIdx -= 1;
+ while (toMove > 0) {
+ currentBufferIdx += 1;
+ currentBuffer = underlying[currentBufferIdx];
+ int thisMove = (int) Math.min(toMove, currentBuffer.remaining());
+ currentBuffer.position(currentBuffer.position() + thisMove);
+ toMove -= thisMove;
+ }
+ _pos += moveTotal;
+ return moveTotal;
+ } else {
+ return 0;
+ }
+ }
+
+ @Override
+ public long remaining() {
+ return size - _pos;
+ }
+
+ @Override
+ public WrappedLargeByteBuffer duplicate() {
+ // the constructor will duplicate the underlying buffers for us
+ WrappedLargeByteBuffer dup = new WrappedLargeByteBuffer(underlying, subBufferSize);
+ dup.skip(position());
+ return dup;
+ }
+
+ @Override
+ public long size() {
+ return size;
+ }
+
+ @Override
+ public long writeTo(WritableByteChannel channel) throws IOException {
+ long written = 0L;
+ for (; currentBufferIdx < underlying.length; currentBufferIdx++) {
+ currentBuffer = underlying[currentBufferIdx];
+ written += currentBuffer.remaining();
+ while (currentBuffer.hasRemaining()) {
+ channel.write(currentBuffer);
+ }
+ }
+ _pos = size();
+ return written;
+ }
+
+ @Override
+ public ByteBuffer asByteBuffer() throws BufferTooLargeException {
+ if (underlying.length == 1) {
+ ByteBuffer b = underlying[0].duplicate();
+ b.rewind();
+ return b;
+ } else {
+ // NOTE: if subBufferSize != LargeByteBufferHelper.MAX_CHUNK_SIZE, in theory
+ // we could copy the data into a new buffer. But we don't want to do any copying.
+ // The only reason we allow smaller subBufferSize is so that we can have tests which
+ // don't require 2GB of memory
+ throw new BufferTooLargeException(size(), underlying[0].capacity());
+ }
+ }
+
+ @VisibleForTesting
+ List