1717
1818package org .apache .spark .storage
1919
20- import java .lang
2120import javax .annotation .concurrent .GuardedBy
2221
2322import scala .collection .JavaConverters ._
2423import scala .collection .mutable
2524
26- import com .google .common .cache .{CacheBuilder , CacheLoader , LoadingCache }
2725import com .google .common .collect .ConcurrentHashMultiset
2826
2927import org .apache .spark .{Logging , SparkException , TaskContext }
@@ -141,18 +139,29 @@ private[storage] class BlockInfoManager extends Logging {
141139
/**
 * For each task attempt, the multiset of blocks on which that task currently holds read locks.
 * A block's multiplicity in the multiset is the number of times the task has locked it, since
 * read locks are re-entrant. Access to this map must be synchronized on this manager.
 */
@GuardedBy("this")
private[this] val readLocksByTask =
  mutable.HashMap.empty[TaskAttemptId, ConcurrentHashMultiset[BlockId]]
147+
148+ // ----------------------------------------------------------------------------------------------
149+
150+ // Initialization for special task attempt ids:
151+ registerTask(BlockInfo .NON_TASK_WRITER )
153152
154153 // ----------------------------------------------------------------------------------------------
155154
/**
 * Registers a task with this [[BlockInfoManager]] by creating its (initially empty) multiset
 * of read locks. This must be called at the start of a task, before that task calls any other
 * BlockInfoManager method.
 *
 * The body runs while holding this manager's monitor: it performs a check-then-insert on
 * `readLocksByTask`, which is a plain mutable map guarded by `this`, so concurrent
 * registrations from different task threads must not interleave.
 *
 * @param taskAttemptId id of the task attempt to register
 * @throws IllegalArgumentException if the task attempt has already been registered
 */
def registerTask(taskAttemptId: TaskAttemptId): Unit = synchronized {
  require(!readLocksByTask.contains(taskAttemptId),
    s"Task attempt $taskAttemptId is already registered")
  readLocksByTask(taskAttemptId) = ConcurrentHashMultiset.create()
}
164+
156165 /**
157166 * Returns the current task's task attempt id (which uniquely identifies the task), or
158167 * [[BlockInfo.NON_TASK_WRITER ]] if called by a non-task thread.
@@ -284,7 +293,7 @@ private[storage] class BlockInfoManager extends Logging {
284293 } else {
285294 assert(info.readerCount > 0 , s " Block $blockId is not locked for reading " )
286295 info.readerCount -= 1
287- val countsForTask = readLocksByTask.get (currentTaskAttemptId)
296+ val countsForTask = readLocksByTask(currentTaskAttemptId)
288297 val newPinCountForTask : Int = countsForTask.remove(blockId, 1 ) - 1
289298 assert(newPinCountForTask >= 0 ,
290299 s " Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it " )
@@ -325,20 +334,21 @@ private[storage] class BlockInfoManager extends Logging {
325334 */
326335 def releaseAllLocksForTask (taskAttemptId : TaskAttemptId ): Seq [BlockId ] = {
327336 val blocksWithReleasedLocks = mutable.ArrayBuffer [BlockId ]()
328- synchronized {
329- writeLocksByTask.remove(taskAttemptId).foreach { locks =>
330- for (blockId <- locks) {
331- infos.get(blockId).foreach { info =>
332- assert(info.writerTask == taskAttemptId)
333- info.writerTask = BlockInfo .NO_WRITER
334- }
335- blocksWithReleasedLocks += blockId
336- }
337+
338+ val readLocks = synchronized {
339+ readLocksByTask.remove(taskAttemptId).get
340+ }
341+ val writeLocks = synchronized {
342+ writeLocksByTask.remove(taskAttemptId).getOrElse(Seq .empty)
343+ }
344+
345+ for (blockId <- writeLocks) {
346+ infos.get(blockId).foreach { info =>
347+ assert(info.writerTask == taskAttemptId)
348+ info.writerTask = BlockInfo .NO_WRITER
337349 }
338- notifyAll()
350+ blocksWithReleasedLocks += blockId
339351 }
340- val readLocks = readLocksByTask.get(taskAttemptId)
341- readLocksByTask.invalidate(taskAttemptId)
342352 readLocks.entrySet().iterator().asScala.foreach { entry =>
343353 val blockId = entry.getElement
344354 val lockCount = entry.getCount
@@ -350,6 +360,7 @@ private[storage] class BlockInfoManager extends Logging {
350360 }
351361 }
352362 }
363+
353364 synchronized {
354365 notifyAll()
355366 }
@@ -369,8 +380,8 @@ private[storage] class BlockInfoManager extends Logging {
369380 */
private[storage] def getNumberOfMapEntries: Long = synchronized {
  // For each lock table, count one entry per task plus the size of every task's lock collection.
  val readLockEntries =
    readLocksByTask.size + readLocksByTask.map { case (_, locks) => locks.size() }.sum
  val writeLockEntries =
    writeLocksByTask.size + writeLocksByTask.map { case (_, locks) => locks.size }.sum
  size + readLockEntries + writeLockEntries
}
@@ -419,7 +430,7 @@ private[storage] class BlockInfoManager extends Logging {
419430 blockInfo.removed = true
420431 }
421432 infos.clear()
422- readLocksByTask.invalidateAll ()
433+ readLocksByTask.clear ()
423434 writeLocksByTask.clear()
424435 notifyAll()
425436 }
0 commit comments