@@ -20,9 +20,9 @@ package org.apache.spark.storage
 import java.nio.ByteBuffer
 import java.util.LinkedHashMap
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.SparkEnv
 import org.apache.spark.util.{SizeEstimator, Utils}
 import org.apache.spark.util.collection.SizeTrackingVector
 
@@ -44,20 +44,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
   // blocks from the memory store.
   private val putLock = new Object()
 
-  /**
-   * Mapping from thread ID to memory used for unrolling blocks.
-   *
-   * To avoid potential deadlocks, all accesses of this map in MemoryStore are assumed to
-   * first synchronize on `putLock` and then on `unrollMemoryMap`, in that particular order.
-   * This is lazy because SparkEnv does not exist when we mock this class in tests.
-   */
-  private lazy val unrollMemoryMap = SparkEnv.get.unrollMemoryMap
+  // A mapping from thread ID to amount of memory used for unrolling a block (in bytes)
+  // All accesses of this map are assumed to have manually synchronized on `putLock`
+  private val unrollMemoryMap = mutable.HashMap[Long, Long]()
 
   /**
    * The amount of space ensured for unrolling values in memory, shared across all cores.
    * This space is not reserved in advance, but allocated dynamically by dropping existing blocks.
    */
-  private val globalUnrollMemory = {
+  private val maxUnrollMemory: Long = {
     val unrollFraction = conf.getDouble("spark.storage.unrollFraction", 0.2)
     (maxMemory * unrollFraction).toLong
   }
@@ -227,22 +222,18 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     var memoryThreshold = initialMemoryThreshold
     // Memory to request as a multiple of current vector size
     val memoryGrowthFactor = 1.5
-
-    val threadId = Thread.currentThread().getId
+    // Underlying vector for unrolling the block
     var vector = new SizeTrackingVector[Any]
 
-    // Request memory for our vector and return whether the request is granted. This involves
-    // synchronizing on putLock and unrollMemoryMap (in that order), which could be expensive.
+    // Request memory for our vector and return whether the request is granted
+    // This involves synchronizing across all threads, which is expensive if called frequently
     def requestMemory(memoryToRequest: Long): Boolean = {
       putLock.synchronized {
-        unrollMemoryMap.synchronized {
-          val previouslyOccupiedMemory = unrollMemoryMap.get(threadId).getOrElse(0L)
-          val otherThreadsMemory = unrollMemoryMap.values.sum - previouslyOccupiedMemory
-          val availableMemory = freeMemory - otherThreadsMemory
-          val granted = availableMemory > memoryToRequest
-          if (granted) { unrollMemoryMap(threadId) = memoryToRequest }
-          granted
-        }
+        val otherThreadsMemory = currentUnrollMemory - threadCurrentUnrollMemory
+        val availableMemory = freeMemory - otherThreadsMemory
+        val granted = availableMemory > memoryToRequest
+        if (granted) { reserveUnrollMemory(memoryToRequest) }
+        granted
       }
     }
 
@@ -261,17 +252,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
           // Hold the put lock, in case another thread concurrently puts a block that takes
           // up the unrolling space we just ensured here
           putLock.synchronized {
-            unrollMemoryMap.synchronized {
-              if (!requestMemory(amountToRequest)) {
-                // If the first request is not granted, try again after ensuring free space
-                // If there is still not enough space, give up and drop the partition
-                val extraSpaceNeeded = globalUnrollMemory - unrollMemoryMap.values.sum
-                val result = ensureFreeSpace(blockId, extraSpaceNeeded)
-                droppedBlocks ++= result.droppedBlocks
-                keepUnrolling = requestMemory(amountToRequest)
-              }
-              memoryThreshold = amountToRequest
+            if (!requestMemory(amountToRequest)) {
+              // If the first request is not granted, try again after ensuring free space
+              // If there is still not enough space, give up and drop the partition
+              val extraSpaceNeeded = maxUnrollMemory - currentUnrollMemory
+              val result = ensureFreeSpace(blockId, extraSpaceNeeded)
+              droppedBlocks ++= result.droppedBlocks
+              keepUnrolling = requestMemory(amountToRequest)
             }
+            memoryThreshold = amountToRequest
           }
         }
       }
@@ -292,9 +281,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
       // we release the memory claimed by this thread later on when the task finishes.
       if (keepUnrolling) {
         vector = null
-        unrollMemoryMap.synchronized {
-          unrollMemoryMap(threadId) = 0
-        }
+        releaseUnrollMemory()
       }
     }
   }
@@ -387,7 +374,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
     }
 
     // Take into account the amount of memory currently occupied by unrolling blocks
-    val freeSpace = unrollMemoryMap.synchronized { freeMemory - unrollMemoryMap.values.sum }
+    val freeSpace = freeMemory - currentUnrollMemory
 
     if (freeSpace < space) {
       val rddToAdd = getRddId(blockIdToAdd)
@@ -439,6 +426,34 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
   override def contains(blockId: BlockId): Boolean = {
     entries.synchronized { entries.containsKey(blockId) }
   }
+
+  /**
+   * Reserve memory for unrolling blocks used by this thread.
+   */
+  private def reserveUnrollMemory(memory: Long): Unit = putLock.synchronized {
+    unrollMemoryMap(Thread.currentThread().getId) = memory
+  }
+
+  /**
+   * Release memory used by this thread for unrolling blocks.
+   */
+  private[spark] def releaseUnrollMemory(): Unit = putLock.synchronized {
+    unrollMemoryMap.remove(Thread.currentThread().getId)
+  }
+
+  /**
+   * Return the amount of memory currently occupied for unrolling blocks across all threads.
+   */
+  private def currentUnrollMemory: Long = putLock.synchronized {
+    unrollMemoryMap.values.sum
+  }
+
+  /**
+   * Return the amount of memory currently occupied for unrolling blocks by this thread.
+   */
+  private def threadCurrentUnrollMemory: Long = putLock.synchronized {
+    unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L)
+  }
 }
 
 private[spark] case class ResultWithDroppedBlocks(
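
For reference, below is a minimal, self-contained sketch of the thread-keyed unroll-memory accounting this patch introduces. It is not the actual `MemoryStore` code: the `UnrollAccountingSketch` object, its `maxMemory` and `blocksMemoryUsed` fields, and the values used in `main` are made up for illustration; only the helper method names mirror the ones added above.

```scala
import scala.collection.mutable

// Standalone sketch of the per-thread unroll-memory bookkeeping (illustrative values only).
object UnrollAccountingSketch {
  private val maxMemory: Long = 1000L         // hypothetical store capacity in bytes
  private val blocksMemoryUsed: Long = 0L     // bytes held by cached blocks (fixed at 0 here)
  private val putLock = new Object()
  // Thread ID -> bytes reserved for unrolling; all accesses synchronize on `putLock`
  private val unrollMemoryMap = mutable.HashMap[Long, Long]()

  private def freeMemory: Long = maxMemory - blocksMemoryUsed

  // Total unroll memory reserved across all threads
  private def currentUnrollMemory: Long = putLock.synchronized {
    unrollMemoryMap.values.sum
  }

  // Unroll memory reserved by the calling thread
  private def threadCurrentUnrollMemory: Long = putLock.synchronized {
    unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L)
  }

  // Record (or replace) the calling thread's reservation
  private def reserveUnrollMemory(memory: Long): Unit = putLock.synchronized {
    unrollMemoryMap(Thread.currentThread().getId) = memory
  }

  // Drop the calling thread's reservation entirely
  def releaseUnrollMemory(): Unit = putLock.synchronized {
    unrollMemoryMap.remove(Thread.currentThread().getId)
  }

  // Grant the request only if it fits alongside the other threads' reservations;
  // the caller's existing reservation does not count against it because it is replaced
  def requestMemory(memoryToRequest: Long): Boolean = putLock.synchronized {
    val otherThreadsMemory = currentUnrollMemory - threadCurrentUnrollMemory
    val availableMemory = freeMemory - otherThreadsMemory
    val granted = availableMemory > memoryToRequest
    if (granted) { reserveUnrollMemory(memoryToRequest) }
    granted
  }

  def main(args: Array[String]): Unit = {
    println(requestMemory(500L))    // true: 500 < 1000 bytes free
    println(requestMemory(800L))    // true: the previous 500-byte reservation is replaced, not added
    println(requestMemory(1200L))   // false: exceeds free memory even with no other threads
    releaseUnrollMemory()
    println(currentUnrollMemory)    // 0
  }
}
```

As in the patch, the key property is that a thread's own reservation is excluded when checking availability, the reservation is overwritten rather than accumulated on each successful request, and everything is guarded by the single `putLock` instead of a second lock on the map itself.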