Skip to content

Commit b9a6eee

Browse files
committed
Simplify locking behavior on unrollMemoryMap
This gets rid of the need to synchronize on unrollMemoryMap in addition to putLock. Now we require all accesses on unrollMemoryMap to synchronize on putLock, including when the executor releases the unroll memory after a task ends.
1 parent ed6cda4 commit b9a6eee

File tree

3 files changed

+57
-44
lines changed

3 files changed

+57
-44
lines changed

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,6 @@ class SparkEnv (
7171
// All accesses should be manually synchronized
7272
val shuffleMemoryMap = mutable.HashMap[Long, Long]()
7373

74-
// A mapping of thread ID to amount of memory, in bytes, used for unrolling a block
75-
// All accesses should be manually synchronized
76-
val unrollMemoryMap = mutable.HashMap[Long, Long]()
77-
7874
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
7975

8076
// A general, soft-reference map for metadata needed during HadoopRDD split computation

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,13 @@ private[spark] class Executor(
266266
}
267267
}
268268
} finally {
269-
val threadId = Thread.currentThread().getId
269+
// Release memory used by this thread for shuffles
270270
val shuffleMemoryMap = env.shuffleMemoryMap
271-
val unrollMemoryMap = env.unrollMemoryMap
272-
shuffleMemoryMap.synchronized { shuffleMemoryMap.remove(threadId) }
273-
unrollMemoryMap.synchronized { unrollMemoryMap.remove(threadId) }
271+
shuffleMemoryMap.synchronized {
272+
shuffleMemoryMap.remove(Thread.currentThread().getId)
273+
}
274+
// Release memory used by this thread for unrolling blocks
275+
env.blockManager.memoryStore.releaseUnrollMemory()
274276
runningTasks.remove(taskId)
275277
}
276278
}

core/src/main/scala/org/apache/spark/storage/MemoryStore.scala

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ package org.apache.spark.storage
2020
import java.nio.ByteBuffer
2121
import java.util.LinkedHashMap
2222

23+
import scala.collection.mutable
2324
import scala.collection.mutable.ArrayBuffer
2425

25-
import org.apache.spark.SparkEnv
2626
import org.apache.spark.util.{SizeEstimator, Utils}
2727
import org.apache.spark.util.collection.SizeTrackingVector
2828

@@ -44,20 +44,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
4444
// blocks from the memory store.
4545
private val putLock = new Object()
4646

47-
/**
48-
* Mapping from thread ID to memory used for unrolling blocks.
49-
*
50-
* To avoid potential deadlocks, all accesses of this map in MemoryStore are assumed to
51-
* first synchronize on `putLock` and then on `unrollMemoryMap`, in that particular order.
52-
* This is lazy because SparkEnv does not exist when we mock this class in tests.
53-
*/
54-
private lazy val unrollMemoryMap = SparkEnv.get.unrollMemoryMap
47+
// A mapping from thread ID to amount of memory used for unrolling a block (in bytes)
48+
// All accesses of this map are assumed to have manually synchronized on `putLock`
49+
private val unrollMemoryMap = mutable.HashMap[Long, Long]()
5550

5651
/**
5752
* The amount of space ensured for unrolling values in memory, shared across all cores.
5853
* This space is not reserved in advance, but allocated dynamically by dropping existing blocks.
5954
*/
60-
private val globalUnrollMemory = {
55+
private val maxUnrollMemory: Long = {
6156
val unrollFraction = conf.getDouble("spark.storage.unrollFraction", 0.2)
6257
(maxMemory * unrollFraction).toLong
6358
}
@@ -227,22 +222,18 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
227222
var memoryThreshold = initialMemoryThreshold
228223
// Memory to request as a multiple of current vector size
229224
val memoryGrowthFactor = 1.5
230-
231-
val threadId = Thread.currentThread().getId
225+
// Underlying vector for unrolling the block
232226
var vector = new SizeTrackingVector[Any]
233227

234-
// Request memory for our vector and return whether the request is granted. This involves
235-
// synchronizing on putLock and unrollMemoryMap (in that order), which could be expensive.
228+
// Request memory for our vector and return whether the request is granted
229+
// This involves synchronizing across all threads, which is expensive if called frequently
236230
def requestMemory(memoryToRequest: Long): Boolean = {
237231
putLock.synchronized {
238-
unrollMemoryMap.synchronized {
239-
val previouslyOccupiedMemory = unrollMemoryMap.get(threadId).getOrElse(0L)
240-
val otherThreadsMemory = unrollMemoryMap.values.sum - previouslyOccupiedMemory
241-
val availableMemory = freeMemory - otherThreadsMemory
242-
val granted = availableMemory > memoryToRequest
243-
if (granted) { unrollMemoryMap(threadId) = memoryToRequest }
244-
granted
245-
}
232+
val otherThreadsMemory = currentUnrollMemory - threadCurrentUnrollMemory
233+
val availableMemory = freeMemory - otherThreadsMemory
234+
val granted = availableMemory > memoryToRequest
235+
if (granted) { reserveUnrollMemory(memoryToRequest) }
236+
granted
246237
}
247238
}
248239

@@ -261,17 +252,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
261252
// Hold the put lock, in case another thread concurrently puts a block that takes
262253
// up the unrolling space we just ensured here
263254
putLock.synchronized {
264-
unrollMemoryMap.synchronized {
265-
if (!requestMemory(amountToRequest)) {
266-
// If the first request is not granted, try again after ensuring free space
267-
// If there is still not enough space, give up and drop the partition
268-
val extraSpaceNeeded = globalUnrollMemory - unrollMemoryMap.values.sum
269-
val result = ensureFreeSpace(blockId, extraSpaceNeeded)
270-
droppedBlocks ++= result.droppedBlocks
271-
keepUnrolling = requestMemory(amountToRequest)
272-
}
273-
memoryThreshold = amountToRequest
255+
if (!requestMemory(amountToRequest)) {
256+
// If the first request is not granted, try again after ensuring free space
257+
// If there is still not enough space, give up and drop the partition
258+
val extraSpaceNeeded = maxUnrollMemory - currentUnrollMemory
259+
val result = ensureFreeSpace(blockId, extraSpaceNeeded)
260+
droppedBlocks ++= result.droppedBlocks
261+
keepUnrolling = requestMemory(amountToRequest)
274262
}
263+
memoryThreshold = amountToRequest
275264
}
276265
}
277266
}
@@ -292,9 +281,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
292281
// we release the memory claimed by this thread later on when the task finishes.
293282
if (keepUnrolling) {
294283
vector = null
295-
unrollMemoryMap.synchronized {
296-
unrollMemoryMap(threadId) = 0
297-
}
284+
releaseUnrollMemory()
298285
}
299286
}
300287
}
@@ -387,7 +374,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
387374
}
388375

389376
// Take into account the amount of memory currently occupied by unrolling blocks
390-
val freeSpace = unrollMemoryMap.synchronized { freeMemory - unrollMemoryMap.values.sum }
377+
val freeSpace = freeMemory - currentUnrollMemory
391378

392379
if (freeSpace < space) {
393380
val rddToAdd = getRddId(blockIdToAdd)
@@ -439,6 +426,34 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
439426
override def contains(blockId: BlockId): Boolean = {
440427
entries.synchronized { entries.containsKey(blockId) }
441428
}
429+
430+
/**
431+
* Reserve memory for unrolling blocks used by this thread.
432+
*/
433+
private def reserveUnrollMemory(memory: Long): Unit = putLock.synchronized {
434+
unrollMemoryMap(Thread.currentThread().getId) = memory
435+
}
436+
437+
/**
438+
* Release memory used by this thread for unrolling blocks.
439+
*/
440+
private[spark] def releaseUnrollMemory(): Unit = putLock.synchronized {
441+
unrollMemoryMap.remove(Thread.currentThread().getId)
442+
}
443+
444+
/**
445+
* Return the amount of memory currently occupied for unrolling blocks across all threads.
446+
*/
447+
private def currentUnrollMemory: Long = putLock.synchronized {
448+
unrollMemoryMap.values.sum
449+
}
450+
451+
/**
452+
* Return the amount of memory currently occupied for unrolling blocks by this thread.
453+
*/
454+
private def threadCurrentUnrollMemory: Long = putLock.synchronized {
455+
unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L)
456+
}
442457
}
443458

444459
private[spark] case class ResultWithDroppedBlocks(

0 commit comments

Comments
 (0)