1717
1818package org .apache .spark .shuffle
1919
20- import java .util .concurrent .ConcurrentLinkedQueue
20+ import java .util .concurrent .{ ConcurrentHashMap , ConcurrentLinkedQueue }
2121
2222import scala .collection .JavaConverters ._
2323
@@ -63,7 +63,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
6363 val completedMapTasks = new ConcurrentLinkedQueue [Int ]()
6464 }
6565
66- private val shuffleStates = new scala.collection.mutable. HashMap [ShuffleId , ShuffleState ]
66+ private val shuffleStates = new ConcurrentHashMap [ShuffleId , ShuffleState ]
6767
6868 /**
6969 * Get a ShuffleWriterGroup for the given map task, which will register it as complete
@@ -72,8 +72,12 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
7272 def forMapTask (shuffleId : Int , mapId : Int , numReducers : Int , serializer : Serializer ,
7373 writeMetrics : ShuffleWriteMetrics ): ShuffleWriterGroup = {
7474 new ShuffleWriterGroup {
75- private val shuffleState =
76- shuffleStates.getOrElseUpdate(shuffleId, new ShuffleState (numReducers))
75+ private val shuffleState : ShuffleState = {
76+ // Note: we do _not_ want to just wrap this java ConcurrentHashMap into a Scala map and use
77+ // .getOrElseUpdate() because that's actually NOT atomic.
78+ shuffleStates.putIfAbsent(shuffleId, new ShuffleState (numReducers))
79+ shuffleStates.get(shuffleId)
80+ }
7781 val openStartTime = System .nanoTime
7882 val serializerInstance = serializer.newInstance()
7983 val writers : Array [DiskBlockObjectWriter ] = {
@@ -110,7 +114,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
110114
111115 /** Remove all the blocks / files related to a particular shuffle. */
112116 private def removeShuffleBlocks (shuffleId : ShuffleId ): Boolean = {
113- shuffleStates.get(shuffleId) match {
117+ Option ( shuffleStates.get(shuffleId) ) match {
114118 case Some (state) =>
115119 for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) {
116120 val blockId = new ShuffleBlockId (shuffleId, mapId, reduceId)
0 commit comments