Commit 688b0c0

ankuriitg authored and squito committed
[SPARK-26089][CORE] Handle corruption in large shuffle blocks
## What changes were proposed in this pull request?

SPARK-4105 added corruption detection for shuffle blocks, but it was limited to blocks smaller than maxBytesInFlight/3. This commit builds on that by adding a corruption check for large blocks. It makes two changes/improvements:

1. Large blocks are checked up to maxBytesInFlight/3 bytes in the same way as smaller blocks, so if a large block is corrupt at the beginning, that block is re-fetched, and if that also fails, a FetchFailedException is thrown.
2. If a large block is corrupt beyond the first maxBytesInFlight/3 bytes, any IOException thrown while reading the stream is converted to a FetchFailedException. This is slightly more aggressive than originally intended, but since the consumer of the stream may have already read and processed some records, we can't simply re-fetch the block; we need to fail the whole task. We also considered adding a new type of TaskEndReason that would retry the task a couple of times before failing the previous stage, but given the complexity involved in that solution we decided not to proceed in that direction.

Thanks to squito for direction and support.

## How was this patch tested?

Changed the junit test for big blocks to check for corruption.

Closes #23453 from ankuriitg/ankurgupta/SPARK-26089.

Authored-by: ankurgupta <[email protected]>
Signed-off-by: Imran Rashid <[email protected]>
1 parent 3f9247d commit 688b0c0

6 files changed: +306 −45 lines

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 9 additions & 0 deletions
```diff
@@ -928,6 +928,15 @@ package object config {
       .booleanConf
       .createWithDefault(true)
 
+  private[spark] val SHUFFLE_DETECT_CORRUPT_MEMORY =
+    ConfigBuilder("spark.shuffle.detectCorrupt.useExtraMemory")
+      .doc("If enabled, part of a compressed/encrypted stream will be de-compressed/de-crypted " +
+        "by using extra memory to detect early corruption. Any IOException thrown will cause " +
+        "the task to be retried once and if it fails again with same exception, then " +
+        "FetchFailedException will be thrown to retry previous stage")
+      .booleanConf
+      .createWithDefault(false)
+
   private[spark] val SHUFFLE_SYNC =
     ConfigBuilder("spark.shuffle.sync")
       .doc("Whether to force outstanding writes to disk.")
```

core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -55,6 +55,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
       SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
       SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
       SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
+      SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
       readMetrics).toCompletionIterator
 
     val serializerInstance = dep.serializer.newInstance()
```

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 71 additions & 23 deletions
```diff
@@ -17,22 +17,23 @@
 
 package org.apache.spark.storage
 
-import java.io.{InputStream, IOException}
+import java.io.{InputStream, IOException, SequenceInputStream}
 import java.nio.ByteBuffer
 import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
 
+import org.apache.commons.io.IOUtils
+
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle._
 import org.apache.spark.network.util.TransportConf
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
 import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
-import org.apache.spark.util.io.ChunkedByteBufferOutputStream
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -73,6 +74,7 @@ final class ShuffleBlockFetcherIterator(
     maxBlocksInFlightPerAddress: Int,
     maxReqSizeShuffleToMem: Long,
     detectCorrupt: Boolean,
+    detectCorruptUseExtraMemory: Boolean,
     shuffleMetrics: ShuffleReadMetricsReporter)
   extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
 
@@ -406,6 +408,7 @@ final class ShuffleBlockFetcherIterator(
 
     var result: FetchResult = null
     var input: InputStream = null
+    var streamCompressedOrEncrypted: Boolean = false
     // Take the next fetched result and try to decompress it to detect data corruption,
     // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
     // is also corrupt, so the previous stage could be retried.
@@ -463,25 +466,22 @@ final class ShuffleBlockFetcherIterator(
             buf.release()
             throwFetchFailedException(blockId, address, e)
         }
-        var isStreamCopied: Boolean = false
         try {
           input = streamWrapper(blockId, in)
-          // Only copy the stream if it's wrapped by compression or encryption, also the size of
-          // block is small (the decompressed block is smaller than maxBytesInFlight)
-          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
-            isStreamCopied = true
-            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
-            // Decompress the whole block at once to detect any corruption, which could increase
-            // the memory usage tne potential increase the chance of OOM.
+          // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+          // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+          // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+          // the corruption is later, we'll still detect the corruption later in the stream.
+          streamCompressedOrEncrypted = !input.eq(in)
+          if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
             // TODO: manage the memory used here, and spill it into disk in case of OOM.
-            Utils.copyStream(input, out, closeStreams = true)
-            input = out.toChunkedByteBuffer.toInputStream(dispose = true)
+            input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
           }
         } catch {
           case e: IOException =>
             buf.release()
             if (buf.isInstanceOf[FileSegmentManagedBuffer]
-              || corruptedBlocks.contains(blockId)) {
+                || corruptedBlocks.contains(blockId)) {
               throwFetchFailedException(blockId, address, e)
             } else {
               logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
@@ -491,7 +491,9 @@
           }
         } finally {
           // TODO: release the buf here to free memory earlier
-          if (isStreamCopied) {
+          if (input == null) {
+            // Close the underlying stream if there was an issue in wrapping the stream using
+            // streamWrapper
             in.close()
           }
         }
@@ -508,7 +510,13 @@
       throw new NoSuchElementException()
     }
     currentResult = result.asInstanceOf[SuccessFetchResult]
-    (currentResult.blockId, new BufferReleasingInputStream(input, this))
+    (currentResult.blockId,
+      new BufferReleasingInputStream(
+        input,
+        this,
+        currentResult.blockId,
+        currentResult.address,
+        detectCorrupt && streamCompressedOrEncrypted))
   }
 
   def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
@@ -571,7 +579,10 @@ final class ShuffleBlockFetcherIterator(
       }
     }
 
-  private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = {
+  private[storage] def throwFetchFailedException(
+      blockId: BlockId,
+      address: BlockManagerId,
+      e: Throwable) = {
     blockId match {
       case ShuffleBlockId(shufId, mapId, reduceId) =>
         throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
@@ -583,15 +594,28 @@
 }
 
 /**
- * Helper class that ensures a ManagedBuffer is released upon InputStream.close()
+ * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and
+ * also detects stream corruption if streamCompressedOrEncrypted is true
  */
 private class BufferReleasingInputStream(
-    private val delegate: InputStream,
-    private val iterator: ShuffleBlockFetcherIterator)
+    // This is visible for testing
+    private[storage] val delegate: InputStream,
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val blockId: BlockId,
+    private val address: BlockManagerId,
+    private val detectCorruption: Boolean)
   extends InputStream {
   private[this] var closed = false
 
-  override def read(): Int = delegate.read()
+  override def read(): Int = {
+    try {
+      delegate.read()
+    } catch {
+      case e: IOException if detectCorruption =>
+        IOUtils.closeQuietly(this)
+        iterator.throwFetchFailedException(blockId, address, e)
+    }
+  }
 
   override def close(): Unit = {
     if (!closed) {
@@ -605,13 +629,37 @@ private class BufferReleasingInputStream(
 
   override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
 
-  override def skip(n: Long): Long = delegate.skip(n)
+  override def skip(n: Long): Long = {
+    try {
+      delegate.skip(n)
+    } catch {
+      case e: IOException if detectCorruption =>
+        IOUtils.closeQuietly(this)
+        iterator.throwFetchFailedException(blockId, address, e)
+    }
+  }
 
   override def markSupported(): Boolean = delegate.markSupported()
 
-  override def read(b: Array[Byte]): Int = delegate.read(b)
+  override def read(b: Array[Byte]): Int = {
+    try {
+      delegate.read(b)
+    } catch {
+      case e: IOException if detectCorruption =>
+        IOUtils.closeQuietly(this)
+        iterator.throwFetchFailedException(blockId, address, e)
+    }
+  }
 
-  override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
+  override def read(b: Array[Byte], off: Int, len: Int): Int = {
+    try {
+      delegate.read(b, off, len)
+    } catch {
+      case e: IOException if detectCorruption =>
+        IOUtils.closeQuietly(this)
+        iterator.throwFetchFailedException(blockId, address, e)
+    }
+  }
 
   override def reset(): Unit = delegate.reset()
 }
```
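Taken together, the iterator's handling of a corrupt block can be summarized by the following simplified sketch. It is not the actual Spark code: the class and method names are hypothetical, and a plain `RuntimeException` stands in for `FetchFailedException`.

```scala
import scala.collection.mutable

// First corruption of a remote block: remember it and re-fetch.
// Second corruption of the same block, or corruption of a local
// FileSegment-backed block: fail the fetch so the previous stage is retried.
class CorruptionTracker {
  private val corruptedBlocks = mutable.HashSet[String]()

  def onCorruptBlock(blockId: String, isLocalFileSegment: Boolean, cause: Exception): Unit = {
    if (isLocalFileSegment || corruptedBlocks.contains(blockId)) {
      // In Spark this is throwFetchFailedException, producing a FetchFailedException.
      throw new RuntimeException(s"Fetch failed for block $blockId", cause)
    } else {
      corruptedBlocks += blockId
      // ...re-enqueue a fetch request for this block and try again.
    }
  }
}
```

Corruption detected after the eagerly checked prefix is handled by `BufferReleasingInputStream`, which converts the `IOException` from any `read`/`skip` call directly into a `FetchFailedException`, failing the task rather than re-fetching.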

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 45 additions & 0 deletions
```diff
@@ -67,6 +67,7 @@ import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
 import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace}
+import org.apache.spark.util.io.ChunkedByteBufferOutputStream
 
 /** CallSite represents a place in user code. It can have a short and a long form. */
 private[spark] case class CallSite(shortForm: String, longForm: String)
@@ -337,6 +338,50 @@ private[spark] object Utils extends Logging {
     }
   }
 
+  /**
+   * Copy the first `maxSize` bytes of data from the InputStream to an in-memory
+   * buffer, primarily to check for corruption.
+   *
+   * This returns a new InputStream which contains the same data as the original input stream.
+   * It may be entirely on in-memory buffer, or it may be a combination of in-memory data, and then
+   * continue to read from the original stream. The only real use of this is if the original input
+   * stream will potentially detect corruption while the data is being read (eg. from compression).
+   * This allows for an eager check of corruption in the first maxSize bytes of data.
+   *
+   * @return An InputStream which includes all data from the original stream (combining buffered
+   *         data and remaining data in the original stream)
+   */
+  def copyStreamUpTo(in: InputStream, maxSize: Long): InputStream = {
+    var count = 0L
+    val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+    val fullyCopied = tryWithSafeFinally {
+      val bufSize = Math.min(8192L, maxSize)
+      val buf = new Array[Byte](bufSize.toInt)
+      var n = 0
+      while (n != -1 && count < maxSize) {
+        n = in.read(buf, 0, Math.min(maxSize - count, bufSize).toInt)
+        if (n != -1) {
+          out.write(buf, 0, n)
+          count += n
+        }
+      }
+      count < maxSize
+    } {
+      try {
+        if (count < maxSize) {
+          in.close()
+        }
+      } finally {
+        out.close()
+      }
+    }
+    if (fullyCopied) {
+      out.toChunkedByteBuffer.toInputStream(dispose = true)
+    } else {
+      new SequenceInputStream(out.toChunkedByteBuffer.toInputStream(dispose = true), in)
+    }
+  }
+
   def copyFileStreamNIO(
       input: FileChannel,
       output: FileChannel,
```
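To make the semantics of `copyStreamUpTo` concrete, here is a self-contained sketch that uses a plain `ByteArrayOutputStream` instead of Spark's `ChunkedByteBufferOutputStream` and omits `tryWithSafeFinally`; it approximates the committed method for illustration only.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, SequenceInputStream}

// Buffer at most `maxSize` bytes in memory. If the source ends within the limit,
// return only the buffered bytes (and close the source); otherwise chain the
// buffered prefix with the remainder of the source so callers still see all data.
def copyUpToSketch(in: InputStream, maxSize: Long): InputStream = {
  val out = new ByteArrayOutputStream()
  val buf = new Array[Byte](math.min(8192L, maxSize).toInt)
  var count = 0L
  var n = 0
  while (n != -1 && count < maxSize) {
    n = in.read(buf, 0, math.min(maxSize - count, buf.length.toLong).toInt)
    if (n != -1) {
      out.write(buf, 0, n)
      count += n
    }
  }
  val buffered = new ByteArrayInputStream(out.toByteArray)
  if (count < maxSize) {   // source fully consumed within the limit
    in.close()
    buffered
  } else {                 // more data may remain in the source
    new SequenceInputStream(buffered, in)
  }
}
```

When the input is a decompressing or decrypting wrapper, pulling the first `maxSize` bytes through it up front means corruption in that prefix surfaces immediately, before any records are handed to the consumer.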
