Skip to content

Commit 195abd7

Browse files
committed
Refactor: move unfold logic to MemoryStore
This logic is also needed in other parts of the MemoryStore, e.g. when we try to store deserialized bytes in memory. The unfolding logic is specific to the memory case, so it makes sense for it to reside in MemoryStore, as opposed to the higher level CacheManager.
1 parent 1e82d00 commit 195abd7

File tree

2 files changed

+119
-80
lines changed

2 files changed

+119
-80
lines changed

core/src/main/scala/org/apache/spark/CacheManager.scala

Lines changed: 25 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,6 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
3434
/** Keys of RDD partitions that are being computed/loaded. */
3535
private val loading = new mutable.HashSet[RDDBlockId]
3636

37-
/**
38-
* The amount of space ensured for unrolling partitions, shared across all cores.
39-
* This space is not reserved in advance, but allocated dynamically by dropping existing blocks.
40-
* It must be a lazy val in order to access a mocked BlockManager's conf in tests properly.
41-
*/
42-
private lazy val globalBufferMemory = BlockManager.getBufferMemory(blockManager.conf)
43-
4437
/** Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is cached. */
4538
def getOrCompute[T](
4639
rdd: RDD[T],
@@ -137,10 +130,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
137130
updatedBlocks: ArrayBuffer[(BlockId, BlockStatus)]): Iterator[T] = {
138131

139132
if (!storageLevel.useMemory) {
140-
/* This RDD is not to be cached in memory, so we can just pass the computed values
141-
* as an iterator directly to the BlockManager, rather than first fully unrolling
133+
/*
134+
* This RDD is not to be cached in memory, so we can just pass the computed values
135+
* as an iterator directly to the BlockManager, rather than first fully unfolding
142136
* it in memory. The latter option potentially uses much more memory and risks OOM
143-
* exceptions that can be avoided. */
137+
* exceptions that can be avoided.
138+
*/
144139
updatedBlocks ++= blockManager.put(key, values, storageLevel, tellMaster = true)
145140
blockManager.get(key) match {
146141
case Some(v) => v.data.asInstanceOf[Iterator[T]]
@@ -149,86 +144,38 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
149144
throw new BlockException(key, s"Block manager failed to return cached value for $key!")
150145
}
151146
} else {
152-
/* This RDD is to be cached in memory. In this case we cannot pass the computed values
147+
/*
148+
* This RDD is to be cached in memory. In this case we cannot pass the computed values
153149
* to the BlockManager as an iterator and expect to read it back later. This is because
154150
* we may end up dropping a partition from memory store before getting it back, e.g.
155-
* when the entirety of the RDD does not fit in memory. */
156-
157-
var count = 0 // The number of elements unrolled so far
158-
var dropPartition = false // Whether to drop the new partition from memory
159-
var previousSize = 0L // Previous estimate of the size of our buffer
160-
val memoryRequestPeriod = 1000 // How frequently we request for more memory for our buffer
161-
162-
val threadId = Thread.currentThread().getId
163-
val cacheMemoryMap = SparkEnv.get.cacheMemoryMap
164-
var buffer = new SizeTrackingAppendOnlyBuffer[Any]
165-
166-
try {
167-
/* While adding values to the in-memory buffer, periodically check whether the memory
168-
* restrictions for unrolling partitions are still satisfied. If not, stop immediately,
169-
* and persist the partition to disk if specified by the storage level. This check is
170-
* a safeguard against the scenario when a single partition does not fit in memory. */
171-
while (values.hasNext && !dropPartition) {
172-
buffer += values.next()
173-
count += 1
174-
if (count % memoryRequestPeriod == 1) {
175-
// Calculate the amount of memory to request from the global memory pool
176-
val currentSize = buffer.estimateSize()
177-
val delta = math.max(currentSize - previousSize, 0)
178-
val memoryToRequest = currentSize + delta
179-
previousSize = currentSize
180-
181-
// Atomically check whether there is sufficient memory in the global pool to continue
182-
cacheMemoryMap.synchronized {
183-
val previouslyOccupiedMemory = cacheMemoryMap.get(threadId).getOrElse(0L)
184-
val otherThreadsMemory = cacheMemoryMap.values.sum - previouslyOccupiedMemory
185-
186-
// Request for memory for the local buffer, and return whether request is granted
187-
def requestForMemory(): Boolean = {
188-
val availableMemory = blockManager.memoryStore.freeMemory - otherThreadsMemory
189-
val granted = availableMemory > memoryToRequest
190-
if (granted) { cacheMemoryMap(threadId) = memoryToRequest }
191-
granted
192-
}
193-
194-
// If the first request is not granted, try again after ensuring free space
195-
// If there is still not enough space, give up and drop the partition
196-
if (!requestForMemory()) {
197-
val result = blockManager.memoryStore.ensureFreeSpace(key, globalBufferMemory)
198-
updatedBlocks ++= result.droppedBlocks
199-
dropPartition = !requestForMemory()
200-
}
201-
}
202-
}
203-
}
204-
205-
if (!dropPartition) {
206-
// We have successfully unrolled the entire partition, so cache it in memory
207-
updatedBlocks ++= blockManager.put(key, buffer.array, storageLevel, tellMaster = true)
208-
buffer.iterator.asInstanceOf[Iterator[T]]
209-
} else {
210-
// We have exceeded our collective quota. This partition will not be cached in memory.
151+
* when the entirety of the RDD does not fit in memory.
152+
*
153+
* In addition, we must be careful to not unfold the entire partition in memory at once.
154+
* Otherwise, we may cause an OOM exception if the JVM does not have enough space for this
155+
* single partition. Instead, we unfold the values cautiously, potentially aborting and
156+
* dropping the partition to disk if applicable.
157+
*/
158+
blockManager.memoryStore.unfoldSafely(key, values, storageLevel, updatedBlocks) match {
159+
case Left(arrayValues) =>
160+
// We have successfully unfolded the entire partition, so cache it in memory
161+
updatedBlocks ++= blockManager.put(key, arrayValues, storageLevel, tellMaster = true)
162+
arrayValues.iterator.asInstanceOf[Iterator[T]]
163+
case Right(iteratorValues) =>
164+
// There is not enough space to cache this partition in memory
165+
var returnValues = iteratorValues.asInstanceOf[Iterator[T]]
211166
val persistToDisk = storageLevel.useDisk
212-
logWarning(s"Failed to cache $key in memory! There is not enough space to unroll the " +
167+
logWarning(s"Failed to cache $key in memory! There is not enough space to unfold the " +
213168
s"entire partition. " + (if (persistToDisk) "Persisting to disk instead." else ""))
214-
var newValues = (buffer.iterator ++ values).asInstanceOf[Iterator[T]]
215169
if (persistToDisk) {
216170
val newLevel = StorageLevel(
217171
storageLevel.useDisk,
218172
useMemory = false,
219173
storageLevel.useOffHeap,
220174
deserialized = false,
221175
storageLevel.replication)
222-
newValues = putInBlockManager[T](key, newValues, newLevel, updatedBlocks)
176+
returnValues = putInBlockManager[T](key, returnValues, newLevel, updatedBlocks)
223177
}
224-
newValues
225-
}
226-
} finally {
227-
// Free up buffer for other threads
228-
buffer = null
229-
cacheMemoryMap.synchronized {
230-
cacheMemoryMap(threadId) = 0
231-
}
178+
returnValues
232179
}
233180
}
234181
}

core/src/main/scala/org/apache/spark/storage/MemoryStore.scala

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ import java.util.LinkedHashMap
2222

2323
import scala.collection.mutable.ArrayBuffer
2424

25+
import org.apache.spark.SparkEnv
2526
import org.apache.spark.util.{SizeEstimator, Utils}
27+
import org.apache.spark.util.collection.SizeTrackingAppendOnlyBuffer
2628

2729
private case class MemoryEntry(value: Any, size: Long, deserialized: Boolean)
2830

@@ -34,11 +36,20 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
3436
extends BlockStore(blockManager) {
3537

3638
private val entries = new LinkedHashMap[BlockId, MemoryEntry](32, 0.75f, true)
39+
3740
@volatile private var currentMemory = 0L
38-
// Object used to ensure that only one thread is putting blocks and if necessary, dropping
39-
// blocks from the memory store.
41+
42+
// Object used to ensure that only one thread is putting blocks and if necessary,
43+
// dropping blocks from the memory store.
4044
private val putLock = new Object()
4145

46+
/**
47+
* The amount of space ensured for unfolding values in memory, shared across all cores.
48+
* This space is not reserved in advance, but allocated dynamically by dropping existing blocks.
49+
* It must be a lazy val in order to access a mocked BlockManager's conf in tests properly.
50+
*/
51+
private lazy val globalBufferMemory = BlockManager.getBufferMemory(blockManager.conf)
52+
4253
logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory)))
4354

4455
def freeMemory: Long = maxMemory - currentMemory
@@ -137,6 +148,87 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
137148
logInfo("MemoryStore cleared")
138149
}
139150

151+
/**
152+
* Unfold the given block in memory safely.
153+
*
154+
* The safety of this operation refers to avoiding potential OOM exceptions caused by
155+
* unfolding the entirety of the block in memory at once. This is achieved by periodically
156+
* checking whether the memory restrictions for unfolding blocks are still satisfied,
157+
* stopping immediately if not. This check is a safeguard against the scenario in which
158+
* there is not enough free memory to accommodate the entirety of a single block.
159+
*
160+
* This method returns either a fully unfolded array or a partially unfolded iterator.
161+
*/
162+
def unfoldSafely(
163+
blockId: BlockId,
164+
values: Iterator[Any],
165+
storageLevel: StorageLevel,
166+
droppedBlocks: ArrayBuffer[(BlockId, BlockStatus)])
167+
: Either[Array[Any], Iterator[Any]] = {
168+
169+
var count = 0 // The number of elements unfolded so far
170+
var enoughMemory = true // Whether there is enough memory to unfold this block
171+
var previousSize = 0L // Previous estimate of the size of our buffer
172+
val memoryRequestPeriod = 1000 // How frequently we request for more memory for our buffer
173+
174+
val threadId = Thread.currentThread().getId
175+
val cacheMemoryMap = SparkEnv.get.cacheMemoryMap
176+
var buffer = new SizeTrackingAppendOnlyBuffer[Any]
177+
178+
try {
179+
while (values.hasNext && enoughMemory) {
180+
buffer += values.next()
181+
count += 1
182+
if (count % memoryRequestPeriod == 1) {
183+
// Calculate the amount of memory to request from the global memory pool
184+
val currentSize = buffer.estimateSize()
185+
val delta = math.max(currentSize - previousSize, 0)
186+
val memoryToRequest = currentSize + delta
187+
previousSize = currentSize
188+
189+
// Atomically check whether there is sufficient memory in the global pool to continue
190+
cacheMemoryMap.synchronized {
191+
val previouslyOccupiedMemory = cacheMemoryMap.get(threadId).getOrElse(0L)
192+
val otherThreadsMemory = cacheMemoryMap.values.sum - previouslyOccupiedMemory
193+
194+
// Request for memory for the local buffer, and return whether request is granted
195+
def requestForMemory(): Boolean = {
196+
val availableMemory = freeMemory - otherThreadsMemory
197+
val granted = availableMemory > memoryToRequest
198+
if (granted) { cacheMemoryMap(threadId) = memoryToRequest }
199+
granted
200+
}
201+
202+
// If the first request is not granted, try again after ensuring free space
203+
// If there is still not enough space, give up and drop the partition
204+
if (!requestForMemory()) {
205+
val result = ensureFreeSpace(blockId, globalBufferMemory)
206+
droppedBlocks ++= result.droppedBlocks
207+
enoughMemory = requestForMemory()
208+
}
209+
}
210+
}
211+
}
212+
213+
if (enoughMemory) {
214+
// We successfully unfolded the entirety of this block
215+
Left(buffer.array)
216+
} else {
217+
// We ran out of space while unfolding the values for this block
218+
Right(buffer.iterator ++ values)
219+
}
220+
221+
} finally {
222+
// Unless we return an iterator that depends on the buffer, free up space for other threads
223+
if (enoughMemory) {
224+
buffer = null
225+
cacheMemoryMap.synchronized {
226+
cacheMemoryMap(threadId) = 0
227+
}
228+
}
229+
}
230+
}
231+
140232
/**
141233
* Return the RDD ID that a given block ID is from, or None if it is not an RDD block.
142234
*/

0 commit comments

Comments
 (0)