 
 package org.apache.spark.storage.memory
 
+import java.io.OutputStream
+import java.nio.ByteBuffer
 import java.util.LinkedHashMap
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
+import com.google.common.io.ByteStreams
+
 import org.apache.spark.{SparkConf, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.memory.MemoryManager
-import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.serializer.{SerializationStream, SerializerManager}
 import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel}
 import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
-import org.apache.spark.util.io.ChunkedByteBuffer
+import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer}
 
 private sealed trait MemoryEntry[T] {
   def size: Long
@@ -42,8 +46,9 @@ private case class DeserializedMemoryEntry[T](
     classTag: ClassTag[T]) extends MemoryEntry[T]
 private case class SerializedMemoryEntry[T](
     buffer: ChunkedByteBuffer,
-    size: Long,
-    classTag: ClassTag[T]) extends MemoryEntry[T]
+    classTag: ClassTag[T]) extends MemoryEntry[T] {
+  def size: Long = buffer.size
+}
 
 private[storage] trait BlockEvictionHandler {
   /**
@@ -132,7 +137,7 @@ private[spark] class MemoryStore(
       // We acquired enough memory for the block, so go ahead and put it
       val bytes = _bytes()
       assert(bytes.size == size)
-      val entry = new SerializedMemoryEntry[T](bytes, size, implicitly[ClassTag[T]])
+      val entry = new SerializedMemoryEntry[T](bytes, implicitly[ClassTag[T]])
       entries.synchronized {
         entries.put(blockId, entry)
       }
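With `size` now derived from the buffer itself (see the `SerializedMemoryEntry` change above), an entry can no longer be constructed with a size that disagrees with its payload. A micro-check of the new invariant (illustrative only; `bytes` stands for any in-scope `ChunkedByteBuffer`):

    val entry = SerializedMemoryEntry(bytes, implicitly[ClassTag[Any]])
    assert(entry.size == bytes.size)  // holds by construction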
@@ -145,7 +150,7 @@ private[spark] class MemoryStore(
   }
 
   /**
-   * Attempt to put the given block in memory store.
+   * Attempt to put the given block in memory store as values.
    *
    * It's possible that the iterator is too large to materialize and store in memory. To avoid
    * OOM exceptions, this method will gradually unroll the iterator while periodically checking
@@ -160,10 +165,9 @@ private[spark] class MemoryStore(
    * iterator or call `close()` on it in order to free the storage memory consumed by the
    * partially-unrolled block.
    */
-  private[storage] def putIterator[T](
+  private[storage] def putIteratorAsValues[T](
       blockId: BlockId,
       values: Iterator[T],
-      level: StorageLevel,
       classTag: ClassTag[T]): Either[PartiallyUnrolledIterator[T], Long] = {
 
     require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
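With `level` removed from the signature, the serialized-vs-deserialized decision moves to the caller. A sketch of the dispatch a BlockManager-style caller might perform (assumed for illustration; `level`, `iterator`, and `memoryStore` are stand-in names, not part of this patch):

    // Illustrative caller-side dispatch on the storage level:
    if (level.deserialized) {
      memoryStore.putIteratorAsValues(blockId, iterator, classTag)  // keep live objects
    } else {
      memoryStore.putIteratorAsBytes(blockId, iterator, classTag)   // serialize while unrolling
    }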
@@ -218,12 +222,8 @@ private[spark] class MemoryStore(
       // We successfully unrolled the entirety of this block
       val arrayValues = vector.toArray
       vector = null
-      val entry = if (level.deserialized) {
+      val entry =
         new DeserializedMemoryEntry[T](arrayValues, SizeEstimator.estimate(arrayValues), classTag)
-      } else {
-        val bytes = serializerManager.dataSerialize(blockId, arrayValues.iterator)(classTag)
-        new SerializedMemoryEntry[T](bytes, bytes.size, classTag)
-      }
       val size = entry.size
       def transferUnrollToStorage(amount: Long): Unit = {
         // Synchronize so that transfer is atomic
@@ -255,12 +255,8 @@ private[spark] class MemoryStore(
         entries.synchronized {
           entries.put(blockId, entry)
         }
-        val bytesOrValues = if (level.deserialized) "values" else "bytes"
-        logInfo("Block %s stored as %s in memory (estimated size %s, free %s)".format(
-          blockId,
-          bytesOrValues,
-          Utils.bytesToString(size),
-          Utils.bytesToString(maxMemory - blocksMemoryUsed)))
+        logInfo("Block %s stored as values in memory (estimated size %s, free %s)".format(
+          blockId, Utils.bytesToString(size), Utils.bytesToString(maxMemory - blocksMemoryUsed)))
         Right(size)
       } else {
         assert(currentUnrollMemoryForThisTask >= currentUnrollMemoryForThisTask,
@@ -279,13 +275,117 @@ private[spark] class MemoryStore(
     }
   }
 
+  /**
+   * Attempt to put the given block in memory store as bytes.
+   *
+   * It's possible that the iterator is too large to materialize and store in memory. To avoid
+   * OOM exceptions, this method will gradually unroll the iterator while periodically checking
+   * whether there is enough free memory. If the block is successfully materialized, then the
+   * temporary unroll memory used during the materialization is "transferred" to storage memory,
+   * so we won't acquire more memory than is actually needed to store the block.
+   *
+   * @return in case of success, the estimated size of the stored data. In case of failure,
+   *         return a handle which allows the caller to either finish the serialization by
+   *         spilling to disk or to deserialize the partially-serialized block and reconstruct
+   *         the original input iterator. The caller must either fully consume this result
+   *         iterator or call `discard()` on it in order to free the storage memory consumed
+   *         by the partially-unrolled block.
+   */
+  private[storage] def putIteratorAsBytes[T](
+      blockId: BlockId,
+      values: Iterator[T],
+      classTag: ClassTag[T]): Either[PartiallySerializedBlock[T], Long] = {
+
+    require(!contains(blockId), s"Block $blockId is already present in the MemoryStore")
+
+    // Whether there is still enough memory for us to continue unrolling this block
+    var keepUnrolling = true
+    // Initial per-task memory to request for unrolling blocks (bytes).
+    val initialMemoryThreshold = unrollMemoryThreshold
+    // Keep track of unroll memory used by this particular block / putIterator() operation
+    var unrollMemoryUsedByThisBlock = 0L
+    // Underlying buffer for unrolling the block
+    val redirectableStream = new RedirectableOutputStream
+    val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(initialMemoryThreshold.toInt)
+    redirectableStream.setOutputStream(byteArrayChunkOutputStream)
+    val serializationStream: SerializationStream = {
+      val ser = serializerManager.getSerializer(classTag).newInstance()
+      ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream))
+    }
+
+    // Request enough memory to begin unrolling
+    keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold)
+
+    if (!keepUnrolling) {
+      logWarning(s"Failed to reserve initial memory threshold of " +
+        s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+    } else {
+      unrollMemoryUsedByThisBlock += initialMemoryThreshold
+    }
+
+    // Reserve additional unroll memory whenever the serialized buffer has outgrown what
+    // has been reserved so far; on failure, flag that unrolling must stop.
+    def reserveAdditionalMemoryIfNecessary(): Unit = {
+      if (byteArrayChunkOutputStream.size > unrollMemoryUsedByThisBlock) {
+        val amountToRequest = byteArrayChunkOutputStream.size - unrollMemoryUsedByThisBlock
+        keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest)
+        if (keepUnrolling) {
+          unrollMemoryUsedByThisBlock += amountToRequest
+        }
+      }
+    }
+
+    // Unroll this block safely, checking whether we have exceeded our threshold
+    while (values.hasNext && keepUnrolling) {
+      serializationStream.writeObject(values.next())(classTag)
+      reserveAdditionalMemoryIfNecessary()
+    }
+
+    // Make sure that we have enough memory to store the block. By this point, it is possible that
+    // the block's actual memory usage has exceeded the unroll memory by a small amount, so we
+    // perform one final call to attempt to allocate additional memory if necessary.
+    if (keepUnrolling) {
+      serializationStream.close()
+      reserveAdditionalMemoryIfNecessary()
+    }
+
+    if (keepUnrolling) {
+      val entry = SerializedMemoryEntry[T](
+        new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)), classTag)
+      // Synchronize so that transfer is atomic
+      memoryManager.synchronized {
+        releaseUnrollMemoryForThisTask(unrollMemoryUsedByThisBlock)
+        val success = memoryManager.acquireStorageMemory(blockId, entry.size)
+        assert(success, "transferring unroll memory to storage memory failed")
+      }
+      entries.synchronized {
+        entries.put(blockId, entry)
+      }
+      logInfo("Block %s stored as bytes in memory (estimated size %s, free %s)".format(
+        blockId, Utils.bytesToString(entry.size),
+        Utils.bytesToString(maxMemory - blocksMemoryUsed)))
+      Right(entry.size)
+    } else {
+      // We ran out of space while unrolling the values for this block
+      logUnrollFailureMessage(blockId, byteArrayChunkOutputStream.size)
+      Left(
+        new PartiallySerializedBlock(
+          this,
+          serializerManager,
+          blockId,
+          serializationStream,
+          redirectableStream,
+          unrollMemoryUsedByThisBlock,
+          new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)),
+          values,
+          classTag))
+    }
+  }
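The unroll bookkeeping above is deliberately lazy: memory is requested only when the serialized buffer outgrows what has already been reserved, and the first failed reservation stops the loop. A self-contained toy model of that pattern (the memory pool, sizes, and record loop here are illustrative assumptions, not Spark APIs):

    object UnrollAccountingDemo extends App {
      var free = 1024L                       // toy storage pool: 1 KiB in total
      def reserve(amount: Long): Boolean =
        if (amount <= free) { free -= amount; true } else false

      var reserved = 256L                    // initial unroll threshold, reserved up front
      require(reserve(reserved))
      var written = 0L                       // bytes "serialized" so far
      var keepUnrolling = true

      for (_ <- 1 to 12 if keepUnrolling) {  // simulate twelve ~100-byte records
        written += 100
        if (written > reserved) {            // same check as reserveAdditionalMemoryIfNecessary()
          val delta = written - reserved
          keepUnrolling = reserve(delta)
          if (keepUnrolling) reserved += delta
        }
      }
      // Prints: keepUnrolling=false reserved=1000 written=1100 -- the block did not fit,
      // which is the case where the real code returns a PartiallySerializedBlock.
      println(s"keepUnrolling=$keepUnrolling reserved=$reserved written=$written")
    }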
+
   def getBytes(blockId: BlockId): Option[ChunkedByteBuffer] = {
     val entry = entries.synchronized { entries.get(blockId) }
     entry match {
       case null => None
       case e: DeserializedMemoryEntry[_] =>
         throw new IllegalArgumentException("should only call getBytes on serialized blocks")
-      case SerializedMemoryEntry(bytes, _, _) => Some(bytes)
+      case SerializedMemoryEntry(bytes, _) => Some(bytes)
     }
   }
 
@@ -373,7 +473,7 @@ private[spark] class MemoryStore(
     def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = {
       val data = entry match {
         case DeserializedMemoryEntry(values, _, _) => Left(values)
-        case SerializedMemoryEntry(buffer, _, _) => Right(buffer)
+        case SerializedMemoryEntry(buffer, _) => Right(buffer)
       }
       val newEffectiveStorageLevel =
         blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag)
@@ -507,12 +607,13 @@
 }
 
 /**
- * The result of a failed [[MemoryStore.putIterator()]] call.
+ * The result of a failed [[MemoryStore.putIteratorAsValues()]] call.
  *
- * @param memoryStore the memoryStore, used for freeing memory.
+ * @param memoryStore  the memoryStore, used for freeing memory.
  * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
- * @param unrolled an iterator for the partially-unrolled values.
- * @param rest the rest of the original iterator passed to [[MemoryStore.putIterator()]].
+ * @param unrolled     an iterator for the partially-unrolled values.
+ * @param rest         the rest of the original iterator passed to
+ *                     [[MemoryStore.putIteratorAsValues()]].
 */
 private[storage] class PartiallyUnrolledIterator[T](
     memoryStore: MemoryStore,
@@ -544,3 +645,81 @@ private[storage] class PartiallyUnrolledIterator[T](
     iter = null
   }
 }
+
+/**
+ * A wrapper which allows an open [[OutputStream]] to be redirected to a different sink.
+ */
+private class RedirectableOutputStream extends OutputStream {
+  private[this] var os: OutputStream = _
+  def setOutputStream(s: OutputStream): Unit = { os = s }
+  override def write(b: Int): Unit = os.write(b)
+  override def write(b: Array[Byte]): Unit = os.write(b)
+  override def write(b: Array[Byte], off: Int, len: Int): Unit = os.write(b, off, len)
+  override def flush(): Unit = os.flush()
+  override def close(): Unit = os.close()
+}
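This is the mechanism that lets `finishWritingToStream` keep using one open `SerializationStream` while swapping its sink. A minimal standalone sketch of the redirection behavior (plain java.io, assuming the class is visible from the calling code):

    object RedirectDemo extends App {
      val redirectable = new RedirectableOutputStream
      val inMemory = new java.io.ByteArrayOutputStream()
      redirectable.setOutputStream(inMemory)
      redirectable.write("first".getBytes("UTF-8"))   // buffered in memory

      val spilled = new java.io.ByteArrayOutputStream()
      redirectable.setOutputStream(spilled)           // redirect the open stream
      redirectable.write("second".getBytes("UTF-8"))  // later writes go to the new sink

      println(inMemory.toString("UTF-8"))  // "first"
      println(spilled.toString("UTF-8"))   // "second"
    }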
+
+/**
+ * The result of a failed [[MemoryStore.putIteratorAsBytes()]] call.
+ *
+ * @param memoryStore the MemoryStore, used for freeing memory.
+ * @param serializerManager the SerializerManager, used for deserializing values.
+ * @param blockId the block id.
+ * @param serializationStream a serialization stream which writes to [[redirectableOutputStream]].
+ * @param redirectableOutputStream an OutputStream which can be redirected to a different sink.
+ * @param unrollMemory the amount of unroll memory used by the values in `unrolled`.
+ * @param unrolled a byte buffer containing the partially-serialized values.
+ * @param rest the rest of the original iterator passed to
+ *             [[MemoryStore.putIteratorAsBytes()]].
+ * @param classTag the [[ClassTag]] for the block.
+ */
+private[storage] class PartiallySerializedBlock[T](
+    memoryStore: MemoryStore,
+    serializerManager: SerializerManager,
+    blockId: BlockId,
+    serializationStream: SerializationStream,
+    redirectableOutputStream: RedirectableOutputStream,
+    unrollMemory: Long,
+    unrolled: ChunkedByteBuffer,
+    rest: Iterator[T],
+    classTag: ClassTag[T]) {
+
+  /**
+   * Called to dispose of this block and free its memory.
+   */
+  def discard(): Unit = {
+    try {
+      serializationStream.close()
+    } finally {
+      memoryStore.releaseUnrollMemoryForThisTask(unrollMemory)
+    }
+  }
+
+  /**
+   * Finish writing this block to the given output stream by first writing the serialized values
+   * and then serializing the values from the original input iterator.
+   */
+  def finishWritingToStream(os: OutputStream): Unit = {
+    ByteStreams.copy(unrolled.toInputStream(), os)
+    redirectableOutputStream.setOutputStream(os)
+    while (rest.hasNext) {
+      serializationStream.writeObject(rest.next())(classTag)
+    }
+    discard()
+  }
+
+  /**
+   * Returns an iterator over the values in this block by first deserializing the serialized
+   * values and then consuming the rest of the original input iterator.
+   *
+   * If the caller does not plan to fully consume the resulting iterator then they must call
+   * `close()` on it to free its resources.
+   */
+  def valuesIterator: PartiallyUnrolledIterator[T] = {
+    new PartiallyUnrolledIterator(
+      memoryStore,
+      unrollMemory,
+      unrolled = serializerManager.dataDeserialize(blockId, unrolled)(classTag),
+      rest = rest)
+  }
+}
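A hypothetical consumer of the failure path (`spillToDisk` and the file handling are illustrative assumptions, not APIs established by this patch): the already-serialized bytes are copied out, the open serialization stream is redirected at the file, and the remaining records are serialized straight through.

    def spillToDisk[T](partial: PartiallySerializedBlock[T], file: java.io.File): Unit = {
      val out = new java.io.FileOutputStream(file)
      try {
        // Writes the buffered bytes, redirects the stream, serializes the rest,
        // and finally discards the block to release its unroll memory.
        partial.finishWritingToStream(out)
      } finally {
        out.close()
      }
    }

Alternatively, a caller that prefers to recompute in memory can call `valuesIterator` and either drain it fully or `close()` it.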