Commit f83d1a7

[SPARK-28917][CORE] Synchronize access to RDD mutable state.
RDD dependencies, partitions, and storageLevel can be simultaneously accessed and mutated by user threads and Spark's scheduler threads, so access must be thread-safe. In particular, because partitions and dependencies are lazily initialized, before this change they could be initialized multiple times, which would leave the scheduler with an inconsistent view of the pending stages and cause it to get stuck.
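
To illustrate the fix, here is a minimal, self-contained sketch of the double-checked locking idiom this commit applies (a simplified stand-alone class, not the actual RDD; LazyInitSketch, LazyPartitions, compute, and the Array[Int] element type are placeholders): a @volatile field allows lock-free reads once initialized, and a re-check under a dedicated lock ensures the initializer runs at most once even when several threads race past the first null check.

// Minimal sketch of the idiom; not Spark code.
object LazyInitSketch {
  final class LazyPartitions(compute: () => Array[Int]) {
    // Dedicated lock object, analogous to the stateLock this commit adds to RDD.
    private val stateLock = new Object()

    // @volatile publishes the fully built array to threads taking the lock-free fast path.
    @volatile private var partitions_ : Array[Int] = _

    def partitions: Array[Int] = {
      if (partitions_ == null) {        // fast path once initialized
        stateLock.synchronized {
          if (partitions_ == null) {    // re-check: another thread may have initialized it first
            partitions_ = compute()
          }
        }
      }
      partitions_
    }
  }

  def main(args: Array[String]): Unit = {
    var calls = 0
    val cell = new LazyPartitions(() => { calls += 1; Array.tabulate(4)(identity) })
    // Many threads race to read; the initializer still runs exactly once.
    val threads = (1 to 8).map(_ => new Thread(() => { cell.partitions; () }))
    threads.foreach(_.start())
    threads.foreach(_.join())
    println(s"initializer ran $calls time(s): ${cell.partitions.mkString(",")}")
  }
}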
1 parent 8beb736 commit f83d1a7

2 files changed: +41 -17 lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 39 additions & 16 deletions
@@ -103,6 +103,17 @@ abstract class RDD[T: ClassTag](
     _sc
   }
 
+  /**
+   * Lock for all mutable state of this RDD (persistence, partitions, dependencies, etc.). We do
+   * not use `this` because RDDs are user-visible, so users might have added their own locking on
+   * RDDs; sharing that could lead to a deadlock.
+   *
+   * One thread might hold the lock on many of these, for a chain of RDD dependencies; but
+   * because DAGs are acyclic, and we only ever hold locks for one path in that DAG, there is no
+   * chance of deadlock.
+   */
+  private val stateLock = new Object()
+
   /** Construct an RDD with just a one-to-one dependency on one parent */
   def this(@transient oneParent: RDD[_]) =
     this(oneParent.context, List(new OneToOneDependency(oneParent)))
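
The Scaladoc above motivates using a private lock object rather than `this`: an RDD is user-visible, so application code may already synchronize on it. The hedged illustration below (hypothetical user code, not part of this commit; the app name and local master are placeholders) shows the kind of externally held monitor the private stateLock deliberately avoids sharing.

import org.apache.spark.{SparkConf, SparkContext}

object UserLockingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("user-locking-sketch").setMaster("local[2]"))
    val rdd = sc.parallelize(1 to 100)

    // Application code is free to use the RDD instance itself as a monitor. Because the RDD's
    // internal state is guarded by a separate, private stateLock, holding this monitor does not
    // interleave with the locking done by Spark's scheduler threads.
    rdd.synchronized {
      rdd.cache()
    }

    println(rdd.count())
    sc.stop()
  }
}

If the internals synchronized on `this` instead, a monitor held by user code like the block above would be the same lock that Spark's own threads take, which is the deadlock risk the comment describes.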
@@ -167,7 +178,9 @@ abstract class RDD[T: ClassTag](
    * @param newLevel the target storage level
    * @param allowOverride whether to override any existing level with the new one
    */
-  private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = {
+  private def persist(
+      newLevel: StorageLevel,
+      allowOverride: Boolean): this.type = stateLock.synchronized {
     // TODO: Handle changes of StorageLevel
     if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) {
       throw new UnsupportedOperationException(
@@ -223,12 +236,12 @@ abstract class RDD[T: ClassTag](
   }
 
   /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
-  def getStorageLevel: StorageLevel = storageLevel
+  def getStorageLevel: StorageLevel = stateLock.synchronized { storageLevel }
 
   // Our dependencies and partitions will be gotten by calling subclass's methods below, and will
   // be overwritten when we're checkpointed
-  private var dependencies_ : Seq[Dependency[_]] = _
-  @transient private var partitions_ : Array[Partition] = _
+  @volatile private var dependencies_ : Seq[Dependency[_]] = _
+  @volatile @transient private var partitions_ : Array[Partition] = _
 
   /** An Option holding our checkpoint RDD, if we are checkpointed */
   private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
@@ -240,7 +253,11 @@ abstract class RDD[T: ClassTag](
   final def dependencies: Seq[Dependency[_]] = {
     checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
       if (dependencies_ == null) {
-        dependencies_ = getDependencies
+        stateLock.synchronized {
+          if (dependencies_ == null) {
+            dependencies_ = getDependencies
+          }
+        }
       }
       dependencies_
     }
@@ -253,10 +270,14 @@ abstract class RDD[T: ClassTag](
   final def partitions: Array[Partition] = {
     checkpointRDD.map(_.partitions).getOrElse {
       if (partitions_ == null) {
-        partitions_ = getPartitions
-        partitions_.zipWithIndex.foreach { case (partition, index) =>
-          require(partition.index == index,
-            s"partitions($index).partition == ${partition.index}, but it should equal $index")
+        stateLock.synchronized {
+          if (partitions_ == null) {
+            partitions_ = getPartitions
+            partitions_.zipWithIndex.foreach { case (partition, index) =>
+              require(partition.index == index,
+                s"partitions($index).partition == ${partition.index}, but it should equal $index")
+            }
+          }
         }
       }
       partitions_
@@ -285,7 +306,7 @@ abstract class RDD[T: ClassTag](
    * subclasses of RDD.
    */
   final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
-    if (storageLevel != StorageLevel.NONE) {
+    if (getStorageLevel != StorageLevel.NONE) {
       getOrCompute(split, context)
     } else {
       computeOrReadCheckpoint(split, context)
@@ -335,7 +356,7 @@ abstract class RDD[T: ClassTag](
     val blockId = RDDBlockId(id, partition.index)
     var readCachedBlock = true
     // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
-    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
+    SparkEnv.get.blockManager.getOrElseUpdate(blockId, getStorageLevel, elementClassTag, () => {
       readCachedBlock = false
       computeOrReadCheckpoint(partition, context)
     }) match {
@@ -1606,10 +1627,12 @@ abstract class RDD[T: ClassTag](
     // the storage level he/she specified to one that is appropriate for local checkpointing
     // (i.e. uses disk) to guarantee correctness.
 
-    if (storageLevel == StorageLevel.NONE) {
-      persist(LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
-    } else {
-      persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true)
+    stateLock.synchronized {
+      if (storageLevel == StorageLevel.NONE) {
+        persist(LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+      } else {
+        persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true)
+      }
     }
 
     // If this RDD is already checkpointed and materialized, its lineage is already truncated.
@@ -1807,7 +1830,7 @@ abstract class RDD[T: ClassTag](
   /** A description of this RDD and its recursive dependencies for debugging. */
   def toDebugString: String = {
     // Get a debug description of an rdd without its children
-    def debugSelf(rdd: RDD[_]): Seq[String] = {
+    def debugSelf(rdd: RDD[_]): Seq[String] = stateLock.synchronized {
       import Utils.bytesToString
 
       val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else ""

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 2 additions & 1 deletion
@@ -400,7 +400,8 @@ private[spark] class DAGScheduler(
     if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
-      logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
+      logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + s") as input to " +
+        s"shuffle ${shuffleDep.shuffleId}")
       mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
     }
     stage
