
Commit d6969ff

JoshRosen authored and Andrew Or committed
[SPARK-12817] Add BlockManager.getOrElseUpdate and remove CacheManager
CacheManager directly calls MemoryStore.unrollSafely() and has its own logic for handling graceful fallback to disk when cached data does not fit in memory. However, this logic also exists inside of the MemoryStore itself, so this appears to be unnecessary duplication.

Thanks to the addition of block-level read/write locks in apache#10705, we can refactor the code to remove the CacheManager and replace it with an atomic `BlockManager.getOrElseUpdate()` method.

This pull request replaces / subsumes apache#10748. /cc andrewor14 and nongli for review.

Note that this changes the locking semantics of a couple of internal BlockManager methods (`doPut()` and `lockNewBlockForWriting`), so please pay attention to the Scaladoc changes and new test cases for those methods.

Author: Josh Rosen <[email protected]>

Closes apache#11436 from JoshRosen/remove-cachemanager.
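For orientation, the shape of the new atomic API can be sketched in isolation. The following is a minimal toy model, not the actual BlockManager implementation (BlockManager.scala itself is not among the diffs shown below); the SimpleBlockStore class, its HashMap backing, and the canStore size check are illustrative assumptions, but the Left/Right result shape matches the call site visible in the RDD.scala diff further down.

import scala.collection.mutable

// Toy model of an atomic "get or else update" on a block store:
//   Left(values)    -> the block is (now) cached; serve the stored values
//   Right(iterator) -> the block could not be cached, so the caller consumes
//                      the freshly computed iterator directly instead.
class SimpleBlockStore {
  private val store = mutable.HashMap.empty[String, Seq[Any]]

  def getOrElseUpdate(
      blockId: String,
      makeIterator: () => Iterator[Any]): Either[Seq[Any], Iterator[Any]] = synchronized {
    store.get(blockId) match {
      case Some(values) =>
        Left(values) // cache hit: another caller already stored this block
      case None =>
        val computed = makeIterator().toSeq
        if (canStore(computed)) { // in Spark, this is where memory/disk fallback would apply
          store(blockId) = computed
          Left(computed)
        } else {
          Right(computed.iterator) // could not cache: hand the values back uncached
        }
    }
  }

  // Illustrative stand-in for the real "does this block fit?" decision.
  private def canStore(values: Seq[Any]): Boolean = values.size < 10000
}

A caller pattern-matches on the result, e.g. new SimpleBlockStore().getOrElseUpdate("rdd_0_0", () => Iterator(1, 2, 3)), which mirrors how the new RDD.getOrCompute below handles Left(blockResult) versus Right(iter).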
1 parent 8f8d8a2 commit d6969ff

File tree

16 files changed: +365 −597 lines


core/src/main/scala/org/apache/spark/CacheManager.scala

Lines changed: 0 additions & 179 deletions
This file was deleted.

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 0 additions & 4 deletions
@@ -56,7 +56,6 @@ class SparkEnv (
     private[spark] val rpcEnv: RpcEnv,
     val serializer: Serializer,
     val closureSerializer: Serializer,
-    val cacheManager: CacheManager,
     val mapOutputTracker: MapOutputTracker,
     val shuffleManager: ShuffleManager,
     val broadcastManager: BroadcastManager,
@@ -333,8 +332,6 @@ object SparkEnv extends Logging {
 
     val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
 
-    val cacheManager = new CacheManager(blockManager)
-
     val metricsSystem = if (isDriver) {
       // Don't start metrics system right now for Driver.
       // We need to wait for the task scheduler to give us an app ID.
@@ -371,7 +368,6 @@ object SparkEnv extends Logging {
       rpcEnv,
       serializer,
       closureSerializer,
-      cacheManager,
      mapOutputTracker,
      shuffleManager,
      broadcastManager,

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 20 additions & 24 deletions
@@ -99,18 +99,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
     // Store a copy of the broadcast variable in the driver so that tasks run on the driver
     // do not create a duplicate copy of the broadcast variable's value.
     val blockManager = SparkEnv.get.blockManager
-    if (blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
-      blockManager.releaseLock(broadcastId)
-    } else {
+    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
       throw new SparkException(s"Failed to store $broadcastId in BlockManager")
     }
     val blocks =
       TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
     blocks.zipWithIndex.foreach { case (block, i) =>
       val pieceId = BroadcastBlockId(id, "piece" + i)
-      if (blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) {
-        blockManager.releaseLock(pieceId)
-      } else {
+      if (!blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) {
         throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
       }
     }
@@ -130,22 +126,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
       // First try getLocalBytes because there is a chance that previous attempts to fetch the
       // broadcast blocks have already fetched some of the blocks. In that case, some blocks
       // would be available locally (on this executor).
-      def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
-      def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
-        // If we found the block from remote executors/driver's BlockManager, put the block
-        // in this executor's BlockManager.
-        if (!bm.putBytes(pieceId, block, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
-          throw new SparkException(
-            s"Failed to store $pieceId of $broadcastId in local BlockManager")
-        }
-        block
+      bm.getLocalBytes(pieceId) match {
+        case Some(block) =>
+          blocks(pid) = block
+          releaseLock(pieceId)
+        case None =>
+          bm.getRemoteBytes(pieceId) match {
+            case Some(b) =>
+              // We found the block from remote executors/driver's BlockManager, so put the block
+              // in this executor's BlockManager.
+              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
+                throw new SparkException(
+                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
+              }
+              blocks(pid) = b
+            case None =>
+              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
+          }
       }
-      val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
-        throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
-      // At this point we are guaranteed to hold a read lock, since we either got the block locally
-      // or stored the remotely-fetched block and automatically downgraded the write lock.
-      blocks(pid) = block
-      releaseLock(pieceId)
     }
     blocks
   }
@@ -191,9 +189,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
          // Store the merged copy in BlockManager so other tasks on this executor don't
          // need to re-fetch it.
          val storageLevel = StorageLevel.MEMORY_AND_DISK
-          if (blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
-            releaseLock(broadcastId)
-          } else {
+          if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
            throw new SparkException(s"Failed to store $broadcastId in BlockManager")
          }
          obj
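The rewritten readBlocks loop above boils down to a local-then-remote fallback that re-stores remotely fetched pieces on this executor. A rough standalone sketch of that pattern, where fetchLocal, fetchRemote, and storeLocally are hypothetical stand-ins for getLocalBytes, getRemoteBytes, and putBytes (lock bookkeeping omitted):

object PieceFetchSketch {
  // Hypothetical stand-ins for BlockManager.getLocalBytes / getRemoteBytes / putBytes.
  def fetchLocal(pieceId: String): Option[Array[Byte]] = None
  def fetchRemote(pieceId: String): Option[Array[Byte]] = Some(Array[Byte](1, 2, 3))
  def storeLocally(pieceId: String, bytes: Array[Byte]): Boolean = true

  def fetchPiece(pieceId: String): Array[Byte] =
    fetchLocal(pieceId) match {
      case Some(bytes) =>
        bytes // already on this executor from an earlier attempt
      case None =>
        fetchRemote(pieceId) match {
          case Some(bytes) =>
            // Re-store the remotely fetched piece so later tasks on this executor find it locally.
            if (!storeLocally(pieceId, bytes)) {
              throw new RuntimeException(s"Failed to store $pieceId locally")
            }
            bytes
          case None =>
            throw new RuntimeException(s"Failed to get $pieceId")
        }
    }
}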

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 1 addition & 4 deletions
@@ -292,11 +292,8 @@ private[spark] class Executor(
              ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
            } else if (resultSize >= maxRpcMessageSize) {
              val blockId = TaskResultBlockId(taskId)
-              val putSucceeded = env.blockManager.putBytes(
+              env.blockManager.putBytes(
                blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
-              if (putSucceeded) {
-                env.blockManager.releaseLock(blockId)
-              }
              logInfo(
                s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
              ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))

core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala

Lines changed: 1 addition & 4 deletions
@@ -66,10 +66,7 @@ class NettyBlockRpcServer(
          serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
        val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
        val blockId = BlockId(uploadBlock.blockId)
-        val putSucceeded = blockManager.putBlockData(blockId, data, level)
-        if (putSucceeded) {
-          blockManager.releaseLock(blockId)
-        }
+        blockManager.putBlockData(blockId, data, level)
        responseContext.onSuccess(ByteBuffer.allocate(0))
      }
  }

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 31 additions & 2 deletions
@@ -37,7 +37,7 @@ import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
 import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
@@ -272,7 +272,7 @@ abstract class RDD[T: ClassTag](
    */
   final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
     if (storageLevel != StorageLevel.NONE) {
-      SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
+      getOrCompute(split, context)
     } else {
       computeOrReadCheckpoint(split, context)
     }
@@ -314,6 +314,35 @@ abstract class RDD[T: ClassTag](
     }
   }
 
+  /**
+   * Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached.
+   */
+  private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
+    val blockId = RDDBlockId(id, partition.index)
+    var readCachedBlock = true
+    // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
+    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, () => {
+      readCachedBlock = false
+      computeOrReadCheckpoint(partition, context)
+    }) match {
+      case Left(blockResult) =>
+        if (readCachedBlock) {
+          val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod)
+          existingMetrics.incBytesReadInternal(blockResult.bytes)
+          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
+            override def next(): T = {
+              existingMetrics.incRecordsReadInternal(1)
+              delegate.next()
+            }
+          }
+        } else {
+          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
+        }
+      case Right(iter) =>
+        new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
+    }
+  }
+
   /**
    * Execute a block of code in a scope such that all new RDDs created in this body will
    * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}.
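
To see the new getOrCompute path in action, here is a small usage sketch (this assumes an existing SparkContext named sc; the data and partition count are arbitrary):

import org.apache.spark.storage.StorageLevel

// Each partition of a persisted RDD is now read or computed through
// BlockManager.getOrElseUpdate(RDDBlockId(id, partitionIndex), ...).
val doubled = sc.parallelize(1 to 1000, 4).map(_ * 2)
doubled.persist(StorageLevel.MEMORY_AND_DISK)

doubled.count() // first action: partitions are computed and stored (the makeIterator closure runs)
doubled.count() // second action: getOrElseUpdate returns Left(blockResult) for each cached partition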
