
Commit 86ca915

jx158167 authored and cloud-fan committed
[SPARK-23524] Big local shuffle blocks should not be checked for corruption.
## What changes were proposed in this pull request?

In the current code, all local blocks are checked for corruption regardless of whether they are big or small, for two reasons:

1. The size in `FetchResult` for a local block is set to 0 (https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala#L327).
2. SPARK-4105 intended to check only small blocks (size < maxBytesInFlight / 3), but because of reason 1 the check at https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala#L420 never takes effect.

This patch reports the real size for local blocks so that big local blocks are not checked, avoiding the OOM.

## How was this patch tested?

Unit test added.

Author: jx158167 <[email protected]>

Closes apache#20685 from jinxing64/SPARK-23524.

(cherry picked from commit 77c91cc)
Signed-off-by: Wenchen Fan <[email protected]>
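The intent of the SPARK-4105 guard, and why a reported size of 0 defeats it, can be shown with a small standalone sketch. This is an illustration under assumed names (`shouldCheckForCorruption`, `maxBytesInFlight` defaulting to Spark's 48m), not the actual `ShuffleBlockFetcherIterator` implementation:

```scala
// Hedged sketch of the size guard this fix restores; names and the constant are
// illustrative and mirror, but are not, the real Spark code.
object CorruptionCheckSketch {
  val maxBytesInFlight: Long = 48L * 1024 * 1024 // Spark's default spark.reducer.maxSizeInFlight

  // Before the fix, local blocks reported size = 0, so even a multi-gigabyte local
  // block passed this guard and was eagerly materialized in memory, risking OOM.
  def shouldCheckForCorruption(reportedSize: Long, detectCorrupt: Boolean): Boolean =
    detectCorrupt && reportedSize < maxBytesInFlight / 3

  def main(args: Array[String]): Unit = {
    val hugeLocalBlockSize = 4L * 1024 * 1024 * 1024 // real buf.size() of a big local block
    println(shouldCheckForCorruption(reportedSize = 0L, detectCorrupt = true)) // true  (old behaviour)
    println(shouldCheckForCorruption(hugeLocalBlockSize, detectCorrupt = true)) // false (after the fix)
  }
}
```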
1 parent ee6e797 commit 86ca915

File tree: 2 files changed, +54 −5 lines changed


core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 9 additions & 5 deletions
@@ -90,7 +90,7 @@ final class ShuffleBlockFetcherIterator(
   private[this] val startTime = System.currentTimeMillis
 
   /** Local blocks to fetch, excluding zero-sized blocks. */
-  private[this] val localBlocks = new ArrayBuffer[BlockId]()
+  private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]()
 
   /** Remote blocks to fetch, excluding zero-sized blocks. */
   private[this] val remoteBlocks = new HashSet[BlockId]()
@@ -316,6 +316,7 @@ final class ShuffleBlockFetcherIterator(
    * track in-memory are the ManagedBuffer references themselves.
    */
   private[this] def fetchLocalBlocks() {
+    logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
     val iter = localBlocks.iterator
     while (iter.hasNext) {
       val blockId = iter.next()
@@ -324,7 +325,8 @@ final class ShuffleBlockFetcherIterator(
         shuffleMetrics.incLocalBlocksFetched(1)
         shuffleMetrics.incLocalBytesRead(buf.size)
         buf.retain()
-        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
+        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId,
+          buf.size(), buf, false))
       } catch {
         case e: Exception =>
           // If we see an exception, stop immediately.
@@ -397,7 +399,9 @@ final class ShuffleBlockFetcherIterator(
           }
           shuffleMetrics.incRemoteBlocksFetched(1)
         }
-        bytesInFlight -= size
+        if (!localBlocks.contains(blockId)) {
+          bytesInFlight -= size
+        }
         if (isNetworkReqDone) {
           reqsInFlight -= 1
           logDebug("Number of requests in flight " + reqsInFlight)
@@ -583,8 +587,8 @@ object ShuffleBlockFetcherIterator {
    * Result of a fetch from a remote block successfully.
    * @param blockId block id
    * @param address BlockManager that the block was fetched from.
-   * @param size estimated size of the block, used to calculate bytesInFlight.
-   *             Note that this is NOT the exact bytes.
+   * @param size estimated size of the block. Note that this is NOT the exact bytes.
+   *             Size of remote block is used to calculate bytesInFlight.
    * @param buf `ManagedBuffer` for the content.
    * @param isNetworkReqDone Is this the last network request for this host in this fetch request.
    */
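The `!localBlocks.contains(blockId)` guard and the switch from `ArrayBuffer` to `LinkedHashSet` can be shown in isolation. The following is a hedged, standalone sketch under assumed names, not the actual `ShuffleBlockFetcherIterator` code: only remote fetch requests add to `bytesInFlight`, so once local results carry their real size, the decrement must be skipped for them or the counter would go negative and the throttle would admit too many concurrent remote requests.

```scala
import scala.collection.mutable

// Hedged sketch of the bytesInFlight bookkeeping the guard protects; names are
// illustrative. LinkedHashSet keeps insertion order for fetching while making
// the per-result contains() lookup O(1), unlike the previous ArrayBuffer.
object BytesInFlightSketch {
  private val localBlocks = mutable.LinkedHashSet[String]("local-block-0")
  private var bytesInFlight = 0L

  def onRemoteRequestSent(size: Long): Unit = bytesInFlight += size

  def onSuccessFetchResult(blockId: String, size: Long): Unit = {
    // Mirrors the fixed code: only remote results release in-flight bytes.
    if (!localBlocks.contains(blockId)) {
      bytesInFlight -= size
    }
  }

  def main(args: Array[String]): Unit = {
    onRemoteRequestSent(1024L)
    onSuccessFetchResult("remote-block-0", 1024L) // remote result: counter returns to 0
    onSuccessFetchResult("local-block-0", 10000L) // local result: counter untouched
    println(bytesInFlight)                        // 0
  }
}
```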

core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala

Lines changed: 45 additions & 0 deletions
@@ -352,6 +352,51 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     intercept[FetchFailedException] { iterator.next() }
   }
 
+  test("big blocks are not checked for corruption") {
+    val corruptStream = mock(classOf[InputStream])
+    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+    val corruptBuffer = mock(classOf[ManagedBuffer])
+    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+    doReturn(10000L).when(corruptBuffer).size()
+
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+    doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0))
+    val localBlockLengths = Seq[Tuple2[BlockId, Long]](
+      ShuffleBlockId(0, 0, 0) -> corruptBuffer.size()
+    )
+
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val remoteBlockLengths = Seq[Tuple2[BlockId, Long]](
+      ShuffleBlockId(0, 1, 0) -> corruptBuffer.size()
+    )
+
+    val transfer = createMockTransfer(
+      Map(ShuffleBlockId(0, 0, 0) -> corruptBuffer, ShuffleBlockId(0, 1, 0) -> corruptBuffer))
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (localBmId, localBlockLengths),
+      (remoteBmId, remoteBlockLengths)
+    )
+
+    val taskContext = TaskContext.empty()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => new LimitedInputStream(in, 10000),
+      2048,
+      Int.MaxValue,
+      Int.MaxValue,
+      Int.MaxValue,
+      true)
+    // Blocks should be returned without exceptions.
+    assert(Set(iterator.next()._1, iterator.next()._1) ===
+      Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0)))
+  }
+
   test("retry corrupt blocks (disabled)") {
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)
