Skip to content

Commit e6482fa

Browse files
committed
Thread-safety fixes.
1 parent 5ffe30f commit e6482fa

File tree

3 files changed, with 16 additions and 8 deletions.

core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
291291

292292
// HashMaps for storing mapStatuses and cached serialized statuses in the driver.
293293
// Statuses are dropped only by explicit de-registering.
294-
protected val mapStatuses = new HashMap[Int, Array[MapStatus]]()
295-
private val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]()
294+
protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
295+
private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala
296296

297297
def registerShuffle(shuffleId: Int, numMaps: Int) {
298298
if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import java.io._
2323
import java.lang.reflect.Constructor
2424
import java.net.URI
2525
import java.util.{Arrays, Properties, UUID}
26+
import java.util.concurrent.ConcurrentMap
2627
import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger}
2728
import java.util.UUID.randomUUID
2829

@@ -295,7 +296,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
295296
private[spark] val addedJars = HashMap[String, Long]()
296297

297298
// Keeps track of all persisted RDDs
298-
private[spark] val persistentRdds = new MapMaker().weakValues().makeMap[Int, RDD[_]]().asScala
299+
private[spark] val persistentRdds = {
300+
val map : ConcurrentMap[Int, RDD[_]] = new MapMaker().weakValues().makeMap[Int, RDD[_]]()
301+
map.asScala
302+
}
299303
private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener
300304

301305
def statusTracker: SparkStatusTracker = _statusTracker

core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.shuffle
1919

20-
import java.util.concurrent.ConcurrentLinkedQueue
20+
import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
2121

2222
import scala.collection.JavaConverters._
2323

@@ -63,7 +63,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
6363
val completedMapTasks = new ConcurrentLinkedQueue[Int]()
6464
}
6565

66-
private val shuffleStates = new scala.collection.mutable.HashMap[ShuffleId, ShuffleState]
66+
private val shuffleStates = new ConcurrentHashMap[ShuffleId, ShuffleState]
6767

6868
/**
6969
* Get a ShuffleWriterGroup for the given map task, which will register it as complete
@@ -72,8 +72,12 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
7272
def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer,
7373
writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = {
7474
new ShuffleWriterGroup {
75-
private val shuffleState =
76-
shuffleStates.getOrElseUpdate(shuffleId, new ShuffleState(numReducers))
75+
private val shuffleState: ShuffleState = {
76+
// Note: we do _not_ want to just wrap this Java ConcurrentHashMap into a Scala map and use
77+
// .getOrElseUpdate() because that's actually NOT atomic.
78+
shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers))
79+
shuffleStates.get(shuffleId)
80+
}
7781
val openStartTime = System.nanoTime
7882
val serializerInstance = serializer.newInstance()
7983
val writers: Array[DiskBlockObjectWriter] = {
@@ -110,7 +114,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
110114

111115
/** Remove all the blocks / files related to a particular shuffle. */
112116
private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
113-
shuffleStates.get(shuffleId) match {
117+
Option(shuffleStates.get(shuffleId)) match {
114118
case Some(state) =>
115119
for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) {
116120
val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId)

0 commit comments

Comments
 (0)