@@ -407,6 +407,25 @@ final class ShuffleBlockFetcherIterator(
logDebug("Number of requests in flight " + reqsInFlight)
}

if (buf.size == 0) {
// We will never legitimately receive a zero-size block. All blocks with zero records
// have zero size and all zero-size blocks have no records (and hence should never
// have been requested in the first place). This statement relies on behaviors of the
// shuffle writers, which are guaranteed by the following test cases:
//
// - BypassMergeSortShuffleWriterSuite: "write with some empty partitions"
// - UnsafeShuffleWriterSuite: "writeEmptyIterator"
// - DiskBlockObjectWriterSuite: "commit() and close() without ever opening or writing"
//
// There is not an explicit test for SortShuffleWriter but the underlying APIs that it
// uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter,
// which returns a zero-size segment from commitAndGet() in case no records were written
// since the last call).
val msg = s"Received a zero-size buffer for block $blockId from $address " +
s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
throwFetchFailedException(blockId, address, new IOException(msg))
}

val in = try {
buf.createInputStream()
} catch {
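
The guard above relies on zero-size blocks never being requested in the first place: map outputs with no records are advertised with size 0 and are skipped when fetch requests are assembled. Below is a minimal sketch of that filtering idea (a hypothetical helper, not part of this change and not the actual request-splitting code; it assumes org.apache.spark.storage.BlockId is in scope, as it is in this file):

def nonEmptyBlocks(blockSizes: Seq[(BlockId, Long)]): Seq[(BlockId, Long)] = {
  // Blocks advertised with size 0 carry no records and are never fetched, so any
  // zero-size buffer that still arrives can only be the result of an error.
  blockSizes.filter { case (_, size) => size > 0 }
}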
@@ -65,12 +65,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
}

// Create a mock managed buffer for testing
def createMockManagedBuffer(): ManagedBuffer = {
def createMockManagedBuffer(size: Int = 1): ManagedBuffer = {
val mockManagedBuffer = mock(classOf[ManagedBuffer])
val in = mock(classOf[InputStream])
when(in.read(any())).thenReturn(1)
when(in.read(any(), any(), any())).thenReturn(1)
when(mockManagedBuffer.createInputStream()).thenReturn(in)
when(mockManagedBuffer.size()).thenReturn(size)
mockManagedBuffer
}
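
For reference, a brief usage sketch of the updated helper (a hypothetical snippet, assuming the ScalaTest assertions already available in this suite): the default argument keeps existing call sites unchanged, while an explicit 0 produces a buffer whose size() reports zero, which the new zero-size test below relies on.

val normalBuffer = createMockManagedBuffer()  // size() reports the default of 1
val emptyBuffer = createMockManagedBuffer(0)  // size() reports 0
assert(normalBuffer.size() === 1)
assert(emptyBuffer.size() === 0)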

@@ -269,6 +270,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
intercept[FetchFailedException] { iterator.next() }
}

private def mockCorruptBuffer(size: Long = 1L): ManagedBuffer = {
val corruptStream = mock(classOf[InputStream])
when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
val corruptBuffer = mock(classOf[ManagedBuffer])
when(corruptBuffer.size()).thenReturn(size)
when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
corruptBuffer
}
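
A brief illustration of the extracted helper (a hypothetical snippet): the stubbed three-argument read of the underlying stream throws an IOException, while the size argument controls how large the buffer claims to be, which is what the "big blocks are not checked for corruption" test below uses to keep the fetcher from eagerly checking the block.

val smallCorruptBuffer = mockCorruptBuffer()        // claims a size of 1 byte
val largeCorruptBuffer = mockCorruptBuffer(10000L)  // claims to be large
intercept[IOException] {
  smallCorruptBuffer.createInputStream().read(new Array[Byte](4), 0, 4)
}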

test("retry corrupt blocks") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
@@ -284,11 +294,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT

// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)

val corruptStream = mock(classOf[InputStream])
when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
val corruptBuffer = mock(classOf[ManagedBuffer])
when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)

val transfer = mock(classOf[BlockTransferService])
@@ -301,7 +306,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
sem.release()
@@ -339,7 +344,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
Future {
// Return the first block, and then fail.
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
sem.release()
}
}
@@ -353,11 +358,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
}

test("big blocks are not checked for corruption") {
val corruptStream = mock(classOf[InputStream])
when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
val corruptBuffer = mock(classOf[ManagedBuffer])
when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
doReturn(10000L).when(corruptBuffer).size()
val corruptBuffer = mockCorruptBuffer(10000L)

val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
@@ -413,11 +414,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// Semaphore to coordinate event sequence in two different threads.
val sem = new Semaphore(0)

val corruptStream = mock(classOf[InputStream])
when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
val corruptBuffer = mock(classOf[ManagedBuffer])
when(corruptBuffer.createInputStream()).thenReturn(corruptStream)

val transfer = mock(classOf[BlockTransferService])
when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
.thenAnswer(new Answer[Unit] {
@@ -428,9 +424,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
listener.onBlockFetchSuccess(
ShuffleBlockId(0, 2, 0).toString, corruptBuffer)
ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
sem.release()
}
}
@@ -526,4 +522,39 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// shuffle block to disk.
assert(tempFileManager != null)
}

test("fail zero-size blocks") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
doReturn(localBmId).when(blockManager).blockManagerId

// Make sure remote blocks would return
val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
val blocks = Map[BlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()
)

// The advertised sizes below are non-zero, so both blocks are actually requested,
// but the transfer service hands back zero-size buffers to simulate the failure.
val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)))

val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
(remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))

val taskContext = TaskContext.empty()
val iterator = new ShuffleBlockFetcherIterator(
taskContext,
transfer,
blockManager,
blocksByAddress,
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
Int.MaxValue,
Int.MaxValue,
true)

// All blocks fetched return zero length and should trigger a receive-side error:
val e = intercept[FetchFailedException] { iterator.next() }
assert(e.getMessage.contains("Received a zero-size buffer"))
}
}