
Commit 0da667d

squito authored and Marcelo Vanzin committed
[SPARK-28917][CORE] Synchronize access to RDD mutable state
RDD dependencies and partitions can be simultaneously accessed and mutated by user threads and Spark's scheduler threads, so access must be thread-safe. In particular, because partitions and dependencies are lazily initialized, before this change they could get initialized multiple times, which would lead to the scheduler having an inconsistent view of the pending stages and getting stuck.

Tested with existing unit tests.

Closes #25951 from squito/SPARK-28917.

Authored-by: Imran Rashid <[email protected]>
Signed-off-by: Marcelo Vanzin <[email protected]>
1 parent: de360e9
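The fix is the classic double-checked locking idiom: a @volatile field plus a dedicated lock object, so the lazy initializer runs at most once under concurrent access. A minimal, Spark-free sketch of the idiom (LazyInitSketch, compute, and values are illustrative names, not part of this patch):

object LazyInitSketch {
  // Dedicated lock object: synchronizing on `this` could deadlock with user
  // code that already locks the publicly visible object (the patch's rationale).
  private val stateLock = new Object

  // @volatile so a write made under the lock is visible to lock-free readers.
  @volatile private var cached: Seq[Int] = _

  private def compute(): Seq[Int] = {
    println("computing once")
    1 to 5
  }

  def values: Seq[Int] = {
    if (cached == null) {        // fast path: no locking once initialized
      stateLock.synchronized {
        if (cached == null) {    // re-check under the lock
          cached = compute()
        }
      }
    }
    cached
  }

  def main(args: Array[String]): Unit = {
    // Hammer `values` from several threads; "computing once" prints exactly once.
    val threads = (1 to 8).map(_ => new Thread(() => { values; () }))
    threads.foreach(_.start())
    threads.foreach(_.join())
    assert(values == (1 to 5))
  }
}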

File tree: 3 files changed, +51 -11 lines changed


core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 32 additions & 9 deletions
@@ -225,10 +225,24 @@ abstract class RDD[T: ClassTag](
   /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
   def getStorageLevel: StorageLevel = storageLevel
 
+  /**
+   * Lock for all mutable state of this RDD (persistence, partitions, dependencies, etc.). We do
+   * not use `this` because RDDs are user-visible, so users might have added their own locking on
+   * RDDs; sharing that could lead to a deadlock.
+   *
+   * One thread might hold the lock on many of these, for a chain of RDD dependencies; but
+   * because DAGs are acyclic, and we only ever hold locks for one path in that DAG, there is no
+   * chance of deadlock.
+   *
+   * The use of Integer is simply so this is serializable -- executors may reference the shared
+   * fields (though they should never mutate them, that only happens on the driver).
+   */
+  private val stateLock = new Integer(0)
+
   // Our dependencies and partitions will be gotten by calling subclass's methods below, and will
   // be overwritten when we're checkpointed
-  private var dependencies_ : Seq[Dependency[_]] = _
-  @transient private var partitions_ : Array[Partition] = _
+  @volatile private var dependencies_ : Seq[Dependency[_]] = _
+  @volatile @transient private var partitions_ : Array[Partition] = _
 
   /** An Option holding our checkpoint RDD, if we are checkpointed */
   private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
@@ -240,7 +254,11 @@ abstract class RDD[T: ClassTag](
   final def dependencies: Seq[Dependency[_]] = {
     checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
       if (dependencies_ == null) {
-        dependencies_ = getDependencies
+        stateLock.synchronized {
+          if (dependencies_ == null) {
+            dependencies_ = getDependencies
+          }
+        }
       }
       dependencies_
     }
@@ -253,10 +271,14 @@ abstract class RDD[T: ClassTag](
   final def partitions: Array[Partition] = {
     checkpointRDD.map(_.partitions).getOrElse {
       if (partitions_ == null) {
-        partitions_ = getPartitions
-        partitions_.zipWithIndex.foreach { case (partition, index) =>
-          require(partition.index == index,
-            s"partitions($index).partition == ${partition.index}, but it should equal $index")
+        stateLock.synchronized {
+          if (partitions_ == null) {
+            partitions_ = getPartitions
+            partitions_.zipWithIndex.foreach { case (partition, index) =>
+              require(partition.index == index,
+                s"partitions($index).partition == ${partition.index}, but it should equal $index")
+            }
+          }
         }
       }
       partitions_
@@ -1788,7 +1810,7 @@ abstract class RDD[T: ClassTag](
   * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
   * created from the checkpoint file, and forget its old dependencies and partitions.
   */
-  private[spark] def markCheckpointed(): Unit = {
+  private[spark] def markCheckpointed(): Unit = stateLock.synchronized {
     clearDependencies()
     partitions_ = null
     deps = null // Forget the constructor argument for dependencies too
@@ -1800,7 +1822,7 @@ abstract class RDD[T: ClassTag](
   * collected. Subclasses of RDD may override this method for implementing their own cleaning
   * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example.
   */
-  protected def clearDependencies(): Unit = {
+  protected def clearDependencies(): Unit = stateLock.synchronized {
     dependencies_ = null
   }
 
@@ -1959,6 +1981,7 @@ abstract class RDD[T: ClassTag](
       deterministicLevelCandidates.maxBy(_.id)
     }
   }
+
 }
 
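For context on what these hunks buy: before the patch, a user thread and the scheduler thread could each run getPartitions and end up with different Partition arrays. A hypothetical repro sketch against the public RDD API (ConcurrentPartitionsAccess is an invented name; assumes a local Spark dependency on the classpath):

import org.apache.spark.{SparkConf, SparkContext}

object ConcurrentPartitionsAccess {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("SPARK-28917-sketch")
    val sc = new SparkContext(conf)
    try {
      val rdd = sc.parallelize(1 to 100, 10).map(_ + 1)
      // A user thread resolves partitions while the main thread runs a job,
      // which makes the scheduler resolve them as well.
      val userThread = new Thread(() => assert(rdd.partitions.length == 10))
      userThread.start()
      rdd.count()
      userThread.join()
      // With the patch, both accesses observed the same singly-initialized array.
      assert(rdd.partitions eq rdd.partitions)
    } finally {
      sc.stop()
    }
  }
}

The existing unit tests mentioned in the commit message cover the actual regression; this sketch only illustrates the access pattern being protected.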
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 4 additions & 2 deletions
@@ -400,7 +400,8 @@ private[spark] class DAGScheduler(
     if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
-      logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
+      logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " +
+        s"shuffle ${shuffleDep.shuffleId}")
       mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
     }
     stage
@@ -1080,7 +1081,8 @@ private[spark] class DAGScheduler(
   private def submitStage(stage: Stage): Unit = {
     val jobId = activeJobForStage(stage)
     if (jobId.isDefined) {
-      logDebug("submitStage(" + stage + ")")
+      logDebug(s"submitStage($stage (name=${stage.name};" +
+        s"jobs=${stage.jobIds.toSeq.sorted.mkString(",")}))")
       if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
         val missing = getMissingParentStages(stage).sortBy(_.id)
         logDebug("missing: " + missing)
core/src/test/scala/org/apache/spark/DistributedSuite.scala

Lines changed: 15 additions & 0 deletions
@@ -339,6 +339,21 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext
     }
   }
 
+  test("reference partitions inside a task") {
+    // Run a simple job which just makes sure there is no failure if we touch rdd.partitions
+    // inside a task. This requires the stateLock to be serializable. This is a very convoluted
+    // use case; it's just a check for backwards compatibility after the fix for SPARK-28917.
+    sc = new SparkContext("local-cluster[1,1,1024]", "test")
+    val rdd1 = sc.parallelize(1 to 10, 1)
+    val rdd2 = rdd1.map { x => x + 1 }
+    // Ensure we can force computation of rdd2.dependencies inside a task. Just touching
+    // it will force computation and touch the stateLock. The check for null is just
+    // to make sure that we've set up our test correctly and haven't precomputed dependencies
+    // in the driver.
+    val dependencyComputeCount = rdd1.map { x => if (rdd2.dependencies == null) 1 else 0 }.sum()
+    assert(dependencyComputeCount > 0)
+  }
+
 }
 
 object DistributedSuite {