Closed
43 commits
3871877
SPARK-2799
Ngone51 Apr 22, 2021
6b37944
Add comment
Ngone51 Apr 22, 2021
2cc07e2
add new line for NettyOutOfMemoryError
Ngone51 Apr 23, 2021
416bdaa
use AtomicBoolean
Ngone51 Apr 23, 2021
15426b5
unset in finally
Ngone51 Apr 23, 2021
0863419
revert FailureFetchResult
Ngone51 Apr 23, 2021
04dff9f
add DeferFetchResult
Ngone51 Apr 23, 2021
88d6d81
rename to TestNettyOutOfMemoryError
Ngone51 Apr 23, 2021
95aae3b
avoid endless retry
Ngone51 Apr 23, 2021
5ecb6dc
fix fmt
Ngone51 Apr 23, 2021
857f305
use reflection to create OutOfDirectMemoryError
Ngone51 Apr 27, 2021
e3d552c
remove random
Ngone51 Apr 27, 2021
92d4ee5
set isNettyOOMOnShuffle to flase when inflight requests = 0
Ngone51 Apr 27, 2021
254fa7c
move isNettyOOMOnShuffle from netty utils to shuffleblockfetchiterator
Ngone51 Apr 27, 2021
dc01c85
defer blocks in batch
Ngone51 Apr 27, 2021
466966c
fix thread safety
Ngone51 Apr 29, 2021
7bc1b93
fix build
Ngone51 Apr 29, 2021
1460299
fix build
Ngone51 Apr 29, 2021
b64482e
fix /zero
Ngone51 Apr 29, 2021
7fc9a0a
fix with request.size
Ngone51 Apr 29, 2021
bff1ce3
fix scala2.13 error
Ngone51 Apr 29, 2021
b5dc374
add isZombie
Ngone51 Apr 29, 2021
77e43ba
remove ;
Ngone51 Apr 29, 2021
5538c33
add resetOOMFlatIfPossible
Ngone51 Apr 29, 2021
959dd8d
add comment
Ngone51 Apr 29, 2021
90aa9ea
unset the flag when the free memory > 200m
Ngone51 Apr 30, 2021
22a441f
add test
Ngone51 Apr 30, 2021
f78b906
fail fast when netty memory less than max in-mem block
Ngone51 Apr 30, 2021
7ac87d8
use the configured
Ngone51 Apr 30, 2021
4aaf04f
update comment
Ngone51 Apr 30, 2021
298810d
revert testing code
Ngone51 Apr 30, 2021
f334557
add internal conf for max attempts
Ngone51 May 3, 2021
48deaed
fix typo and loginfo
Ngone51 May 3, 2021
3611378
ensure thread safe of blockOOMRetryTimes
Ngone51 May 3, 2021
8b74473
failfast when netty memory < maxRemoteBlockSizeFetchToMem
Ngone51 May 7, 2021
353e681
address comment
Ngone51 May 7, 2021
2b4bfd0
move maxAttemptsOnNettyOOM to constructor
Ngone51 May 8, 2021
c9cc13f
fix fmt
Ngone51 May 10, 2021
28201b6
rename conf
Ngone51 May 10, 2021
9e877d5
use Option[String]
Ngone51 May 10, 2021
7affa68
update comment
Ngone51 May 10, 2021
08e5f61
rename to blockOOMRetryCounts
Ngone51 May 10, 2021
4af6ee7
rebase
Ngone51 May 19, 2021
@@ -52,6 +52,10 @@ public class NettyUtils {
private static final PooledByteBufAllocator[] _sharedPooledByteBufAllocator =
new PooledByteBufAllocator[2];

public static long freeDirectMemory() {
return PlatformDependent.maxDirectMemory() - PlatformDependent.usedDirectMemory();
}

/** Creates a new ThreadFactory which prefixes each thread with the given name. */
public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
return new DefaultThreadFactory(threadPoolPrefix, true);
@@ -26,6 +26,7 @@ import scala.collection.mutable
import scala.util.{Failure, Success}
import scala.util.control.NonFatal

import io.netty.util.internal.PlatformDependent
import org.json4s.DefaultFormats

import org.apache.spark._
@@ -90,6 +91,14 @@ private[spark] class CoarseGrainedExecutorBackend(

logInfo("Connecting to driver: " + driverUrl)
try {
if (PlatformDependent.directBufferPreferred() &&
PlatformDependent.maxDirectMemory() < env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)) {
throw new SparkException(s"Netty direct memory should at least be bigger than " +
s"'${MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM.key}', but got " +
s"${PlatformDependent.maxDirectMemory()} bytes < " +
s"${env.conf.get(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)}")
Member Author:

I tried to write a unit test for this change but realized while trying that it isn't easy, so I gave it up. I have done a manual test and it looks fine. cc @mridulm
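For reference, a hypothetical way to give the executor enough Netty direct memory if this startup check trips; the option value is illustrative and not something this PR prescribes (lowering the fetch-to-memory threshold instead is the other option).

import org.apache.spark.SparkConf

// Illustrative only: raise the executor's direct memory cap so it exceeds
// MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM.
val conf = new SparkConf()
  .set("spark.executor.extraJavaOptions", "-XX:MaxDirectMemorySize=512m")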

}

_resources = parseOrFindResources(resourcesFileOpt)
} catch {
case NonFatal(e) =>
@@ -1193,6 +1193,15 @@ package object config {
.intConf
.createWithDefault(3)

private[spark] val SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM =
ConfigBuilder("spark.shuffle.maxAttemptsOnNettyOOM")
.doc("The max attempts of a shuffle block would retry on Netty OOM issue before throwing " +
"the shuffle fetch failure.")
.version("3.2.0")
.internal()
.intConf
.createWithDefault(10)
Member:

Just a question: is there a reason for 10 instead of 3?

Member Author:

Given the discussion there (#32287 (comment)), Netty OOM could be raised more frequently in certain cases, e.g.:

For case b), the OOM threshold might be 20 requests. In this case, there are still 80 deferred requests, which would hit the OOM again soon, as you mentioned. That being said, I think the current fix would work around the issue in the end. Note that the application would fail before the fix.

Thus, I'd like to give the block more chances in case we fall into a case like b).
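For illustration, a minimal sketch of overriding this internal conf; the value here is arbitrary and the default of 10 applies otherwise.

import org.apache.spark.SparkConf

// Hypothetical override of spark.shuffle.maxAttemptsOnNettyOOM for a job that sees
// frequent Netty OOMs; not something the PR itself recommends.
val conf = new SparkConf()
  .set("spark.shuffle.maxAttemptsOnNettyOOM", "20")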


private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS =
ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress")
.doc("This configuration limits the number of remote blocks being fetched per reduce task " +
@@ -78,6 +78,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
readMetrics,
@@ -20,19 +20,21 @@ package org.apache.spark.storage
import java.io.{InputStream, IOException}
import java.nio.channels.ClosedByInterruptException
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashMap, Queue}
import scala.util.{Failure, Success}

import io.netty.util.internal.OutOfDirectMemoryError
import org.apache.commons.io.IOUtils

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle._
import org.apache.spark.network.util.TransportConf
import org.apache.spark.network.util.{NettyUtils, TransportConf}
import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}

@@ -61,6 +63,8 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
* @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
* for a given remote host:port.
* @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
* @param maxAttemptsOnNettyOOM The max number of a block could retry due to Netty OOM before
* throwing the fetch failure.
* @param detectCorrupt whether to detect any corruption in fetched blocks.
* @param shuffleMetrics used to report shuffle metrics.
* @param doBatchFetch fetch continuous shuffle blocks from same executor in batch if the server
@@ -76,7 +80,8 @@ final class ShuffleBlockFetcherIterator(
maxBytesInFlight: Long,
maxReqsInFlight: Int,
maxBlocksInFlightPerAddress: Int,
maxReqSizeShuffleToMem: Long,
val maxReqSizeShuffleToMem: Long,
maxAttemptsOnNettyOOM: Int,
detectCorrupt: Boolean,
detectCorruptUseExtraMemory: Boolean,
shuffleMetrics: ShuffleReadMetricsReporter,
@@ -146,6 +151,12 @@ final class ShuffleBlockFetcherIterator(
/** Current number of blocks in flight per host:port */
private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()

/**
* Count the retry times for the blocks due to Netty OOM. The block will stop retry if
* retry times has exceeded the [[maxAttemptsOnNettyOOM]].
*/
private[this] val blockOOMRetryCounts = new HashMap[String, Int]

/**
* The blocks that can't be decompressed successfully, it is used to guarantee that we retry
* at most once for those corrupted blocks.
@@ -245,9 +256,21 @@ final class ShuffleBlockFetcherIterator(
case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex))
}.toMap
val remainingBlocks = new HashSet[String]() ++= infoMap.keys
val deferredBlocks = new ArrayBuffer[String]()
val blockIds = req.blocks.map(_.blockId.toString)
val address = req.address

@inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
val blocks = deferredBlocks.map { blockId =>
val (size, mapIndex) = infoMap(blockId)
FetchBlockInfo(BlockId(blockId), size, mapIndex)
}
results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq)))
deferredBlocks.clear()
}
}
Contributor:

I will let @otterc elaborate more; we had discussed whether we can minimize the direct memory load in fetchUpToMaxBytes caused by all the 'expensive' deferred blocks being sent out (mostly) together (modulo the constraints on maxBlocksInFlightPerAddress/maxBytesInFlight).

Member Author:

I'm not sure what you mean by "direct memory load" here. I'll wait for @otterc's explanation.

Contributor:

This is related to the conversation here: #32287 (comment)
We were discussing simple ways to reduce the number of remote fetch requests after the OOM. One option could be that, after the OOM, we only send out the requests that were deferred due to the OOM and don't send any additional requests.
I am not sure how effective this is going to be, though. Since the in-flight limit remains the same, the next call to fetchUpToMaxBytes, when it sends out non-deferred requests, can cause new blocks to OOM.

Another simple option could be to modify isRemoteBlockFetchable such that, after this iterator has seen an OOM, it also checks bytesInFlight + fetchReqQueue.front.size < freeDirectMemory?

Member Author:

So after #32287 (comment), I have changed the unset condition to freeDirectMemory > maxReqSizeShuffleToMem (200M by default), which I think is already very strict. So it should avoid the issue you mentioned in #32287 (comment).

Another simple option could be to modify isRemoteBlockFetchable such that, after this iterator has seen an OOM, it also checks bytesInFlight + fetchReqQueue.front.size < freeDirectMemory?

Do you mean checking bytesInFlight + fetchReqQueue.front.size < freeDirectMemory in all cases, or only when isNettyOOMOnShuffle=true? If you also want to check when isNettyOOMOnShuffle=false, I'd like to mention that block size is not equal to the memory Netty actually consumes. As you know, blocks bigger than maxReqSizeShuffleToMem are stored on disk.

Contributor (@otterc, May 4, 2021):

I have changed the unset condition to freeDirectMemory > maxReqSizeShuffleToMem (200M by default), which I think is already very strict. So it should avoid the issue you mentioned in #32287 (comment).

I don't think it will avoid the issue. It adds more time before the next set of requests is sent, but the next set (including the deferred ones) will still be sent at the same rate, so some of them would again see OOMs.

Do you mean checking bytesInFlight + fetchReqQueue.front.size < freeDirectMemory in all cases, or only when isNettyOOMOnShuffle=true?

I meant that once this OOM is encountered, the iterator afterwards checks against freeDirectMemory as well. If request.size > maxReqSizeShuffleToMem, we can skip the check for that request, since it is stored on disk.
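To make the proposal concrete, here is a minimal, self-contained sketch of such a check; it is not the code in this PR, and all parameters below are stand-ins for the iterator's real fields.

// Illustrative sketch of the reviewer's proposal: once a Netty OOM has been seen,
// only admit the next request if it also fits into the currently free direct memory.
def isRemoteBlockFetchable(
    frontRequestSize: Long,
    bytesInFlight: Long,
    maxBytesInFlight: Long,
    maxReqSizeShuffleToMem: Long,
    seenNettyOOM: Boolean,
    freeDirectMemory: Long): Boolean = {
  val withinFlightLimit = bytesInFlight + frontRequestSize <= maxBytesInFlight
  // Requests larger than maxReqSizeShuffleToMem are streamed to disk, so the
  // direct-memory guard can be skipped for them.
  val withinDirectMemory = !seenNettyOOM ||
    frontRequestSize > maxReqSizeShuffleToMem ||
    bytesInFlight + frontRequestSize <= freeDirectMemory
  withinFlightLimit && withinDirectMemory
}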

Member Author (@Ngone51, May 6, 2021):

I think there are two different cases here. For example, assume there are 100 requests.

For case a), the OOM threshold might be 80 requests. In this case, after the OOM, the 20 deferred requests shouldn't hit the OOM again.

For case b), the OOM threshold might be 20 requests. In this case, there are still 80 deferred requests, which would hit the OOM again soon, as you mentioned. That being said, I think the current fix would work around the issue in the end. Note that the application would fail before the fix.

To improve the current fix further, I think we should do it in a separate PR, as it's not an easy thing[1] to do (or do you have any other ideas?) and would require more discussion. WDYT?

  1. Even if we skip the case of request > maxReqSizeShuffleToMem, note that the in-memory request size is not strictly equal to the memory Netty consumes, due to Netty's memory management mechanism (https://www.facebook.com/notes/facebook-engineering/scalable-memory-allocation-using-jemalloc/480222803919). For example, Netty may allocate 16MB for a 9MB block.
     And there could be multiple tasks fetching concurrently, so we may need to track the total bytesInFlight across all tasks rather than within a single task, which would require more synchronization among tasks and make things more complex.

Contributor:

It sounds like a good idea to dynamically tune the request frequency, but this doesn't seem like trivial work. I'd vote for doing it in a follow-up with more discussion.

Contributor:

If it's nontrivial, we can definitely push it to a later effort - can you file a JIRA for it @Ngone51 so that we can track it (and perhaps someone can pick it up later)?
This PR is already a marked improvement over what we currently have w.r.t. OOM at the executor - as long as ESS load does not go up :-)

val blockFetchingListener = new BlockFetchingListener {
override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
// Only add the buffer to results queue if the iterator is not zombie,
@@ -258,17 +281,57 @@ final class ShuffleBlockFetcherIterator(
// This needs to be released after use.
buf.retain()
remainingBlocks -= blockId
blockOOMRetryCounts.remove(blockId)
results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
logDebug("remainingBlocks: " + remainingBlocks)
enqueueDeferredFetchRequestIfNecessary()
}
}
logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}")
}

override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
ShuffleBlockFetcherIterator.this.synchronized {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
e match {
// SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among
// tasks) to true as early as possible. The pending fetch requests won't be sent
// afterwards until the flag is set to false on:
// 1) the Netty free memory >= maxReqSizeShuffleToMem
// - we'll check this whenever there's a fetch request succeeds.
// 2) the number of in-flight requests becomes 0
// - we'll check this in `fetchUpToMaxBytes` whenever it's invoked.
// Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag
// only takes effect for the shuffle due to the implementation simplicity concern.
// And we'll buffer the consecutive block failures caused by the OOM error until there's
// no remaining blocks in the current request. Then, we'll package these blocks into
// a same fetch request for the retry later. In this way, instead of creating the fetch
// request per block, it would help reduce the concurrent connections and data loads
// pressure at remote server.
// Note that catching OOM and do something based on it is only a workaround for
// handling the Netty OOM issue, which is not the best way towards memory management.
// We can get rid of it when we find a way to manage Netty's memory precisely.
case _: OutOfDirectMemoryError
if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM =>
if (!isZombie) {
val failureTimes = blockOOMRetryCounts(blockId)
blockOOMRetryCounts(blockId) += 1
if (isNettyOOMOnShuffle.compareAndSet(false, true)) {
// The fetcher can fail remaining blocks in batch for the same error. So we only
// log the warning once to avoid flooding the logs.
logInfo(s"Block $blockId has failed $failureTimes times " +
s"due to Netty OOM, will retry")
}
remainingBlocks -= blockId
deferredBlocks += blockId
enqueueDeferredFetchRequestIfNecessary()
}
Member:

Shall we add a debug message in an else statement, since this code path changes after this PR? Previously, we always did results.put.

Member Author:

I'm not sure which case you want the else statement to cover here. Do you want:

if (!isZombie) {

} else {
 ...
}

?

(Note that in onBlockFetchSuccess we also don't have such an else statement for the zombie case.)

Or do you actually want the else case of not putting results? If so, I think we already have the log info:

logInfo(s"Block $blockId has failed $failureTimes times " +
                    s"due to Netty OOM, will retry")


case _ =>
results.put(FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
}
}
}
}

@@ -613,6 +676,7 @@ final class ShuffleBlockFetcherIterator(
}
if (isNetworkReqDone) {
reqsInFlight -= 1
resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem)
logDebug("Number of requests in flight " + reqsInFlight)
}

@@ -684,7 +748,25 @@ final class ShuffleBlockFetcherIterator(
}

case FailureFetchResult(blockId, mapIndex, address, e) =>
throwFetchFailedException(blockId, mapIndex, address, e)
var errorMsg: String = null
if (e.isInstanceOf[OutOfDirectMemoryError]) {
errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " +
s"retries due to Netty OOM"
logError(errorMsg)
}
throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg))

case DeferFetchRequestResult(request) =>
val address = request.address
numBlocksInFlightPerAddress(address) =
numBlocksInFlightPerAddress(address) - request.blocks.size
bytesInFlight -= request.size
reqsInFlight -= 1
logDebug("Number of requests in flight " + reqsInFlight)
val defReqQueue =
deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
defReqQueue.enqueue(request)
result = null
}

// Send fetch requests up to maxBytesInFlight
@@ -699,7 +781,8 @@ final class ShuffleBlockFetcherIterator(
currentResult.blockId,
currentResult.mapIndex,
currentResult.address,
detectCorrupt && streamCompressedOrEncrypted))
detectCorrupt && streamCompressedOrEncrypted,
currentResult.isNetworkReqDone))
}

def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
@@ -708,6 +791,15 @@ final class ShuffleBlockFetcherIterator(
}

private def fetchUpToMaxBytes(): Unit = {
if (isNettyOOMOnShuffle.get()) {
if (reqsInFlight > 0) {
// Return immediately if Netty is still OOMed and there're ongoing fetch requests
return
} else {
resetNettyOOMFlagIfPossible(0)
}
}

Contributor:

I was thinking more about this solution, and there is a potential problem I see. Once a Netty OOM is encountered for some responses, the corresponding requests are deferred and no more remote requests are sent. This helps to recover. But we don't change any in-flight remote request limits, so after isNettyOOMOnShuffle is reset by a successful remote response, the next burst of remote requests will be sent out at the same rate. This means there are again chances of seeing Netty OOMs, and again some of the blocks will be deferred. This introduces more delay and increases the load on the shuffle server.

I think solving this may be more complex, and right now this is just a workaround. But maybe we can do something simpler to reduce the number of requests made after a Netty OOM is encountered?

Member Author:

Yes, that's true. I also considered another approach previously, which is to adjust the threshold of in-flight requests dynamically. For example, when an OOM is thrown, the threshold would be reduced to the max number of in-flight requests seen before the OOM, and if an OOM happens again, we continue to reduce the threshold. Then it comes to a question: when do we increase the threshold? When the backlogged requests grow too large and the OOM hasn't appeared for a while? If we go this way, we should be very careful about the adjustment algorithm, as it's directly related to performance.

Let me think more about it. Thanks!
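Purely for illustration, one shape such a dynamic limit could take; nothing like this exists in the PR, and every name below is hypothetical.

// Illustrative sketch of a dynamically adjusted in-flight request limit: shrink it
// when a Netty OOM is observed, grow it back slowly after quiet periods.
class DynamicInFlightLimit(initialLimit: Int) {
  private var limit = initialLimit

  // On a Netty OOM, shrink the limit toward the number of requests that were
  // actually in flight when the error was observed.
  def onNettyOOM(inFlightAtFailure: Int): Unit = {
    limit = math.max(1, math.min(limit / 2, inFlightAtFailure))
  }

  // After a period without OOMs, grow the limit back by one (additive increase).
  def onQuietPeriod(): Unit = {
    limit = math.min(initialLimit, limit + 1)
  }

  def currentLimit: Int = limit
}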

Member Author:

So I have limited the reset condition to whether the Netty free memory is larger than maxReqSizeShuffleToMem (default 200M), which is stricter than the average block size. I think this would mitigate the issue you mentioned here. WDYT?

Contributor:

Should it again check at line 805 that freeDirectMemory > maxReqSizeShuffleToMem?
Otherwise it immediately unsets the flag and sends more requests.

Member Author:

Should it again check at line 805 that freeDirectMemory > maxReqSizeShuffleToMem?

No. Otherwise, for the first invocation of fetchUpToMaxBytes, the fetching can hang if isNettyOOMOnShuffle=true, because if no requests were sent in the first invocation, there would be no later callback into fetchUpToMaxBytes.

Contributor (@otterc, May 3, 2021):

Because if no requests were sent in the first invocation, there would be no callback on fetchUpToMaxBytes later.

If isNettyOOMOnShuffle=true on the first invocation of fetchUpToMaxBytes, that means another iterator saw the Netty OOM and set it to true. This will just unset it.

Also, how will it hang? Let's say there are only remote blocks to be fetched and none of the requests are sent initially; iterator.next() will keep calling fetchUpToMaxBytes as I see it. Eventually, when enough freeDirectMemory is available, it will send remote requests.

Member Author:

IIUC, it will block at

Contributor:

Oh, I see. It just waits indefinitely here. Can we not change this to poll(time)? If there is nothing, result will be null and it will call fetchUpToMaxBytes again.
WDYT?

Member Author:

poll(time) may work, but I think it breaks the existing design. After the change, we would introduce the overhead of calling fetchUpToMaxBytes many times when it's not necessary.

// Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host
// immediately, defer the request until the next time it can be processed.
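For concreteness, a minimal, self-contained sketch of the poll-based alternative discussed above; this is not what the PR does - it keeps the blocking take() - and FetchResult/fetchUpToMaxBytes below are stand-ins for the real members.

import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}

object PollSketch {
  sealed trait FetchResult

  // Poll the results queue with a timeout instead of blocking indefinitely; if nothing
  // arrived, re-trigger the fetch so that deferred requests get another chance once
  // free direct memory is available.
  def pollNextResult(
      results: LinkedBlockingQueue[FetchResult],
      fetchUpToMaxBytes: () => Unit): FetchResult = {
    var result: FetchResult = null
    while (result == null) {
      result = results.poll(100, TimeUnit.MILLISECONDS)
      if (result == null) {
        fetchUpToMaxBytes()
      }
    }
    result
  }
}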

@@ -766,12 +858,14 @@ final class ShuffleBlockFetcherIterator(
blockId: BlockId,
mapIndex: Int,
address: BlockManagerId,
e: Throwable) = {
e: Throwable,
message: Option[String] = None) = {
val msg = message.getOrElse(e.getMessage)
blockId match {
case ShuffleBlockId(shufId, mapId, reduceId) =>
throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, e)
throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, msg, e)
case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) =>
throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, e)
throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, msg, e)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block", e)
@@ -790,7 +884,8 @@ private class BufferReleasingInputStream(
private val blockId: BlockId,
private val mapIndex: Int,
private val address: BlockManagerId,
private val detectCorruption: Boolean)
private val detectCorruption: Boolean,
private val isNetworkReqDone: Boolean)
extends InputStream {
private[this] var closed = false

@@ -799,9 +894,16 @@ private class BufferReleasingInputStream(

override def close(): Unit = {
if (!closed) {
delegate.close()
iterator.releaseCurrentResultBuffer()
closed = true
try {
delegate.close()
iterator.releaseCurrentResultBuffer()
} finally {
// Unset the flag when a remote request finished and free memory is fairly enough.
if (isNetworkReqDone) {
ShuffleBlockFetcherIterator.resetNettyOOMFlagIfPossible(iterator.maxReqSizeShuffleToMem)
}
closed = true
}
}
}

@@ -862,6 +964,20 @@ private class ShuffleFetchCompletionListener(var data: ShuffleBlockFetcherIterat
private[storage]
object ShuffleBlockFetcherIterator {

/**
* A flag which indicates whether the Netty OOM error has raised during shuffle.
* If true, unless there's no in-flight fetch requests, all the pending shuffle
* fetch requests will be deferred until the flag is unset (whenever there's a
* complete fetch request).
*/
val isNettyOOMOnShuffle = new AtomicBoolean(false)

def resetNettyOOMFlagIfPossible(freeMemoryLowerBound: Long): Unit = {
if (isNettyOOMOnShuffle.get() && NettyUtils.freeDirectMemory() >= freeMemoryLowerBound) {
isNettyOOMOnShuffle.compareAndSet(true, false)
}
}

/**
* This function is used to merged blocks when doBatchFetch is true. Blocks which have the
* same `mapId` can be merged into one block batch. The block batch is specified by a range
@@ -966,10 +1082,7 @@ object ShuffleBlockFetcherIterator {
/**
* Result of a fetch from a remote block.
*/
private[storage] sealed trait FetchResult {
val blockId: BlockId
val address: BlockManagerId
}
private[storage] sealed trait FetchResult

/**
* Result of a fetch from a remote block successfully.
Expand Down Expand Up @@ -1005,4 +1118,10 @@ object ShuffleBlockFetcherIterator {
address: BlockManagerId,
e: Throwable)
extends FetchResult

/**
* Result of a fetch request that should be deferred for some reasons, e.g., Netty OOM
*/
private[storage]
case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
}