From 110c8510dcc6c2abaf4ca416b95854daf129b0a5 Mon Sep 17 00:00:00 2001 From: jx158167 Date: Tue, 27 Feb 2018 17:56:38 +0800 Subject: [PATCH 1/3] [SPARK-23524] Big local shuffle blocks should not be checked for corruption. --- .../storage/ShuffleBlockFetcherIterator.scala | 10 +++- .../ShuffleBlockFetcherIteratorSuite.scala | 57 +++++++++++++++++++ 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 98b5a735a4529..ff822f352833b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -90,7 +90,7 @@ final class ShuffleBlockFetcherIterator( private[this] val startTime = System.currentTimeMillis /** Local blocks to fetch, excluding zero-sized blocks. */ - private[this] val localBlocks = new ArrayBuffer[BlockId]() + private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]() /** Remote blocks to fetch, excluding zero-sized blocks. */ private[this] val remoteBlocks = new HashSet[BlockId]() @@ -316,6 +316,7 @@ final class ShuffleBlockFetcherIterator( * track in-memory are the ManagedBuffer references themselves. */ private[this] def fetchLocalBlocks() { + logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") val iter = localBlocks.iterator while (iter.hasNext) { val blockId = iter.next() @@ -324,7 +325,8 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, + buf.size(), buf, false)) } catch { case e: Exception => // If we see an exception, stop immediately. @@ -397,7 +399,9 @@ final class ShuffleBlockFetcherIterator( } shuffleMetrics.incRemoteBlocksFetched(1) } - bytesInFlight -= size + if (!localBlocks.contains(blockId)) { + bytesInFlight -= size + } if (isNetworkReqDone) { reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5bfe9905ff17b..85cc38addf892 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -352,6 +352,63 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } + test("big corrupt blocks will not be retiried") { + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + val corruptBuffer = mock(classOf[ManagedBuffer]) + when(corruptBuffer.createInputStream()).thenReturn(corruptStream) + doReturn(10000L).when(corruptBuffer).size() + + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + val localBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 0, 0) -> corruptBuffer.size() + ) + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlockLengths = Seq[Tuple2[BlockId, Long]]( + ShuffleBlockId(0, 1, 0) -> corruptBuffer.size() + ) + + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] + Future { + blocks.foreach (listener.onBlockFetchSuccess(_, corruptBuffer)) + } + } + }) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (localBmId, localBlockLengths), + (remoteBmId, remoteBlockLengths) + ) + + val taskContext = TaskContext.empty() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => new LimitedInputStream(in, 10000), + 2048, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true) + // Blocks should be returned without exceptions. + val blockSet = collection.mutable.HashSet[BlockId]() + blockSet.add(iterator.next()._1) + blockSet.add(iterator.next()._1) + assert(blockSet == collection.immutable.HashSet( + ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + } + test("retry corrupt blocks (disabled)") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) From 1cd8653d8dca73ef60f65fa344d30ef8dc60367e Mon Sep 17 00:00:00 2001 From: jx158167 Date: Tue, 6 Mar 2018 19:52:36 +0800 Subject: [PATCH 2/3] resolve doc and test's name --- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 4 ++-- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index ff822f352833b..8eaa1e956a37b 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -587,8 +587,8 @@ object ShuffleBlockFetcherIterator { * Result of a fetch from a remote block successfully. * @param blockId block id * @param address BlockManager that the block was fetched from. - * @param size estimated size of the block, used to calculate bytesInFlight. - * Note that this is NOT the exact bytes. + * @param size estimated size of the block. Note that this is NOT the exact bytes. + * Size of remote block is used to calculate bytesInFlight. * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 85cc38addf892..56625738a16d9 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -352,7 +352,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } - test("big corrupt blocks will not be retiried") { + test("big blocks are not checked for corruption") { val corruptStream = mock(classOf[InputStream]) when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) val corruptBuffer = mock(classOf[ManagedBuffer]) From 4e4f07544d17ea0493b4c5887d8215550eedc424 Mon Sep 17 00:00:00 2001 From: jx158167 Date: Wed, 7 Mar 2018 11:16:01 +0800 Subject: [PATCH 3/3] refine test --- .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 20 ++++--------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 8eaa1e956a37b..dd9df74689a13 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -588,7 +588,7 @@ object ShuffleBlockFetcherIterator { * @param blockId block id * @param address BlockManager that the block was fetched from. * @param size estimated size of the block. Note that this is NOT the exact bytes. - * Size of remote block is used to calculate bytesInFlight. + * Size of remote block is used to calculate bytesInFlight. * @param buf `ManagedBuffer` for the content. * @param isNetworkReqDone Is this the last network request for this host in this fetch request. */ diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 56625738a16d9..692ae3bf597e0 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -372,17 +372,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 1, 0) -> corruptBuffer.size() ) - val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer(new Answer[Unit] { - override def answer(invocation: InvocationOnMock): Unit = { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] - Future { - blocks.foreach (listener.onBlockFetchSuccess(_, corruptBuffer)) - } - } - }) + val transfer = createMockTransfer( + Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer)) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (localBmId, localBlockLengths), @@ -402,11 +393,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, true) // Blocks should be returned without exceptions. - val blockSet = collection.mutable.HashSet[BlockId]() - blockSet.add(iterator.next()._1) - blockSet.add(iterator.next()._1) - assert(blockSet == collection.immutable.HashSet( - ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) + assert(Set(iterator.next()._1, iterator.next()._1) === + Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0))) } test("retry corrupt blocks (disabled)") {