@@ -19,12 +19,14 @@ package org.apache.spark.storage
1919
2020import java .io ._
2121import java .nio .{ByteBuffer , MappedByteBuffer }
22+ import java .util .concurrent .ConcurrentHashMap
2223
2324import scala .collection .mutable .{ArrayBuffer , HashMap }
2425import scala .concurrent .duration ._
2526import scala .concurrent .{Await , ExecutionContext , Future }
2627import scala .util .Random
2728import scala .util .control .NonFatal
29+ import scala .collection .JavaConverters ._
2830
2931import sun .nio .ch .DirectBuffer
3032
@@ -65,7 +67,7 @@ private[spark] class BlockManager(
6567 val master : BlockManagerMaster ,
6668 defaultSerializer : Serializer ,
6769 val conf : SparkConf ,
68- memoryManager : MemoryManager ,
70+ val memoryManager : MemoryManager ,
6971 mapOutputTracker : MapOutputTracker ,
7072 shuffleManager : ShuffleManager ,
7173 blockTransferService : BlockTransferService ,
@@ -163,6 +165,11 @@ private[spark] class BlockManager(
163165 * Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been
164166 * loaded yet. */
165167 private lazy val compressionCodec : CompressionCodec = CompressionCodec .createCodec(conf)
168+
169+ // Blocks are removing by another thread
170+ private val pendingToRemove = new ConcurrentHashMap [BlockId , Long ]()
171+
172+ private val NON_TASK_WRITER = - 1024L
166173
167174 /**
168175 * Initializes the BlockManager with the given appId. This is not performed in the constructor as
@@ -1025,7 +1032,7 @@ private[spark] class BlockManager(
10251032 val info = blockInfo.get(blockId).orNull
10261033
10271034 // If the block has not already been dropped
1028- if (info != null ) {
1035+ if (info != null && ! pendingToRemove.containsKey(blockId) ) {
10291036 info.synchronized {
10301037 // required ? As of now, this will be invoked only for blocks which are ready
10311038 // But in case this changes in future, adding for consistency sake.
@@ -1051,11 +1058,13 @@ private[spark] class BlockManager(
10511058 }
10521059 blockIsUpdated = true
10531060 }
1061+ pendingToRemove.put(blockId, currentTaskAttemptId)
10541062
10551063 // Actually drop from memory store
10561064 val droppedMemorySize =
10571065 if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L
10581066 val blockIsRemoved = memoryStore.remove(blockId)
1067+ pendingToRemove.remove(blockId)
10591068 if (blockIsRemoved) {
10601069 blockIsUpdated = true
10611070 } else {
@@ -1080,6 +1089,7 @@ private[spark] class BlockManager(
10801089
10811090 /**
10821091 * Remove all blocks belonging to the given RDD.
1092+ *
10831093 * @return The number of blocks removed.
10841094 */
10851095 def removeRdd (rddId : Int ): Int = {
@@ -1108,11 +1118,14 @@ private[spark] class BlockManager(
11081118 def removeBlock (blockId : BlockId , tellMaster : Boolean = true ): Unit = {
11091119 logDebug(s " Removing block $blockId" )
11101120 val info = blockInfo.get(blockId).orNull
1111- if (info != null ) {
1121+ if (info != null && ! pendingToRemove.containsKey(blockId)) {
1122+ pendingToRemove.put(blockId, currentTaskAttemptId)
11121123 info.synchronized {
1124+ val level = info.level
11131125 // Removals are idempotent in disk store and memory store. At worst, we get a warning.
1114- val removedFromMemory = memoryStore.remove(blockId)
1115- val removedFromDisk = diskStore.remove(blockId)
1126+ val removedFromMemory = if (level.useMemory) memoryStore.remove(blockId) else false
1127+ pendingToRemove.remove(blockId)
1128+ val removedFromDisk = if (level.useDisk) diskStore.remove(blockId) else false
11161129 val removedFromExternalBlockStore =
11171130 if (externalBlockStoreInitialized) externalBlockStore.remove(blockId) else false
11181131 if (! removedFromMemory && ! removedFromDisk && ! removedFromExternalBlockStore) {
@@ -1147,9 +1160,11 @@ private[spark] class BlockManager(
11471160 val entry = iterator.next()
11481161 val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp)
11491162 if (time < cleanupTime && shouldDrop(id)) {
1163+ pendingToRemove.put(id, currentTaskAttemptId)
11501164 info.synchronized {
11511165 val level = info.level
11521166 if (level.useMemory) { memoryStore.remove(id) }
1167+ pendingToRemove.remove(id)
11531168 if (level.useDisk) { diskStore.remove(id) }
11541169 if (level.useOffHeap) { externalBlockStore.remove(id) }
11551170 iterator.remove()
@@ -1160,6 +1175,28 @@ private[spark] class BlockManager(
11601175 }
11611176 }
11621177 }
1178+
1179+ private def currentTaskAttemptId : Long = {
1180+ Option (TaskContext .get()).map(_.taskAttemptId()).getOrElse(NON_TASK_WRITER )
1181+ }
1182+
1183+ /**
1184+ * Release all lock held by the given task, clearing that task's pin bookkeeping
1185+ * structures and updating the global pin counts. This method should be called at the
1186+ * end of a task (either by a task completion handler or in `TaskRunner.run()`).
1187+ *
1188+ * @return the ids of blocks whose pins were released
1189+ */
1190+ def releaseAllLocksForTask (taskAttemptId : Long ): ArrayBuffer [BlockId ] = {
1191+ var selectLocks = ArrayBuffer [BlockId ]()
1192+ pendingToRemove.entrySet().asScala.foreach { entry =>
1193+ if (entry.getValue == taskAttemptId) {
1194+ pendingToRemove.remove(entry.getKey)
1195+ selectLocks += entry.getKey
1196+ }
1197+ }
1198+ selectLocks
1199+ }
11631200
11641201 private def shouldCompress (blockId : BlockId ): Boolean = {
11651202 blockId match {
@@ -1234,6 +1271,7 @@ private[spark] class BlockManager(
12341271 rpcEnv.stop(slaveEndpoint)
12351272 blockInfo.clear()
12361273 memoryStore.clear()
1274+ pendingToRemove.clear()
12371275 diskStore.clear()
12381276 if (externalBlockStoreInitialized) {
12391277 externalBlockStore.clear()
0 commit comments