From f1623a25e14950283776c244496507e96420d1c8 Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Fri, 4 Jan 2019 09:28:20 -0800 Subject: [PATCH 01/12] [SPARK-26089][CORE] Handle corruption in large shuffle blocks
SPARK-4105 added corruption detection for shuffle blocks, but it was limited to blocks smaller than maxBytesInFlight/3. This commit builds on that by adding a corruption check for large blocks. Two changes/improvements are made in this commit: 1. Large blocks are checked up to maxBytesInFlight/3 bytes in the same way as smaller blocks, so if a large block is corrupt near its start, that block is re-fetched, and if the re-fetch also fails, a FetchFailedException is thrown. 2. If a large block is corrupt beyond the first maxBytesInFlight/3 bytes, any IOException thrown while reading the stream is converted to a FetchFailedException. This is slightly more aggressive than originally intended, but since the consumer of the stream may already have read and processed some records, we cannot simply re-fetch the block; we need to fail the whole task. We also considered adding a new type of TaskEndReason that would retry the task a couple of times before failing the previous stage, but given the complexity of that solution we decided not to proceed in that direction. Thanks to @squito for direction and support. Testing done: changed the JUnit test for big blocks to check for corruption.
--- .../storage/ShuffleBlockFetcherIterator.scala | 58 +++++++++++++++---- .../scala/org/apache/spark/util/Utils.scala | 35 +++++++++++ .../ShuffleBlockFetcherIteratorSuite.scala | 47 +++++++++++---- 3 files changed, 118 insertions(+), 22 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c75b20906918..1bb9a05c9de8 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{InputStream, IOException} +import java.io.{InputStream, IOException, SequenceInputStream} import java.nio.ByteBuffer import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import javax.annotation.concurrent.GuardedBy @@ -406,6 +406,7 @@ final class ShuffleBlockFetcherIterator( var result: FetchResult = null var input: InputStream = null + var streamCompressedOrEncrypted: Boolean = false // Take the next fetched result and try to decompress it to detect data corruption, // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch // is also corrupt, so the previous stage could be retried. @@ -468,14 +469,21 @@ final class ShuffleBlockFetcherIterator( input = streamWrapper(blockId, in) // Only copy the stream if it's wrapped by compression or encryption, also the size of // block is small (the decompressed block is smaller than maxBytesInFlight) - if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) { + if (detectCorrupt && !input.eq(in)) { isStreamCopied = true + streamCompressedOrEncrypted = true val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) // Decompress the whole block at once to detect any corruption, which could increase // the memory usage tne potential increase the chance of OOM. // TODO: manage the memory used here, and spill it into disk in case of OOM. 
- Utils.copyStream(input, out, closeStreams = true) - input = out.toChunkedByteBuffer.toInputStream(dispose = true) + isStreamCopied = Utils.copyStreamUpto( + input, out, maxBytesInFlight / 3, closeStreams = true) + if (isStreamCopied) { + input = out.toChunkedByteBuffer.toInputStream(dispose = true) + } else { + input = new SequenceInputStream( + out.toChunkedByteBuffer.toInputStream(dispose = true), input) + } } } catch { case e: IOException => @@ -508,7 +516,9 @@ final class ShuffleBlockFetcherIterator( throw new NoSuchElementException() } currentResult = result.asInstanceOf[SuccessFetchResult] - (currentResult.blockId, new BufferReleasingInputStream(input, this)) + (currentResult.blockId, + new BufferReleasingInputStream( + input, this, currentResult.blockId, currentResult.address, streamCompressedOrEncrypted)) } def toCompletionIterator: Iterator[(BlockId, InputStream)] = { @@ -571,7 +581,8 @@ final class ShuffleBlockFetcherIterator( } } - private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { + private[storage] def throwFetchFailedException( + blockId: BlockId, address: BlockManagerId, e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) @@ -583,15 +594,26 @@ final class ShuffleBlockFetcherIterator( } /** - * Helper class that ensures a ManagedBuffer is released upon InputStream.close() + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and + * also detects stream corruption if detectCorruption is true */ private class BufferReleasingInputStream( private val delegate: InputStream, - private val iterator: ShuffleBlockFetcherIterator) + private val iterator: ShuffleBlockFetcherIterator, + private val blockId: BlockId, + private val address: BlockManagerId, + private val detectCorruption: Boolean) extends InputStream { private[this] var closed = false - override def read(): Int = delegate.read() + override def read(): Int = { + try { + delegate.read() + } catch { + case e: IOException if detectCorruption => + iterator.throwFetchFailedException(blockId, address, e) + } + } override def close(): Unit = { if (!closed) { @@ -609,9 +631,23 @@ private class BufferReleasingInputStream( override def markSupported(): Boolean = delegate.markSupported() - override def read(b: Array[Byte]): Int = delegate.read(b) + override def read(b: Array[Byte]): Int = { + try { + delegate.read(b) + } catch { + case e: IOException if detectCorruption => + iterator.throwFetchFailedException(blockId, address, e) + } + } - override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + override def read(b: Array[Byte], off: Int, len: Int): Int = { + try { + delegate.read(b, off, len) + } catch { + case e: IOException if detectCorruption => + iterator.throwFetchFailedException(blockId, address, e) + } + } override def reset(): Unit = delegate.reset() } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index cade0dd88fc7..6e34966e068c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -337,6 +337,41 @@ private[spark] object Utils extends Logging { } } + /** + * Copy all data from an InputStream to an OutputStream upto maxSize and + * closes the input stream if closeStreams is true and all data is read. 
+ */ + def copyStreamUpto( + in: InputStream, + out: OutputStream, + maxSize: Long, + closeStreams: Boolean = false): Boolean = { + var count = 0L + tryWithSafeFinally { + val bufSize = 8192 + val buf = new Array[Byte](bufSize) + var n = 0 + while (n != -1 && count < maxSize) { + n = in.read(buf, 0, Math.min(maxSize - count, bufSize.toLong).toInt) + if (n != -1) { + out.write(buf, 0, n) + count += n + } + } + count < maxSize + } { + if (closeStreams) { + try { + if (count < maxSize) { + in.close() + } + } finally { + out.close() + } + } + } + } + def copyFileStreamNIO( input: FileChannel, output: FileChannel, diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 98fe9663b621..b652f9803ba0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -337,9 +337,26 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } - private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = { + private def mockCorruptBuffer(size: Long = 1L, corruptInStart: Boolean = true): ManagedBuffer = { val corruptStream = mock(classOf[InputStream]) - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + if (size < 8 * 1024 || corruptInStart) { + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + } else { + when(corruptStream.read(any(), any(), any(classOf[Int]))).thenAnswer(new Answer[Int] { + override def answer(invocationOnMock: InvocationOnMock): Int = { + val bufSize = invocationOnMock.getArgumentAt(2, classOf[Int]) + // This condition is needed as we don't throw exception until we read the stream + // less than maxBytesInFlight/3 + if (bufSize < 8 * 1024) { + return bufSize + } else { + throw new IOException("corrupt") + } + } + }) + when(corruptStream.read()).thenThrow(new IOException("corrupt")) + when(corruptStream.read(any())).thenThrow(new IOException("corrupt")) + } val corruptBuffer = mock(classOf[ManagedBuffer]) when(corruptBuffer.size()).thenReturn(size) when(corruptBuffer.createInputStream()).thenReturn(corruptStream) @@ -425,24 +442,25 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } - test("big blocks are not checked for corruption") { - val corruptBuffer = mockCorruptBuffer(10000L) + test("big blocks are also checked for corruption") { + val corruptBuffer1 = mockCorruptBuffer(10000L, true) val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) doReturn(localBmId).when(blockManager).blockManagerId - doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + doReturn(corruptBuffer1).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) val localBlockLengths = Seq[Tuple2[BlockId, Long]]( - ShuffleBlockId(0, 0, 0) -> corruptBuffer.size() + ShuffleBlockId(0, 0, 0) -> corruptBuffer1.size() ) + val corruptBuffer2 = mockCorruptBuffer(10000L, false) val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlockLengths = Seq[Tuple2[BlockId, Long]]( - ShuffleBlockId(0, 1, 0) -> corruptBuffer.size() + ShuffleBlockId(0, 1, 0) -> corruptBuffer2.size() ) val transfer = createMockTransfer( - 
Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer)) + Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer1, ShuffleBlockId(0, 1, 0) -> corruptBuffer2)) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlockLengths), @@ -462,9 +480,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, true, taskContext.taskMetrics.createTempShuffleReadMetrics()) - // Blocks should be returned without exceptions. - assert(Set(iterator.next()._1, iterator.next()._1) === - Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + // Only one block should be returned which has corruption after maxBytesInFlight/3 + val (id, st) = iterator.next() + assert(id === ShuffleBlockId(0, 1, 0)) + intercept[FetchFailedException] { iterator.next(); iterator.next() } + // Following will succeed as it reads the first part of the stream which is not corrupt + st.read(new Array[Byte](8 * 1024), 0, 8 * 1024) + // Following will fail as it reads the remaining part of the stream which is corrupt + intercept[FetchFailedException] { st.read() } + intercept[FetchFailedException] { st.read(new Array[Byte](8 * 1024)) } + intercept[FetchFailedException] { st.read(new Array[Byte](8 * 1024), 0, 8 * 1024) } } test("retry corrupt blocks (disabled)") { From 365b27dcf21fcd73e42fdadcbdb39a186557bf0b Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Fri, 4 Jan 2019 10:08:00 -0800 Subject: [PATCH 02/12] Correct indentation --- core/src/main/scala/org/apache/spark/util/Utils.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6e34966e068c..de829d02f2e9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -338,9 +338,9 @@ private[spark] object Utils extends Logging { } /** - * Copy all data from an InputStream to an OutputStream upto maxSize and - * closes the input stream if closeStreams is true and all data is read. - */ + * Copy all data from an InputStream to an OutputStream upto maxSize and + * closes the input stream if closeStreams is true and all data is read. + */ def copyStreamUpto( in: InputStream, out: OutputStream, From e130c6fdb129b386a7d39780bfafc88b7e65fcb8 Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Fri, 4 Jan 2019 15:09:12 -0800 Subject: [PATCH 03/12] Review Comments: Part 1 1. Updated comments in the code 2. If IOException is thrown while reading from a stream, it will always be converted to a FetchFailureException, even when detectCorruption is false 3. 
Added a junit test which verifies that data can be read from concatenated stream --- .../storage/ShuffleBlockFetcherIterator.scala | 28 +++++----- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 52 ++++++++++++++++++- 3 files changed, 67 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 1bb9a05c9de8..c99621af9320 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -467,16 +467,17 @@ final class ShuffleBlockFetcherIterator( var isStreamCopied: Boolean = false try { input = streamWrapper(blockId, in) - // Only copy the stream if it's wrapped by compression or encryption, also the size of - // block is small (the decompressed block is smaller than maxBytesInFlight) - if (detectCorrupt && !input.eq(in)) { + // Only copy the stream if it's wrapped by compression or encryption upto a size of + // maxBytesInFlight/3. If stream is longer, then corruption will be caught while reading + // the stream. + streamCompressedOrEncrypted = !input.eq(in) + if (detectCorrupt && streamCompressedOrEncrypted) { isStreamCopied = true - streamCompressedOrEncrypted = true val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) - // Decompress the whole block at once to detect any corruption, which could increase - // the memory usage tne potential increase the chance of OOM. + // Decompress the block upto maxBytesInFlight/3 at once to detect any corruption which + // could increase the memory usage and potentially increase the chance of OOM. // TODO: manage the memory used here, and spill it into disk in case of OOM. 
- isStreamCopied = Utils.copyStreamUpto( + isStreamCopied = Utils.copyStreamUpTo( input, out, maxBytesInFlight / 3, closeStreams = true) if (isStreamCopied) { input = out.toChunkedByteBuffer.toInputStream(dispose = true) @@ -595,14 +596,15 @@ final class ShuffleBlockFetcherIterator( /** * Helper class that ensures a ManagedBuffer is released upon InputStream.close() and - * also detects stream corruption if detectCorruption is true + * also detects stream corruption if streamCompressedOrEncrypted is true */ private class BufferReleasingInputStream( - private val delegate: InputStream, + // This is visible for testing + private[storage] val delegate: InputStream, private val iterator: ShuffleBlockFetcherIterator, private val blockId: BlockId, private val address: BlockManagerId, - private val detectCorruption: Boolean) + private val streamCompressedOrEncrypted: Boolean) extends InputStream { private[this] var closed = false @@ -610,7 +612,7 @@ private class BufferReleasingInputStream( try { delegate.read() } catch { - case e: IOException if detectCorruption => + case e: IOException if streamCompressedOrEncrypted => iterator.throwFetchFailedException(blockId, address, e) } } @@ -635,7 +637,7 @@ private class BufferReleasingInputStream( try { delegate.read(b) } catch { - case e: IOException if detectCorruption => + case e: IOException if streamCompressedOrEncrypted => iterator.throwFetchFailedException(blockId, address, e) } } @@ -644,7 +646,7 @@ private class BufferReleasingInputStream( try { delegate.read(b, off, len) } catch { - case e: IOException if detectCorruption => + case e: IOException if streamCompressedOrEncrypted => iterator.throwFetchFailedException(blockId, address, e) } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index de829d02f2e9..8a5152165e8c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -341,7 +341,7 @@ private[spark] object Utils extends Logging { * Copy all data from an InputStream to an OutputStream upto maxSize and * closes the input stream if closeStreams is true and all data is read. 
*/ - def copyStreamUpto( + def copyStreamUpTo( in: InputStream, out: OutputStream, maxSize: Long, diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index b652f9803ba0..fbfef853bb46 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.storage -import java.io.{File, InputStream, IOException} +import java.io._ +import java.nio.ByteBuffer import java.util.UUID import java.util.concurrent.Semaphore @@ -492,6 +493,55 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { st.read(new Array[Byte](8 * 1024), 0, 8 * 1024) } } + test("ensure big blocks available as a concatenated stream can be read") { + val tmpDir = Utils.createTempDir() + val tmpFile = new File(tmpDir, "someFile.txt") + val os = new FileOutputStream(tmpFile) + val buf = ByteBuffer.allocate(10000) + for (i <- 1 to 2500) { + buf.putInt(i) + } + os.write(buf.array()) + os.close() + val managedBuffer = new FileSegmentManagedBuffer(null, tmpFile, 0, 10000) + + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + doReturn(managedBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + val localBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 0, 0) -> 10000 + ) + val transfer = createMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer)) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (localBmId, localBlockLengths) + ).toIterator + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 10000), + 2048, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true, + taskContext.taskMetrics.createTempShuffleReadMetrics()) + val (id, st) = iterator.next() + // The returned stream is a concatenated stream + assert (st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream]) + + val dst = new DataInputStream(st) + for (i <- 1 to 2500) { + assert(i === dst.readInt()) + } + assert(dst.available() === 0) + dst.close() + } + test("retry corrupt blocks (disabled)") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) From 2ad65479bf4d6e48f44fe2df718583ecfd41c529 Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Mon, 7 Jan 2019 09:55:48 -0800 Subject: [PATCH 04/12] Replaced 'getArgumentAt()' with 'getArguments()' because of changes in Mockito api across versions --- .../apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index fbfef853bb46..5e033ebb0cb6 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -345,7 +345,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } else { 
when(corruptStream.read(any(), any(), any(classOf[Int]))).thenAnswer(new Answer[Int] { override def answer(invocationOnMock: InvocationOnMock): Int = { - val bufSize = invocationOnMock.getArgumentAt(2, classOf[Int]) + val bufSize = invocationOnMock.getArguments()(2).asInstanceOf[Int] // This condition is needed as we don't throw exception until we read the stream // less than maxBytesInFlight/3 if (bufSize < 8 * 1024) { From 15ee096fd77dcff540221e63f1b1985b0991f55c Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Mon, 28 Jan 2019 11:44:35 -0800 Subject: [PATCH 05/12] Review comments: Part 2 1. Minor changes 2. Added a new config for detecting corruption by using extra memory with default set to false 3. Added test cases for copyStreamUpTo --- .../spark/internal/config/package.scala | 16 +++++++ .../shuffle/BlockStoreShuffleReader.scala | 1 + .../storage/ShuffleBlockFetcherIterator.scala | 19 ++++---- .../scala/org/apache/spark/util/Utils.scala | 19 +++++--- .../ShuffleBlockFetcherIteratorSuite.scala | 13 ++++- .../org/apache/spark/util/UtilsSuite.scala | 48 ++++++++++++++++++- 6 files changed, 98 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d6a359db66f4..d31b02236339 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -955,6 +955,22 @@ package object config { .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) + private[spark] val SHUFFLE_DETECT_CORRUPT = + ConfigBuilder("spark.shuffle.detectCorrupt") + .doc("If enabled, IOException thrown while reading a compressed/encrypted stream will be " + + "converted to a FetchFailedException, to ensure that previous stage is retried") + .booleanConf + .createWithDefault(true) + + private[spark] val SHUFFLE_DETECT_CORRUPT_MEMORY = + ConfigBuilder("spark.shuffle.detectCorrupt.useExtraMemory") + .doc("If enabled, part of a compressed/encrypted stream will be de-compressed/de-crypted " + + "by using extra memory to detect early corruption. 
Any IOException thrown will cause " + + "the task to be retried once and if it fails again with same exception, then " + + "FetchFailedException will be thrown to retry previous stage") + .booleanConf + .createWithDefault(false) + private[spark] val MEMORY_MAP_LIMIT_FOR_TESTS = ConfigBuilder("spark.storage.memoryMapLimitForTests") .internal() diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c5eefc7c5c04..c7843710413d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -55,6 +55,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), readMetrics).toCompletionIterator val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index c99621af9320..db4646000186 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -73,6 +73,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean, + detectCorruptUseExtraMemory: Boolean, shuffleMetrics: ShuffleReadMetricsReporter) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { @@ -471,20 +472,16 @@ final class ShuffleBlockFetcherIterator( // maxBytesInFlight/3. If stream is longer, then corruption will be caught while reading // the stream. streamCompressedOrEncrypted = !input.eq(in) - if (detectCorrupt && streamCompressedOrEncrypted) { + if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { isStreamCopied = true val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) // Decompress the block upto maxBytesInFlight/3 at once to detect any corruption which // could increase the memory usage and potentially increase the chance of OOM. // TODO: manage the memory used here, and spill it into disk in case of OOM. 
- isStreamCopied = Utils.copyStreamUpTo( + val (completeStreamCopied: Boolean, newStream: InputStream) = Utils.copyStreamUpTo( input, out, maxBytesInFlight / 3, closeStreams = true) - if (isStreamCopied) { - input = out.toChunkedByteBuffer.toInputStream(dispose = true) - } else { - input = new SequenceInputStream( - out.toChunkedByteBuffer.toInputStream(dispose = true), input) - } + isStreamCopied = completeStreamCopied + input = newStream } } catch { case e: IOException => @@ -519,7 +516,11 @@ final class ShuffleBlockFetcherIterator( currentResult = result.asInstanceOf[SuccessFetchResult] (currentResult.blockId, new BufferReleasingInputStream( - input, this, currentResult.blockId, currentResult.address, streamCompressedOrEncrypted)) + input, + this, + currentResult.blockId, + currentResult.address, + detectCorrupt && streamCompressedOrEncrypted)) } def toCompletionIterator: Iterator[(BlockId, InputStream)] = { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 8a5152165e8c..b0cf36334e62 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -67,6 +67,7 @@ import org.apache.spark.launcher.SparkLauncher import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.status.api.v1.{StackTrace, ThreadStackTrace} +import org.apache.spark.util.io.ChunkedByteBufferOutputStream /** CallSite represents a place in user code. It can have a short and a long form. */ private[spark] case class CallSite(shortForm: String, longForm: String) @@ -343,16 +344,16 @@ private[spark] object Utils extends Logging { */ def copyStreamUpTo( in: InputStream, - out: OutputStream, + out: ChunkedByteBufferOutputStream, maxSize: Long, - closeStreams: Boolean = false): Boolean = { + closeStreams: Boolean = false): (Boolean, InputStream) = { var count = 0L - tryWithSafeFinally { - val bufSize = 8192 - val buf = new Array[Byte](bufSize) + val streamCopied = tryWithSafeFinally { + val bufSize = Math.min(8192L, maxSize) + val buf = new Array[Byte](bufSize.toInt) var n = 0 while (n != -1 && count < maxSize) { - n = in.read(buf, 0, Math.min(maxSize - count, bufSize.toLong).toInt) + n = in.read(buf, 0, Math.min(maxSize - count, bufSize).toInt) if (n != -1) { out.write(buf, 0, n) count += n @@ -370,6 +371,12 @@ private[spark] object Utils extends Logging { } } } + if (streamCopied) { + (streamCopied, out.toChunkedByteBuffer.toInputStream(dispose = true)) + } else { + (streamCopied, new SequenceInputStream( + out.toChunkedByteBuffer.toInputStream(dispose = true), in)) + } } def copyFileStreamNIO( diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5e033ebb0cb6..7fef0a57cd93 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -119,6 +119,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + false, metrics) // 3 local blocks fetched in initialization @@ -198,6 +199,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + false, 
taskContext.taskMetrics.createTempShuffleReadMetrics()) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() @@ -326,6 +328,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + false, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure @@ -414,6 +417,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + true, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Continue only after the mock calls onBlockFetchFailure @@ -480,6 +484,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + true, taskContext.taskMetrics.createTempShuffleReadMetrics()) // Only one block should be returned which has corruption after maxBytesInFlight/3 val (id, st) = iterator.next() @@ -529,16 +534,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + true, taskContext.taskMetrics.createTempShuffleReadMetrics()) val (id, st) = iterator.next() - // The returned stream is a concatenated stream + // Check that the test setup is correct -- make sure we have a concatenated stream. assert (st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream]) val dst = new DataInputStream(st) for (i <- 1 to 2500) { assert(i === dst.readInt()) } - assert(dst.available() === 0) + assert(dst.read() === -1) dst.close() } @@ -590,6 +596,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, Int.MaxValue, + true, false, taskContext.taskMetrics.createTempShuffleReadMetrics()) @@ -653,6 +660,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxBlocksInFlightPerAddress = Int.MaxValue, maxReqSizeShuffleToMem = 200, detectCorrupt = true, + false, taskContext.taskMetrics.createTempShuffleReadMetrics()) } @@ -700,6 +708,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + false, taskContext.taskMetrics.createTempShuffleReadMetrics()) // All blocks fetched return zero length and should trigger a receive-side error: diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 188e3f6907da..10f90f634c71 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutput, DataOutputStream, File, - FileOutputStream, PrintStream} + FileOutputStream, InputStream, PrintStream} import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} @@ -43,6 +43,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.SparkListener +import org.apache.spark.util.io.ChunkedByteBufferOutputStream class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -211,6 +212,51 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(os.toByteArray.toList.equals(bytes.toList)) } + test("copyStreamUpTo") { + // input array initialization 
+ val bytes = Array.ofDim[Byte](1200) + Random.nextBytes(bytes) + + var os: ChunkedByteBufferOutputStream = null + var in: InputStream = null + var copiedStream: InputStream = null + try { + os = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) + in = new ByteArrayInputStream(bytes.take(900)) + val (cp1: Boolean, input1: InputStream) = Utils.copyStreamUpTo(in, os, 1000, true) + copiedStream = input1 + assert(cp1) + assert(in.read() === -1) + IOUtils.closeQuietly(copiedStream) + IOUtils.closeQuietly(in) + IOUtils.closeQuietly(os) + + os = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) + in = new ByteArrayInputStream(bytes.take(1000)) + val (cp2: Boolean, input2: InputStream) = Utils.copyStreamUpTo(in, os, 1000, true) + copiedStream = input2 + assert(!cp2) + assert(in.read() === -1) + IOUtils.closeQuietly(copiedStream) + IOUtils.closeQuietly(in) + IOUtils.closeQuietly(os) + + os = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) + in = new ByteArrayInputStream(bytes.take(1100)) + val (cp3: Boolean, input3: InputStream) = Utils.copyStreamUpTo(in, os, 1000, true) + copiedStream = input3 + assert(!cp3) + assert(in.read() != -1) + IOUtils.closeQuietly(copiedStream) + IOUtils.closeQuietly(in) + IOUtils.closeQuietly(os) + } finally { + IOUtils.closeQuietly(copiedStream) + IOUtils.closeQuietly(in) + IOUtils.closeQuietly(os) + } + } + test("memoryStringToMb") { assert(Utils.memoryStringToMb("1") === 0) assert(Utils.memoryStringToMb("1048575") === 0) From ede5178ea3815270dc4a9309656f46d0c4d07017 Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Tue, 29 Jan 2019 10:51:51 -0800 Subject: [PATCH 06/12] Fixed compilation errors after merging with master --- .../spark/internal/config/package.scala | 25 +++++++------------ .../ShuffleBlockFetcherIteratorSuite.scala | 1 + 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d31b02236339..850d6845684b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -928,6 +928,15 @@ package object config { .booleanConf .createWithDefault(true) + private[spark] val SHUFFLE_DETECT_CORRUPT_MEMORY = + ConfigBuilder("spark.shuffle.detectCorrupt.useExtraMemory") + .doc("If enabled, part of a compressed/encrypted stream will be de-compressed/de-crypted " + + "by using extra memory to detect early corruption. 
Any IOException thrown will cause " + + "the task to be retried once and if it fails again with same exception, then " + + "FetchFailedException will be thrown to retry previous stage") + .booleanConf + .createWithDefault(false) + private[spark] val SHUFFLE_SYNC = ConfigBuilder("spark.shuffle.sync") .doc("Whether to force outstanding writes to disk.") @@ -955,22 +964,6 @@ package object config { .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) - private[spark] val SHUFFLE_DETECT_CORRUPT = - ConfigBuilder("spark.shuffle.detectCorrupt") - .doc("If enabled, IOException thrown while reading a compressed/encrypted stream will be " + - "converted to a FetchFailedException, to ensure that previous stage is retried") - .booleanConf - .createWithDefault(true) - - private[spark] val SHUFFLE_DETECT_CORRUPT_MEMORY = - ConfigBuilder("spark.shuffle.detectCorrupt.useExtraMemory") - .doc("If enabled, part of a compressed/encrypted stream will be de-compressed/de-crypted " + - "by using extra memory to detect early corruption. Any IOException thrown will cause " + - "the task to be retried once and if it fails again with same exception, then " + - "FetchFailedException will be thrown to retry previous stage") - .booleanConf - .createWithDefault(false) - private[spark] val MEMORY_MAP_LIMIT_FOR_TESTS = ConfigBuilder("spark.storage.memoryMapLimitForTests") .internal() diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 7fef0a57cd93..4a59720f1c0d 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -268,6 +268,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, Int.MaxValue, true, + false, taskContext.taskMetrics.createTempShuffleReadMetrics()) From 5b4e74f7fa9ab00c6d4745bf69f90db9f6fe1bbb Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Fri, 1 Feb 2019 08:57:23 -0800 Subject: [PATCH 07/12] Review comments: Part 3 1. Changed test to also compare the contents of stream 2. Other minor refactoring --- .../storage/ShuffleBlockFetcherIterator.scala | 9 +-- .../scala/org/apache/spark/util/Utils.scala | 2 +- .../org/apache/spark/util/UtilsSuite.scala | 77 +++++++++---------- 3 files changed, 43 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index db4646000186..63d6ab2a3c9b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -474,14 +474,13 @@ final class ShuffleBlockFetcherIterator( streamCompressedOrEncrypted = !input.eq(in) if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { isStreamCopied = true - val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) // Decompress the block upto maxBytesInFlight/3 at once to detect any corruption which // could increase the memory usage and potentially increase the chance of OOM. // TODO: manage the memory used here, and spill it into disk in case of OOM. 
- val (completeStreamCopied: Boolean, newStream: InputStream) = Utils.copyStreamUpTo( - input, out, maxBytesInFlight / 3, closeStreams = true) - isStreamCopied = completeStreamCopied - input = newStream + val (fullyCopied: Boolean, mergedStream: InputStream) = Utils.copyStreamUpTo( + input, maxBytesInFlight / 3, closeStreams = true) + isStreamCopied = fullyCopied + input = mergedStream } } catch { case e: IOException => diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b0cf36334e62..4879d00aa179 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -344,10 +344,10 @@ private[spark] object Utils extends Logging { */ def copyStreamUpTo( in: InputStream, - out: ChunkedByteBufferOutputStream, maxSize: Long, closeStreams: Boolean = false): (Boolean, InputStream) = { var count = 0L + val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) val streamCopied = tryWithSafeFinally { val bufSize = Math.min(8192L, maxSize) val buf = new Array[Byte](bufSize.toInt) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 10f90f634c71..f864a8c31069 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutput, DataOutputStream, File, - FileOutputStream, InputStream, PrintStream} + FileOutputStream, InputStream, PrintStream, SequenceInputStream} import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} @@ -29,6 +29,7 @@ import java.util.concurrent.TimeUnit import java.util.zip.GZIPOutputStream import scala.collection.mutable.ListBuffer +import scala.reflect.runtime.universe import scala.util.Random import com.google.common.io.Files @@ -43,7 +44,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.SparkListener -import org.apache.spark.util.io.ChunkedByteBufferOutputStream +import org.apache.spark.util.io.ChunkedByteBufferInputStream class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -217,46 +218,44 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val bytes = Array.ofDim[Byte](1200) Random.nextBytes(bytes) - var os: ChunkedByteBufferOutputStream = null - var in: InputStream = null - var copiedStream: InputStream = null - try { - os = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) - in = new ByteArrayInputStream(bytes.take(900)) - val (cp1: Boolean, input1: InputStream) = Utils.copyStreamUpTo(in, os, 1000, true) - copiedStream = input1 - assert(cp1) - assert(in.read() === -1) - IOUtils.closeQuietly(copiedStream) - IOUtils.closeQuietly(in) - IOUtils.closeQuietly(os) - - os = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) - in = new ByteArrayInputStream(bytes.take(1000)) - val (cp2: Boolean, input2: InputStream) = Utils.copyStreamUpTo(in, os, 1000, true) - copiedStream = input2 - assert(!cp2) - assert(in.read() === -1) - IOUtils.closeQuietly(copiedStream) - IOUtils.closeQuietly(in) - IOUtils.closeQuietly(os) - - os = new ChunkedByteBufferOutputStream(64 * 1024, 
ByteBuffer.allocate) - in = new ByteArrayInputStream(bytes.take(1100)) - val (cp3: Boolean, input3: InputStream) = Utils.copyStreamUpTo(in, os, 1000, true) - copiedStream = input3 - assert(!cp3) - assert(in.read() != -1) - IOUtils.closeQuietly(copiedStream) - IOUtils.closeQuietly(in) - IOUtils.closeQuietly(os) - } finally { - IOUtils.closeQuietly(copiedStream) - IOUtils.closeQuietly(in) - IOUtils.closeQuietly(os) + val limit = 1000 + // testing for inputLength less than, equal to and greater than limit + List(900, 1000, 1100).foreach { inputLength => + val in = new ByteArrayInputStream(bytes.take(inputLength)) + val (fullyCopied: Boolean, mergedStream: InputStream) = Utils.copyStreamUpTo(in, limit, true) + try { + val byteBufferInputStream = if (mergedStream.isInstanceOf[ChunkedByteBufferInputStream]) { + mergedStream.asInstanceOf[ChunkedByteBufferInputStream] + } else { + val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream] + val fieldValue = getFieldValue(sequenceStream, "in") + assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream]) + fieldValue.asInstanceOf[ChunkedByteBufferInputStream] + } + assert(fullyCopied === (inputLength < limit)) + (0 until inputLength).foreach { idx => + assert(bytes(idx) === mergedStream.read().asInstanceOf[Byte]) + if (idx == limit) { + assert(byteBufferInputStream.chunkedByteBuffer === null) + } + } + assert(mergedStream.read() === -1) + assert(byteBufferInputStream.chunkedByteBuffer === null) + } finally { + IOUtils.closeQuietly(mergedStream) + IOUtils.closeQuietly(in) + } } } + private def getFieldValue(obj: AnyRef, fieldName: String): Any = { + val mirror = universe.runtimeMirror(obj.getClass().getClassLoader()) + val field = mirror.classSymbol(obj.getClass()).info.decl(universe.TermName(fieldName)).asTerm + val instanceMirror = mirror.reflect(obj) + val fieldMirror = instanceMirror.reflectField(field) + fieldMirror.get + } + test("memoryStringToMb") { assert(Utils.memoryStringToMb("1") === 0) assert(Utils.memoryStringToMb("1048575") === 0) From 1870ff27c50d8c58385ba547e909ce3d4c3dba1f Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Fri, 1 Feb 2019 10:32:22 -0800 Subject: [PATCH 08/12] Review comments: Part 4 --- .../ShuffleBlockFetcherIteratorSuite.scala | 2 +- .../org/apache/spark/util/UtilsSuite.scala | 20 ++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 4a59720f1c0d..55c859a4e754 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -490,7 +490,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Only one block should be returned which has corruption after maxBytesInFlight/3 val (id, st) = iterator.next() assert(id === ShuffleBlockId(0, 1, 0)) - intercept[FetchFailedException] { iterator.next(); iterator.next() } + intercept[FetchFailedException] { iterator.next() } // Following will succeed as it reads the first part of the stream which is not corrupt st.read(new Array[Byte](8 * 1024), 0, 8 * 1024) // Following will fail as it reads the remaining part of the stream which is corrupt diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index f864a8c31069..44ed452e3841 100644 
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataOutput, DataOutputStream, File, FileOutputStream, InputStream, PrintStream, SequenceInputStream} import java.lang.{Double => JDouble, Float => JFloat} +import java.lang.reflect.Field import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.nio.charset.StandardCharsets @@ -29,7 +30,6 @@ import java.util.concurrent.TimeUnit import java.util.zip.GZIPOutputStream import scala.collection.mutable.ListBuffer -import scala.reflect.runtime.universe import scala.util.Random import com.google.common.io.Files @@ -220,10 +220,12 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val limit = 1000 // testing for inputLength less than, equal to and greater than limit - List(900, 1000, 1100).foreach { inputLength => + List(998, 999, 1000, 1001, 1002).foreach { inputLength => val in = new ByteArrayInputStream(bytes.take(inputLength)) val (fullyCopied: Boolean, mergedStream: InputStream) = Utils.copyStreamUpTo(in, limit, true) try { + // Get a handle on the buffered data, to make sure memory gets freed once we read past the + // end of it. Need to use reflection to get handle on inner structures for this check val byteBufferInputStream = if (mergedStream.isInstanceOf[ChunkedByteBufferInputStream]) { mergedStream.asInstanceOf[ChunkedByteBufferInputStream] } else { @@ -249,11 +251,15 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { } private def getFieldValue(obj: AnyRef, fieldName: String): Any = { - val mirror = universe.runtimeMirror(obj.getClass().getClassLoader()) - val field = mirror.classSymbol(obj.getClass()).info.decl(universe.TermName(fieldName)).asTerm - val instanceMirror = mirror.reflect(obj) - val fieldMirror = instanceMirror.reflectField(field) - fieldMirror.get + val field: Field = obj.getClass().getDeclaredField(fieldName) + if (field.isAccessible()) { + field.get(obj) + } else { + field.setAccessible(true) + val result = field.get(obj) + field.setAccessible(false) + result + } } test("memoryStringToMb") { From 980f2bcb0774a12878c83eb9623664c090a4aba1 Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Wed, 6 Feb 2019 09:43:31 -0800 Subject: [PATCH 09/12] Review comments: Part 5 Changes to unit test case --- .../storage/ShuffleBlockFetcherIterator.scala | 11 +- .../scala/org/apache/spark/util/Utils.scala | 20 ++-- .../ShuffleBlockFetcherIteratorSuite.scala | 100 +++++++++++------- .../org/apache/spark/util/UtilsSuite.scala | 2 +- 4 files changed, 80 insertions(+), 53 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 63d6ab2a3c9b..b47a7d79b60a 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -478,7 +478,7 @@ final class ShuffleBlockFetcherIterator( // could increase the memory usage and potentially increase the chance of OOM. // TODO: manage the memory used here, and spill it into disk in case of OOM. 
val (fullyCopied: Boolean, mergedStream: InputStream) = Utils.copyStreamUpTo( - input, maxBytesInFlight / 3, closeStreams = true) + input, maxBytesInFlight / 3) isStreamCopied = fullyCopied input = mergedStream } @@ -629,7 +629,14 @@ private class BufferReleasingInputStream( override def mark(readlimit: Int): Unit = delegate.mark(readlimit) - override def skip(n: Long): Long = delegate.skip(n) + override def skip(n: Long): Long = { + try { + delegate.skip(n) + } catch { + case e: IOException if streamCompressedOrEncrypted => + iterator.throwFetchFailedException(blockId, address, e) + } + } override def markSupported(): Boolean = delegate.markSupported() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 4879d00aa179..6fed9e9e9755 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -340,12 +340,10 @@ private[spark] object Utils extends Logging { /** * Copy all data from an InputStream to an OutputStream upto maxSize and - * closes the input stream if closeStreams is true and all data is read. + * close the input stream if all data is read. + * @return A combined stream of read data and any remaining data */ - def copyStreamUpTo( - in: InputStream, - maxSize: Long, - closeStreams: Boolean = false): (Boolean, InputStream) = { + def copyStreamUpTo(in: InputStream, maxSize: Long): (Boolean, InputStream) = { var count = 0L val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) val streamCopied = tryWithSafeFinally { @@ -361,14 +359,12 @@ private[spark] object Utils extends Logging { } count < maxSize } { - if (closeStreams) { - try { - if (count < maxSize) { - in.close() - } - } finally { - out.close() + try { + if (count < maxSize) { + in.close() } + } finally { + out.close() } } if (streamCopied) { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 55c859a4e754..aa40aba94eab 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -25,6 +25,7 @@ import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future +import org.apache.commons.io.IOUtils import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, times, verify, when} import org.mockito.invocation.InvocationOnMock @@ -342,32 +343,34 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } - private def mockCorruptBuffer(size: Long = 1L, corruptInStart: Boolean = true): ManagedBuffer = { - val corruptStream = mock(classOf[InputStream]) - if (size < 8 * 1024 || corruptInStart) { - when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) - } else { - when(corruptStream.read(any(), any(), any(classOf[Int]))).thenAnswer(new Answer[Int] { - override def answer(invocationOnMock: InvocationOnMock): Int = { - val bufSize = invocationOnMock.getArguments()(2).asInstanceOf[Int] - // This condition is needed as we don't throw exception until we read the stream - // less than maxBytesInFlight/3 - if (bufSize < 8 * 1024) { - return bufSize - } else { - throw new IOException("corrupt") - } - } - }) - 
when(corruptStream.read()).thenThrow(new IOException("corrupt")) - when(corruptStream.read(any())).thenThrow(new IOException("corrupt")) - } + private def mockCorruptBuffer(size: Long = 1L, corruptAt: Int = 0): ManagedBuffer = { + val corruptStream = new CorruptStream(corruptAt) val corruptBuffer = mock(classOf[ManagedBuffer]) when(corruptBuffer.size()).thenReturn(size) when(corruptBuffer.createInputStream()).thenReturn(corruptStream) corruptBuffer } + private class CorruptStream(corruptAt: Long = 0L) extends InputStream { + var pos = 0 + var closed = false + + override def read(): Int = { + if (pos >= corruptAt) { + throw new IOException("corrupt") + } else { + pos += 1 + pos + } + } + + override def read(dest: Array[Byte], off: Int, len: Int): Int = { + super.read(dest, off, len) + } + + override def close(): Unit = { closed = true } + } + test("retry corrupt blocks") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -449,54 +452,75 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("big blocks are also checked for corruption") { - val corruptBuffer1 = mockCorruptBuffer(10000L, true) - + val streamLength = 10000L val blockManager = mock(classOf[BlockManager]) + + // This stream will throw IOException when the first byte is read + val localBuffer = mockCorruptBuffer(streamLength, 0) val localBmId = BlockManagerId("test-client", "test-client", 1) doReturn(localBmId).when(blockManager).blockManagerId - doReturn(corruptBuffer1).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + doReturn(localBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + val localShuffleBlockId = ShuffleBlockId(0, 0, 0) val localBlockLengths = Seq[Tuple2[BlockId, Long]]( - ShuffleBlockId(0, 0, 0) -> corruptBuffer1.size() + localShuffleBlockId -> localBuffer.size() ) - val corruptBuffer2 = mockCorruptBuffer(10000L, false) + val streamNotCorruptTill = 8 * 1024 + // This stream will throw exception after streamNotCorruptTill bytes are read + val remoteBuffer = mockCorruptBuffer(streamLength, streamNotCorruptTill) val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteShuffleBlockId = ShuffleBlockId(0, 1, 0) val remoteBlockLengths = Seq[Tuple2[BlockId, Long]]( - ShuffleBlockId(0, 1, 0) -> corruptBuffer2.size() + remoteShuffleBlockId -> remoteBuffer.size() ) val transfer = createMockTransfer( - Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer1, ShuffleBlockId(0, 1, 0) -> corruptBuffer2)) - + Map(localShuffleBlockId -> localBuffer, remoteShuffleBlockId -> remoteBuffer)) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlockLengths), (remoteBmId, remoteBlockLengths) ).toIterator - val taskContext = TaskContext.empty() + val maxBytesInFlight = 3 * 1024 val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, blockManager, blocksByAddress, - (_, in) => new LimitedInputStream(in, 10000), - 2048, + (_, in) => new LimitedInputStream(in, streamLength), + maxBytesInFlight, Int.MaxValue, Int.MaxValue, Int.MaxValue, true, true, taskContext.taskMetrics.createTempShuffleReadMetrics()) - // Only one block should be returned which has corruption after maxBytesInFlight/3 + + // Only one block should be returned which has corruption after maxBytesInFlight/3 because the + // other block will be re-fetched val (id, st) = iterator.next() - assert(id === ShuffleBlockId(0, 1, 0)) - intercept[FetchFailedException] { iterator.next() } - // Following 
will succeed as it reads the first part of the stream which is not corrupt - st.read(new Array[Byte](8 * 1024), 0, 8 * 1024) + assert(id === remoteShuffleBlockId) + + // The other block will throw a FetchFailedException + intercept[FetchFailedException] { + iterator.next() + } + + // Following will succeed as it reads part of the stream which is not corrupt. This will read + // maxBytesInFlight/3 bytes from first stream and remaining from the second stream + new DataInputStream(st).readFully( + new Array[Byte](streamNotCorruptTill), 0, streamNotCorruptTill) + // Following will fail as it reads the remaining part of the stream which is corrupt intercept[FetchFailedException] { st.read() } - intercept[FetchFailedException] { st.read(new Array[Byte](8 * 1024)) } - intercept[FetchFailedException] { st.read(new Array[Byte](8 * 1024), 0, 8 * 1024) } + intercept[FetchFailedException] { st.read(new Array[Byte](1024)) } + intercept[FetchFailedException] { st.read(new Array[Byte](1024), 0, 1024) } + intercept[FetchFailedException] { st.skip(1024) } + + IOUtils.closeQuietly(st) + // Buffers are mocked and they return the original input corrupt streams + assert(localBuffer.createInputStream().asInstanceOf[CorruptStream].closed) + assert(remoteBuffer.createInputStream().asInstanceOf[CorruptStream].closed) } test("ensure big blocks available as a concatenated stream can be read") { diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 44ed452e3841..9fd7b9cfcb9b 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -222,7 +222,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { // testing for inputLength less than, equal to and greater than limit List(998, 999, 1000, 1001, 1002).foreach { inputLength => val in = new ByteArrayInputStream(bytes.take(inputLength)) - val (fullyCopied: Boolean, mergedStream: InputStream) = Utils.copyStreamUpTo(in, limit, true) + val (fullyCopied: Boolean, mergedStream: InputStream) = Utils.copyStreamUpTo(in, limit) try { // Get a handle on the buffered data, to make sure memory gets freed once we read past the // end of it. 
Need to use reflection to get handle on inner structures for this check From e7d98947ce4f98c856e1beaf45cd57df2a8c6709 Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Mon, 11 Feb 2019 15:04:24 -0800 Subject: [PATCH 10/12] Review comments: Part 6 --- .../storage/ShuffleBlockFetcherIterator.scala | 7 ++- .../scala/org/apache/spark/util/Utils.scala | 3 +- .../ShuffleBlockFetcherIteratorSuite.scala | 44 +++++++++---------- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b47a7d79b60a..4021360caaab 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -25,6 +25,8 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import org.apache.commons.io.IOUtils + import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} @@ -32,7 +34,6 @@ import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.TransportConf import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} -import org.apache.spark.util.io.ChunkedByteBufferOutputStream /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block @@ -613,6 +614,7 @@ private class BufferReleasingInputStream( delegate.read() } catch { case e: IOException if streamCompressedOrEncrypted => + IOUtils.closeQuietly(this) iterator.throwFetchFailedException(blockId, address, e) } } @@ -634,6 +636,7 @@ private class BufferReleasingInputStream( delegate.skip(n) } catch { case e: IOException if streamCompressedOrEncrypted => + IOUtils.closeQuietly(this) iterator.throwFetchFailedException(blockId, address, e) } } @@ -645,6 +648,7 @@ private class BufferReleasingInputStream( delegate.read(b) } catch { case e: IOException if streamCompressedOrEncrypted => + IOUtils.closeQuietly(this) iterator.throwFetchFailedException(blockId, address, e) } } @@ -654,6 +658,7 @@ private class BufferReleasingInputStream( delegate.read(b, off, len) } catch { case e: IOException if streamCompressedOrEncrypted => + IOUtils.closeQuietly(this) iterator.throwFetchFailedException(blockId, address, e) } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6fed9e9e9755..6af671dc7f4c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -341,7 +341,8 @@ private[spark] object Utils extends Logging { /** * Copy all data from an InputStream to an OutputStream upto maxSize and * close the input stream if all data is read. 
- * @return A combined stream of read data and any remaining data + * @return A tuple of boolean, which is whether the stream was fully copied, and an InputStream, + * which is a combined stream of read data and any remaining data */ def copyStreamUpTo(in: InputStream, maxSize: Long): (Boolean, InputStream) = { var count = 0L diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index aa40aba94eab..25907e5d0a23 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -25,7 +25,6 @@ import java.util.concurrent.Semaphore import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future -import org.apache.commons.io.IOUtils import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, times, verify, when} import org.mockito.invocation.InvocationOnMock @@ -454,31 +453,31 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("big blocks are also checked for corruption") { val streamLength = 10000L val blockManager = mock(classOf[BlockManager]) + val localBlockManagerId = BlockManagerId("local-client", "local-client", 1) + doReturn(localBlockManagerId).when(blockManager).blockManagerId // This stream will throw IOException when the first byte is read - val localBuffer = mockCorruptBuffer(streamLength, 0) - val localBmId = BlockManagerId("test-client", "test-client", 1) - doReturn(localBmId).when(blockManager).blockManagerId - doReturn(localBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) - val localShuffleBlockId = ShuffleBlockId(0, 0, 0) - val localBlockLengths = Seq[Tuple2[BlockId, Long]]( - localShuffleBlockId -> localBuffer.size() + val corruptBuffer1 = mockCorruptBuffer(streamLength, 0) + val blockManagerId1 = BlockManagerId("remote-client-1", "remote-client-1", 1) + val shuffleBlockId1 = ShuffleBlockId(0, 1, 0) + val blockLengths1 = Seq[Tuple2[BlockId, Long]]( + shuffleBlockId1 -> corruptBuffer1.size() ) val streamNotCorruptTill = 8 * 1024 // This stream will throw exception after streamNotCorruptTill bytes are read - val remoteBuffer = mockCorruptBuffer(streamLength, streamNotCorruptTill) - val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) - val remoteShuffleBlockId = ShuffleBlockId(0, 1, 0) - val remoteBlockLengths = Seq[Tuple2[BlockId, Long]]( - remoteShuffleBlockId -> remoteBuffer.size() + val corruptBuffer2 = mockCorruptBuffer(streamLength, streamNotCorruptTill) + val blockManagerId2 = BlockManagerId("remote-client-2", "remote-client-2", 2) + val shuffleBlockId2 = ShuffleBlockId(0, 2, 0) + val blockLengths2 = Seq[Tuple2[BlockId, Long]]( + shuffleBlockId2 -> corruptBuffer2.size() ) val transfer = createMockTransfer( - Map(localShuffleBlockId -> localBuffer, remoteShuffleBlockId -> remoteBuffer)) + Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2)) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (localBmId, localBlockLengths), - (remoteBmId, remoteBlockLengths) + (blockManagerId1, blockLengths1), + (blockManagerId2, blockLengths2) ).toIterator val taskContext = TaskContext.empty() val maxBytesInFlight = 3 * 1024 @@ -497,9 +496,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext.taskMetrics.createTempShuffleReadMetrics()) // 
Only one block should be returned which has corruption after maxBytesInFlight/3 because the - // other block will be re-fetched + // other block will detect corruption on first fetch, and then get added to the queue again for + // a retry val (id, st) = iterator.next() - assert(id === remoteShuffleBlockId) + assert(id === shuffleBlockId2) // The other block will throw a FetchFailedException intercept[FetchFailedException] { @@ -513,14 +513,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Following will fail as it reads the remaining part of the stream which is corrupt intercept[FetchFailedException] { st.read() } - intercept[FetchFailedException] { st.read(new Array[Byte](1024)) } - intercept[FetchFailedException] { st.read(new Array[Byte](1024), 0, 1024) } - intercept[FetchFailedException] { st.skip(1024) } - IOUtils.closeQuietly(st) // Buffers are mocked and they return the original input corrupt streams - assert(localBuffer.createInputStream().asInstanceOf[CorruptStream].closed) - assert(remoteBuffer.createInputStream().asInstanceOf[CorruptStream].closed) + assert(corruptBuffer1.createInputStream().asInstanceOf[CorruptStream].closed) + assert(corruptBuffer2.createInputStream().asInstanceOf[CorruptStream].closed) } test("ensure big blocks available as a concatenated stream can be read") { From bd1a813216f1843f8c4d4c52749a27fccec0da5f Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Fri, 8 Mar 2019 09:47:39 -0800 Subject: [PATCH 11/12] Review Comments: Part 7 Minor changes to variable names and comments --- .../storage/ShuffleBlockFetcherIterator.scala | 14 ++++++++------ .../main/scala/org/apache/spark/util/Utils.scala | 8 ++++---- .../storage/ShuffleBlockFetcherIteratorSuite.scala | 7 ++++--- .../scala/org/apache/spark/util/UtilsSuite.scala | 2 +- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 4021360caaab..66eb1ddebe37 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -584,7 +584,9 @@ final class ShuffleBlockFetcherIterator( } private[storage] def throwFetchFailedException( - blockId: BlockId, address: BlockManagerId, e: Throwable) = { + blockId: BlockId, + address: BlockManagerId, + e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) @@ -605,7 +607,7 @@ private class BufferReleasingInputStream( private val iterator: ShuffleBlockFetcherIterator, private val blockId: BlockId, private val address: BlockManagerId, - private val streamCompressedOrEncrypted: Boolean) + private val detectCorruption: Boolean) extends InputStream { private[this] var closed = false @@ -613,7 +615,7 @@ private class BufferReleasingInputStream( try { delegate.read() } catch { - case e: IOException if streamCompressedOrEncrypted => + case e: IOException if detectCorruption => IOUtils.closeQuietly(this) iterator.throwFetchFailedException(blockId, address, e) } @@ -635,7 +637,7 @@ private class BufferReleasingInputStream( try { delegate.skip(n) } catch { - case e: IOException if streamCompressedOrEncrypted => + case e: IOException if detectCorruption => IOUtils.closeQuietly(this) iterator.throwFetchFailedException(blockId, address, e) } @@ -647,7 +649,7 @@ 
private class BufferReleasingInputStream( try { delegate.read(b) } catch { - case e: IOException if streamCompressedOrEncrypted => + case e: IOException if detectCorruption => IOUtils.closeQuietly(this) iterator.throwFetchFailedException(blockId, address, e) } @@ -657,7 +659,7 @@ private class BufferReleasingInputStream( try { delegate.read(b, off, len) } catch { - case e: IOException if streamCompressedOrEncrypted => + case e: IOException if detectCorruption => IOUtils.closeQuietly(this) iterator.throwFetchFailedException(blockId, address, e) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 6af671dc7f4c..ec485d0db8ee 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -347,7 +347,7 @@ private[spark] object Utils extends Logging { def copyStreamUpTo(in: InputStream, maxSize: Long): (Boolean, InputStream) = { var count = 0L val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) - val streamCopied = tryWithSafeFinally { + val fullyCopied = tryWithSafeFinally { val bufSize = Math.min(8192L, maxSize) val buf = new Array[Byte](bufSize.toInt) var n = 0 @@ -368,10 +368,10 @@ private[spark] object Utils extends Logging { out.close() } } - if (streamCopied) { - (streamCopied, out.toChunkedByteBuffer.toInputStream(dispose = true)) + if (fullyCopied) { + (fullyCopied, out.toChunkedByteBuffer.toInputStream(dispose = true)) } else { - (streamCopied, new SequenceInputStream( + (fullyCopied, new SequenceInputStream( out.toChunkedByteBuffer.toInputStream(dispose = true), in)) } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 25907e5d0a23..a1c298ae9446 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -495,8 +495,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT true, taskContext.taskMetrics.createTempShuffleReadMetrics()) - // Only one block should be returned which has corruption after maxBytesInFlight/3 because the - // other block will detect corruption on first fetch, and then get added to the queue again for + // We'll get back the block which has corruption after maxBytesInFlight/3 because the other + // block will detect corruption on first fetch, and then get added to the queue again for // a retry val (id, st) = iterator.next() assert(id === shuffleBlockId2) @@ -507,7 +507,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // Following will succeed as it reads part of the stream which is not corrupt. 
This will read - // maxBytesInFlight/3 bytes from first stream and remaining from the second stream + // maxBytesInFlight/3 bytes from the portion copied into memory, and remaining from the + // underlying stream new DataInputStream(st).readFully( new Array[Byte](streamNotCorruptTill), 0, streamNotCorruptTill) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 9fd7b9cfcb9b..e88512957363 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -220,7 +220,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { val limit = 1000 // testing for inputLength less than, equal to and greater than limit - List(998, 999, 1000, 1001, 1002).foreach { inputLength => + (limit - 2 to limit + 2).foreach { inputLength => val in = new ByteArrayInputStream(bytes.take(inputLength)) val (fullyCopied: Boolean, mergedStream: InputStream) = Utils.copyStreamUpTo(in, limit) try { From d36c86251946e88a36e2e977d8ae1a0ae815a8d5 Mon Sep 17 00:00:00 2001 From: ankurgupta Date: Mon, 11 Mar 2019 09:25:36 -0700 Subject: [PATCH 12/12] Review Comments: Part 8 1. Ensured input stream is closed on exception 2. Minor comments changes --- .../storage/ShuffleBlockFetcherIterator.scala | 22 ++++++++----------- .../scala/org/apache/spark/util/Utils.scala | 22 ++++++++++++------- .../org/apache/spark/util/UtilsSuite.scala | 5 +++-- 3 files changed, 26 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 66eb1ddebe37..c89d5cc971d2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -466,28 +466,22 @@ final class ShuffleBlockFetcherIterator( buf.release() throwFetchFailedException(blockId, address, e) } - var isStreamCopied: Boolean = false try { input = streamWrapper(blockId, in) - // Only copy the stream if it's wrapped by compression or encryption upto a size of - // maxBytesInFlight/3. If stream is longer, then corruption will be caught while reading - // the stream. + // If the stream is compressed or wrapped, then we optionally decompress/unwrap the + // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion + // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if + // the corruption is later, we'll still detect the corruption later in the stream. streamCompressedOrEncrypted = !input.eq(in) if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) { - isStreamCopied = true - // Decompress the block upto maxBytesInFlight/3 at once to detect any corruption which - // could increase the memory usage and potentially increase the chance of OOM. // TODO: manage the memory used here, and spill it into disk in case of OOM. 
- val (fullyCopied: Boolean, mergedStream: InputStream) = Utils.copyStreamUpTo( - input, maxBytesInFlight / 3) - isStreamCopied = fullyCopied - input = mergedStream + input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3) } } catch { case e: IOException => buf.release() if (buf.isInstanceOf[FileSegmentManagedBuffer] - || corruptedBlocks.contains(blockId)) { + || corruptedBlocks.contains(blockId)) { throwFetchFailedException(blockId, address, e) } else { logWarning(s"got an corrupted block $blockId from $address, fetch again", e) @@ -497,7 +491,9 @@ final class ShuffleBlockFetcherIterator( } } finally { // TODO: release the buf here to free memory earlier - if (isStreamCopied) { + if (input == null) { + // Close the underlying stream if there was an issue in wrapping the stream using + // streamWrapper in.close() } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index ec485d0db8ee..bc5731163afd 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -339,12 +339,19 @@ private[spark] object Utils extends Logging { } /** - * Copy all data from an InputStream to an OutputStream upto maxSize and - * close the input stream if all data is read. - * @return A tuple of boolean, which is whether the stream was fully copied, and an InputStream, - * which is a combined stream of read data and any remaining data + * Copy the first `maxSize` bytes of data from the InputStream to an in-memory + * buffer, primarily to check for corruption. + * + * This returns a new InputStream which contains the same data as the original input stream. + * It may be entirely an in-memory buffer, or it may be a combination of in-memory data, and then + * continue to read from the original stream. The only real use of this is if the original input + * stream will potentially detect corruption while the data is being read (e.g. from compression). + * This allows for an eager check of corruption in the first maxSize bytes of data.
+ * + * @return An InputStream which includes all data from the original stream (combining buffered + * data and remaining data in the original stream) */ - def copyStreamUpTo(in: InputStream, maxSize: Long): (Boolean, InputStream) = { + def copyStreamUpTo(in: InputStream, maxSize: Long): InputStream = { var count = 0L val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate) val fullyCopied = tryWithSafeFinally { @@ -369,10 +376,9 @@ private[spark] object Utils extends Logging { } } if (fullyCopied) { - (fullyCopied, out.toChunkedByteBuffer.toInputStream(dispose = true)) + out.toChunkedByteBuffer.toInputStream(dispose = true) } else { - (fullyCopied, new SequenceInputStream( - out.toChunkedByteBuffer.toInputStream(dispose = true), in)) + new SequenceInputStream( out.toChunkedByteBuffer.toInputStream(dispose = true), in) } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index e88512957363..d2d9eb06339c 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -222,19 +222,20 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { // testing for inputLength less than, equal to and greater than limit (limit - 2 to limit + 2).foreach { inputLength => val in = new ByteArrayInputStream(bytes.take(inputLength)) - val (fullyCopied: Boolean, mergedStream: InputStream) = Utils.copyStreamUpTo(in, limit) + val mergedStream = Utils.copyStreamUpTo(in, limit) try { // Get a handle on the buffered data, to make sure memory gets freed once we read past the // end of it. Need to use reflection to get handle on inner structures for this check val byteBufferInputStream = if (mergedStream.isInstanceOf[ChunkedByteBufferInputStream]) { + assert(inputLength < limit) mergedStream.asInstanceOf[ChunkedByteBufferInputStream] } else { + assert(inputLength >= limit) val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream] val fieldValue = getFieldValue(sequenceStream, "in") assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream]) fieldValue.asInstanceOf[ChunkedByteBufferInputStream] } - assert(fullyCopied === (inputLength < limit)) (0 until inputLength).foreach { idx => assert(bytes(idx) === mergedStream.read().asInstanceOf[Byte]) if (idx == limit) {