179 changes: 0 additions & 179 deletions core/src/main/scala/org/apache/spark/CacheManager.scala

This file was deleted.

4 changes: 0 additions & 4 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -56,7 +56,6 @@ class SparkEnv (
     private[spark] val rpcEnv: RpcEnv,
     val serializer: Serializer,
     val closureSerializer: Serializer,
-    val cacheManager: CacheManager,
     val mapOutputTracker: MapOutputTracker,
     val shuffleManager: ShuffleManager,
     val broadcastManager: BroadcastManager,
@@ -333,8 +332,6 @@ object SparkEnv extends Logging {

     val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)

-    val cacheManager = new CacheManager(blockManager)
-
     val metricsSystem = if (isDriver) {
       // Don't start metrics system right now for Driver.
       // We need to wait for the task scheduler to give us an app ID.
@@ -371,7 +368,6 @@
       rpcEnv,
       serializer,
       closureSerializer,
-      cacheManager,
       mapOutputTracker,
       shuffleManager,
       broadcastManager,
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -99,18 +99,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
     // Store a copy of the broadcast variable in the driver so that tasks run on the driver
     // do not create a duplicate copy of the broadcast variable's value.
     val blockManager = SparkEnv.get.blockManager
-    if (blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
-      blockManager.releaseLock(broadcastId)
-    } else {
+    if (!blockManager.putSingle(broadcastId, value, MEMORY_AND_DISK, tellMaster = false)) {
       throw new SparkException(s"Failed to store $broadcastId in BlockManager")
     }
     val blocks =
       TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
     blocks.zipWithIndex.foreach { case (block, i) =>
       val pieceId = BroadcastBlockId(id, "piece" + i)
-      if (blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) {
-        blockManager.releaseLock(pieceId)
-      } else {
+      if (!blockManager.putBytes(pieceId, block, MEMORY_AND_DISK_SER, tellMaster = true)) {
         throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
       }
     }
@@ -130,22 +126,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
       // First try getLocalBytes because there is a chance that previous attempts to fetch the
       // broadcast blocks have already fetched some of the blocks. In that case, some blocks
       // would be available locally (on this executor).
-      def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
-      def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
-        // If we found the block from remote executors/driver's BlockManager, put the block
-        // in this executor's BlockManager.
-        if (!bm.putBytes(pieceId, block, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
-          throw new SparkException(
-            s"Failed to store $pieceId of $broadcastId in local BlockManager")
-        }
-        block
-      }
+      bm.getLocalBytes(pieceId) match {
+        case Some(block) =>
+          blocks(pid) = block
+          releaseLock(pieceId)
+        case None =>
+          bm.getRemoteBytes(pieceId) match {
+            case Some(b) =>
+              // We found the block from remote executors/driver's BlockManager, so put the block
+              // in this executor's BlockManager.
+              if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
+                throw new SparkException(
+                  s"Failed to store $pieceId of $broadcastId in local BlockManager")
+              }
+              blocks(pid) = b
+            case None =>
+              throw new SparkException(s"Failed to get $pieceId of $broadcastId")
+          }
+      }
-      val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
-        throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
-      // At this point we are guaranteed to hold a read lock, since we either got the block locally
-      // or stored the remotely-fetched block and automatically downgraded the write lock.
-      blocks(pid) = block
-      releaseLock(pieceId)
     }
     blocks
   }
@@ -191,9 +189,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
           // Store the merged copy in BlockManager so other tasks on this executor don't
           // need to re-fetch it.
           val storageLevel = StorageLevel.MEMORY_AND_DISK
-          if (blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
-            releaseLock(broadcastId)
-          } else {
+          if (!blockManager.putSingle(broadcastId, obj, storageLevel, tellMaster = false)) {
             throw new SparkException(s"Failed to store $broadcastId in BlockManager")
           }
           obj
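The TorrentBroadcast hunks above, and the Executor and NettyBlockRpcServer hunks below, all make the same change: the old `if (put succeeded) releaseLock else throw` branches collapse into a single guard, and the explicit releaseLock call disappears from these call sites. Below is a minimal illustrative sketch of that calling convention, written against the same putSingle call shape used in the diff; the storeOrFail helper is ours, not part of the patch.

```scala
import org.apache.spark.SparkException
import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel}

// Illustrative only: the simplified "put or fail fast" convention the updated
// call sites follow. A successful put no longer needs an explicit releaseLock
// here; a failed put is surfaced immediately as a SparkException.
object PutOrFail {
  def storeOrFail(bm: BlockManager, blockId: BlockId, value: Any, level: StorageLevel): Unit = {
    if (!bm.putSingle(blockId, value, level, tellMaster = false)) {
      throw new SparkException(s"Failed to store $blockId in BlockManager")
    }
  }
}
```

The effect is that lock bookkeeping moves out of every caller and into BlockManager itself, the same direction the new getOrElseUpdate path for cached RDD partitions takes further down.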
5 changes: 1 addition & 4 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -279,11 +279,8 @@ private[spark] class Executor(
             ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
           } else if (resultSize >= maxRpcMessageSize) {
             val blockId = TaskResultBlockId(taskId)
-            val putSucceeded = env.blockManager.putBytes(
+            env.blockManager.putBytes(
               blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
-            if (putSucceeded) {
-              env.blockManager.releaseLock(blockId)
-            }
             logInfo(
               s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
             ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -66,10 +66,7 @@ class NettyBlockRpcServer(
           serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
         val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
         val blockId = BlockId(uploadBlock.blockId)
-        val putSucceeded = blockManager.putBlockData(blockId, data, level)
-        if (putSucceeded) {
-          blockManager.releaseLock(blockId)
-        }
+        blockManager.putBlockData(blockId, data, level)
         responseContext.onSuccess(ByteBuffer.allocate(0))
       }
     }
33 changes: 31 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -37,7 +37,7 @@ import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
 import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
@@ -272,7 +272,7 @@ abstract class RDD[T: ClassTag](
    */
   final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
     if (storageLevel != StorageLevel.NONE) {
-      SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
+      getOrCompute(split, context)
     } else {
       computeOrReadCheckpoint(split, context)
     }
@@ -314,6 +314,35 @@
     }
   }

+  /**
+   * Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached.
+   */

Review comment (Contributor): Can you add that this is only called on the executors? (otherwise people might wonder why we don't just use sc.env instead of SparkEnv.get)

Reply (Contributor Author): Done.

+  private[spark] def getOrCompute(partition: Partition, context: TaskContext): Iterator[T] = {
+    val blockId = RDDBlockId(id, partition.index)
+    var readCachedBlock = true
+    // This method is called on executors, so we need to call SparkEnv.get instead of sc.env.
+    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, () => {
+      readCachedBlock = false
+      computeOrReadCheckpoint(partition, context)
+    }) match {
+      case Left(blockResult) =>
+        if (readCachedBlock) {
+          val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod)
+          existingMetrics.incBytesReadInternal(blockResult.bytes)
+          new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) {
+            override def next(): T = {
+              existingMetrics.incRecordsReadInternal(1)
+              delegate.next()

Review comment (Contributor): Crazy idea: what if we push all of this metrics code into getOrElseUpdate? Then we don't need the complicated left or right signature anymore and we can contain all the complexity within BlockManager itself. We might need to pass in TaskContext but I think it's pretty doable to make getOrElseUpdate return an Iterator[T] instead. Or just do TaskContext.get there; doPut already does that anyway.

Reply (Contributor Author): I considered this but had some reservations because we might not want to consider every read of a BlockManager iterator to be a read of an input record (since we also read from iterators in the implementation of getSingle() and in a couple of other places, and those methods are used for things like TorrentBroadcast blocks, which I don't think we want to treat as input records for the purposes of these metrics).

On the other hand, this getOrElseUpdate method really only makes sense in the context of caching code anyway, so maybe it's fine to couple its semantics a little more tightly to its only current use.

+            }
+          }
+        } else {
+          new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
+        }
+      case Right(iter) =>
+        new InterruptibleIterator(context, iter.asInstanceOf[Iterator[T]])
+    }
+  }

   /**
    * Execute a block of code in a scope such that all new RDDs created in this body will
    * be part of the same scope. For more detail, see {{org.apache.spark.rdd.RDDOperationScope}}.
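For orientation, here is a small, self-contained usage sketch of the new read path from the user's side: once an RDD is persisted, iterator() dispatches to the getOrCompute method added above, which delegates to BlockManager.getOrElseUpdate on the executors. Only public API appears below; the comments describe the flow added in this patch, and the app name and local master are placeholder settings.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

// Usage sketch: exercising the cached read path twice on a persisted RDD.
object CachedReadPathExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("cached-read-path").setMaster("local[2]")
    val sc = new SparkContext(conf)
    try {
      val doubled = sc.parallelize(1 to 1000).map(_ * 2).persist(StorageLevel.MEMORY_AND_DISK)

      // First action: getOrCompute misses, so getOrElseUpdate runs
      // computeOrReadCheckpoint and caches each partition under
      // RDDBlockId(doubled.id, partition.index).
      val firstSum = doubled.reduce(_ + _)

      // Second action: getOrElseUpdate returns the cached block (the Left branch
      // above), and the read is recorded in the task's input metrics.
      val secondSum = doubled.reduce(_ + _)

      assert(firstSum == secondSum)
    } finally {
      sc.stop()
    }
  }
}
```

The Left/Right handling and the metrics updates stay inside getOrCompute, which is exactly what the review thread above debates folding into getOrElseUpdate itself.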