Commit 07a401a

JoshRosen authored and pwendell committed
[SPARK-4454] Properly synchronize accesses to DAGScheduler cacheLocs map
This patch addresses a race condition in DAGScheduler by properly synchronizing accesses to its `cacheLocs` map. This map is accessed by the `getCacheLocs` and `clearCacheLocs()` methods, which can be called by separate threads, since DAGScheduler's `getPreferredLocs()` method is called by SparkContext and indirectly calls `getCacheLocs()`. If this map is cleared by the DAGScheduler event processing thread while a user thread is submitting a job and computing preferred locations, then this can cause the user thread to throw "NoSuchElementException: key not found" errors.

Most accesses to DAGScheduler's internal state do not need synchronization because that state is only accessed from the event processing loop's thread. An alternative approach to fixing this bug would be to refactor this code so that SparkContext sends the DAGScheduler a message in order to get the list of preferred locations. However, this would involve more extensive changes to this code and would be significantly harder to backport to maintenance branches, since some of the related code has undergone significant refactoring (e.g. the introduction of EventLoop). Since `cacheLocs` is the only state that's accessed in this way, adding simple synchronization seems like a better short-term fix.

See #3345 for additional context.

Author: Josh Rosen <[email protected]>

Closes #4660 from JoshRosen/SPARK-4454 and squashes the following commits:

12d64ba [Josh Rosen] Properly synchronize accesses to DAGScheduler cacheLocs map.

(cherry picked from commit d46d624)
Signed-off-by: Patrick Wendell <[email protected]>
1 parent cb90584 commit 07a401a
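
To make the race concrete, here is a minimal standalone sketch of the pattern the commit message describes. The names (CacheLocsRaceSketch and its methods) are illustrative, not Spark code, and locations are simplified to plain Strings with a pretend fixed partition count:

import scala.collection.mutable.HashMap

object CacheLocsRaceSketch {
  private val cacheLocs = new HashMap[Int, Array[Seq[String]]]

  // The pre-patch shape: an unsynchronized check-then-read. If the event loop
  // clears the map between contains() and apply(), apply() throws
  // "NoSuchElementException: key not found".
  def getCacheLocsUnsafe(rddId: Int): Array[Seq[String]] = {
    if (!cacheLocs.contains(rddId)) {
      cacheLocs(rddId) = Array.fill[Seq[String]](4)(Nil)
    }
    cacheLocs(rddId)  // may fail if clear() ran after the contains() check
  }

  // The post-patch shape: the whole check-compute-store-read sequence runs
  // while holding the map's own monitor, so clear() cannot interleave.
  def getCacheLocsSafe(rddId: Int): Array[Seq[String]] = cacheLocs.synchronized {
    cacheLocs.getOrElseUpdate(rddId, Array.fill[Seq[String]](4)(Nil))
  }

  def clearCacheLocs(): Unit = cacheLocs.synchronized {
    cacheLocs.clear()
  }
}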

File tree

1 file changed: +24 -10 lines changed


core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 24 additions & 10 deletions
@@ -98,7 +98,13 @@ class DAGScheduler(
 
   private[scheduler] val activeJobs = new HashSet[ActiveJob]
 
-  // Contains the locations that each RDD's partitions are cached on
+  /**
+   * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids
+   * and its values are arrays indexed by partition numbers. Each array value is the set of
+   * locations where that RDD partition is cached.
+   *
+   * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454).
+   */
   private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
 
   // For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
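
As a quick illustration of the structure this new doc comment describes, a hypothetical read of one partition's cached locations could look like the following (rdd and partition are assumed to be in scope; the synchronized block follows the locking rule the comment establishes):

cacheLocs.synchronized {
  // Outer key: RDD id. Inner array index: partition number.
  // Element: the cached locations for that partition (Nil if none recorded).
  val locsForPartition: Seq[TaskLocation] =
    cacheLocs.get(rdd.id).map(arr => arr(partition)).getOrElse(Nil)
}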
@@ -183,18 +189,17 @@ class DAGScheduler(
     eventProcessLoop.post(TaskSetFailed(taskSet, reason))
   }
 
-  private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
-    if (!cacheLocs.contains(rdd.id)) {
+  private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = cacheLocs.synchronized {
+    cacheLocs.getOrElseUpdate(rdd.id, {
       val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
       val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
-      cacheLocs(rdd.id) = blockIds.map { id =>
+      blockIds.map { id =>
         locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
       }
-    }
-    cacheLocs(rdd.id)
+    })
   }
 
-  private def clearCacheLocs() {
+  private def clearCacheLocs(): Unit = cacheLocs.synchronized {
     cacheLocs.clear()
   }
 
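The rewrite leans on HashMap.getOrElseUpdate, which evaluates its by-name second argument only on a miss and stores the result before returning it, so the old contains/update/apply sequence collapses into a single call. Combined with clearCacheLocs() taking the same cacheLocs monitor, a clear() can no longer slip in between the check and the read. A small illustration of those semantics with plain types:

import scala.collection.mutable.HashMap

val m = new HashMap[Int, String]
m.getOrElseUpdate(1, { println("computing"); "one" })  // miss: prints "computing", stores and returns "one"
m.getOrElseUpdate(1, { println("computing"); "one" })  // hit: prints nothing, returns the cached "one"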

@@ -1276,17 +1281,26 @@ class DAGScheduler(
   }
 
   /**
-   * Synchronized method that might be called from other threads.
+   * Gets the locality information associated with a partition of a particular RDD.
+   *
+   * This method is thread-safe and is called from both DAGScheduler and SparkContext.
+   *
    * @param rdd whose partitions are to be looked at
    * @param partition to lookup locality information for
    * @return list of machines that are preferred by the partition
    */
   private[spark]
-  def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized {
+  def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
     getPreferredLocsInternal(rdd, partition, new HashSet)
   }
 
-  /** Recursive implementation for getPreferredLocs. */
+  /**
+   * Recursive implementation for getPreferredLocs.
+   *
+   * This method is thread-safe because it only accesses DAGScheduler state through thread-safe
+   * methods (getCacheLocs()); please be careful when modifying this method, because any new
+   * DAGScheduler state accessed by it may require additional synchronization.
+   */
   private def getPreferredLocsInternal(
       rdd: RDD[_],
       partition: Int,
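
Note that the method-level synchronized is gone from getPreferredLocs: callers no longer serialize on the whole DAGScheduler object and only contend on the cacheLocs monitor inside getCacheLocs. A hypothetical multi-threaded caller, with dagScheduler and rdd assumed to be in scope (in real Spark, SparkContext makes this call internally):

import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}

implicit val ec: ExecutionContext =
  ExecutionContext.fromExecutor(Executors.newFixedThreadPool(4))

// Each lookup can now run concurrently with the event loop's
// clearCacheLocs() calls; see the doc comments added above.
val lookups = rdd.partitions.indices.map { p =>
  Future(dagScheduler.getPreferredLocs(rdd, p))
}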
