@@ -928,6 +928,15 @@ package object config {
.booleanConf
.createWithDefault(true)

private[spark] val SHUFFLE_DETECT_CORRUPT_MEMORY =
ConfigBuilder("spark.shuffle.detectCorrupt.useExtraMemory")
.doc("If enabled, part of a compressed/encrypted stream will be de-compressed/de-crypted " +
"by using extra memory to detect early corruption. Any IOException thrown will cause " +
"the task to be retried once and if it fails again with same exception, then " +
"FetchFailedException will be thrown to retry previous stage")
.booleanConf
.createWithDefault(false)
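
// Illustrative sketch, not part of this patch: how an application would opt in to the
// extra-memory corruption check defined above, alongside the existing
// spark.shuffle.detectCorrupt flag (which defaults to true). The val name is arbitrary,
// and the same keys can equally be passed via --conf on spark-submit.
import org.apache.spark.SparkConf

val sketchConf = new SparkConf()
  .set("spark.shuffle.detectCorrupt", "true")
  .set("spark.shuffle.detectCorrupt.useExtraMemory", "true")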

private[spark] val SHUFFLE_SYNC =
ConfigBuilder("spark.shuffle.sync")
.doc("Whether to force outstanding writes to disk.")
@@ -55,6 +55,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
readMetrics).toCompletionIterator

val serializerInstance = dep.serializer.newInstance()
@@ -17,22 +17,23 @@

package org.apache.spark.storage

import java.io.{InputStream, IOException, SequenceInputStream}
import java.nio.ByteBuffer
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}

import org.apache.commons.io.IOUtils

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle._
import org.apache.spark.network.util.TransportConf
import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}

/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -73,6 +74,7 @@ final class ShuffleBlockFetcherIterator(
maxBlocksInFlightPerAddress: Int,
maxReqSizeShuffleToMem: Long,
detectCorrupt: Boolean,
detectCorruptUseExtraMemory: Boolean,
shuffleMetrics: ShuffleReadMetricsReporter)
extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {

@@ -406,6 +408,7 @@ final class ShuffleBlockFetcherIterator(

var result: FetchResult = null
var input: InputStream = null
var streamCompressedOrEncrypted: Boolean = false
// Take the next fetched result and try to decompress it to detect data corruption,
// then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
// is also corrupt, so the previous stage could be retried.
@@ -463,25 +466,22 @@
buf.release()
throwFetchFailedException(blockId, address, e)
}
try {
input = streamWrapper(blockId, in)
// If the stream is compressed or wrapped, then we optionally decompress/unwrap the
// first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
// of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
// the corruption is later, we'll still detect the corruption later in the stream.
streamCompressedOrEncrypted = !input.eq(in)
if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
// TODO: manage the memory used here, and spill it into disk in case of OOM.
input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
}
} catch {
case e: IOException =>
buf.release()
if (buf.isInstanceOf[FileSegmentManagedBuffer]
|| corruptedBlocks.contains(blockId)) {
throwFetchFailedException(blockId, address, e)
} else {
logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
@@ -491,7 +491,9 @@
}
} finally {
// TODO: release the buf here to free memory earlier
if (input == null) {
// Close the underlying stream if there was an issue in wrapping the stream using
// streamWrapper
in.close()
}
}
@@ -508,7 +510,13 @@
throw new NoSuchElementException()
}
currentResult = result.asInstanceOf[SuccessFetchResult]
(currentResult.blockId,
new BufferReleasingInputStream(
input,
this,
currentResult.blockId,
currentResult.address,
detectCorrupt && streamCompressedOrEncrypted))
}

def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
@@ -571,7 +579,10 @@ final class ShuffleBlockFetcherIterator(
}
}

private[storage] def throwFetchFailedException(
blockId: BlockId,
address: BlockManagerId,
e: Throwable) = {
blockId match {
case ShuffleBlockId(shufId, mapId, reduceId) =>
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
@@ -583,15 +594,28 @@
}

/**
* Helper class that ensures a ManagedBuffer is released upon InputStream.close() and
* also detects stream corruption if streamCompressedOrEncrypted is true
*/
private class BufferReleasingInputStream(
// This is visible for testing
private[storage] val delegate: InputStream,
private val iterator: ShuffleBlockFetcherIterator,
private val blockId: BlockId,
private val address: BlockManagerId,
private val detectCorruption: Boolean)
extends InputStream {
private[this] var closed = false

override def read(): Int = {
try {
delegate.read()
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
iterator.throwFetchFailedException(blockId, address, e)
}
}

override def close(): Unit = {
if (!closed) {
@@ -605,13 +629,37 @@ private class BufferReleasingInputStream(

override def mark(readlimit: Int): Unit = delegate.mark(readlimit)

override def skip(n: Long): Long = {
try {
delegate.skip(n)
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
iterator.throwFetchFailedException(blockId, address, e)
}
}

override def markSupported(): Boolean = delegate.markSupported()

override def read(b: Array[Byte]): Int = {
try {
delegate.read(b)
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
iterator.throwFetchFailedException(blockId, address, e)
}
}

override def read(b: Array[Byte], off: Int, len: Int): Int = {
try {
delegate.read(b, off, len)
} catch {
case e: IOException if detectCorruption =>
IOUtils.closeQuietly(this)
iterator.throwFetchFailedException(blockId, address, e)
}
}

override def reset(): Unit = delegate.reset()
}
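
For background on why the eager decompression above catches bad blocks: compression codecs raise an IOException as soon as they are fed invalid input, which is exactly what BufferReleasingInputStream now converts into a retry or a FetchFailedException. Below is a minimal, self-contained sketch of that failure mode; it uses the JDK's GZIP streams rather than Spark's shuffle codec, and the object and variable names are illustrative only.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object CorruptStreamDemo {
  def main(args: Array[String]): Unit = {
    // Compress a small payload, then flip a byte in the middle to simulate a
    // corrupted shuffle block arriving over the network.
    val bytesOut = new ByteArrayOutputStream()
    val gzipOut = new GZIPOutputStream(bytesOut)
    gzipOut.write("some shuffle data".getBytes("UTF-8"))
    gzipOut.close()
    val compressed = bytesOut.toByteArray
    val mid = compressed.length / 2
    compressed(mid) = (compressed(mid) ^ 0xFF).toByte

    val in = new GZIPInputStream(new ByteArrayInputStream(compressed))
    try {
      while (in.read() != -1) {}  // corruption only surfaces once the bad region is read
      println("no corruption detected")
    } catch {
      case e: IOException => println(s"corruption detected: $e")  // retried, then FetchFailed, in Spark
    } finally {
      in.close()
    }
  }
}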
45 changes: 45 additions & 0 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -67,6 +67,7 @@ import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace}
import org.apache.spark.util.io.ChunkedByteBufferOutputStream

/** CallSite represents a place in user code. It can have a short and a long form. */
private[spark] case class CallSite(shortForm: String, longForm: String)
@@ -337,6 +338,50 @@ private[spark] object Utils extends Logging {
}
}

/**
* Copy the first `maxSize` bytes of data from the InputStream to an in-memory
* buffer, primarily to check for corruption.
*
* This returns a new InputStream which contains the same data as the original input stream.
* It may be backed entirely by an in-memory buffer, or it may be a combination of in-memory
* data followed by the remainder of the original stream. The main use of this is when the
* original input stream can detect corruption while the data is being read (e.g. from
* decompression), since it allows an eager check for corruption in the first `maxSize`
* bytes of data.
*
* @return An InputStream which includes all data from the original stream (combining buffered
* data and remaining data in the original stream)
*/
def copyStreamUpTo(in: InputStream, maxSize: Long): InputStream = {
var count = 0L
val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
val fullyCopied = tryWithSafeFinally {
val bufSize = Math.min(8192L, maxSize)
val buf = new Array[Byte](bufSize.toInt)
var n = 0
while (n != -1 && count < maxSize) {
n = in.read(buf, 0, Math.min(maxSize - count, bufSize).toInt)
if (n != -1) {
out.write(buf, 0, n)
count += n
}
}
count < maxSize
} {
try {
if (count < maxSize) {
in.close()
}
} finally {
out.close()
}
}
if (fullyCopied) {
out.toChunkedByteBuffer.toInputStream(dispose = true)
} else {
new SequenceInputStream(out.toChunkedByteBuffer.toInputStream(dispose = true), in)
}
}
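
// Usage sketch, not part of this patch: the stream returned by copyStreamUpTo yields
// exactly the original bytes, whether the input fits entirely in the in-memory buffer or
// spills over into a SequenceInputStream. Something like the following could live in a
// test in the org.apache.spark.util package (Utils is private[spark]); variable names and
// sizes here are illustrative.
import java.io.ByteArrayInputStream
import org.apache.commons.io.IOUtils

val originalBytes = Array.tabulate[Byte](100)(_.toByte)
val fullyBuffered = Utils.copyStreamUpTo(new ByteArrayInputStream(originalBytes), 200L)
val partlyBuffered = Utils.copyStreamUpTo(new ByteArrayInputStream(originalBytes), 32L)
assert(IOUtils.toByteArray(fullyBuffered).sameElements(originalBytes))
assert(IOUtils.toByteArray(partlyBuffered).sameElements(originalBytes))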

def copyFileStreamNIO(
input: FileChannel,
output: FileChannel,