diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 11793ea92adb1..e3403c867e936 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -18,6 +18,7 @@ package org.apache.spark.network.nio import java.nio.ByteBuffer +import java.io.IOException import org.apache.spark.network._ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} @@ -25,6 +26,7 @@ import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import scala.collection.mutable.HashMap import scala.concurrent.Future @@ -39,6 +41,10 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa private var blockDataManager: BlockDataManager = _ + private val blockFailedCounts = new HashMap[Seq[String], Int] + + val maxRetryNum = conf.getInt("spark.shuffle.fetch.maxRetryNumber", 3) + /** * Port number the service is listening on, available only after [[init]] is invoked. */ @@ -96,6 +102,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa future.onSuccess { case message => val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) + blockFailedCounts.synchronized { + blockFailedCounts -= blockIds + } // SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty. if (blockMessageArray.isEmpty) { @@ -121,8 +130,28 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa }(cm.futureExecContext) future.onFailure { case exception => - blockIds.foreach { blockId => - listener.onBlockFetchFailure(blockId, exception) + exception match { + case connectExcpt: IOException => + logWarning("Failed to connect to " + hostName + ":" + port) + val failedCount = blockFailedCounts.synchronized { + val newFailedCount = blockFailedCounts(blockIds).getOrElse(0) + 1 + blockFailedCounts(blockIds) = newFailedCount + newFailedCount + } + if (failedCount >= maxRetryNum) { + blockFailedCounts.synchronized { + blockFailedCounts -= blockIds + } + blockIds.foreach { blockId => + listener.onBlockFetchFailure(blockId, connectExcpt) + } + } else { + fetchBlocks(hostName, port, blockIds, listener) + } + case t: Throwable => + blockIds.foreach { blockId => + listener.onBlockFetchFailure(blockId, t) + } } }(cm.futureExecContext) }