@@ -36,6 +36,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
3636import org .apache .spark .network .shuffle .BlockFetchingListener
3737import org .apache .spark .network .util .LimitedInputStream
3838import org .apache .spark .shuffle .FetchFailedException
39+ import org .apache .spark .util .Utils
3940
4041
4142class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
@@ -420,9 +421,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
420421 doReturn(localBmId).when(blockManager).blockManagerId
421422
422423 val diskBlockManager = mock(classOf [DiskBlockManager ])
424+ val tmpDir = Utils .createTempDir()
423425 doReturn{
424- var blockId = new TempLocalBlockId (UUID .randomUUID())
425- (blockId, new File (blockId.name))
426+ val blockId = TempLocalBlockId (UUID .randomUUID())
427+ (blockId, new File (tmpDir, blockId.name))
426428 }.when(diskBlockManager).createTempLocalBlock()
427429 doReturn(diskBlockManager).when(blockManager).diskBlockManager
428430
@@ -443,34 +445,34 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
443445 }
444446 })
445447
448+ def fetchShuffleBlock (blocksByAddress : Seq [(BlockManagerId , Seq [(BlockId , Long )])]): Unit = {
449+ // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the
450+ // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks
451+ // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here.
452+ new ShuffleBlockFetcherIterator (
453+ TaskContext .empty(),
454+ transfer,
455+ blockManager,
456+ blocksByAddress,
457+ (_, in) => in,
458+ maxBytesInFlight = Int .MaxValue ,
459+ maxReqsInFlight = Int .MaxValue ,
460+ maxReqSizeShuffleToMem = 200 ,
461+ detectCorrupt = true )
462+ }
463+
446464 val blocksByAddress1 = Seq [(BlockManagerId , Seq [(BlockId , Long )])](
447465 (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L )).toSeq))
448- // Set maxReqSizeShuffleToMem to be 200.
449- val iterator1 = new ShuffleBlockFetcherIterator (
450- TaskContext .empty(),
451- transfer,
452- blockManager,
453- blocksByAddress1,
454- (_, in) => in,
455- Int .MaxValue ,
456- Int .MaxValue ,
457- 200 ,
458- true )
466+ fetchShuffleBlock(blocksByAddress1)
467+ // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch
468+ // shuffle block to disk.
459469 assert(shuffleFiles === null )
460470
461471 val blocksByAddress2 = Seq [(BlockManagerId , Seq [(BlockId , Long )])](
462472 (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L )).toSeq))
463- // Set maxReqSizeShuffleToMem to be 200.
464- val iterator2 = new ShuffleBlockFetcherIterator (
465- TaskContext .empty(),
466- transfer,
467- blockManager,
468- blocksByAddress2,
469- (_, in) => in,
470- Int .MaxValue ,
471- Int .MaxValue ,
472- 200 ,
473- true )
473+ fetchShuffleBlock(blocksByAddress2)
474+ // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch
475+ // shuffle block to disk.
474476 assert(shuffleFiles != null )
475477 }
476478}
0 commit comments