@@ -217,14 +217,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
217217 val initialMemoryThreshold = conf.getLong(" spark.storage.unrollMemoryThreshold" , 1024 * 1024 )
218218 // How often to check whether we need to request more memory
219219 val memoryCheckPeriod = 16
220- // Memory currently reserved by this thread (bytes)
220+ // Memory currently reserved by this thread for this particular unrolling operation
221221 var memoryThreshold = initialMemoryThreshold
222222 // Memory to request as a multiple of current vector size
223223 val memoryGrowthFactor = 1.5
224+ // Previous unroll memory held by this thread, for releasing later
225+ val previousMemoryReserved = currentUnrollMemoryForThisThread
224226 // Underlying vector for unrolling the block
225227 var vector = new SizeTrackingVector [Any ]
226228
227- // Request memory for our vector and return whether the request is granted
229+ // Request additional memory for our vector and return whether the request is granted
228230 // This involves synchronizing across all threads, which is expensive if called frequently
229231 def requestMemory (memoryToRequest : Long ): Boolean = {
230232 accountingLock.synchronized {
@@ -237,7 +239,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
237239 }
238240
239241 // Request enough memory to begin unrolling
240- keepUnrolling = requestMemory(memoryThreshold )
242+ keepUnrolling = requestMemory(initialMemoryThreshold )
241243
242244 // Unroll this block safely, checking whether we have exceeded our threshold periodically
243245 try {
@@ -247,7 +249,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
247249 // If our vector's size has exceeded the threshold, request more memory
248250 val currentSize = vector.estimateSize()
249251 if (currentSize >= memoryThreshold) {
250- val amountToRequest = (currentSize * memoryGrowthFactor).toLong
252+ val amountToRequest = (currentSize * ( memoryGrowthFactor - 1 ) ).toLong
251253 // Hold the put lock, in case another thread concurrently puts a block that takes
252254 // up the unrolling space we just ensured here
253255 accountingLock.synchronized {
@@ -259,8 +261,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
259261 droppedBlocks ++= result.droppedBlocks
260262 keepUnrolling = requestMemory(amountToRequest)
261263 }
262- memoryThreshold = amountToRequest
263264 }
265+ // New threshold is currentSize * memoryGrowthFactor
266+ memoryThreshold = currentSize + amountToRequest
264267 }
265268 }
266269 elementsUnrolled += 1
@@ -280,7 +283,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
280283 // we release the memory claimed by this thread later on when the task finishes.
281284 if (keepUnrolling) {
282285 vector = null
283- releaseUnrollMemoryForThisThread()
286+ val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved
287+ releaseUnrollMemoryForThisThread(amountToRelease)
284288 }
285289 }
286290 }
@@ -355,8 +359,8 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
355359 * from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
356360 * don't fit into memory that we want to avoid).
357361 *
358- * Assume that `accountingLock` is held by the caller to ensure only one thread is dropping blocks.
359- * Otherwise, the freed space may fill up before the caller puts in their new value.
362+ * Assume that `accountingLock` is held by the caller to ensure only one thread is dropping
363+ * blocks. Otherwise, the freed space may fill up before the caller puts in their new value.
360364 *
361365 * Return whether there is enough free space, along with the blocks dropped in the process.
362366 */
@@ -427,17 +431,32 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
427431 }
428432
429433 /**
430- * Reserve memory for unrolling blocks used by this thread.
434+ * Reserve additional memory for unrolling blocks used by this thread.
431435 */
432- private def reserveUnrollMemoryForThisThread (memory : Long ): Unit = accountingLock.synchronized {
433- unrollMemoryMap(Thread .currentThread().getId) = memory
436+ private def reserveUnrollMemoryForThisThread (memory : Long ): Unit = {
437+ val threadId = Thread .currentThread().getId
438+ accountingLock.synchronized {
439+ unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L ) + memory
440+ }
434441 }
435442
436443 /**
437444 * Release memory used by this thread for unrolling blocks.
445+ * If the amount is not specified, remove the current thread's allocation altogether.
438446 */
439- private [spark] def releaseUnrollMemoryForThisThread (): Unit = accountingLock.synchronized {
440- unrollMemoryMap.remove(Thread .currentThread().getId)
447+ private [spark] def releaseUnrollMemoryForThisThread (memory : Long = - 1L ): Unit = {
448+ val threadId = Thread .currentThread().getId
449+ accountingLock.synchronized {
450+ if (memory < 0 ) {
451+ unrollMemoryMap.remove(threadId)
452+ } else {
453+ unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory
454+ // If this thread claims no more unroll memory, release it completely
455+ if (unrollMemoryMap(threadId) <= 0 ) {
456+ unrollMemoryMap.remove(threadId)
457+ }
458+ }
459+ }
441460 }
442461
443462 /**
0 commit comments