
Commit a8f181d
Add special handling for StorageLevel.MEMORY_*_SER
We only unroll the serialized form of each partition for this case, because the deserialized form may be much larger and may not fit in memory. This commit also abstracts out part of the logic of getOrCompute to make it more readable.
1 parent b3736e3 commit a8f181d
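
The new behavior hinges on the flags carried by each StorageLevel (useMemory, useDisk, useOffHeap, deserialized). The following is a minimal sketch, not part of the commit, of which branch of the new cacheValues helper a few common levels would take, using the same flags the diff checks; the object and method names (StorageLevelBranches, branchFor) are illustrative, and spark-core is assumed to be on the classpath.

import org.apache.spark.storage.StorageLevel

// Illustrative sketch: map a storage level to the branch of cacheValues it would take.
object StorageLevelBranches {
  def branchFor(level: StorageLevel): String = {
    if (!level.useMemory) {
      // e.g. DISK_ONLY, OFF_HEAP: the iterator is passed straight to the BlockManager
      "stream iterator directly to the BlockManager"
    } else if (level.deserialized) {
      // e.g. MEMORY_ONLY, MEMORY_AND_DISK: unroll the objects into an ArrayBuffer
      "unroll deserialized values into an ArrayBuffer"
    } else {
      // e.g. MEMORY_ONLY_SER, MEMORY_AND_DISK_SER: unroll only the serialized bytes
      "serialize once and cache only the serialized bytes"
    }
  }

  def main(args: Array[String]): Unit = {
    Seq(
      "DISK_ONLY" -> StorageLevel.DISK_ONLY,
      "MEMORY_ONLY" -> StorageLevel.MEMORY_ONLY,
      "MEMORY_ONLY_SER" -> StorageLevel.MEMORY_ONLY_SER,
      "MEMORY_AND_DISK_SER" -> StorageLevel.MEMORY_AND_DISK_SER
    ).foreach { case (name, level) => println(s"$name -> ${branchFor(level)}") }
  }
}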

1 file changed: +104 −71 lines changed

core/src/main/scala/org/apache/spark/CacheManager.scala

Lines changed: 104 additions & 71 deletions
@@ -20,105 +20,53 @@ package org.apache.spark
 import scala.collection.mutable.{ArrayBuffer, HashSet}
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockId, BlockManager, BlockStatus, RDDBlockId, StorageLevel}
+import org.apache.spark.storage._
 
 /**
- * Spark class responsible for passing RDDs split contents to the BlockManager and making
+ * Spark class responsible for passing RDDs partition contents to the BlockManager and making
  * sure a node doesn't load two copies of an RDD at once.
  */
 private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
 
-  /** Keys of RDD splits that are being computed/loaded. */
+  /** Keys of RDD partitions that are being computed/loaded. */
   private val loading = new HashSet[RDDBlockId]()
 
-  /** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
+  /** Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached. */
   def getOrCompute[T](
       rdd: RDD[T],
-      split: Partition,
+      partition: Partition,
       context: TaskContext,
       storageLevel: StorageLevel): Iterator[T] = {
 
-    val key = RDDBlockId(rdd.id, split.index)
+    val key = RDDBlockId(rdd.id, partition.index)
     logDebug(s"Looking for partition $key")
     blockManager.get(key) match {
       case Some(values) =>
         // Partition is already materialized, so just return its values
         new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
 
       case None =>
-        // Mark the split as loading (unless someone else marks it first)
-        loading.synchronized {
-          if (loading.contains(key)) {
-            logInfo(s"Another thread is loading $key, waiting for it to finish...")
-            while (loading.contains(key)) {
-              try {
-                loading.wait()
-              } catch {
-                case e: Exception =>
-                  logWarning(s"Got an exception while waiting for another thread to load $key", e)
-              }
-            }
-            logInfo(s"Finished waiting for $key")
-            /* See whether someone else has successfully loaded it. The main way this would fail
-             * is for the RDD-level cache eviction policy if someone else has loaded the same RDD
-             * partition but we didn't want to make space for it. However, that case is unlikely
-             * because it's unlikely that two threads would work on the same RDD partition. One
-             * downside of the current code is that threads wait serially if this does happen. */
-            blockManager.get(key) match {
-              case Some(values) =>
-                return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
-              case None =>
-                logInfo(s"Whoever was loading $key failed; we'll try it ourselves")
-                loading.add(key)
-            }
-          } else {
-            loading.add(key)
-          }
+        // Acquire a lock for loading this partition.
+        // If another thread already holds the lock, wait for it to finish and return its results.
+        acquireLockForPartition(key).foreach { values =>
+          return new InterruptibleIterator[T](context, values.asInstanceOf[Iterator[T]])
         }
+
+        // Otherwise, we have to load the partition ourselves
         try {
-          // If we got here, we have to load the split
           logInfo(s"Partition $key not found, computing it")
-          val computedValues = rdd.computeOrReadCheckpoint(split, context)
+          val computedValues = rdd.computeOrReadCheckpoint(partition, context)
 
-          // Persist the result, so long as the task is not running locally
+          // If the task is running locally, do not persist the result
           if (context.runningLocally) {
             return computedValues
           }
 
-          // Keep track of blocks with updated statuses
-          var updatedBlocks = Seq[(BlockId, BlockStatus)]()
-          val returnValue: Iterator[T] = {
-            if (storageLevel.useDisk && !storageLevel.useMemory) {
-              /* In the case that this RDD is to be persisted using DISK_ONLY
-               * the iterator will be passed directly to the blockManager (rather than
-               * caching it to an ArrayBuffer first), then the resulting block data iterator
-               * will be passed back to the user. If the iterator generates a lot of data,
-               * this means that it doesn't all have to be held in memory at one time.
-               * This could also apply to MEMORY_ONLY_SER storage, but we need to make sure
-               * blocks aren't dropped by the block store before enabling that. */
-              updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
-              blockManager.get(key) match {
-                case Some(values) =>
-                  values.asInstanceOf[Iterator[T]]
-                case None =>
-                  logInfo(s"Failure to store $key")
-                  throw new SparkException("Block manager failed to return persisted value")
-              }
-            } else {
-              // In this case the RDD is cached to an array buffer. This will save the results
-              // if we're dealing with a 'one-time' iterator
-              val elements = new ArrayBuffer[Any]
-              elements ++= computedValues
-              updatedBlocks = blockManager.put(key, elements, storageLevel, tellMaster = true)
-              elements.iterator.asInstanceOf[Iterator[T]]
-            }
-          }
-
-          // Update task metrics to include any blocks whose storage status is updated
-          val metrics = context.taskMetrics
-          metrics.updatedBlocks = Some(updatedBlocks)
-
-          new InterruptibleIterator(context, returnValue)
+          // Otherwise, cache the values and keep track of any updates in block statuses
+          val updatedBlocks = new ArrayBuffer[(BlockId, BlockStatus)]
+          val cachedValues = cacheValues(key, computedValues, storageLevel, updatedBlocks)
+          context.taskMetrics.updatedBlocks = Some(updatedBlocks)
+          new InterruptibleIterator(context, cachedValues)
 
         } finally {
           loading.synchronized {
@@ -128,4 +76,89 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
         }
     }
   }
+
+  /**
+   * Acquire a loading lock for the partition identified by the given block ID.
+   *
+   * If the lock is free, just acquire it and return None. Otherwise, another thread is already
+   * loading the partition, so we wait for it to finish and return the values loaded by the thread.
+   */
+  private def acquireLockForPartition(id: RDDBlockId): Option[Iterator[Any]] = {
+    loading.synchronized {
+      if (!loading.contains(id)) {
+        // If the partition is free, acquire its lock and begin computing its value
+        loading.add(id)
+        None
+      } else {
+        // Otherwise, wait for another thread to finish and return its result
+        logInfo(s"Another thread is loading $id, waiting for it to finish...")
+        while (loading.contains(id)) {
+          try {
+            loading.wait()
+          } catch {
+            case e: Exception =>
+              logWarning(s"Exception while waiting for another thread to load $id", e)
+          }
+        }
+        logInfo(s"Finished waiting for $id")
+        /* See whether someone else has successfully loaded it. The main way this would fail
+         * is for the RDD-level cache eviction policy if someone else has loaded the same RDD
+         * partition but we didn't want to make space for it. However, that case is unlikely
+         * because it's unlikely that two threads would work on the same RDD partition. One
+         * downside of the current code is that threads wait serially if this does happen. */
+        val values = blockManager.get(id)
+        if (!values.isDefined) {
+          logInfo(s"Whoever was loading $id failed; we'll try it ourselves")
+          loading.add(id)
+        }
+        values
+      }
+    }
+  }
+
+  /**
+   * Cache the values of a partition, keeping track of any updates in the storage statuses
+   * of other blocks along the way.
+   */
+  private def cacheValues[T](
+      key: BlockId,
+      value: Iterator[T],
+      storageLevel: StorageLevel,
+      updatedBlocks: ArrayBuffer[(BlockId, BlockStatus)]): Iterator[T] = {
+
+    if (!storageLevel.useMemory) {
+      /* This RDD is not to be cached in memory, so we can just pass the computed values
+       * as an iterator directly to the BlockManager, rather than first fully unrolling
+       * it in memory. The latter option potentially uses much more memory and risks OOM
+       * exceptions that can be avoided. */
+      assume(storageLevel.useDisk || storageLevel.useOffHeap, s"Empty storage level for $key!")
+      updatedBlocks ++= blockManager.put(key, value, storageLevel, tellMaster = true)
+      blockManager.get(key) match {
+        case Some(values) =>
+          values.asInstanceOf[Iterator[T]]
+        case None =>
+          logInfo(s"Failure to store $key")
+          throw new BlockException(key, s"Block manager failed to return cached value for $key!")
+      }
+    } else {
+      /* This RDD is to be cached in memory. In this case we cannot pass the computed values
+       * to the BlockManager as an iterator and expect to read it back later. This is because
+       * we may end up dropping a partition from memory store before getting it back, e.g.
+       * when the entirety of the RDD does not fit in memory. */
+      if (storageLevel.deserialized) {
+        val elements = new ArrayBuffer[Any]
+        elements ++= value
+        updatedBlocks ++= blockManager.put(key, elements, storageLevel, tellMaster = true)
+        elements.iterator.asInstanceOf[Iterator[T]]
+      } else {
+        /* This RDD is to be cached in memory in the form of serialized bytes. In this case,
+         * we only unroll the serialized form of the data, because the deserialized form may
+         * be much larger and may not fit in memory. */
+        val bytes = blockManager.dataSerialize(key, value)
+        updatedBlocks ++= blockManager.putBytes(key, bytes, storageLevel, tellMaster = true)
+        blockManager.dataDeserialize(key, bytes).asInstanceOf[Iterator[T]]
+      }
+    }
+  }
+
 }
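
For context, here is a usage sketch, not part of this commit, showing how an application would exercise the new serialized in-memory path: persisting an RDD with MEMORY_ONLY_SER sends the first computation of each partition through CacheManager.getOrCompute, which now serializes the values once and caches only the bytes. The object name, the local[2] master, and the data sizes are assumptions for illustration.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.storage.StorageLevel

// Illustrative usage sketch: trigger the serialized-bytes caching branch.
object SerializedCacheExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("memory-ser-cache")
    val sc = new SparkContext(conf)

    val rdd = sc.parallelize(1 to 1000000).map(_ * 2)

    // Only the serialized form of each partition is unrolled in memory
    rdd.persist(StorageLevel.MEMORY_ONLY_SER)

    rdd.count()  // first action: compute, serialize once, and cache the bytes
    rdd.count()  // second action: served from the cached serialized blocks

    sc.stop()
  }
}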
