core/src/main/scala/org/apache/spark/storage/BlockManager.scala (35 additions, 12 deletions)

@@ -133,6 +133,9 @@ private[spark] class BlockManager(
private val compressRdds = conf.getBoolean("spark.rdd.compress", false)
// Whether to compress shuffle output temporarily spilled to disk
private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true)
// Max number of failures before this block manager refreshes the block locations from the driver
private val maxFailuresBeforeLocationRefresh =
conf.getInt("spark.block.failures.beforeLocationRefresh", 5)

private val slaveEndpoint = rpcEnv.setupEndpoint(
"BlockManagerEndpoint" + BlockManager.ID_GENERATOR.next,
@@ -568,26 +571,46 @@ private[spark] class BlockManager(
def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
logDebug(s"Getting remote block $blockId")
require(blockId != null, "BlockId is null")
var runningFailureCount = 0
var totalFailureCount = 0
val locations = getLocations(blockId)
var numFetchFailures = 0
for (loc <- locations) {
val maxFetchFailures = locations.size
var locationIterator = locations.iterator
while (locationIterator.hasNext) {
Review comment (Contributor):
just a thought: would this be easier to reason about as a tail-recursive method? would eliminate the need for the vars and returns.
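
A minimal sketch of the tail-recursive alternative the reviewer suggests, for illustration only and not part of this PR. The helper name fetchFrom is hypothetical; blockTransferService, getLocations, BlockFetchException, and maxFailuresBeforeLocationRefresh are the members used by the surrounding code.

import scala.annotation.tailrec
import scala.util.control.NonFatal

// Try each location in turn, carrying the failure counters as parameters
// instead of vars; recursion replaces the while loop below.
@tailrec
private def fetchFrom(
    blockId: BlockId,
    locations: Iterator[BlockManagerId],
    runningFailureCount: Int,
    totalFailureCount: Int,
    maxFetchFailures: Int): Option[ByteBuffer] = {
  if (!locations.hasNext) {
    None // ran out of locations without reaching the failure cap
  } else {
    val loc = locations.next()
    // Fetch outside the recursion so the recursive calls stay in tail position.
    val fetched =
      try {
        Some(blockTransferService.fetchBlockSync(
          loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer())
      } catch {
        case NonFatal(e) =>
          if (totalFailureCount + 1 >= maxFetchFailures) {
            throw new BlockFetchException(s"Failed to fetch block after" +
              s" ${totalFailureCount + 1} fetch failures. Most recent failure cause:", e)
          }
          logWarning(s"Failed to fetch remote block $blockId from $loc " +
            s"(failed attempt ${runningFailureCount + 1})", e)
          None
      }
    fetched match {
      case Some(data) => Some(data)
      case None if runningFailureCount + 1 >= maxFailuresBeforeLocationRefresh =>
        // Refresh possibly stale locations from the driver and reset the running count.
        fetchFrom(blockId, getLocations(blockId).iterator, 0,
          totalFailureCount + 1, maxFetchFailures)
      case None =>
        fetchFrom(blockId, locations,
          runningFailureCount + 1, totalFailureCount + 1, maxFetchFailures)
    }
  }
}

The call site in getRemoteBytes would then be fetchFrom(blockId, locations.iterator, 0, 0, locations.size).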

val loc = locationIterator.next()
logDebug(s"Getting remote block $blockId from $loc")
val data = try {
blockTransferService.fetchBlockSync(
loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
} catch {
case NonFatal(e) =>
numFetchFailures += 1
if (numFetchFailures == locations.size) {
// An exception is thrown while fetching this block from all locations
throw new BlockFetchException(s"Failed to fetch block from" +
s" ${locations.size} locations. Most recent failure cause:", e)
} else {
// This location failed, so we retry fetch from a different one by returning null here
logWarning(s"Failed to fetch remote block $blockId " +
s"from $loc (failed attempt $numFetchFailures)", e)
null
runningFailureCount += 1
totalFailureCount += 1

if (totalFailureCount >= maxFetchFailures) {
// Give up trying any more locations. Either we've tried all of the original locations,
// or we've refreshed the list of locations from the master, and have still
// hit failures after trying locations from the refreshed list.
throw new BlockFetchException(s"Failed to fetch block after" +
s" ${totalFailureCount} fetch failures. Most recent failure cause:", e)
}

logWarning(s"Failed to fetch remote block $blockId " +
s"from $loc (failed attempt $runningFailureCount)", e)

// If there is a large number of executors, the locations list can contain a
// large number of stale entries, causing many retries that may take a
// significant amount of time. To get rid of these stale entries we refresh
// the block locations after a certain number of fetch failures.
if (runningFailureCount >= maxFailuresBeforeLocationRefresh) {
locationIterator = getLocations(blockId).iterator
logDebug(s"Refreshed locations from the driver " +
s"after ${runningFailureCount} fetch failures.")
runningFailureCount = 0
}

// This location failed, so we retry fetch from a different one by returning null here
null
}

if (data != null) {
core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala

@@ -21,19 +21,23 @@ import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._
import scala.concurrent.Future
import scala.language.implicitConversions
import scala.language.postfixOps

import org.mockito.{Matchers => mc}
import org.mockito.Mockito.{mock, when}
import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest._
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._

import org.apache.spark._
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.memory.StaticMemoryManager
import org.apache.spark.network.{BlockDataManager, BlockTransferService}
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
@@ -66,9 +70,11 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach
private def makeBlockManager(
maxMem: Long,
name: String = SparkContext.DRIVER_IDENTIFIER,
master: BlockManagerMaster = this.master): BlockManager = {
master: BlockManagerMaster = this.master,
transferService: Option[BlockTransferService] = Option.empty): BlockManager = {
val serializer = new KryoSerializer(conf)
val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1)
val transfer = transferService
.getOrElse(new NettyBlockTransferService(conf, securityMgr, numCores = 1))
val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1)
val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf,
memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0)
@@ -1287,6 +1293,78 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach
assert(store.getSingle("a1").isDefined, "a1 was not in store")
assert(store.getSingle("a3").isDefined, "a3 was not in store")
}

test("SPARK-13328: refresh block locations (fetch should fail after hitting a threshold)") {
val mockBlockTransferService =
new MockBlockTransferService(conf.getInt("spark.block.failures.beforeLocationRefresh", 5))
store = makeBlockManager(8000, "executor1", transferService = Option(mockBlockTransferService))
store.putSingle("item", 999L, StorageLevel.MEMORY_ONLY, tellMaster = true)
intercept[BlockFetchException] {
store.getRemoteBytes("item")
}
}

test("SPARK-13328: refresh block locations (fetch should succeed after location refresh)") {
val maxFailuresBeforeLocationRefresh =
conf.getInt("spark.block.failures.beforeLocationRefresh", 5)
val mockBlockManagerMaster = mock(classOf[BlockManagerMaster])
val mockBlockTransferService =
new MockBlockTransferService(maxFailuresBeforeLocationRefresh)
// Make sure we have more than maxFailuresBeforeLocationRefresh locations
// so that we have a chance to do a location refresh.
val blockManagerIds = (0 to maxFailuresBeforeLocationRefresh)
.map { i => BlockManagerId(s"id-$i", s"host-$i", i + 1) }
when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockManagerIds)
store = makeBlockManager(8000, "executor1", mockBlockManagerMaster,
transferService = Option(mockBlockTransferService))
val block = store.getRemoteBytes("item")
.asInstanceOf[Option[ByteBuffer]]
assert(block.isDefined)
verify(mockBlockManagerMaster, times(2)).getLocations("item")
}

class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
var numCalls = 0

override def init(blockDataManager: BlockDataManager): Unit = {}

override def fetchBlocks(
host: String,
port: Int,
execId: String,
blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
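// The default fetchBlockSync implementation delegates here, so once the
// mock stops throwing (after maxFailures calls) the fetch reports success.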
listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
}

override def close(): Unit = {}

override def hostName: String = { "MockBlockTransferServiceHost" }

override def port: Int = { 63332 }

override def uploadBlock(
hostname: String,
port: Int,
execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit] = {
import scala.concurrent.ExecutionContext.Implicits.global
Future {}
}

override def fetchBlockSync(
host: String,
port: Int,
execId: String,
blockId: String): ManagedBuffer = {
numCalls += 1
if (numCalls <= maxFailures) {
throw new RuntimeException("Failing block fetch in the mock block transfer service")
}
super.fetchBlockSync(host, port, execId, blockId)
}
}
}

private object BlockManagerSuite {