Commit b9d6e18

Require tasks to explicitly register themselves with the BlockManager.
Parent: a5ef11b

File tree: 5 files changed, +56 −31 lines


core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 2 additions & 0 deletions

@@ -207,6 +207,8 @@ private[spark] class Executor(
         logDebug("Task " + taskId + "'s epoch is " + task.epoch)
         env.mapOutputTracker.updateEpoch(task.epoch)
 
+        env.blockManager.registerTask(taskId)
+
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()
         var threwException = true
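The new call pairs with the existing releaseAllLocksForTask cleanup on the lock-release path: a task registers its bookkeeping once, up front, and everything it acquired is torn down when it finishes. Below is a hedged sketch of that lifecycle; BlockManagerLike, TaskLifecycleSketch, and runTaskAttempt are illustrative names, not Spark's actual TaskRunner internals.

    // Illustrative interface only: the real BlockManager gains registerTask in
    // this commit and already had releaseAllLocksForTask; the rest is a sketch.
    trait BlockManagerLike {
      def registerTask(taskAttemptId: Long): Unit
      def releaseAllLocksForTask(taskAttemptId: Long): Seq[String]
    }

    object TaskLifecycleSketch {
      def runTaskAttempt(taskAttemptId: Long, blockManager: BlockManagerLike): Unit = {
        // Registration must precede any lock acquisition by this task.
        blockManager.registerTask(taskAttemptId)
        try {
          // ... run the task body, which may take read/write locks on blocks ...
        } finally {
          // Tears down the task's lock state and deregisters its read-lock multiset.
          blockManager.releaseAllLocksForTask(taskAttemptId)
        }
      }
    }

Because registerTask rejects duplicates (see the require in BlockInfoManager below), this call must happen exactly once per task attempt.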

core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala

Lines changed: 37 additions & 26 deletions

@@ -17,13 +17,11 @@
 
 package org.apache.spark.storage
 
-import java.lang
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
-import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
 import com.google.common.collect.ConcurrentHashMultiset
 
 import org.apache.spark.{Logging, SparkException, TaskContext}

@@ -141,18 +139,29 @@ private[storage] class BlockInfoManager extends Logging {
 
   /**
    * Tracks the set of blocks that each task has locked for reading, along with the number of times
-   * that a block has been locked (since our read locks are re-entrant). This is thread-safe.
+   * that a block has been locked (since our read locks are re-entrant).
    */
-  private[this] val readLocksByTask: LoadingCache[lang.Long, ConcurrentHashMultiset[BlockId]] = {
-    // We need to explicitly box as java.lang.Long to avoid a type mismatch error:
-    val loader = new CacheLoader[java.lang.Long, ConcurrentHashMultiset[BlockId]] {
-      override def load(t: java.lang.Long) = ConcurrentHashMultiset.create[BlockId]()
-    }
-    CacheBuilder.newBuilder().build(loader)
-  }
+  @GuardedBy("this")
+  private[this] val readLocksByTask =
+    new mutable.HashMap[TaskAttemptId, ConcurrentHashMultiset[BlockId]]
+
+  // ----------------------------------------------------------------------------------------------
+
+  // Initialization for special task attempt ids:
+  registerTask(BlockInfo.NON_TASK_WRITER)
 
   // ----------------------------------------------------------------------------------------------
 
+  /**
+   * Called at the start of a task in order to register that task with this [[BlockInfoManager]].
+   * This must be called prior to calling any other BlockInfoManager methods from that task.
+   */
+  def registerTask(taskAttemptId: TaskAttemptId): Unit = {
+    require(!readLocksByTask.contains(taskAttemptId),
+      s"Task attempt $taskAttemptId is already registered")
+    readLocksByTask(taskAttemptId) = ConcurrentHashMultiset.create()
+  }
+
   /**
    * Returns the current task's task attempt id (which uniquely identifies the task), or
    * [[BlockInfo.NON_TASK_WRITER]] if called by a non-task thread.

@@ -284,7 +293,7 @@ private[storage] class BlockInfoManager extends Logging {
     } else {
       assert(info.readerCount > 0, s"Block $blockId is not locked for reading")
       info.readerCount -= 1
-      val countsForTask = readLocksByTask.get(currentTaskAttemptId)
+      val countsForTask = readLocksByTask(currentTaskAttemptId)
       val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1
       assert(newPinCountForTask >= 0,
         s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it")

@@ -325,20 +334,21 @@ private[storage] class BlockInfoManager extends Logging {
    */
   def releaseAllLocksForTask(taskAttemptId: TaskAttemptId): Seq[BlockId] = {
     val blocksWithReleasedLocks = mutable.ArrayBuffer[BlockId]()
-    synchronized {
-      writeLocksByTask.remove(taskAttemptId).foreach { locks =>
-        for (blockId <- locks) {
-          infos.get(blockId).foreach { info =>
-            assert(info.writerTask == taskAttemptId)
-            info.writerTask = BlockInfo.NO_WRITER
-          }
-          blocksWithReleasedLocks += blockId
-        }
+
+    val readLocks = synchronized {
+      readLocksByTask.remove(taskAttemptId).get
+    }
+    val writeLocks = synchronized {
+      writeLocksByTask.remove(taskAttemptId).getOrElse(Seq.empty)
+    }
+
+    for (blockId <- writeLocks) {
+      infos.get(blockId).foreach { info =>
+        assert(info.writerTask == taskAttemptId)
+        info.writerTask = BlockInfo.NO_WRITER
       }
-      notifyAll()
+      blocksWithReleasedLocks += blockId
     }
-    val readLocks = readLocksByTask.get(taskAttemptId)
-    readLocksByTask.invalidate(taskAttemptId)
     readLocks.entrySet().iterator().asScala.foreach { entry =>
       val blockId = entry.getElement
       val lockCount = entry.getCount

@@ -350,6 +360,7 @@ private[storage] class BlockInfoManager extends Logging {
         }
       }
     }
+
     synchronized {
       notifyAll()
     }

@@ -369,8 +380,8 @@ private[storage] class BlockInfoManager extends Logging {
    */
   private[storage] def getNumberOfMapEntries: Long = synchronized {
     size +
-      readLocksByTask.size() +
-      readLocksByTask.asMap().asScala.map(_._2.size()).sum +
+      readLocksByTask.size +
+      readLocksByTask.map(_._2.size()).sum +
       writeLocksByTask.size +
       writeLocksByTask.map(_._2.size).sum
   }

@@ -419,7 +430,7 @@ private[storage] class BlockInfoManager extends Logging {
       blockInfo.removed = true
     }
     infos.clear()
-    readLocksByTask.invalidateAll()
+    readLocksByTask.clear()
     writeLocksByTask.clear()
     notifyAll()
   }
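The substantive change is the move from a Guava LoadingCache, which silently created an empty multiset for any task id it was asked about, to a plain HashMap populated only by registerTask, so a lookup for an unregistered task now fails fast. Below is a minimal, self-contained sketch of that fail-fast pattern in plain Scala (not Spark code; the Guava multiset is modeled as a BlockId-to-count map):

    import scala.collection.mutable

    object RegistrationSketch {
      type TaskAttemptId = Long

      // Per-task read-lock counts; state for a task exists only after registerTask.
      private val readLocksByTask =
        new mutable.HashMap[TaskAttemptId, mutable.Map[String, Int]]

      def registerTask(taskAttemptId: TaskAttemptId): Unit = synchronized {
        require(!readLocksByTask.contains(taskAttemptId),
          s"Task attempt $taskAttemptId is already registered")
        readLocksByTask(taskAttemptId) = mutable.Map.empty[String, Int].withDefaultValue(0)
      }

      def acquireReadLock(taskAttemptId: TaskAttemptId, blockId: String): Unit = synchronized {
        // apply() throws NoSuchElementException for an unregistered task,
        // surfacing the bug at the first lock attempt rather than hiding it.
        val counts = readLocksByTask(taskAttemptId)
        counts(blockId) += 1  // re-entrant read locks are just counted
      }

      def main(args: Array[String]): Unit = {
        registerTask(0L)
        acquireReadLock(0L, "rdd_0_0")
        acquireReadLock(0L, "rdd_0_0")
        println(readLocksByTask(0L)("rdd_0_0"))  // 2
      }
    }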

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 7 additions & 0 deletions

@@ -631,6 +631,13 @@ private[spark] class BlockManager(
     blockInfoManager.unlock(blockId)
   }
 
+  /**
+   * Registers a task with the BlockManager in order to initialize per-task bookkeeping structures.
+   */
+  def registerTask(taskAttemptId: Long): Unit = {
+    blockInfoManager.registerTask(taskAttemptId)
+  }
+
   /**
    * Release all locks for the given task.
    *
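The new method is pure delegation: BlockInfoManager is private[storage], so this one-liner is what gives callers outside the storage package (the Executor above, the suite below) access to registration. A compilable sketch of that facade arrangement, using illustrative names rather than Spark's real classes:

    package sketch.storage

    import scala.collection.mutable

    // Visible only inside this package, like Spark's private[storage] BlockInfoManager.
    private[storage] class InfoManagerSketch {
      private val registered = mutable.Set.empty[Long]

      def registerTask(taskAttemptId: Long): Unit = synchronized {
        // Set.add returns false on a duplicate, so double registration fails here.
        require(registered.add(taskAttemptId),
          s"Task attempt $taskAttemptId is already registered")
      }
    }

    // The public facade: external callers see only this method, which forwards
    // to the package-private bookkeeping object.
    class ManagerSketch {
      private val infoManager = new InfoManagerSketch

      def registerTask(taskAttemptId: Long): Unit = infoManager.registerTask(taskAttemptId)
    }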

core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala

Lines changed: 9 additions & 5 deletions

@@ -34,6 +34,9 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
   override protected def beforeEach(): Unit = {
     super.beforeEach()
     blockInfoManager = new BlockInfoManager()
+    for (t <- 0 to 4) {
+      blockInfoManager.registerTask(t)
+    }
   }
 
   override protected def afterEach(): Unit = {

@@ -62,7 +65,6 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("initial memory usage") {
-    assert(blockInfoManager.getNumberOfMapEntries === 0)
     assert(blockInfoManager.size === 0)
   }
 

@@ -72,7 +74,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
     assert(blockInfoManager.lockForWriting("non-existent-block").isEmpty)
   }
 
-  test("basic putAndLockForWritingIfAbsent") {
+  test("basic lockNewBlockForWriting") {
+    val initialNumMapEntries = blockInfoManager.getNumberOfMapEntries
     val blockInfo = newBlockInfo()
     withTaskId(1) {
       assert(blockInfoManager.lockNewBlockForWriting("block", blockInfo))

@@ -86,7 +89,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
       assert(blockInfo.writerTask === BlockInfo.NO_WRITER)
     }
     assert(blockInfoManager.size === 1)
-    assert(blockInfoManager.getNumberOfMapEntries === 1)
+    assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries + 1)
   }
 
   test("read locks are reentrant") {

@@ -273,11 +276,12 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("releaseAllLocksForTask releases write locks") {
+    val initialNumMapEntries = blockInfoManager.getNumberOfMapEntries
     withTaskId(0) {
      assert(blockInfoManager.lockNewBlockForWriting("block", newBlockInfo()))
     }
-    assert(blockInfoManager.getNumberOfMapEntries === 3)
+    assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries + 3)
     blockInfoManager.releaseAllLocksForTask(0)
-    assert(blockInfoManager.getNumberOfMapEntries === 1)
+    assert(blockInfoManager.getNumberOfMapEntries === initialNumMapEntries)
   }
 }
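The constants in the last test can be checked by hand against getNumberOfMapEntries from the BlockInfoManager diff. A worked tally, assuming the only registered tasks at the start of the test are 0 through 4 from beforeEach plus BlockInfo.NON_TASK_WRITER from the constructor:

    // All numbers below are derived from the diffs above.
    object MapEntryTally extends App {
      // baseline: 0 infos entries + 6 registered read-lock multisets (tasks 0..4
      // and NON_TASK_WRITER) + 0 read-lock counts + 0 write-lock keys + 0 block ids
      val initial = 0 + 6 + 0 + 0 + 0
      // lockNewBlockForWriting adds one infos entry, one writeLocksByTask key for
      // task 0, and one block id inside that key's set:
      val afterLock = initial + 3
      // releaseAllLocksForTask(0) removes task 0's read-lock multiset (1 entry)
      // and both write-lock entries (2); the infos entry for "block" survives:
      val afterRelease = afterLock - (1 + 2)
      assert(afterRelease == initial)
      println(s"initial=$initial, afterLock=$afterLock, afterRelease=$afterRelease")
    }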

core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala

Lines changed: 1 addition & 0 deletions

@@ -865,6 +865,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
 
   test("updated block statuses") {
     store = makeBlockManager(12000)
+    store.registerTask(0)
     val list = List.fill(2)(new Array[Byte](2000))
     val bigList = List.fill(8)(new Array[Byte](2000))
 
