From 22d824fb60892bc02a39e26692b4b7e624ebbc2c Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Wed, 28 Apr 2021 12:43:46 -0700 Subject: [PATCH 01/11] Refactor ShuffleBlockFetcherIteratorSuite -- break out common logic, especially instantiating the ShuffleBlockFetcherIterator --- .../ShuffleBlockFetcherIteratorSuite.scala | 382 ++++++------------ 1 file changed, 126 insertions(+), 256 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 99c43b12d655..27dd4f2ce085 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -35,14 +35,13 @@ import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient} import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo import org.apache.spark.util.Utils class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { - - private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) + import ShuffleBlockFetcherIteratorSuite._ // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. @@ -66,16 +65,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer } - private def createMockBlockManager(): BlockManager = { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-local-host", 1) - doReturn(localBmId).when(blockManager).blockManagerId - // By default, the mock BlockManager returns None for hostLocalDirManager. One could - // still use initHostLocalDirManager() to specify a custom hostLocalDirManager. 
- doReturn(None).when(blockManager).hostLocalDirManager - blockManager - } - private def initHostLocalDirManager( blockManager: BlockManager, hostLocalDirs: Map[String, Array[String]]): Unit = { @@ -167,22 +156,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) ).toIterator - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, blocksByAddress, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - metrics, - false) + blockManager = Some(blockManager) + ) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getLocalBlockData(any()) @@ -238,28 +216,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) ).toIterator - val transfer = createMockTransfer(Map()) - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - metrics, - false) + val iterator = createShuffleBlockIteratorWithDefaults( + createMockTransfer(Map()), + blocksByAddress + ) intercept[FetchFailedException] { iterator.next() } } test("Hit maxBytesInFlight limitation before maxBlocksInFlightPerAddress") { - val blockManager = createMockBlockManager() val remoteBmId1 = BlockManagerId("test-remote-client-1", "test-remote-host1", 1) val remoteBmId2 = BlockManagerId("test-remote-client-2", "test-remote-host2", 2) val blockId1 = ShuffleBlockId(0, 1, 0) @@ -270,22 +234,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val transfer = createMockTransfer(Map( blockId1 -> createMockManagedBuffer(1000), blockId2 -> createMockManagedBuffer(1000))) - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, blocksByAddress, - (_, in) => in, - 1000L, // allow 1 FetchRequests at most at the same time - Int.MaxValue, - Int.MaxValue, // set maxBlocksInFlightPerAddress to Int.MaxValue - Int.MaxValue, - true, - false, - metrics, - false) + maxBytesInFlight = 1000L + ) // After initialize() we'll have 2 FetchRequests and each is 1000 bytes. So only the // first FetchRequests can be sent, and the second one will hit maxBytesInFlight so // it won't be sent. 
@@ -301,7 +254,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("Hit maxBlocksInFlightPerAddress limitation before maxBytesInFlight") { - val blockManager = createMockBlockManager() val remoteBmId = BlockManagerId("test-remote-client-1", "test-remote-host", 2) val blockId1 = ShuffleBlockId(0, 1, 0) val blockId2 = ShuffleBlockId(0, 2, 0) @@ -312,22 +264,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockId1 -> createMockManagedBuffer(), blockId2 -> createMockManagedBuffer(), blockId3 -> createMockManagedBuffer())) - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, blocksByAddress, - (_, in) => in, - Int.MaxValue, // set maxBytesInFlight to Int.MaxValue - Int.MaxValue, - 2, // set maxBlocksInFlightPerAddress to 2 - Int.MaxValue, - true, - false, - metrics, - false) + maxBlocksInFlightPerAddress = 2 + ) // After initialize(), we'll have 2 FetchRequests that one has 2 blocks inside and another one // has only one block. So only the first FetchRequest can be sent. The second FetchRequest will // hit maxBlocksInFlightPerAddress so it won't be sent. @@ -392,22 +333,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) ).toIterator - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, blocksByAddress, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - metrics, - true) + blockManager = Some(blockManager), + doBatchFetch = true + ) // 3 local blocks batch fetched in initialization verify(blockManager, times(1)).getLocalBlockData(any()) @@ -430,7 +361,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("fetch continuous blocks in batch should respect maxBytesInFlight") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return the merged block val remoteBmId1 = BlockManagerId("test-client-1", "test-client-1", 1) val remoteBmId2 = BlockManagerId("test-client-2", "test-client-2", 2) @@ -449,22 +379,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (remoteBmId1, remoteBlocks1.map(blockId => (blockId, 100L, 1))), (remoteBmId2, remoteBlocks2.map(blockId => (blockId, 100L, 1)))).toIterator - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, blocksByAddress, - (_, in) => in, - 1500, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - metrics, - true) + maxBytesInFlight = 1500, + doBatchFetch = true + ) var numResults = 0 // After initialize(), there will be 6 FetchRequests. 
And each of the first 5 requests @@ -486,7 +406,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("fetch continuous blocks in batch should respect maxBlocksInFlightPerAddress") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return the merged block val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 1) val remoteBlocks = Seq( @@ -501,24 +420,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockBatchId(0, 5, 0, 1) -> createMockManagedBuffer()) val transfer = createMockTransfer(mergedRemoteBlocks) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, remoteBlocks.map(blockId => (blockId, 100L, 1)))).toIterator - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, - blocksByAddress, - (_, in) => in, - Int.MaxValue, - Int.MaxValue, - 2, - Int.MaxValue, - true, - false, - metrics, - true) + getBlocksByAddressForSingleBM(remoteBmId, remoteBlocks, 100L, 1), + maxBlocksInFlightPerAddress = 2, + doBatchFetch = true + ) var numResults = 0 // After initialize(), there will be 2 FetchRequests. First one has 2 merged blocks and each // of them is merged from 2 shuffle blocks, second one has 1 merged block which is merged from @@ -538,7 +445,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("release current unexhausted buffer in case the task completes early") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -565,24 +471,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)).toIterator - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, - blocksByAddress, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0), + taskContext = Some(taskContext) + ) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -603,7 +497,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("fail all blocks if any of the remote request fails") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -631,25 +524,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)) - .toIterator - - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - 
blockManager, - blocksByAddress, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0) + ) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -719,24 +597,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)).toIterator - - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, - blocksByAddress, - (_, in) => new LimitedInputStream(in, 100), - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - true, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0), + streamWrapperLimitSize = Some(100) + ) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -765,7 +630,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("big blocks are also checked for corruption") { val streamLength = 10000L - val blockManager = createMockBlockManager() // This stream will throw IOException when the first byte is read val corruptBuffer1 = mockCorruptBuffer(streamLength, 0) @@ -790,22 +654,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (blockManagerId1, blockLengths1), (blockManagerId2, blockLengths2) ).toIterator - val taskContext = TaskContext.empty() - val maxBytesInFlight = 3 * 1024 - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, blocksByAddress, - (_, in) => new LimitedInputStream(in, streamLength), - maxBytesInFlight, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - true, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + streamWrapperLimitSize = Some(streamLength), + maxBytesInFlight = 3 * 1024 + ) // We'll get back the block which has corruption after maxBytesInFlight/3 because the other // block will detect corruption on first fetch, and then get added to the queue again for @@ -856,21 +710,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (localBmId, localBlockLengths) ).toIterator - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, blocksByAddress, - (_, in) => new LimitedInputStream(in, 10000), - 2048, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - true, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + blockManager = Some(blockManager), + streamWrapperLimitSize = Some(10000), + maxBytesInFlight = 2048 + ) val (id, st) = iterator.next() // Check that the test setup is correct -- make sure we have a concatenated stream. 
assert (st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream]) @@ -884,7 +730,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("retry corrupt blocks (disabled)") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -912,25 +757,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)) - .toIterator - - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, - blocksByAddress, - (_, in) => new LimitedInputStream(in, 100), - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0), + streamWrapperLimitSize = Some(100), + maxBytesInFlight = 48 * 1024 * 1024, + detectCorruptUseExtraMemory = false + ) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -975,32 +808,20 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. - val taskContext = TaskContext.empty() - new ShuffleBlockFetcherIterator( - taskContext, + createShuffleBlockIteratorWithDefaults( transfer, - blockManager, blocksByAddress, - (_, in) => in, - maxBytesInFlight = Int.MaxValue, - maxReqsInFlight = Int.MaxValue, - maxBlocksInFlightPerAddress = Int.MaxValue, - maxReqSizeShuffleToMem = 200, - detectCorrupt = true, - false, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + blockManager = Some(blockManager), + maxReqSizeShuffleToMem = 200) } - val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L, 0)).toSeq)).toIterator + val blocksByAddress1 = getBlocksByAddressForSingleBM(remoteBmId, remoteBlocks.keys, 100L, 0) fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) - val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L, 0)).toSeq)).toIterator + val blocksByAddress2 = getBlocksByAddressForSingleBM(remoteBmId, remoteBlocks.keys, 300L, 0) fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. 
@@ -1008,7 +829,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("fail zero-size blocks") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -1018,24 +838,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)).toMap) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)) - - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, + val iterator = createShuffleBlockIteratorWithDefaults( transfer, - blockManager, - blocksByAddress.toIterator, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0) + ) // All blocks fetched return zero length and should trigger a receive-side error: val e = intercept[FetchFailedException] { iterator.next() } @@ -1061,3 +867,67 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(mergedBlock.size === inputBlocks.map(_.size).sum) } } + +object ShuffleBlockFetcherIteratorSuite { + + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) + + private def createMockBlockManager(): BlockManager = { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-local-host", 1) + doReturn(localBmId).when(blockManager).blockManagerId + // By default, the mock BlockManager returns None for hostLocalDirManager. One could + // still use initHostLocalDirManager() to specify a custom hostLocalDirManager. + doReturn(None).when(blockManager).hostLocalDirManager + blockManager + } + + /** + * Get a blockByAddress iterator for a single BlockManagerId assuming all blocks have the same + * size and `blockMapId`. 
+ */ + private def getBlocksByAddressForSingleBM( + blockManagerId: BlockManagerId, + blocks: Traversable[BlockId], + blockSize: Long, + blockMapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq( + (blockManagerId, blocks.map(blockId => (blockId, blockSize, blockMapId)).toSeq) + ).toIterator + } + + // scalastyle:off argcount + private def createShuffleBlockIteratorWithDefaults( + shuffleClient: BlockTransferService, + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + taskContext: Option[TaskContext] = None, + streamWrapperLimitSize: Option[Long] = None, + blockManager: Option[BlockManager] = None, + maxBytesInFlight: Long = Long.MaxValue, + maxReqsInFlight: Int = Int.MaxValue, + maxBlocksInFlightPerAddress: Int = Int.MaxValue, + maxReqSizeShuffleToMem: Int = Int.MaxValue, + detectCorrupt: Boolean = true, + detectCorruptUseExtraMemory: Boolean = true, + shuffleMetrics: Option[ShuffleReadMetricsReporter] = None, + doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = { + val tContext = taskContext.getOrElse(TaskContext.empty()) + new ShuffleBlockFetcherIterator( + tContext, + shuffleClient, + blockManager.getOrElse(createMockBlockManager()), + blocksByAddress, + streamWrapperLimitSize + .map(limit => (_: BlockId, in: InputStream) => new LimitedInputStream(in, limit)) + .getOrElse((_: BlockId, in: InputStream) => in), + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + detectCorrupt, + detectCorruptUseExtraMemory, + shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()), + doBatchFetch) + } + // scalastyle:on argcount +} From d37adb80b268cff12a14962ea1574e9390113cc1 Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Wed, 28 Apr 2021 13:01:20 -0700 Subject: [PATCH 02/11] Refactor ShuffleBlockFetcherIteratorSuite -- use more Scala-like Mockito syntax and pull out common calls to when(transfer.fetchBlocks) and verify(transfer, ...).fetchBlocks --- .../ShuffleBlockFetcherIteratorSuite.scala | 198 +++++++++--------- 1 file changed, 97 insertions(+), 101 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 27dd4f2ce085..40710d8fa0d3 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -27,7 +27,7 @@ import scala.concurrent.Future import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, times, verify, when} -import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} @@ -49,19 +49,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer( - (invocation: InvocationOnMock) => { - val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - - for (blockId <- blocks) { - if (data.contains(BlockId(blockId))) { - listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) - } else { - listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId)) - } + answerFetchBlocks(transfer) { invocation => + val blocks = invocation.getArgument[Array[String]](3) + val listener = invocation.getArgument[BlockFetchingListener](4) + + for (blockId <- blocks) { + if (data.contains(BlockId(blockId))) { + listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) + } else { + listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId)) } - }) + } + } transfer } @@ -77,10 +76,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager)) when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(), any())) .thenAnswer { invocation => - val completableFuture = invocation.getArguments()(3) - .asInstanceOf[CompletableFuture[java.util.Map[String, Array[String]]]] import scala.collection.JavaConverters._ - completableFuture.complete(hostLocalDirs.asJava) + invocation.getArgument[CompletableFuture[java.util.Map[String, Array[String]]]](3) + .complete(hostLocalDirs.asJava) } blockManager.hostLocalDirManager = Some(hostLocalDirManager) @@ -181,7 +179,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT .getHostLocalShuffleData(any(), meq(Array("local-dir"))) // 2 remote blocks are read from the same block manager - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksCount(transfer, 1) assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) } @@ -206,9 +204,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager)) when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(), any())) .thenAnswer { invocation => - val completableFuture = invocation.getArguments()(3) - .asInstanceOf[CompletableFuture[java.util.Map[String, Array[String]]]] - completableFuture.completeExceptionally(new Throwable("failed fetch")) + invocation.getArgument[CompletableFuture[java.util.Map[String, Array[String]]]](3) + .completeExceptionally(new Throwable("failed fetch")) } blockManager.hostLocalDirManager = Some(hostLocalDirManager) @@ -242,12 +239,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // After initialize() we'll have 2 FetchRequests and each is 1000 bytes. So only the // first FetchRequests can be sent, and the second one will hit maxBytesInFlight so // it won't be sent. 
- verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksCount(transfer, 1) assert(iterator.hasNext) // next() will trigger off sending deferred request iterator.next() // the second FetchRequest should be sent at this time - verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksCount(transfer, 2) assert(iterator.hasNext) iterator.next() assert(!iterator.hasNext) @@ -272,14 +269,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // After initialize(), we'll have 2 FetchRequests that one has 2 blocks inside and another one // has only one block. So only the first FetchRequest can be sent. The second FetchRequest will // hit maxBlocksInFlightPerAddress so it won't be sent. - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksCount(transfer, 1) // the first request packaged 2 blocks, so we also need to // call next() for 2 times to exhaust the iterator. assert(iterator.hasNext) iterator.next() assert(iterator.hasNext) iterator.next() - verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksCount(transfer, 2) assert(iterator.hasNext) iterator.next() assert(!iterator.hasNext) @@ -347,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 3) { assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements") val (blockId, inputStream) = iterator.next() - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksCount(transfer, 1) // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = allBlocks(blockId) verifyBufferRelease(mockBuf, inputStream) @@ -392,7 +389,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // block which merged from 2 shuffle blocks. So, only the first 5 requests(5 * 3 * 100 >= 1500) // can be sent. The 6th FetchRequest will hit maxBlocksInFlightPerAddress so it won't // be sent. - verify(transfer, times(5)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksCount(transfer, 5) while (iterator.hasNext) { val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. @@ -401,7 +398,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT numResults += 1 } // The 6th request will be sent after next() is called. - verify(transfer, times(6)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksCount(transfer, 6) assert(numResults == 6) } @@ -431,7 +428,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // of them is merged from 2 shuffle blocks, second one has 1 merged block which is merged from // 1 shuffle block. So only the first FetchRequest can be sent. The second FetchRequest will // hit maxBlocksInFlightPerAddress so it won't be sent. - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksCount(transfer, 1) while (iterator.hasNext) { val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. @@ -440,7 +437,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT numResults += 1 } // The second request will be sent after next() is called. 
- verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksCount(transfer, 2) assert(numResults == 3) } @@ -456,20 +453,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - Future { - // Return the first two blocks, and wait till task completion before returning the 3rd one - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0))) - sem.acquire() - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0))) - } - }) + answerFetchBlocks(transfer) { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + // Return the first two blocks, and wait till task completion before returning the 3rd one + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0))) + sem.acquire() + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0))) + } + } val taskContext = TaskContext.empty() val iterator = createShuffleBlockIteratorWithDefaults( @@ -509,20 +505,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - Future { - // Return the first block, and then fail. - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) - listener.onBlockFetchFailure( - ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah")) + answerFetchBlocks(transfer) { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + // Return the first block, and then fail. + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchFailure( + ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah")) listener.onBlockFetchFailure( ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah")) - sem.release() - } - }) + sem.release() + } + } val iterator = createShuffleBlockIteratorWithDefaults( transfer, @@ -582,20 +577,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - Future { - // Return the first block, and then fail. 
- listener.onBlockFetchSuccess( - ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) - sem.release() - } - }) + answerFetchBlocks(transfer) { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + // Return the first block, and then fail. + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) + sem.release() + } + } val iterator = createShuffleBlockIteratorWithDefaults( transfer, @@ -610,16 +604,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - Future { - // Return the first block, and then fail. - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) - sem.release() - } - }) + answerFetchBlocks(transfer) { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + // Return the first block, and then fail. + listener.onBlockFetchSuccess(ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) + sem.release() + } + } // The next block is corrupt local block (the second one is corrupt and retried) intercept[FetchFailedException] { iterator.next() } @@ -742,20 +734,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - Future { - // Return the first block, and then fail. - listener.onBlockFetchSuccess( + answerFetchBlocks(transfer) { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + // Return the first block, and then fail. 
+        listener.onBlockFetchSuccess(
             ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
-        listener.onBlockFetchSuccess(
-          ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
-        listener.onBlockFetchSuccess(
-          ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
-        sem.release()
-      }
-    })
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
+        sem.release()
+      }
+    }
 
     val iterator = createShuffleBlockIteratorWithDefaults(
       transfer,
@@ -793,15 +784,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
     val transfer = mock(classOf[BlockTransferService])
     var tempFileManager: DownloadFileManager = null
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
-      .thenAnswer((invocation: InvocationOnMock) => {
-        val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-        tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
-        Future {
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
-        }
-      })
+    answerFetchBlocks(transfer) { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      tempFileManager = invocation.getArgument[DownloadFileManager](5)
+      Future {
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
+      }
+    }
 
     def fetchShuffleBlock(
         blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
@@ -872,6 +862,12 @@ object ShuffleBlockFetcherIteratorSuite {
 
   private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*)
 
+  private def answerFetchBlocks(transfer: BlockTransferService)(answer: Answer[Unit]): Unit =
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer(answer)
+
+  private def verifyFetchBlocksCount(transfer: BlockTransferService, expectedCount: Int): Unit =
+    verify(transfer, times(expectedCount)).fetchBlocks(any(), any(), any(), any(), any(), any())
+
   private def createMockBlockManager(): BlockManager = {
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-local-host", 1)

From 6aa6364c88c9afb181ee3781c05b969fbb3277a3 Mon Sep 17 00:00:00 2001
From: Erik Krogen
Date: Thu, 29 Apr 2021 09:10:36 -0700
Subject: [PATCH 03/11] Get rid of object, move helper methods into class.
 Rename verifyFetchBlocksCount to verifyFetchBlocksInvocationCount. Restore
 helpful comments. Remove one redundant parameter override. Fix up a few
 minor issues such as unnecessary specification of very long types.
--- .../ShuffleBlockFetcherIteratorSuite.scala | 186 ++++++++---------- 1 file changed, 87 insertions(+), 99 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 40710d8fa0d3..5c339d8807a2 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -41,7 +41,16 @@ import org.apache.spark.util.Utils class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { - import ShuffleBlockFetcherIteratorSuite._ + + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) + + private def answerFetchBlocks(transfer: BlockTransferService)(answer: Answer[Unit]): Unit = + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer(answer) + + private def verifyFetchBlocksInvocationCount( + transfer: BlockTransferService, + expectedCount: Int): Unit = + verify(transfer, times(expectedCount)).fetchBlocks(any(), any(), any(), any(), any(), any()) // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. @@ -64,6 +73,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer } + private def createMockBlockManager(): BlockManager = { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-local-host", 1) + doReturn(localBmId).when(blockManager).blockManagerId + // By default, the mock BlockManager returns None for hostLocalDirManager. One could + // still use initHostLocalDirManager() to specify a custom hostLocalDirManager. + doReturn(None).when(blockManager).hostLocalDirManager + blockManager + } + private def initHostLocalDirManager( blockManager: BlockManager, hostLocalDirs: Map[String, Array[String]]): Unit = { @@ -110,6 +129,55 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } + /** + * Get a blockByAddress iterator for a single BlockManagerId assuming all blocks have the same + * size and `blockMapId`. 
+ */ + private def getBlocksByAddressForSingleBM( + blockManagerId: BlockManagerId, + blocks: Traversable[BlockId], + blockSize: Long, + blockMapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + Seq( + (blockManagerId, blocks.map(blockId => (blockId, blockSize, blockMapId)).toSeq) + ).toIterator + } + + // scalastyle:off argcount + private def createShuffleBlockIteratorWithDefaults( + shuffleClient: BlockTransferService, + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + taskContext: Option[TaskContext] = None, + streamWrapperLimitSize: Option[Long] = None, + blockManager: Option[BlockManager] = None, + maxBytesInFlight: Long = Long.MaxValue, + maxReqsInFlight: Int = Int.MaxValue, + maxBlocksInFlightPerAddress: Int = Int.MaxValue, + maxReqSizeShuffleToMem: Int = Int.MaxValue, + detectCorrupt: Boolean = true, + detectCorruptUseExtraMemory: Boolean = true, + shuffleMetrics: Option[ShuffleReadMetricsReporter] = None, + doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = { + val tContext = taskContext.getOrElse(TaskContext.empty()) + new ShuffleBlockFetcherIterator( + tContext, + shuffleClient, + blockManager.getOrElse(createMockBlockManager()), + blocksByAddress, + streamWrapperLimitSize + .map(limit => (_: BlockId, in: InputStream) => new LimitedInputStream(in, limit)) + .getOrElse((_: BlockId, in: InputStream) => in), + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + detectCorrupt, + detectCorruptUseExtraMemory, + shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()), + doBatchFetch) + } + // scalastyle:on argcount + test("successful 3 local + 4 host local + 2 remote reads") { val blockManager = createMockBlockManager() val localBmId = blockManager.blockManagerId @@ -179,7 +247,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT .getHostLocalShuffleData(any(), meq(Array("local-dir"))) // 2 remote blocks are read from the same block manager - verifyFetchBlocksCount(transfer, 1) + verifyFetchBlocksInvocationCount(transfer, 1) assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) } @@ -234,17 +302,17 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val iterator = createShuffleBlockIteratorWithDefaults( transfer, blocksByAddress, - maxBytesInFlight = 1000L + maxBytesInFlight = 1000L // allow 1 FetchRequests at most at the same time ) // After initialize() we'll have 2 FetchRequests and each is 1000 bytes. So only the // first FetchRequests can be sent, and the second one will hit maxBytesInFlight so // it won't be sent. - verifyFetchBlocksCount(transfer, 1) + verifyFetchBlocksInvocationCount(transfer, 1) assert(iterator.hasNext) // next() will trigger off sending deferred request iterator.next() // the second FetchRequest should be sent at this time - verifyFetchBlocksCount(transfer, 2) + verifyFetchBlocksInvocationCount(transfer, 2) assert(iterator.hasNext) iterator.next() assert(!iterator.hasNext) @@ -269,14 +337,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // After initialize(), we'll have 2 FetchRequests that one has 2 blocks inside and another one // has only one block. So only the first FetchRequest can be sent. The second FetchRequest will // hit maxBlocksInFlightPerAddress so it won't be sent. 
- verifyFetchBlocksCount(transfer, 1) + verifyFetchBlocksInvocationCount(transfer, 1) // the first request packaged 2 blocks, so we also need to // call next() for 2 times to exhaust the iterator. assert(iterator.hasNext) iterator.next() assert(iterator.hasNext) iterator.next() - verifyFetchBlocksCount(transfer, 2) + verifyFetchBlocksInvocationCount(transfer, 2) assert(iterator.hasNext) iterator.next() assert(!iterator.hasNext) @@ -344,7 +412,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 3) { assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements") val (blockId, inputStream) = iterator.next() - verifyFetchBlocksCount(transfer, 1) + verifyFetchBlocksInvocationCount(transfer, 1) // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = allBlocks(blockId) verifyBufferRelease(mockBuf, inputStream) @@ -389,7 +457,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // block which merged from 2 shuffle blocks. So, only the first 5 requests(5 * 3 * 100 >= 1500) // can be sent. The 6th FetchRequest will hit maxBlocksInFlightPerAddress so it won't // be sent. - verifyFetchBlocksCount(transfer, 5) + verifyFetchBlocksInvocationCount(transfer, 5) while (iterator.hasNext) { val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. @@ -398,7 +466,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT numResults += 1 } // The 6th request will be sent after next() is called. - verifyFetchBlocksCount(transfer, 6) + verifyFetchBlocksInvocationCount(transfer, 6) assert(numResults == 6) } @@ -428,7 +496,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // of them is merged from 2 shuffle blocks, second one has 1 merged block which is merged from // 1 shuffle block. So only the first FetchRequest can be sent. The second FetchRequest will // hit maxBlocksInFlightPerAddress so it won't be sent. - verifyFetchBlocksCount(transfer, 1) + verifyFetchBlocksInvocationCount(transfer, 1) while (iterator.hasNext) { val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. @@ -437,7 +505,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT numResults += 1 } // The second request will be sent after next() is called. 
- verifyFetchBlocksCount(transfer, 2) + verifyFetchBlocksInvocationCount(transfer, 2) assert(numResults == 3) } @@ -563,7 +631,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("retry corrupt blocks") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -627,22 +694,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptBuffer1 = mockCorruptBuffer(streamLength, 0) val blockManagerId1 = BlockManagerId("remote-client-1", "remote-client-1", 1) val shuffleBlockId1 = ShuffleBlockId(0, 1, 0) - val blockLengths1 = Seq[Tuple3[BlockId, Long, Int]]( - (shuffleBlockId1, corruptBuffer1.size(), 1) - ) + val blockLengths1 = Seq((shuffleBlockId1, corruptBuffer1.size(), 1)) val streamNotCorruptTill = 8 * 1024 // This stream will throw exception after streamNotCorruptTill bytes are read val corruptBuffer2 = mockCorruptBuffer(streamLength, streamNotCorruptTill) val blockManagerId2 = BlockManagerId("remote-client-2", "remote-client-2", 2) val shuffleBlockId2 = ShuffleBlockId(0, 2, 0) - val blockLengths2 = Seq[Tuple3[BlockId, Long, Int]]( - (shuffleBlockId2, corruptBuffer2.size(), 2) - ) + val blockLengths2 = Seq((shuffleBlockId2, corruptBuffer2.size(), 2)) val transfer = createMockTransfer( Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2)) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + val blocksByAddress = Seq( (blockManagerId1, blockLengths1), (blockManagerId2, blockLengths2) ).toIterator @@ -694,22 +757,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val localBmId = BlockManagerId("test-client", "test-client", 1) doReturn(localBmId).when(blockManager).blockManagerId doReturn(managedBuffer).when(blockManager).getLocalBlockData(meq(ShuffleBlockId(0, 0, 0))) - val localBlockLengths = Seq[Tuple3[BlockId, Long, Int]]( - (ShuffleBlockId(0, 0, 0), 10000, 0) - ) val transfer = createMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer)) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (localBmId, localBlockLengths) - ).toIterator + val blocksByAddress = + getBlocksByAddressForSingleBM(localBmId, Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0) val iterator = createShuffleBlockIteratorWithDefaults( transfer, blocksByAddress, blockManager = Some(blockManager), streamWrapperLimitSize = Some(10000), - maxBytesInFlight = 2048 + maxBytesInFlight = 2048 // force concatenation of stream by limiting bytes in flight ) - val (id, st) = iterator.next() + val (_, st) = iterator.next() // Check that the test setup is correct -- make sure we have a concatenated stream. 
assert (st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream]) @@ -752,7 +811,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0), streamWrapperLimitSize = Some(100), - maxBytesInFlight = 48 * 1024 * 1024, detectCorruptUseExtraMemory = false ) @@ -857,73 +915,3 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(mergedBlock.size === inputBlocks.map(_.size).sum) } } - -object ShuffleBlockFetcherIteratorSuite { - - private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) - - private def answerFetchBlocks(transfer: BlockTransferService)(answer: Answer[Unit]): Unit = - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer(answer) - - private def verifyFetchBlocksCount(transfer: BlockTransferService, expectedCount: Int): Unit = - verify(transfer, times(expectedCount)).fetchBlocks(any(), any(), any(), any(), any(), any()) - - private def createMockBlockManager(): BlockManager = { - val blockManager = mock(classOf[BlockManager]) - val localBmId = BlockManagerId("test-client", "test-local-host", 1) - doReturn(localBmId).when(blockManager).blockManagerId - // By default, the mock BlockManager returns None for hostLocalDirManager. One could - // still use initHostLocalDirManager() to specify a custom hostLocalDirManager. - doReturn(None).when(blockManager).hostLocalDirManager - blockManager - } - - /** - * Get a blockByAddress iterator for a single BlockManagerId assuming all blocks have the same - * size and `blockMapId`. - */ - private def getBlocksByAddressForSingleBM( - blockManagerId: BlockManagerId, - blocks: Traversable[BlockId], - blockSize: Long, - blockMapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - Seq( - (blockManagerId, blocks.map(blockId => (blockId, blockSize, blockMapId)).toSeq) - ).toIterator - } - - // scalastyle:off argcount - private def createShuffleBlockIteratorWithDefaults( - shuffleClient: BlockTransferService, - blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], - taskContext: Option[TaskContext] = None, - streamWrapperLimitSize: Option[Long] = None, - blockManager: Option[BlockManager] = None, - maxBytesInFlight: Long = Long.MaxValue, - maxReqsInFlight: Int = Int.MaxValue, - maxBlocksInFlightPerAddress: Int = Int.MaxValue, - maxReqSizeShuffleToMem: Int = Int.MaxValue, - detectCorrupt: Boolean = true, - detectCorruptUseExtraMemory: Boolean = true, - shuffleMetrics: Option[ShuffleReadMetricsReporter] = None, - doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = { - val tContext = taskContext.getOrElse(TaskContext.empty()) - new ShuffleBlockFetcherIterator( - tContext, - shuffleClient, - blockManager.getOrElse(createMockBlockManager()), - blocksByAddress, - streamWrapperLimitSize - .map(limit => (_: BlockId, in: InputStream) => new LimitedInputStream(in, limit)) - .getOrElse((_: BlockId, in: InputStream) => in), - maxBytesInFlight, - maxReqsInFlight, - maxBlocksInFlightPerAddress, - maxReqSizeShuffleToMem, - detectCorrupt, - detectCorruptUseExtraMemory, - shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()), - doBatchFetch) - } - // scalastyle:on argcount -} From 9329214f71ddc53c6e4f24605ede8bed0636764f Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Thu, 29 Apr 2021 09:18:08 -0700 Subject: [PATCH 04/11] Move 'transfer' to be a field to reduce duplicated mock setup 
and passing it around in method signatures. --- .../ShuffleBlockFetcherIteratorSuite.scala | 103 +++++++----------- 1 file changed, 40 insertions(+), 63 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 5c339d8807a2..009f56361c97 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -42,23 +42,26 @@ import org.apache.spark.util.Utils class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { + private var transfer: BlockTransferService = _ + + override def beforeEach(): Unit = { + transfer = mock(classOf[BlockTransferService]) + } + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) - private def answerFetchBlocks(transfer: BlockTransferService)(answer: Answer[Unit]): Unit = + private def answerFetchBlocks(answer: Answer[Unit]): Unit = when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer(answer) - private def verifyFetchBlocksInvocationCount( - transfer: BlockTransferService, - expectedCount: Int): Unit = + private def verifyFetchBlocksInvocationCount(expectedCount: Int): Unit = verify(transfer, times(expectedCount)).fetchBlocks(any(), any(), any(), any(), any(), any()) // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. - /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ - private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { - val transfer = mock(classOf[BlockTransferService]) - answerFetchBlocks(transfer) { invocation => + /** Configures `transfer` (mock [[BlockTransferService]]) to return data from the given map. 
*/ + private def configureMockTransfer(data: Map[BlockId, ManagedBuffer]): Unit = { + answerFetchBlocks { invocation => val blocks = invocation.getArgument[Array[String]](3) val listener = invocation.getArgument[BlockFetchingListener](4) @@ -70,7 +73,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } } } - transfer } private def createMockBlockManager(): BlockManager = { @@ -145,7 +147,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // scalastyle:off argcount private def createShuffleBlockIteratorWithDefaults( - shuffleClient: BlockTransferService, blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], taskContext: Option[TaskContext] = None, streamWrapperLimitSize: Option[Long] = None, @@ -161,7 +162,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val tContext = taskContext.getOrElse(TaskContext.empty()) new ShuffleBlockFetcherIterator( tContext, - shuffleClient, + transfer, blockManager.getOrElse(createMockBlockManager()), blocksByAddress, streamWrapperLimitSize @@ -197,7 +198,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer()) - val transfer = createMockTransfer(remoteBlocks) + configureMockTransfer(remoteBlocks) // Create a block manager running on the same host (host-local) val hostLocalBmId = BlockManagerId("test-host-local-client-1", "test-local-host", 3) @@ -223,7 +224,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ).toIterator val iterator = createShuffleBlockIteratorWithDefaults( - transfer, blocksByAddress, blockManager = Some(blockManager) ) @@ -247,7 +247,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT .getHostLocalShuffleData(any(), meq(Array("local-dir"))) // 2 remote blocks are read from the same block manager - verifyFetchBlocksInvocationCount(transfer, 1) + verifyFetchBlocksInvocationCount(1) assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) } @@ -281,10 +281,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) ).toIterator - val iterator = createShuffleBlockIteratorWithDefaults( - createMockTransfer(Map()), - blocksByAddress - ) + configureMockTransfer(Map()) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress) intercept[FetchFailedException] { iterator.next() } } @@ -296,23 +294,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq( (remoteBmId1, Seq((blockId1, 1000L, 0))), (remoteBmId2, Seq((blockId2, 1000L, 0)))).toIterator - val transfer = createMockTransfer(Map( + configureMockTransfer(Map( blockId1 -> createMockManagedBuffer(1000), blockId2 -> createMockManagedBuffer(1000))) - val iterator = createShuffleBlockIteratorWithDefaults( - transfer, - blocksByAddress, - maxBytesInFlight = 1000L // allow 1 FetchRequests at most at the same time - ) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, maxBytesInFlight = 1000L) // After initialize() we'll have 2 FetchRequests and each is 1000 bytes. So only the // first FetchRequests can be sent, and the second one will hit maxBytesInFlight so // it won't be sent. 
- verifyFetchBlocksInvocationCount(transfer, 1) + verifyFetchBlocksInvocationCount(1) assert(iterator.hasNext) // next() will trigger off sending deferred request iterator.next() // the second FetchRequest should be sent at this time - verifyFetchBlocksInvocationCount(transfer, 2) + verifyFetchBlocksInvocationCount(2) assert(iterator.hasNext) iterator.next() assert(!iterator.hasNext) @@ -325,26 +319,25 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blockId3 = ShuffleBlockId(0, 3, 0) val blocksByAddress = Seq((remoteBmId, Seq((blockId1, 1000L, 0), (blockId2, 1000L, 0), (blockId3, 1000L, 0)))).toIterator - val transfer = createMockTransfer(Map( + configureMockTransfer(Map( blockId1 -> createMockManagedBuffer(), blockId2 -> createMockManagedBuffer(), blockId3 -> createMockManagedBuffer())) val iterator = createShuffleBlockIteratorWithDefaults( - transfer, blocksByAddress, maxBlocksInFlightPerAddress = 2 ) // After initialize(), we'll have 2 FetchRequests that one has 2 blocks inside and another one // has only one block. So only the first FetchRequest can be sent. The second FetchRequest will // hit maxBlocksInFlightPerAddress so it won't be sent. - verifyFetchBlocksInvocationCount(transfer, 1) + verifyFetchBlocksInvocationCount(1) // the first request packaged 2 blocks, so we also need to // call next() for 2 times to exhaust the iterator. assert(iterator.hasNext) iterator.next() assert(iterator.hasNext) iterator.next() - verifyFetchBlocksInvocationCount(transfer, 2) + verifyFetchBlocksInvocationCount(2) assert(iterator.hasNext) iterator.next() assert(!iterator.hasNext) @@ -371,7 +364,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 3, 1)) val mergedRemoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockBatchId(0, 3, 0, 2) -> createMockManagedBuffer()) - val transfer = createMockTransfer(mergedRemoteBlocks) + configureMockTransfer(mergedRemoteBlocks) // Create a block manager running on the same host (host-local) val hostLocalBmId = BlockManagerId("test-host-local-client-1", "test-local-host", 3) @@ -399,7 +392,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ).toIterator val iterator = createShuffleBlockIteratorWithDefaults( - transfer, blocksByAddress, blockManager = Some(blockManager), doBatchFetch = true @@ -412,7 +404,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 3) { assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements") val (blockId, inputStream) = iterator.next() - verifyFetchBlocksInvocationCount(transfer, 1) + verifyFetchBlocksInvocationCount(1) // Make sure we release buffers when a wrapped input stream is closed. 
val mockBuf = allBlocks(blockId) verifyBufferRelease(mockBuf, inputStream) @@ -438,14 +430,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockBatchId(0, 3, 9, 12) -> createMockManagedBuffer(), ShuffleBlockBatchId(0, 3, 12, 15) -> createMockManagedBuffer(), ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer()) - val transfer = createMockTransfer(mergedRemoteBlocks) + configureMockTransfer(mergedRemoteBlocks) val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( (remoteBmId1, remoteBlocks1.map(blockId => (blockId, 100L, 1))), (remoteBmId2, remoteBlocks2.map(blockId => (blockId, 100L, 1)))).toIterator val iterator = createShuffleBlockIteratorWithDefaults( - transfer, blocksByAddress, maxBytesInFlight = 1500, doBatchFetch = true @@ -457,7 +448,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // block which merged from 2 shuffle blocks. So, only the first 5 requests(5 * 3 * 100 >= 1500) // can be sent. The 6th FetchRequest will hit maxBlocksInFlightPerAddress so it won't // be sent. - verifyFetchBlocksInvocationCount(transfer, 5) + verifyFetchBlocksInvocationCount(5) while (iterator.hasNext) { val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. @@ -466,7 +457,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT numResults += 1 } // The 6th request will be sent after next() is called. - verifyFetchBlocksInvocationCount(transfer, 6) + verifyFetchBlocksInvocationCount(6) assert(numResults == 6) } @@ -484,9 +475,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer(), ShuffleBlockBatchId(0, 5, 0, 1) -> createMockManagedBuffer()) - val transfer = createMockTransfer(mergedRemoteBlocks) + configureMockTransfer(mergedRemoteBlocks) val iterator = createShuffleBlockIteratorWithDefaults( - transfer, getBlocksByAddressForSingleBM(remoteBmId, remoteBlocks, 100L, 1), maxBlocksInFlightPerAddress = 2, doBatchFetch = true @@ -496,7 +486,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // of them is merged from 2 shuffle blocks, second one has 1 merged block which is merged from // 1 shuffle block. So only the first FetchRequest can be sent. The second FetchRequest will // hit maxBlocksInFlightPerAddress so it won't be sent. - verifyFetchBlocksInvocationCount(transfer, 1) + verifyFetchBlocksInvocationCount(1) while (iterator.hasNext) { val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. @@ -505,7 +495,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT numResults += 1 } // The second request will be sent after next() is called. - verifyFetchBlocksInvocationCount(transfer, 2) + verifyFetchBlocksInvocationCount(2) assert(numResults == 3) } @@ -520,8 +510,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. 
val sem = new Semaphore(0) - val transfer = mock(classOf[BlockTransferService]) - answerFetchBlocks(transfer) { invocation => + answerFetchBlocks { invocation => val listener = invocation.getArgument[BlockFetchingListener](4) Future { // Return the first two blocks, and wait till task completion before returning the 3rd one @@ -537,7 +526,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val taskContext = TaskContext.empty() val iterator = createShuffleBlockIteratorWithDefaults( - transfer, getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0), taskContext = Some(taskContext) ) @@ -572,8 +560,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - val transfer = mock(classOf[BlockTransferService]) - answerFetchBlocks(transfer) { invocation => + answerFetchBlocks { invocation => val listener = invocation.getArgument[BlockFetchingListener](4) Future { // Return the first block, and then fail. @@ -588,7 +575,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val iterator = createShuffleBlockIteratorWithDefaults( - transfer, getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0) ) @@ -643,8 +629,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) - val transfer = mock(classOf[BlockTransferService]) - answerFetchBlocks(transfer) { invocation => + answerFetchBlocks { invocation => val listener = invocation.getArgument[BlockFetchingListener](4) Future { // Return the first block, and then fail. @@ -659,7 +644,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val iterator = createShuffleBlockIteratorWithDefaults( - transfer, getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0), streamWrapperLimitSize = Some(100) ) @@ -671,7 +655,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - answerFetchBlocks(transfer) { invocation => + answerFetchBlocks { invocation => val listener = invocation.getArgument[BlockFetchingListener](4) Future { // Return the first block, and then fail. 
@@ -703,14 +687,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val shuffleBlockId2 = ShuffleBlockId(0, 2, 0) val blockLengths2 = Seq((shuffleBlockId2, corruptBuffer2.size(), 2)) - val transfer = createMockTransfer( + configureMockTransfer( Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2)) val blocksByAddress = Seq( (blockManagerId1, blockLengths1), (blockManagerId2, blockLengths2) ).toIterator val iterator = createShuffleBlockIteratorWithDefaults( - transfer, blocksByAddress, streamWrapperLimitSize = Some(streamLength), maxBytesInFlight = 3 * 1024 @@ -757,12 +740,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val localBmId = BlockManagerId("test-client", "test-client", 1) doReturn(localBmId).when(blockManager).blockManagerId doReturn(managedBuffer).when(blockManager).getLocalBlockData(meq(ShuffleBlockId(0, 0, 0))) - val transfer = createMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer)) + configureMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer)) val blocksByAddress = getBlocksByAddressForSingleBM(localBmId, Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0) val iterator = createShuffleBlockIteratorWithDefaults( - transfer, blocksByAddress, blockManager = Some(blockManager), streamWrapperLimitSize = Some(10000), @@ -792,8 +774,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - val transfer = mock(classOf[BlockTransferService]) - answerFetchBlocks(transfer) { invocation => + answerFetchBlocks { invocation => val listener = invocation.getArgument[BlockFetchingListener](4) Future { // Return the first block, and then fail. @@ -808,7 +789,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val iterator = createShuffleBlockIteratorWithDefaults( - transfer, getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0), streamWrapperLimitSize = Some(100), detectCorruptUseExtraMemory = false @@ -840,9 +820,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) - val transfer = mock(classOf[BlockTransferService]) var tempFileManager: DownloadFileManager = null - answerFetchBlocks(transfer) { invocation => + answerFetchBlocks { invocation => val listener = invocation.getArgument[BlockFetchingListener](4) tempFileManager = invocation.getArgument[DownloadFileManager](5) Future { @@ -857,7 +836,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. 
createShuffleBlockIteratorWithDefaults( - transfer, blocksByAddress, blockManager = Some(blockManager), maxReqSizeShuffleToMem = 200) @@ -884,10 +862,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer() ) - val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)).toMap) + configureMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)).toMap) val iterator = createShuffleBlockIteratorWithDefaults( - transfer, getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0) ) From 4dfe47dd43630799a277096d637fe48be8b85fd0 Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Mon, 10 May 2021 09:11:55 -0700 Subject: [PATCH 05/11] Address ngone51 comments -- minor updates and add a new getBlocksByAddress which accepts a Map --- .../ShuffleBlockFetcherIteratorSuite.scala | 107 +++++++++--------- 1 file changed, 53 insertions(+), 54 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 009f56361c97..211dc9d4cf53 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -131,18 +131,30 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } + /** + * Get a blockByAddress iterator for a set of block managers, assuming that for each block + * manager, all associated blocks have the same size and `blockMapId`. Each value of the + * map contain `(blocks, blockSize, blockMapIndex)`. + */ + private def getBlocksByAddress( + blockManagerIdToBlocks: Map[BlockManagerId, (Traversable[BlockId], Long, Int)]) + : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + blockManagerIdToBlocks.map { case (blockManagerId, (blocks, blockSize, blockMapIndex)) => + (blockManagerId, blocks.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq) + }.toIterator + } + /** * Get a blockByAddress iterator for a single BlockManagerId assuming all blocks have the same - * size and `blockMapId`. + * size and `blockMapIndex`. Convenience method for calling `getBlocksByAddress` with a map + * containing a single entry. 
*/ private def getBlocksByAddressForSingleBM( blockManagerId: BlockManagerId, blocks: Traversable[BlockId], blockSize: Long, - blockMapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - Seq( - (blockManagerId, blocks.map(blockId => (blockId, blockSize, blockMapId)).toSeq) - ).toIterator + blockMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + getBlocksByAddress(Map(blockManagerId -> (blocks, blockSize, blockMapIndex))) } // scalastyle:off argcount @@ -166,8 +178,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockManager.getOrElse(createMockBlockManager()), blocksByAddress, streamWrapperLimitSize - .map(limit => (_: BlockId, in: InputStream) => new LimitedInputStream(in, limit)) - .getOrElse((_: BlockId, in: InputStream) => in), + .map(limit => (_: BlockId, in: InputStream) => new LimitedInputStream(in, limit)) + .getOrElse((_: BlockId, in: InputStream) => in), maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, @@ -217,11 +229,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // returning local dir for hostLocalBmId initHostLocalDirManager(blockManager, hostLocalDirs) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (localBmId, localBlocks.keys.map(blockId => (blockId, 1L, 0)).toSeq), - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq), - (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) - ).toIterator + val blocksByAddress = getBlocksByAddress(Map( + localBmId -> (localBlocks.keys, 1L, 0), + remoteBmId -> (remoteBlocks.keys, 1L, 1), + hostLocalBlocks -> (hostLocalBlocks.keys, 1L, 1) + )) val iterator = createShuffleBlockIteratorWithDefaults( blocksByAddress, @@ -277,12 +289,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } blockManager.hostLocalDirManager = Some(hostLocalDirManager) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) - ).toIterator configureMockTransfer(Map()) - val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress) + val iterator = createShuffleBlockIteratorWithDefaults( + getBlocksByAddressForSingleBM(hostLocalBmId, hostLocalBlocks.keys, 1L, 1) + ) intercept[FetchFailedException] { iterator.next() } } @@ -291,9 +302,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBmId2 = BlockManagerId("test-remote-client-2", "test-remote-host2", 2) val blockId1 = ShuffleBlockId(0, 1, 0) val blockId2 = ShuffleBlockId(1, 1, 0) - val blocksByAddress = Seq( - (remoteBmId1, Seq((blockId1, 1000L, 0))), - (remoteBmId2, Seq((blockId2, 1000L, 0)))).toIterator + val blocksByAddress = getBlocksByAddress(Map( + remoteBmId1 -> (Seq(blockId1), 1000L, 0), + remoteBmId2 -> (Seq(blockId2), 1000L, 0) + )) configureMockTransfer(Map( blockId1 -> createMockManagedBuffer(1000), blockId2 -> createMockManagedBuffer(1000))) @@ -314,17 +326,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("Hit maxBlocksInFlightPerAddress limitation before maxBytesInFlight") { val remoteBmId = BlockManagerId("test-remote-client-1", "test-remote-host", 2) - val blockId1 = ShuffleBlockId(0, 1, 0) - val blockId2 = ShuffleBlockId(0, 2, 0) - val blockId3 = ShuffleBlockId(0, 3, 0) - val blocksByAddress = Seq((remoteBmId, - Seq((blockId1, 1000L, 0), (blockId2, 1000L, 0), (blockId3, 
1000L, 0)))).toIterator - configureMockTransfer(Map( - blockId1 -> createMockManagedBuffer(), - blockId2 -> createMockManagedBuffer(), - blockId3 -> createMockManagedBuffer())) + val blocks = 1.to(3).map(ShuffleBlockId(0, _, 0)) + configureMockTransfer(blocks.map(_ -> createMockManagedBuffer()).toMap) val iterator = createShuffleBlockIteratorWithDefaults( - blocksByAddress, + getBlocksByAddressForSingleBM(remoteBmId, blocks, 1000L, 0), maxBlocksInFlightPerAddress = 2 ) // After initialize(), we'll have 2 FetchRequests that one has 2 blocks inside and another one @@ -385,11 +390,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // returning local dir for hostLocalBmId initHostLocalDirManager(blockManager, hostLocalDirs) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (localBmId, localBlocks.map(blockId => (blockId, 1L, 0))), - (remoteBmId, remoteBlocks.map(blockId => (blockId, 1L, 1))), - (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) - ).toIterator + val blocksByAddress = getBlocksByAddress(Map( + localBmId -> (localBlocks, 1L, 0), + remoteBmId -> (remoteBlocks, 1L, 1), + hostLocalBmId -> (hostLocalBlocks.keys, 1L, 1) + )) val iterator = createShuffleBlockIteratorWithDefaults( blocksByAddress, @@ -432,9 +437,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer()) configureMockTransfer(mergedRemoteBlocks) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId1, remoteBlocks1.map(blockId => (blockId, 100L, 1))), - (remoteBmId2, remoteBlocks2.map(blockId => (blockId, 100L, 1)))).toIterator + val blocksByAddress = getBlocksByAddress(Map( + remoteBmId1 -> (remoteBlocks1, 100L, 1), + remoteBmId2 -> (remoteBlocks2, 100L, 1) + )) val iterator = createShuffleBlockIteratorWithDefaults( blocksByAddress, @@ -678,21 +684,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptBuffer1 = mockCorruptBuffer(streamLength, 0) val blockManagerId1 = BlockManagerId("remote-client-1", "remote-client-1", 1) val shuffleBlockId1 = ShuffleBlockId(0, 1, 0) - val blockLengths1 = Seq((shuffleBlockId1, corruptBuffer1.size(), 1)) val streamNotCorruptTill = 8 * 1024 // This stream will throw exception after streamNotCorruptTill bytes are read val corruptBuffer2 = mockCorruptBuffer(streamLength, streamNotCorruptTill) val blockManagerId2 = BlockManagerId("remote-client-2", "remote-client-2", 2) val shuffleBlockId2 = ShuffleBlockId(0, 2, 0) - val blockLengths2 = Seq((shuffleBlockId2, corruptBuffer2.size(), 2)) configureMockTransfer( Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2)) - val blocksByAddress = Seq( - (blockManagerId1, blockLengths1), - (blockManagerId2, blockLengths2) - ).toIterator + val blocksByAddress = getBlocksByAddress(Map( + blockManagerId1 -> (Seq(shuffleBlockId1), corruptBuffer1.size(), 1), + blockManagerId2 -> (Seq(shuffleBlockId2), corruptBuffer2.size(), 2) + )) val iterator = createShuffleBlockIteratorWithDefaults( blocksByAddress, streamWrapperLimitSize = Some(streamLength), @@ -741,11 +745,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT doReturn(localBmId).when(blockManager).blockManagerId doReturn(managedBuffer).when(blockManager).getLocalBlockData(meq(ShuffleBlockId(0, 0, 0))) configureMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer)) - val blocksByAddress = - 
getBlocksByAddressForSingleBM(localBmId, Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0) val iterator = createShuffleBlockIteratorWithDefaults( - blocksByAddress, + getBlocksByAddressForSingleBM(localBmId, Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0), blockManager = Some(blockManager), streamWrapperLimitSize = Some(10000), maxBytesInFlight = 2048 // force concatenation of stream by limiting bytes in flight @@ -830,25 +832,22 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } } - def fetchShuffleBlock( - blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = { - // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the + def fetchShuffleBlock(blockSize: Long): Unit = { + // Use default `maxBytesInFlight` and `maxReqsInFlight` (`Int.MaxValue`) so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. createShuffleBlockIteratorWithDefaults( - blocksByAddress, + getBlocksByAddressForSingleBM(remoteBmId, remoteBlocks.keys, blockSize, 0), blockManager = Some(blockManager), maxReqSizeShuffleToMem = 200) } - val blocksByAddress1 = getBlocksByAddressForSingleBM(remoteBmId, remoteBlocks.keys, 100L, 0) - fetchShuffleBlock(blocksByAddress1) + fetchShuffleBlock(100L) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) - val blocksByAddress2 = getBlocksByAddressForSingleBM(remoteBmId, remoteBlocks.keys, 300L, 0) - fetchShuffleBlock(blocksByAddress2) + fetchShuffleBlock(300L) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. assert(tempFileManager != null) From e5d4e348b50c4098f0e1c76acb4ca5ff879193e1 Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Mon, 10 May 2021 09:16:08 -0700 Subject: [PATCH 06/11] Additional refactor to just make getShuffleIteratorWithDefaults directly accept a map of block info --- .../ShuffleBlockFetcherIteratorSuite.scala | 110 ++++++------------ 1 file changed, 37 insertions(+), 73 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 211dc9d4cf53..6b06eec97c17 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -131,35 +131,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } - /** - * Get a blockByAddress iterator for a set of block managers, assuming that for each block - * manager, all associated blocks have the same size and `blockMapId`. Each value of the - * map contain `(blocks, blockSize, blockMapIndex)`. - */ - private def getBlocksByAddress( - blockManagerIdToBlocks: Map[BlockManagerId, (Traversable[BlockId], Long, Int)]) - : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - blockManagerIdToBlocks.map { case (blockManagerId, (blocks, blockSize, blockMapIndex)) => - (blockManagerId, blocks.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq) - }.toIterator - } - - /** - * Get a blockByAddress iterator for a single BlockManagerId assuming all blocks have the same - * size and `blockMapIndex`. 
Convenience method for calling `getBlocksByAddress` with a map - * containing a single entry. - */ - private def getBlocksByAddressForSingleBM( - blockManagerId: BlockManagerId, - blocks: Traversable[BlockId], - blockSize: Long, - blockMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { - getBlocksByAddress(Map(blockManagerId -> (blocks, blockSize, blockMapIndex))) - } - // scalastyle:off argcount private def createShuffleBlockIteratorWithDefaults( - blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + blocksByAddress: Map[BlockManagerId, (Traversable[BlockId], Long, Int)], taskContext: Option[TaskContext] = None, streamWrapperLimitSize: Option[Long] = None, blockManager: Option[BlockManager] = None, @@ -176,7 +150,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT tContext, transfer, blockManager.getOrElse(createMockBlockManager()), - blocksByAddress, + blocksByAddress.map { case (blockManagerId, (blocks, blockSize, blockMapIndex)) => + (blockManagerId, blocks.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq) + }.toIterator, streamWrapperLimitSize .map(limit => (_: BlockId, in: InputStream) => new LimitedInputStream(in, limit)) .getOrElse((_: BlockId, in: InputStream) => in), @@ -214,11 +190,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Create a block manager running on the same host (host-local) val hostLocalBmId = BlockManagerId("test-host-local-client-1", "test-local-host", 3) - val hostLocalBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 5, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 6, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 7, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 8, 0) -> createMockManagedBuffer()) + val hostLocalBlocks = 5.to(8).map(ShuffleBlockId(0, _, 0) -> createMockManagedBuffer()).toMap hostLocalBlocks.foreach { case (blockId, buf) => doReturn(buf) @@ -229,14 +201,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // returning local dir for hostLocalBmId initHostLocalDirManager(blockManager, hostLocalDirs) - val blocksByAddress = getBlocksByAddress(Map( - localBmId -> (localBlocks.keys, 1L, 0), - remoteBmId -> (remoteBlocks.keys, 1L, 1), - hostLocalBlocks -> (hostLocalBlocks.keys, 1L, 1) - )) - val iterator = createShuffleBlockIteratorWithDefaults( - blocksByAddress, + Map( + localBmId -> (localBlocks.keys, 1L, 0), + remoteBmId -> (remoteBlocks.keys, 1L, 1), + hostLocalBmId -> (hostLocalBlocks.keys, 1L, 1) + ), blockManager = Some(blockManager) ) @@ -292,7 +262,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransfer(Map()) val iterator = createShuffleBlockIteratorWithDefaults( - getBlocksByAddressForSingleBM(hostLocalBmId, hostLocalBlocks.keys, 1L, 1) + Map(hostLocalBmId -> (hostLocalBlocks.keys, 1L, 1)) ) intercept[FetchFailedException] { iterator.next() } } @@ -302,14 +272,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBmId2 = BlockManagerId("test-remote-client-2", "test-remote-host2", 2) val blockId1 = ShuffleBlockId(0, 1, 0) val blockId2 = ShuffleBlockId(1, 1, 0) - val blocksByAddress = getBlocksByAddress(Map( - remoteBmId1 -> (Seq(blockId1), 1000L, 0), - remoteBmId2 -> (Seq(blockId2), 1000L, 0) - )) configureMockTransfer(Map( blockId1 -> createMockManagedBuffer(1000), blockId2 -> createMockManagedBuffer(1000))) - val iterator = 
createShuffleBlockIteratorWithDefaults(blocksByAddress, maxBytesInFlight = 1000L) + val iterator = createShuffleBlockIteratorWithDefaults(Map( + remoteBmId1 -> (Seq(blockId1), 1000L, 0), + remoteBmId2 -> (Seq(blockId2), 1000L, 0) + ), maxBytesInFlight = 1000L) // After initialize() we'll have 2 FetchRequests and each is 1000 bytes. So only the // first FetchRequests can be sent, and the second one will hit maxBytesInFlight so // it won't be sent. @@ -329,7 +298,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocks = 1.to(3).map(ShuffleBlockId(0, _, 0)) configureMockTransfer(blocks.map(_ -> createMockManagedBuffer()).toMap) val iterator = createShuffleBlockIteratorWithDefaults( - getBlocksByAddressForSingleBM(remoteBmId, blocks, 1000L, 0), + Map(remoteBmId -> (blocks, 1000L, 0)), maxBlocksInFlightPerAddress = 2 ) // After initialize(), we'll have 2 FetchRequests that one has 2 blocks inside and another one @@ -390,14 +359,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // returning local dir for hostLocalBmId initHostLocalDirManager(blockManager, hostLocalDirs) - val blocksByAddress = getBlocksByAddress(Map( - localBmId -> (localBlocks, 1L, 0), - remoteBmId -> (remoteBlocks, 1L, 1), - hostLocalBmId -> (hostLocalBlocks.keys, 1L, 1) - )) - val iterator = createShuffleBlockIteratorWithDefaults( - blocksByAddress, + Map( + localBmId -> (localBlocks, 1L, 0), + remoteBmId -> (remoteBlocks, 1L, 1), + hostLocalBmId -> (hostLocalBlocks.keys, 1L, 1) + ), blockManager = Some(blockManager), doBatchFetch = true ) @@ -437,13 +404,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer()) configureMockTransfer(mergedRemoteBlocks) - val blocksByAddress = getBlocksByAddress(Map( - remoteBmId1 -> (remoteBlocks1, 100L, 1), - remoteBmId2 -> (remoteBlocks2, 100L, 1) - )) - val iterator = createShuffleBlockIteratorWithDefaults( - blocksByAddress, + Map( + remoteBmId1 -> (remoteBlocks1, 100L, 1), + remoteBmId2 -> (remoteBlocks2, 100L, 1) + ), maxBytesInFlight = 1500, doBatchFetch = true ) @@ -483,7 +448,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransfer(mergedRemoteBlocks) val iterator = createShuffleBlockIteratorWithDefaults( - getBlocksByAddressForSingleBM(remoteBmId, remoteBlocks, 100L, 1), + Map(remoteBmId -> (remoteBlocks, 100L, 1)), maxBlocksInFlightPerAddress = 2, doBatchFetch = true ) @@ -532,7 +497,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val taskContext = TaskContext.empty() val iterator = createShuffleBlockIteratorWithDefaults( - getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0), + Map(remoteBmId -> (blocks.keys, 1L, 0)), taskContext = Some(taskContext) ) @@ -581,7 +546,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val iterator = createShuffleBlockIteratorWithDefaults( - getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0) + Map(remoteBmId -> (blocks.keys, 1L, 0)) ) // Continue only after the mock calls onBlockFetchFailure @@ -650,7 +615,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val iterator = createShuffleBlockIteratorWithDefaults( - getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0), + Map(remoteBmId ->(blocks.keys, 1L, 0)), streamWrapperLimitSize = Some(100) ) @@ -693,12 +658,11 @@ class 
ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransfer( Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2)) - val blocksByAddress = getBlocksByAddress(Map( - blockManagerId1 -> (Seq(shuffleBlockId1), corruptBuffer1.size(), 1), - blockManagerId2 -> (Seq(shuffleBlockId2), corruptBuffer2.size(), 2) - )) val iterator = createShuffleBlockIteratorWithDefaults( - blocksByAddress, + Map( + blockManagerId1 -> (Seq(shuffleBlockId1), corruptBuffer1.size(), 1), + blockManagerId2 -> (Seq(shuffleBlockId2), corruptBuffer2.size(), 2) + ), streamWrapperLimitSize = Some(streamLength), maxBytesInFlight = 3 * 1024 ) @@ -747,7 +711,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer)) val iterator = createShuffleBlockIteratorWithDefaults( - getBlocksByAddressForSingleBM(localBmId, Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0), + Map(localBmId -> (Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0)), blockManager = Some(blockManager), streamWrapperLimitSize = Some(10000), maxBytesInFlight = 2048 // force concatenation of stream by limiting bytes in flight @@ -791,7 +755,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val iterator = createShuffleBlockIteratorWithDefaults( - getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0), + Map(remoteBmId -> (blocks.keys, 1L, 0)), streamWrapperLimitSize = Some(100), detectCorruptUseExtraMemory = false ) @@ -837,7 +801,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. 
createShuffleBlockIteratorWithDefaults( - getBlocksByAddressForSingleBM(remoteBmId, remoteBlocks.keys, blockSize, 0), + Map(remoteBmId -> (remoteBlocks.keys, blockSize, 0)), blockManager = Some(blockManager), maxReqSizeShuffleToMem = 200) } @@ -864,7 +828,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)).toMap) val iterator = createShuffleBlockIteratorWithDefaults( - getBlocksByAddressForSingleBM(remoteBmId, blocks.keys, 1L, 0) + Map(remoteBmId ->(blocks.keys, 1L, 0)) ) // All blocks fetched return zero length and should trigger a receive-side error: From cc4f3b47636a303b7a50e9a8d6673ad1dbe6f377 Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Mon, 17 May 2021 10:57:21 -0700 Subject: [PATCH 07/11] Address Mridul's comments --- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 6b06eec97c17..d0252a95504d 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -153,9 +153,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blocksByAddress.map { case (blockManagerId, (blocks, blockSize, blockMapIndex)) => (blockManagerId, blocks.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq) }.toIterator, - streamWrapperLimitSize - .map(limit => (_: BlockId, in: InputStream) => new LimitedInputStream(in, limit)) - .getOrElse((_: BlockId, in: InputStream) => in), + (_, in) => new LimitedInputStream(in, streamWrapperLimitSize.getOrElse(Long.MaxValue)), maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, @@ -615,7 +613,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val iterator = createShuffleBlockIteratorWithDefaults( - Map(remoteBmId ->(blocks.keys, 1L, 0)), + Map(remoteBmId -> (blocks.keys, 1L, 0)), streamWrapperLimitSize = Some(100) ) @@ -828,7 +826,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)).toMap) val iterator = createShuffleBlockIteratorWithDefaults( - Map(remoteBmId ->(blocks.keys, 1L, 0)) + Map(remoteBmId -> (blocks.keys, 1L, 0)) ) // All blocks fetched return zero length and should trigger a receive-side error: From b29dc14839bea3a80980e842ece80473ffa8b547 Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Mon, 17 May 2021 11:04:15 -0700 Subject: [PATCH 08/11] Address Mridul's comments pt 2: Make block list passing more general --- .../ShuffleBlockFetcherIteratorSuite.scala | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index d0252a95504d..15ad9daf32e2 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -133,7 +133,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // scalastyle:off argcount private def 
createShuffleBlockIteratorWithDefaults( - blocksByAddress: Map[BlockManagerId, (Traversable[BlockId], Long, Int)], + blocksByAddress: Map[BlockManagerId, Seq[(BlockId, Long, Int)]], taskContext: Option[TaskContext] = None, streamWrapperLimitSize: Option[Long] = None, blockManager: Option[BlockManager] = None, @@ -150,9 +150,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT tContext, transfer, blockManager.getOrElse(createMockBlockManager()), - blocksByAddress.map { case (blockManagerId, (blocks, blockSize, blockMapIndex)) => - (blockManagerId, blocks.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq) - }.toIterator, + blocksByAddress.toIterator, (_, in) => new LimitedInputStream(in, streamWrapperLimitSize.getOrElse(Long.MaxValue)), maxBytesInFlight, maxReqsInFlight, @@ -165,6 +163,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // scalastyle:on argcount + /** + * Convert a list of block IDs into a list of blocks with metadata, assuming all blocks have the + * same size and index. + */ + private def toBlockList(blockIds: Traversable[BlockId], blockSize: Long, blockMapIndex: Int) + : Seq[(BlockId, Long, Int)] = { + blockIds.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq + } + test("successful 3 local + 4 host local + 2 remote reads") { val blockManager = createMockBlockManager() val localBmId = blockManager.blockManagerId @@ -201,9 +208,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val iterator = createShuffleBlockIteratorWithDefaults( Map( - localBmId -> (localBlocks.keys, 1L, 0), - remoteBmId -> (remoteBlocks.keys, 1L, 1), - hostLocalBmId -> (hostLocalBlocks.keys, 1L, 1) + localBmId -> toBlockList(localBlocks.keys, 1L, 0), + remoteBmId -> toBlockList(remoteBlocks.keys, 1L, 1), + hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1) ), blockManager = Some(blockManager) ) @@ -260,7 +267,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransfer(Map()) val iterator = createShuffleBlockIteratorWithDefaults( - Map(hostLocalBmId -> (hostLocalBlocks.keys, 1L, 1)) + Map(hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1)) ) intercept[FetchFailedException] { iterator.next() } } @@ -274,8 +281,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockId1 -> createMockManagedBuffer(1000), blockId2 -> createMockManagedBuffer(1000))) val iterator = createShuffleBlockIteratorWithDefaults(Map( - remoteBmId1 -> (Seq(blockId1), 1000L, 0), - remoteBmId2 -> (Seq(blockId2), 1000L, 0) + remoteBmId1 -> toBlockList(Seq(blockId1), 1000L, 0), + remoteBmId2 -> toBlockList(Seq(blockId2), 1000L, 0) ), maxBytesInFlight = 1000L) // After initialize() we'll have 2 FetchRequests and each is 1000 bytes. 
So only the // first FetchRequests can be sent, and the second one will hit maxBytesInFlight so @@ -296,7 +303,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocks = 1.to(3).map(ShuffleBlockId(0, _, 0)) configureMockTransfer(blocks.map(_ -> createMockManagedBuffer()).toMap) val iterator = createShuffleBlockIteratorWithDefaults( - Map(remoteBmId -> (blocks, 1000L, 0)), + Map(remoteBmId -> toBlockList(blocks, 1000L, 0)), maxBlocksInFlightPerAddress = 2 ) // After initialize(), we'll have 2 FetchRequests that one has 2 blocks inside and another one @@ -359,9 +366,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val iterator = createShuffleBlockIteratorWithDefaults( Map( - localBmId -> (localBlocks, 1L, 0), - remoteBmId -> (remoteBlocks, 1L, 1), - hostLocalBmId -> (hostLocalBlocks.keys, 1L, 1) + localBmId -> toBlockList(localBlocks, 1L, 0), + remoteBmId -> toBlockList(remoteBlocks, 1L, 1), + hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1) ), blockManager = Some(blockManager), doBatchFetch = true @@ -404,8 +411,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val iterator = createShuffleBlockIteratorWithDefaults( Map( - remoteBmId1 -> (remoteBlocks1, 100L, 1), - remoteBmId2 -> (remoteBlocks2, 100L, 1) + remoteBmId1 -> toBlockList(remoteBlocks1, 100L, 1), + remoteBmId2 -> toBlockList(remoteBlocks2, 100L, 1) ), maxBytesInFlight = 1500, doBatchFetch = true @@ -446,7 +453,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransfer(mergedRemoteBlocks) val iterator = createShuffleBlockIteratorWithDefaults( - Map(remoteBmId -> (remoteBlocks, 100L, 1)), + Map(remoteBmId -> toBlockList(remoteBlocks, 100L, 1)), maxBlocksInFlightPerAddress = 2, doBatchFetch = true ) @@ -495,7 +502,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val taskContext = TaskContext.empty() val iterator = createShuffleBlockIteratorWithDefaults( - Map(remoteBmId -> (blocks.keys, 1L, 0)), + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), taskContext = Some(taskContext) ) @@ -544,7 +551,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val iterator = createShuffleBlockIteratorWithDefaults( - Map(remoteBmId -> (blocks.keys, 1L, 0)) + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)) ) // Continue only after the mock calls onBlockFetchFailure @@ -613,7 +620,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val iterator = createShuffleBlockIteratorWithDefaults( - Map(remoteBmId -> (blocks.keys, 1L, 0)), + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), streamWrapperLimitSize = Some(100) ) @@ -658,8 +665,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2)) val iterator = createShuffleBlockIteratorWithDefaults( Map( - blockManagerId1 -> (Seq(shuffleBlockId1), corruptBuffer1.size(), 1), - blockManagerId2 -> (Seq(shuffleBlockId2), corruptBuffer2.size(), 2) + blockManagerId1 -> toBlockList(Seq(shuffleBlockId1), corruptBuffer1.size(), 1), + blockManagerId2 -> toBlockList(Seq(shuffleBlockId2), corruptBuffer2.size(), 2) ), streamWrapperLimitSize = Some(streamLength), maxBytesInFlight = 3 * 1024 @@ -709,7 +716,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> 
managedBuffer)) val iterator = createShuffleBlockIteratorWithDefaults( - Map(localBmId -> (Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0)), + Map(localBmId -> toBlockList(Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0)), blockManager = Some(blockManager), streamWrapperLimitSize = Some(10000), maxBytesInFlight = 2048 // force concatenation of stream by limiting bytes in flight @@ -753,7 +760,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } val iterator = createShuffleBlockIteratorWithDefaults( - Map(remoteBmId -> (blocks.keys, 1L, 0)), + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), streamWrapperLimitSize = Some(100), detectCorruptUseExtraMemory = false ) @@ -799,7 +806,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. createShuffleBlockIteratorWithDefaults( - Map(remoteBmId -> (remoteBlocks.keys, blockSize, 0)), + Map(remoteBmId -> toBlockList(remoteBlocks.keys, blockSize, 0)), blockManager = Some(blockManager), maxReqSizeShuffleToMem = 200) } @@ -826,7 +833,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)).toMap) val iterator = createShuffleBlockIteratorWithDefaults( - Map(remoteBmId -> (blocks.keys, 1L, 0)) + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)) ) // All blocks fetched return zero length and should trigger a receive-side error: From 9a60603881d678318436b672c1bf9ea887956526 Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Mon, 17 May 2021 11:23:51 -0700 Subject: [PATCH 09/11] Revert part of the changes to address test failure --- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 15ad9daf32e2..774e5ab91659 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -151,7 +151,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager.getOrElse(createMockBlockManager()), blocksByAddress.toIterator, - (_, in) => new LimitedInputStream(in, streamWrapperLimitSize.getOrElse(Long.MaxValue)), + streamWrapperLimitSize + .map(limit => (_: BlockId, in: InputStream) => new LimitedInputStream(in, limit)) + .getOrElse((_: BlockId, in: InputStream) => in), maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, From e3ccd5167f6ad63c9e96831d06baa6371b6d0217 Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Mon, 17 May 2021 11:26:25 -0700 Subject: [PATCH 10/11] new simplification for the input stream wrapper that doesn't break any tests --- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 774e5ab91659..ffbd7aba5271 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -151,9 +151,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer, blockManager.getOrElse(createMockBlockManager()), blocksByAddress.toIterator, - streamWrapperLimitSize - .map(limit => (_: BlockId, in: InputStream) => new LimitedInputStream(in, limit)) - .getOrElse((_: BlockId, in: InputStream) => in), + (_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in), maxBytesInFlight, maxReqsInFlight, maxBlocksInFlightPerAddress, From eea80f5c8c487a7117036b7cfb963137a4b7eeb5 Mon Sep 17 00:00:00 2001 From: Erik Krogen Date: Tue, 18 May 2021 10:14:24 -0700 Subject: [PATCH 11/11] minor refactor --- .../spark/storage/ShuffleBlockFetcherIteratorSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index ffbd7aba5271..4be5faea4b25 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -167,8 +167,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT * Convert a list of block IDs into a list of blocks with metadata, assuming all blocks have the * same size and index. */ - private def toBlockList(blockIds: Traversable[BlockId], blockSize: Long, blockMapIndex: Int) - : Seq[(BlockId, Long, Int)] = { + private def toBlockList( + blockIds: Traversable[BlockId], + blockSize: Long, + blockMapIndex: Int): Seq[(BlockId, Long, Int)] = { blockIds.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq }