1717
1818package org .apache .spark .storage
1919
20- import java .io .{InputStream , IOException }
20+ import java .io .{InputStream , IOException , SequenceInputStream }
2121import java .nio .ByteBuffer
2222import java .util .concurrent .{LinkedBlockingQueue , TimeUnit }
2323import javax .annotation .concurrent .GuardedBy
2424
2525import scala .collection .mutable
2626import scala .collection .mutable .{ArrayBuffer , HashMap , HashSet , Queue }
2727
28+ import org .apache .commons .io .IOUtils
29+
2830import org .apache .spark .{SparkException , TaskContext }
2931import org .apache .spark .internal .Logging
3032import org .apache .spark .network .buffer .{FileSegmentManagedBuffer , ManagedBuffer }
3133import org .apache .spark .network .shuffle ._
3234import org .apache .spark .network .util .TransportConf
3335import org .apache .spark .shuffle .{FetchFailedException , ShuffleReadMetricsReporter }
3436import org .apache .spark .util .{CompletionIterator , TaskCompletionListener , Utils }
35- import org .apache .spark .util .io .ChunkedByteBufferOutputStream
3637
3738/**
3839 * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -73,6 +74,7 @@ final class ShuffleBlockFetcherIterator(
7374 maxBlocksInFlightPerAddress : Int ,
7475 maxReqSizeShuffleToMem : Long ,
7576 detectCorrupt : Boolean ,
77+ detectCorruptUseExtraMemory : Boolean ,
7678 shuffleMetrics : ShuffleReadMetricsReporter )
7779 extends Iterator [(BlockId , InputStream )] with DownloadFileManager with Logging {
7880
@@ -406,6 +408,7 @@ final class ShuffleBlockFetcherIterator(
406408
407409 var result : FetchResult = null
408410 var input : InputStream = null
411+ var streamCompressedOrEncrypted : Boolean = false
409412 // Take the next fetched result and try to decompress it to detect data corruption,
410413 // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
411414 // is also corrupt, so the previous stage could be retried.
@@ -463,25 +466,22 @@ final class ShuffleBlockFetcherIterator(
463466 buf.release()
464467 throwFetchFailedException(blockId, address, e)
465468 }
466- var isStreamCopied : Boolean = false
467469 try {
468470 input = streamWrapper(blockId, in)
469- // Only copy the stream if it's wrapped by compression or encryption, also the size of
470- // block is small (the decompressed block is smaller than maxBytesInFlight)
471- if (detectCorrupt && ! input.eq(in) && size < maxBytesInFlight / 3 ) {
472- isStreamCopied = true
473- val out = new ChunkedByteBufferOutputStream (64 * 1024 , ByteBuffer .allocate)
474- // Decompress the whole block at once to detect any corruption, which could increase
475- // the memory usage tne potential increase the chance of OOM.
471+ // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
472+ // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
473+ // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
474+ // the corruption is later, we'll still detect the corruption later in the stream.
475+ streamCompressedOrEncrypted = ! input.eq(in)
476+ if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
476477 // TODO: manage the memory used here, and spill it into disk in case of OOM.
477- Utils .copyStream(input, out, closeStreams = true )
478- input = out.toChunkedByteBuffer.toInputStream(dispose = true )
478+ input = Utils .copyStreamUpTo(input, maxBytesInFlight / 3 )
479479 }
480480 } catch {
481481 case e : IOException =>
482482 buf.release()
483483 if (buf.isInstanceOf [FileSegmentManagedBuffer ]
484- || corruptedBlocks.contains(blockId)) {
484+ || corruptedBlocks.contains(blockId)) {
485485 throwFetchFailedException(blockId, address, e)
486486 } else {
487487 logWarning(s " got an corrupted block $blockId from $address, fetch again " , e)
@@ -491,7 +491,9 @@ final class ShuffleBlockFetcherIterator(
491491 }
492492 } finally {
493493 // TODO: release the buf here to free memory earlier
494- if (isStreamCopied) {
494+ if (input == null ) {
495+ // Close the underlying stream if there was an issue in wrapping the stream using
496+ // streamWrapper
495497 in.close()
496498 }
497499 }
@@ -508,7 +510,13 @@ final class ShuffleBlockFetcherIterator(
508510 throw new NoSuchElementException ()
509511 }
510512 currentResult = result.asInstanceOf [SuccessFetchResult ]
511- (currentResult.blockId, new BufferReleasingInputStream (input, this ))
513+ (currentResult.blockId,
514+ new BufferReleasingInputStream (
515+ input,
516+ this ,
517+ currentResult.blockId,
518+ currentResult.address,
519+ detectCorrupt && streamCompressedOrEncrypted))
512520 }
513521
514522 def toCompletionIterator : Iterator [(BlockId , InputStream )] = {
@@ -571,7 +579,10 @@ final class ShuffleBlockFetcherIterator(
571579 }
572580 }
573581
574- private def throwFetchFailedException (blockId : BlockId , address : BlockManagerId , e : Throwable ) = {
582+ private [storage] def throwFetchFailedException (
583+ blockId : BlockId ,
584+ address : BlockManagerId ,
585+ e : Throwable ) = {
575586 blockId match {
576587 case ShuffleBlockId (shufId, mapId, reduceId) =>
577588 throw new FetchFailedException (address, shufId.toInt, mapId.toInt, reduceId, e)
@@ -583,15 +594,28 @@ final class ShuffleBlockFetcherIterator(
583594}
584595
585596/**
586- * Helper class that ensures a ManagedBuffer is released upon InputStream.close()
597+ * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and
598+ * also detects stream corruption if streamCompressedOrEncrypted is true
587599 */
588600private class BufferReleasingInputStream (
589- private val delegate : InputStream ,
590- private val iterator : ShuffleBlockFetcherIterator )
601+ // This is visible for testing
602+ private [storage] val delegate : InputStream ,
603+ private val iterator : ShuffleBlockFetcherIterator ,
604+ private val blockId : BlockId ,
605+ private val address : BlockManagerId ,
606+ private val detectCorruption : Boolean )
591607 extends InputStream {
592608 private [this ] var closed = false
593609
594- override def read (): Int = delegate.read()
610+ override def read (): Int = {
611+ try {
612+ delegate.read()
613+ } catch {
614+ case e : IOException if detectCorruption =>
615+ IOUtils .closeQuietly(this )
616+ iterator.throwFetchFailedException(blockId, address, e)
617+ }
618+ }
595619
596620 override def close (): Unit = {
597621 if (! closed) {
@@ -605,13 +629,37 @@ private class BufferReleasingInputStream(
605629
606630 override def mark (readlimit : Int ): Unit = delegate.mark(readlimit)
607631
608- override def skip (n : Long ): Long = delegate.skip(n)
632+ override def skip (n : Long ): Long = {
633+ try {
634+ delegate.skip(n)
635+ } catch {
636+ case e : IOException if detectCorruption =>
637+ IOUtils .closeQuietly(this )
638+ iterator.throwFetchFailedException(blockId, address, e)
639+ }
640+ }
609641
610642 override def markSupported (): Boolean = delegate.markSupported()
611643
612- override def read (b : Array [Byte ]): Int = delegate.read(b)
644+ override def read (b : Array [Byte ]): Int = {
645+ try {
646+ delegate.read(b)
647+ } catch {
648+ case e : IOException if detectCorruption =>
649+ IOUtils .closeQuietly(this )
650+ iterator.throwFetchFailedException(blockId, address, e)
651+ }
652+ }
613653
614- override def read (b : Array [Byte ], off : Int , len : Int ): Int = delegate.read(b, off, len)
654+ override def read (b : Array [Byte ], off : Int , len : Int ): Int = {
655+ try {
656+ delegate.read(b, off, len)
657+ } catch {
658+ case e : IOException if detectCorruption =>
659+ IOUtils .closeQuietly(this )
660+ iterator.throwFetchFailedException(blockId, address, e)
661+ }
662+ }
615663
616664 override def reset (): Unit = delegate.reset()
617665}
0 commit comments