-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-27991][CORE] Defer the fetch request on Netty OOM #32287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3871877
6b37944
2cc07e2
416bdaa
15426b5
0863419
04dff9f
88d6d81
95aae3b
5ecb6dc
857f305
e3d552c
92d4ee5
254fa7c
dc01c85
466966c
7bc1b93
1460299
b64482e
7fc9a0a
bff1ce3
b5dc374
77e43ba
5538c33
959dd8d
90aa9ea
22a441f
f78b906
7ac87d8
4aaf04f
298810d
f334557
48deaed
3611378
8b74473
353e681
2b4bfd0
c9cc13f
28201b6
9e877d5
7affa68
08e5f61
4af6ee7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1193,6 +1193,15 @@ package object config { | |
| .intConf | ||
| .createWithDefault(3) | ||
|
|
||
| private[spark] val SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM = | ||
| ConfigBuilder("spark.shuffle.maxAttemptsOnNettyOOM") | ||
| .doc("The max attempts of a shuffle block would retry on Netty OOM issue before throwing " + | ||
| "the shuffle fetch failure.") | ||
| .version("3.2.0") | ||
| .internal() | ||
| .intConf | ||
| .createWithDefault(10) | ||
|
||
|
|
||
| private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS = | ||
| ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress") | ||
| .doc("This configuration limits the number of remote blocks being fetched per reduce task " + | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -20,19 +20,21 @@ package org.apache.spark.storage | |||
| import java.io.{InputStream, IOException} | ||||
| import java.nio.channels.ClosedByInterruptException | ||||
| import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} | ||||
| import java.util.concurrent.atomic.AtomicBoolean | ||||
| import javax.annotation.concurrent.GuardedBy | ||||
|
|
||||
| import scala.collection.mutable | ||||
| import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashMap, Queue} | ||||
| import scala.util.{Failure, Success} | ||||
|
|
||||
| import io.netty.util.internal.OutOfDirectMemoryError | ||||
| import org.apache.commons.io.IOUtils | ||||
|
|
||||
| import org.apache.spark.{SparkException, TaskContext} | ||||
| import org.apache.spark.internal.Logging | ||||
| import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} | ||||
| import org.apache.spark.network.shuffle._ | ||||
| import org.apache.spark.network.util.TransportConf | ||||
| import org.apache.spark.network.util.{NettyUtils, TransportConf} | ||||
| import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} | ||||
| import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} | ||||
|
|
||||
|
|
@@ -61,6 +63,8 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} | |||
| * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point | ||||
| * for a given remote host:port. | ||||
| * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. | ||||
| * @param maxAttemptsOnNettyOOM The max number of times a block could retry due to Netty OOM | ||||
| * before throwing the fetch failure. | ||||
| * @param detectCorrupt whether to detect any corruption in fetched blocks. | ||||
| * @param shuffleMetrics used to report shuffle metrics. | ||||
| * @param doBatchFetch fetch continuous shuffle blocks from same executor in batch if the server | ||||
|
|
@@ -76,7 +80,8 @@ final class ShuffleBlockFetcherIterator( | |||
| maxBytesInFlight: Long, | ||||
| maxReqsInFlight: Int, | ||||
| maxBlocksInFlightPerAddress: Int, | ||||
| maxReqSizeShuffleToMem: Long, | ||||
| val maxReqSizeShuffleToMem: Long, | ||||
| maxAttemptsOnNettyOOM: Int, | ||||
| detectCorrupt: Boolean, | ||||
| detectCorruptUseExtraMemory: Boolean, | ||||
| shuffleMetrics: ShuffleReadMetricsReporter, | ||||
|
|
@@ -146,6 +151,12 @@ final class ShuffleBlockFetcherIterator( | |||
| /** Current number of blocks in flight per host:port */ | ||||
| private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() | ||||
|
|
||||
| /** | ||||
| * Counts the retry attempts for blocks that failed due to Netty OOM. A block stops retrying | ||||
| * once its retry count has exceeded [[maxAttemptsOnNettyOOM]]. | ||||
| */ | ||||
| private[this] val blockOOMRetryCounts = new HashMap[String, Int] | ||||
|
|
||||
| /** | ||||
| * The blocks that can't be decompressed successfully, it is used to guarantee that we retry | ||||
| * at most once for those corrupted blocks. | ||||
|
|
@@ -245,9 +256,21 @@ final class ShuffleBlockFetcherIterator( | |||
| case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) | ||||
| }.toMap | ||||
| val remainingBlocks = new HashSet[String]() ++= infoMap.keys | ||||
| val deferredBlocks = new ArrayBuffer[String]() | ||||
| val blockIds = req.blocks.map(_.blockId.toString) | ||||
| val address = req.address | ||||
|
|
||||
| @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = { | ||||
Ngone51 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
| if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) { | ||||
| val blocks = deferredBlocks.map { blockId => | ||||
| val (size, mapIndex) = infoMap(blockId) | ||||
| FetchBlockInfo(BlockId(blockId), size, mapIndex) | ||||
| } | ||||
| results.put(DeferFetchRequestResult(FetchRequest(address, blocks.toSeq))) | ||||
| deferredBlocks.clear() | ||||
| } | ||||
| } | ||||
|
||||
|
|
||||
| val blockFetchingListener = new BlockFetchingListener { | ||||
| override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { | ||||
| // Only add the buffer to results queue if the iterator is not zombie, | ||||
|
|
@@ -258,17 +281,57 @@ final class ShuffleBlockFetcherIterator( | |||
| // This needs to be released after use. | ||||
| buf.retain() | ||||
| remainingBlocks -= blockId | ||||
| blockOOMRetryCounts.remove(blockId) | ||||
| results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2, | ||||
| address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty)) | ||||
| logDebug("remainingBlocks: " + remainingBlocks) | ||||
| enqueueDeferredFetchRequestIfNecessary() | ||||
| } | ||||
| } | ||||
| logTrace(s"Got remote block $blockId after ${Utils.getUsedTimeNs(startTimeNs)}") | ||||
| } | ||||
|
|
||||
| override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { | ||||
| logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) | ||||
| results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e)) | ||||
| ShuffleBlockFetcherIterator.this.synchronized { | ||||
| logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) | ||||
| e match { | ||||
| // SPARK-27991: Catch the Netty OOM and set the flag `isNettyOOMOnShuffle` (shared among | ||||
| // tasks) to true as early as possible. The pending fetch requests won't be sent | ||||
| // afterwards until the flag is set to false on: | ||||
| // 1) the Netty free memory >= maxReqSizeShuffleToMem | ||||
| // - we'll check this whenever a fetch request succeeds. | ||||
| // 2) the number of in-flight requests becomes 0 | ||||
| // - we'll check this in `fetchUpToMaxBytes` whenever it's invoked. | ||||
| // Although Netty memory is shared across multiple modules, e.g., shuffle, rpc, the flag | ||||
| // only takes effect for the shuffle due to the implementation simplicity concern. | ||||
| // And we'll buffer the consecutive block failures caused by the OOM error until there are | ||||
| // no remaining blocks in the current request. Then, we'll package these blocks into | ||||
| // the same fetch request for the retry later. In this way, instead of creating the fetch | ||||
| // request per block, it would help reduce the concurrent connections and data loads | ||||
| // pressure at remote server. | ||||
| // Note that catching OOM and doing something based on it is only a workaround for | ||||
| // handling the Netty OOM issue, which is not the best way towards memory management. | ||||
| // We can get rid of it when we find a way to manage Netty's memory precisely. | ||||
| case _: OutOfDirectMemoryError | ||||
| if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < maxAttemptsOnNettyOOM => | ||||
| if (!isZombie) { | ||||
| val failureTimes = blockOOMRetryCounts(blockId) | ||||
| blockOOMRetryCounts(blockId) += 1 | ||||
| if (isNettyOOMOnShuffle.compareAndSet(false, true)) { | ||||
| // The fetcher can fail remaining blocks in batch for the same error. So we only | ||||
| // log the warning once to avoid flooding the logs. | ||||
| logInfo(s"Block $blockId has failed $failureTimes times " + | ||||
| s"due to Netty OOM, will retry") | ||||
| } | ||||
| remainingBlocks -= blockId | ||||
| deferredBlocks += blockId | ||||
| enqueueDeferredFetchRequestIfNecessary() | ||||
| } | ||||
|
||||
|
|
||||
| case _ => | ||||
| results.put(FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e)) | ||||
| } | ||||
| } | ||||
| } | ||||
| } | ||||
|
|
||||
|
|
@@ -613,6 +676,7 @@ final class ShuffleBlockFetcherIterator( | |||
| } | ||||
| if (isNetworkReqDone) { | ||||
| reqsInFlight -= 1 | ||||
| resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem) | ||||
| logDebug("Number of requests in flight " + reqsInFlight) | ||||
| } | ||||
|
|
||||
|
|
@@ -684,7 +748,25 @@ final class ShuffleBlockFetcherIterator( | |||
| } | ||||
|
|
||||
| case FailureFetchResult(blockId, mapIndex, address, e) => | ||||
| throwFetchFailedException(blockId, mapIndex, address, e) | ||||
| var errorMsg: String = null | ||||
| if (e.isInstanceOf[OutOfDirectMemoryError]) { | ||||
| errorMsg = s"Block $blockId fetch failed after $maxAttemptsOnNettyOOM " + | ||||
| s"retries due to Netty OOM" | ||||
| logError(errorMsg) | ||||
| } | ||||
| throwFetchFailedException(blockId, mapIndex, address, e, Some(errorMsg)) | ||||
|
|
||||
| case DeferFetchRequestResult(request) => | ||||
| val address = request.address | ||||
| numBlocksInFlightPerAddress(address) = | ||||
| numBlocksInFlightPerAddress(address) - request.blocks.size | ||||
| bytesInFlight -= request.size | ||||
| reqsInFlight -= 1 | ||||
| logDebug("Number of requests in flight " + reqsInFlight) | ||||
| val defReqQueue = | ||||
| deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]()) | ||||
| defReqQueue.enqueue(request) | ||||
| result = null | ||||
| } | ||||
|
|
||||
| // Send fetch requests up to maxBytesInFlight | ||||
|
|
@@ -699,7 +781,8 @@ final class ShuffleBlockFetcherIterator( | |||
| currentResult.blockId, | ||||
| currentResult.mapIndex, | ||||
| currentResult.address, | ||||
| detectCorrupt && streamCompressedOrEncrypted)) | ||||
| detectCorrupt && streamCompressedOrEncrypted, | ||||
| currentResult.isNetworkReqDone)) | ||||
| } | ||||
|
|
||||
| def toCompletionIterator: Iterator[(BlockId, InputStream)] = { | ||||
|
|
@@ -708,6 +791,15 @@ final class ShuffleBlockFetcherIterator( | |||
| } | ||||
|
|
||||
| private def fetchUpToMaxBytes(): Unit = { | ||||
| if (isNettyOOMOnShuffle.get()) { | ||||
| if (reqsInFlight > 0) { | ||||
| // Return immediately if Netty is still OOMed and there're ongoing fetch requests | ||||
| return | ||||
| } else { | ||||
| resetNettyOOMFlagIfPossible(0) | ||||
| } | ||||
| } | ||||
|
|
||||
|
||||
| result = results.take() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh I see. It just waits indefinitely here. Can we not change this to poll(time)? If there is nothing, result will be null and it will call fetchUpToMaxBytes again.
WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
poll(time) may work, but I think it breaks the existing design. After the change, we would introduce overhead due to calling fetchUpToMaxBytes many times when it's not necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried to write a unit test for this change but realized it's not easy while I was trying. So I gave it up. But I have done the manual test and looks fine. cc @mridulm