
Commit fdd460f

[SPARK-13980] Incrementally serialize blocks while unrolling them in MemoryStore
When a block is persisted in the MemoryStore at a serialized storage level, the current MemoryStore.putIterator() code will unroll the entire iterator as Java objects in memory, then will turn around and serialize an iterator obtained from the unrolled array. This is inefficient and doubles our peak memory requirements. Instead, I think that we should incrementally serialize blocks while unrolling them.

A downside to incremental serialization is the fact that we will need to deserialize the partially-unrolled data in case there is not enough space to unroll the block and the block cannot be dropped to disk. However, I'm hoping that the memory efficiency improvements will outweigh any performance losses as a result of extra serialization in that hopefully-rare case.

Author: Josh Rosen <[email protected]>

Closes apache#11791 from JoshRosen/serialize-incrementally.
1 parent 2cf46d5 commit fdd460f
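The core of the change, in isolation: rather than collecting every record into an array and serializing the array afterwards, records are written one at a time into a growing byte buffer, and unroll memory is reserved incrementally as that buffer grows. Below is a minimal standalone sketch of the idea; it uses plain Java serialization and a hypothetical reserveMemory callback rather than Spark's MemoryStore/SerializerManager APIs, so it illustrates the technique only.

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}

    // Illustrative sketch only: serialize-while-unrolling with incremental memory reservation.
    // `reserveMemory` stands in for the memory store's unroll-memory reservation (hypothetical here).
    def serializeIncrementally[T](
        values: Iterator[T],
        reserveMemory: Long => Boolean): Either[Iterator[T], Array[Byte]] = {
      val initialThreshold = 64 * 1024            // analogous to the initial unroll threshold
      val buffer = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(buffer)
      var reserved = 0L
      var keepUnrolling = reserveMemory(initialThreshold)
      if (keepUnrolling) reserved += initialThreshold
      while (values.hasNext && keepUnrolling) {
        out.writeObject(values.next())            // serialize one record at a time
        if (buffer.size > reserved) {             // only request the delta actually written
          val delta = buffer.size - reserved
          keepUnrolling = reserveMemory(delta)
          if (keepUnrolling) reserved += delta
        }
      }
      if (keepUnrolling) { out.close(); Right(buffer.toByteArray) }
      else Left(values)                           // remainder handed back, as in the real API
    }

At no point do all of the deserialized records need to be resident at once, which is what removes the doubled peak-memory requirement described in the commit message.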


6 files changed: +392, -77 lines changed


core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 34 additions & 15 deletions
@@ -746,7 +746,7 @@ private[spark] class BlockManager(
       // We will drop it to disk later if the memory store can't hold it.
       val putSucceeded = if (level.deserialized) {
         val values = serializerManager.dataDeserialize(blockId, bytes)(classTag)
-        memoryStore.putIterator(blockId, values, level, classTag) match {
+        memoryStore.putIteratorAsValues(blockId, values, classTag) match {
           case Right(_) => true
           case Left(iter) =>
             // If putting deserialized values in memory failed, we will put the bytes directly to
@@ -876,21 +876,40 @@ private[spark] class BlockManager(
       if (level.useMemory) {
         // Put it in memory first, even if it also has useDisk set to true;
         // We will drop it to disk later if the memory store can't hold it.
-        memoryStore.putIterator(blockId, iterator(), level, classTag) match {
-          case Right(s) =>
-            size = s
-          case Left(iter) =>
-            // Not enough space to unroll this block; drop to disk if applicable
-            if (level.useDisk) {
-              logWarning(s"Persisting block $blockId to disk instead.")
-              diskStore.put(blockId) { fileOutputStream =>
-                serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
+        if (level.deserialized) {
+          memoryStore.putIteratorAsValues(blockId, iterator(), classTag) match {
+            case Right(s) =>
+              size = s
+            case Left(iter) =>
+              // Not enough space to unroll this block; drop to disk if applicable
+              if (level.useDisk) {
+                logWarning(s"Persisting block $blockId to disk instead.")
+                diskStore.put(blockId) { fileOutputStream =>
+                  serializerManager.dataSerializeStream(blockId, fileOutputStream, iter)(classTag)
+                }
+                size = diskStore.getSize(blockId)
+              } else {
+                iteratorFromFailedMemoryStorePut = Some(iter)
               }
-              size = diskStore.getSize(blockId)
-            } else {
-              iteratorFromFailedMemoryStorePut = Some(iter)
-            }
+          }
+        } else { // !level.deserialized
+          memoryStore.putIteratorAsBytes(blockId, iterator(), classTag) match {
+            case Right(s) =>
+              size = s
+            case Left(partiallySerializedValues) =>
+              // Not enough space to unroll this block; drop to disk if applicable
+              if (level.useDisk) {
+                logWarning(s"Persisting block $blockId to disk instead.")
+                diskStore.put(blockId) { fileOutputStream =>
+                  partiallySerializedValues.finishWritingToStream(fileOutputStream)
+                }
+                size = diskStore.getSize(blockId)
+              } else {
+                iteratorFromFailedMemoryStorePut = Some(partiallySerializedValues.valuesIterator)
+              }
+          }
         }
+
       } else if (level.useDisk) {
         diskStore.put(blockId) { fileOutputStream =>
           serializerManager.dataSerializeStream(blockId, fileOutputStream, iterator())(classTag)
@@ -991,7 +1010,7 @@ private[spark] class BlockManager(
           // Note: if we had a means to discard the disk iterator, we would do that here.
           memoryStore.getValues(blockId).get
         } else {
-          memoryStore.putIterator(blockId, diskIterator, level, classTag) match {
+          memoryStore.putIteratorAsValues(blockId, diskIterator, classTag) match {
             case Left(iter) =>
               // The memory store put() failed, so it returned the iterator back to us:
               iter
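Condensed, the serialized-level path added to BlockManager above has a three-way outcome; the sketch below restates the diff with the surrounding method stripped away (illustrative condensation only, using the same identifiers as the diff):

    memoryStore.putIteratorAsBytes(blockId, iterator(), classTag) match {
      case Right(storedSize) =>
        size = storedSize                                     // fully serialized and cached in memory
      case Left(partial) if level.useDisk =>
        diskStore.put(blockId) { out => partial.finishWritingToStream(out) }
        size = diskStore.getSize(blockId)                     // already-serialized bytes are reused
      case Left(partial) =>
        iteratorFromFailedMemoryStorePut = Some(partial.valuesIterator)  // requires deserialization
    }

The key point is that when the block spills, finishWritingToStream copies the bytes that were already serialized and only serializes the remainder of the iterator, so the partial work is not thrown away; deserialization is needed only on the no-disk fallback path.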

core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala

Lines changed: 204 additions & 25 deletions
@@ -17,20 +17,24 @@
 
 package org.apache.spark.storage.memory
 
+import java.io.OutputStream
+import java.nio.ByteBuffer
 import java.util.LinkedHashMap
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
+import com.google.common.io.ByteStreams
+
 import org.apache.spark.{SparkConf, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.memory.MemoryManager
-import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.serializer.{SerializationStream, SerializerManager}
 import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
 import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
-import org.apache.spark.util.io.ChunkedByteBuffer
+import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}
 
 private sealed trait MemoryEntry[T] {
   def size: Long
@@ -42,8 +46,9 @@ private case class DeserializedMemoryEntry[T](
     classTag: ClassTag[T]) extends MemoryEntry[T]
 private case class SerializedMemoryEntry[T](
     buffer: ChunkedByteBuffer,
-    size: Long,
-    classTag: ClassTag[T]) extends MemoryEntry[T]
+    classTag: ClassTag[T]) extends MemoryEntry[T] {
+  def size: Long = buffer.size
+}
 
 private[storage] trait BlockEvictionHandler {
   /**
@@ -132,7 +137,7 @@ private[spark] class MemoryStore(
       // We acquired enough memory for the block, so go ahead and put it
       val bytes = _bytes()
       assert(bytes.size == size)
-      val entry = new SerializedMemoryEntry[T](bytes, size, implicitly[ClassTag[T]])
+      val entry = new SerializedMemoryEntry[T](bytes, implicitly[ClassTag[T]])
       entries.synchronized {
         entries.put(blockId, entry)
       }
@@ -145,7 +150,7 @@ private[spark] class MemoryStore(
   }
 
   /**
-   * Attempt to put the given block in memory store.
+   * Attempt to put the given block in memory store as values.
    *
    * It's possible that the iterator is too large to materialize and store in memory. To avoid
    * OOM exceptions, this method will gradually unroll the iterator while periodically checking
@@ -160,10 +165,9 @@ private[spark] class MemoryStore(
   * iterator or call `close()` on it in order to free the storage memory consumed by the
   * partially-unrolled block.
   */
-  private[storage] def putIterator[T](
+  private[storage] def putIteratorAsValues[T](
       blockId: BlockId,
       values: Iterator[T],
-      level: StorageLevel,
      classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {
 
    require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
@@ -218,12 +222,8 @@ private[spark] class MemoryStore(
       // We successfully unrolled the entirety of this block
       val arrayValues = vector.toArray
       vector = null
-      val entry = if (level.deserialized) {
+      val entry =
         new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag)
-      } else {
-        val bytes = serializerManager.dataSerialize(blockId, arrayValues.iterator)(classTag)
-        new SerializedMemoryEntry[T](bytes, bytes.size, classTag)
-      }
       val size = entry.size
       def transferUnrollToStorage(amount: Long): Unit = {
         // Synchronize so that transfer is atomic
@@ -255,12 +255,8 @@ private[spark] class MemoryStore(
       entries.synchronized {
         entries.put(blockId, entry)
       }
-      val bytesOrValues = if (level.deserialized) "values" else "bytes"
-      logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format(
-        blockId,
-        bytesOrValues,
-        Utils.bytesToString(size),
-        Utils.bytesToString(maxMemory - blocksMemoryUsed)))
+      logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(
+        blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
       Right(size)
     } else {
       assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask,
@@ -279,13 +275,117 @@ private[spark] class MemoryStore(
     }
   }
 
+  /**
+   * Attempt to put the given block in memory store as bytes.
+   *
+   * It's possible that the iterator is too large to materialize and store in memory. To avoid
+   * OOM exceptions, this method will gradually unroll the iterator while periodically checking
+   * whether there is enough free memory. If the block is successfully materialized, then the
+   * temporary unroll memory used during the materialization is "transferred" to storage memory,
+   * so we won't acquire more memory than is actually needed to store the block.
+   *
+   * @return in case of success, the estimated size of the stored data. In case of
+   *         failure, return a handle which allows the caller to either finish the serialization
+   *         by spilling to disk or to deserialize the partially-serialized block and reconstruct
+   *         the original input iterator. The caller must either fully consume this result
+   *         iterator or call `discard()` on it in order to free the storage memory consumed by the
+   *         partially-unrolled block.
+   */
+  private[storage] def putIteratorAsBytes[T](
+      blockId: BlockId,
+      values: Iterator[T],
+      classTag: ClassTag[T]): Either[PartiallySerializedBlock[T], Long] = {
+
+    require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
+
+    // Whether there is still enough memory for us to continue unrolling this block
+    var keepUnrolling = true
+    // Initial per-task memory to request for unrolling blocks (bytes).
+    val initialMemoryThreshold = unrollMemoryThreshold
+    // Keep track of unroll memory used by this particular block / putIterator() operation
+    var unrollMemoryUsedByThisBlock = 0L
+    // Underlying buffer for unrolling the block
+    val redirectableStream = new RedirectableOutputStream
+    val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(initialMemoryThreshold.toInt)
+    redirectableStream.setOutputStream(byteArrayChunkOutputStream)
+    val serializationStream: SerializationStream = {
+      val ser = serializerManager.getSerializer(classTag).newInstance()
+      ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
+    }
+
+    // Request enough memory to begin unrolling
+    keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold)
+
+    if (!keepUnrolling) {
+      logWarning(s"Failed to reserve initial memory threshold of " +
+        s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+    } else {
+      unrollMemoryUsedByThisBlock += initialMemoryThreshold
+    }
+
+    def reserveAdditionalMemoryIfNecessary(): Unit = {
+      if (byteArrayChunkOutputStream.size > unrollMemoryUsedByThisBlock) {
+        val amountToRequest = byteArrayChunkOutputStream.size - unrollMemoryUsedByThisBlock
+        keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest)
+        if (keepUnrolling) {
+          unrollMemoryUsedByThisBlock += amountToRequest
+        }
+      }
+    }
+
+    // Unroll this block safely, checking whether we have exceeded our threshold
+    while (values.hasNext && keepUnrolling) {
+      serializationStream.writeObject(values.next())(classTag)
+      reserveAdditionalMemoryIfNecessary()
+    }
+
+    // Make sure that we have enough memory to store the block. By this point, it is possible that
+    // the block's actual memory usage has exceeded the unroll memory by a small amount, so we
+    // perform one final call to attempt to allocate additional memory if necessary.
+    if (keepUnrolling) {
+      serializationStream.close()
+      reserveAdditionalMemoryIfNecessary()
+    }
+
+    if (keepUnrolling) {
+      val entry = SerializedMemoryEntry[T](
+        new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)), classTag)
+      // Synchronize so that transfer is atomic
+      memoryManager.synchronized {
+        releaseUnrollMemoryForThisTask(unrollMemoryUsedByThisBlock)
+        val success = memoryManager.acquireStorageMemory(blockId, entry.size)
+        assert(success, "transferring unroll memory to storage memory failed")
+      }
+      entries.synchronized {
+        entries.put(blockId, entry)
+      }
+      logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
+        blockId, Utils.bytesToString(entry.size), Utils.bytesToString(blocksMemoryUsed)))
+      Right(entry.size)
+    } else {
+      // We ran out of space while unrolling the values for this block
+      logUnrollFailureMessage(blockId, byteArrayChunkOutputStream.size)
+      Left(
+        new PartiallySerializedBlock(
+          this,
+          serializerManager,
+          blockId,
+          serializationStream,
+          redirectableStream,
+          unrollMemoryUsedByThisBlock,
+          new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)),
+          values,
+          classTag))
+    }
+  }
+
   def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
     val entry = entries.synchronized { entries.get(blockId) }
     entry match {
       case null => None
       case e: DeserializedMemoryEntry[_] =>
         throw new IllegalArgumentException("should only call getBytes on serialized blocks")
-      case SerializedMemoryEntry(bytes, _, _) => Some(bytes)
+      case SerializedMemoryEntry(bytes, _) => Some(bytes)
     }
   }
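The buffering choice matters in the method above: ByteArrayChunkOutputStream accumulates the serialized output as a sequence of fixed-size chunks rather than one contiguous array, so the buffer can grow without copying previously written bytes, and the chunks can be wrapped directly into a ChunkedByteBuffer. A small hedged sketch of that usage, relying only on the calls that appear in the diff (the chunk size and element values are arbitrary):

    import java.nio.ByteBuffer
    import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}

    // Sketch: write some bytes through a chunked buffer, then wrap the chunks without copying.
    val chunked = new ByteArrayChunkOutputStream(4 * 1024)   // 4 KB chunks, size is illustrative
    chunked.write(Array.fill[Byte](10 * 1024)(1))            // spills across three chunks
    val asBuffer = new ChunkedByteBuffer(chunked.toArrays.map(ByteBuffer.wrap))
    assert(asBuffer.size == 10 * 1024)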

@@ -373,7 +473,7 @@ private[spark] class MemoryStore(
     def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = {
       val data = entry match {
         case DeserializedMemoryEntry(values, _, _) => Left(values)
-        case SerializedMemoryEntry(buffer, _, _) => Right(buffer)
+        case SerializedMemoryEntry(buffer, _) => Right(buffer)
       }
       val newEffectiveStorageLevel =
         blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag)
@@ -507,12 +607,13 @@ private[spark] class MemoryStore(
 }
 
 /**
- * The result of a failed [[MemoryStore.putIterator()]] call.
+ * The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
  *
- * @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryStore  the memoryStore, used for freeing memory.
  * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
- * @param unrolled an iterator for the partially-unrolled values.
- * @param rest the rest of the original iterator passed to [[MemoryStore.putIterator()]].
+ * @param unrolled     an iterator for the partially-unrolled values.
+ * @param rest         the rest of the original iterator passed to
+ *                     [[MemoryStore.putIteratorAsValues()]].
  */
 private[storage] class PartiallyUnrolledIterator[T](
     memoryStore: MemoryStore,
@@ -544,3 +645,81 @@ private[storage] class PartiallyUnrolledIterator[T](
     iter = null
   }
 }
+
+/**
+ * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
+ */
+private class RedirectableOutputStream extends OutputStream {
+  private[this] var os: OutputStream = _
+  def setOutputStream(s: OutputStream): Unit = { os = s }
+  override def write(b: Int): Unit = os.write(b)
+  override def write(b: Array[Byte]): Unit = os.write(b)
+  override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len)
+  override def flush(): Unit = os.flush()
+  override def close(): Unit = os.close()
+}
+
+/**
+ * The result of a failed [[MemoryStore.putIteratorAsBytes()]] call.
+ *
+ * @param memoryStore the MemoryStore, used for freeing memory.
+ * @param serializerManager the SerializerManager, used for deserializing values.
+ * @param blockId the block id.
+ * @param serializationStream a serialization stream which writes to [[redirectableOutputStream]].
+ * @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
+ * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
+ * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param rest the rest of the original iterator passed to
+ *             [[MemoryStore.putIteratorAsValues()]].
+ * @param classTag the [[ClassTag]] for the block.
+ */
+private[storage] class PartiallySerializedBlock[T](
+    memoryStore: MemoryStore,
+    serializerManager: SerializerManager,
+    blockId: BlockId,
+    serializationStream: SerializationStream,
+    redirectableOutputStream: RedirectableOutputStream,
+    unrollMemory: Long,
+    unrolled: ChunkedByteBuffer,
+    rest: Iterator[T],
+    classTag: ClassTag[T]) {
+
+  /**
+   * Called to dispose of this block and free its memory.
+   */
+  def discard(): Unit = {
+    try {
+      serializationStream.close()
+    } finally {
+      memoryStore.releaseUnrollMemoryForThisTask(unrollMemory)
+    }
+  }
+
+  /**
+   * Finish writing this block to the given output stream by first writing the serialized values
+   * and then serializing the values from the original input iterator.
+   */
+  def finishWritingToStream(os: OutputStream): Unit = {
+    ByteStreams.copy(unrolled.toInputStream(), os)
+    redirectableOutputStream.setOutputStream(os)
+    while (rest.hasNext) {
+      serializationStream.writeObject(rest.next())(classTag)
+    }
+    discard()
+  }
+
+  /**
+   * Returns an iterator over the values in this block by first deserializing the serialized
+   * values and then consuming the rest of the original input iterator.
+   *
+   * If the caller does not plan to fully consume the resulting iterator then they must call
+   * `close()` on it to free its resources.
+   */
+  def valuesIterator: PartiallyUnrolledIterator[T] = {
+    new PartiallyUnrolledIterator(
+      memoryStore,
+      unrollMemory,
+      unrolled = serializerManager.dataDeserialize(blockId, unrolled)(classTag),
+      rest = rest)
+  }
+}
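The redirection trick above is what lets finishWritingToStream reuse a still-open serialization stream: the already-serialized chunks are copied to the new sink, then the sink underneath the live stream is swapped so the remaining records are serialized straight to it. A self-contained toy illustration follows (the class is re-declared here only for the example; in Spark it is private to this file, and the two ByteArrayOutputStreams stand in for the chunked in-memory buffer and the disk file):

    import java.io.{ByteArrayOutputStream, OutputStream}

    // Toy re-declaration for illustration; mirrors the private class added above.
    class RedirectableOutputStream extends OutputStream {
      private[this] var os: OutputStream = _
      def setOutputStream(s: OutputStream): Unit = { os = s }
      override def write(b: Int): Unit = os.write(b)
    }

    val memoryBuffer = new ByteArrayOutputStream()  // plays the role of ByteArrayChunkOutputStream
    val diskFile = new ByteArrayOutputStream()      // stands in for the disk file's OutputStream

    val redirectable = new RedirectableOutputStream
    redirectable.setOutputStream(memoryBuffer)
    redirectable.write(1)                           // bytes produced while unrolling go to memory
    redirectable.setOutputStream(diskFile)          // spill: swap the sink under the open stream
    redirectable.write(2)                           // bytes for the rest of the iterator go to "disk"
    assert(memoryBuffer.toByteArray.sameElements(Array[Byte](1)))
    assert(diskFile.toByteArray.sameElements(Array[Byte](2)))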
