diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 08fc309d5238..5788b70e75a7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -225,10 +225,24 @@ abstract class RDD[T: ClassTag](
   /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
   def getStorageLevel: StorageLevel = storageLevel
 
+  /**
+   * Lock for all mutable state of this RDD (persistence, partitions, dependencies, etc.). We do
+   * not use `this` because RDDs are user-visible, so users might have added their own locking on
+   * RDDs; sharing that could lead to a deadlock.
+   *
+   * One thread might hold the lock on many of these, for a chain of RDD dependencies; but
+   * because DAGs are acyclic, and we only ever hold locks for one path in that DAG, there is no
+   * chance of deadlock.
+   *
+   * The use of Integer is simply so this is serializable -- executors may reference the shared
+   * fields (though they should never mutate them, that only happens on the driver).
+   */
+  private val stateLock = new Integer(0)
+
   // Our dependencies and partitions will be gotten by calling subclass's methods below, and will
   // be overwritten when we're checkpointed
-  private var dependencies_ : Seq[Dependency[_]] = _
-  @transient private var partitions_ : Array[Partition] = _
+  @volatile private var dependencies_ : Seq[Dependency[_]] = _
+  @volatile @transient private var partitions_ : Array[Partition] = _
 
   /** An Option holding our checkpoint RDD, if we are checkpointed */
   private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
@@ -240,7 +254,11 @@ abstract class RDD[T: ClassTag](
   final def dependencies: Seq[Dependency[_]] = {
     checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
       if (dependencies_ == null) {
-        dependencies_ = getDependencies
+        stateLock.synchronized {
+          if (dependencies_ == null) {
+            dependencies_ = getDependencies
+          }
+        }
       }
       dependencies_
     }
@@ -253,10 +271,14 @@ abstract class RDD[T: ClassTag](
   final def partitions: Array[Partition] = {
     checkpointRDD.map(_.partitions).getOrElse {
       if (partitions_ == null) {
-        partitions_ = getPartitions
-        partitions_.zipWithIndex.foreach { case (partition, index) =>
-          require(partition.index == index,
-            s"partitions($index).partition == ${partition.index}, but it should equal $index")
+        stateLock.synchronized {
+          if (partitions_ == null) {
+            partitions_ = getPartitions
+            partitions_.zipWithIndex.foreach { case (partition, index) =>
+              require(partition.index == index,
+                s"partitions($index).partition == ${partition.index}, but it should equal $index")
+            }
+          }
         }
       }
       partitions_
@@ -1788,7 +1810,7 @@ abstract class RDD[T: ClassTag](
    * Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
    * created from the checkpoint file, and forget its old dependencies and partitions.
    */
-  private[spark] def markCheckpointed(): Unit = {
+  private[spark] def markCheckpointed(): Unit = stateLock.synchronized {
     clearDependencies()
     partitions_ = null
     deps = null  // Forget the constructor argument for dependencies too
@@ -1800,7 +1822,7 @@ abstract class RDD[T: ClassTag](
    * collected. Subclasses of RDD may override this method for implementing their own cleaning
    * logic. See [[org.apache.spark.rdd.UnionRDD]] for an example.
    */
-  protected def clearDependencies(): Unit = {
+  protected def clearDependencies(): Unit = stateLock.synchronized {
     dependencies_ = null
   }
 
@@ -1959,6 +1981,7 @@ abstract class RDD[T: ClassTag](
       deterministicLevelCandidates.maxBy(_.id)
     }
   }
+
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 81e0543ccefe..c3e1cd8b23f1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -400,7 +400,8 @@ private[spark] class DAGScheduler(
     if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
-      logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
+      logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " +
+        s"shuffle ${shuffleDep.shuffleId}")
       mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
     }
     stage
@@ -1080,7 +1081,8 @@ private[spark] class DAGScheduler(
   private def submitStage(stage: Stage): Unit = {
     val jobId = activeJobForStage(stage)
     if (jobId.isDefined) {
-      logDebug("submitStage(" + stage + ")")
+      logDebug(s"submitStage($stage (name=${stage.name};" +
+        s"jobs=${stage.jobIds.toSeq.sorted.mkString(",")}))")
       if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
         val missing = getMissingParentStages(stage).sortBy(_.id)
         logDebug("missing: " + missing)
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index aad20545bafb..f1c2bc0677c8 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -339,6 +339,21 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
     }
   }
 
+  test("reference partitions inside a task") {
+    // Run a simple job which just makes sure there is no failure if we touch rdd.partitions
+    // inside a task. This requires the stateLock to be serializable. It is a very convoluted
+    // use case; it's just a check for backwards-compatibility after the fix for SPARK-28917.
+    sc = new SparkContext("local-cluster[1,1,1024]", "test")
+    val rdd1 = sc.parallelize(1 to 10, 1)
+    val rdd2 = rdd1.map { x => x + 1}
+    // Ensure we can force computation of rdd2.dependencies inside a task. Just touching it
+    // forces the computation, which touches the stateLock. The check for null is just to make
+    // sure that we've set up our test correctly and haven't precomputed the dependencies on
+    // the driver.
+    val dependencyComputeCount = rdd1.map { x => if (rdd2.dependencies == null) 1 else 0}.sum()
+    assert(dependencyComputeCount > 0)
+  }
+
 }
 
 object DistributedSuite {
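
Reviewer note, not part of the patch: the RDD.scala change boils down to the double-checked locking idiom sketched below in isolation. The names used here (LazyDeps, cached, compute) are illustrative only; in the patch the same pattern guards RDD.dependencies and RDD.partitions behind stateLock, with the fields marked @volatile so the unsynchronized first check observes a fully published value.

class LazyDeps extends Serializable {
  // A dedicated lock object rather than `this`, so user code that synchronizes on the outer
  // object cannot deadlock against us. `new Integer(0)` is used only because java.lang.Integer
  // is serializable, which keeps closures that happen to capture this object shippable.
  private val stateLock = new Integer(0)

  // @volatile so a reader taking the unsynchronized fast path sees the fully written value.
  @volatile private var cached: Seq[Int] = _

  private def compute(): Seq[Int] = Seq(1, 2, 3)

  def value: Seq[Int] = {
    if (cached == null) {        // fast path: no locking once initialized
      stateLock.synchronized {
        if (cached == null) {    // re-check under the lock
          cached = compute()     // runs at most once across threads
        }
      }
    }
    cached
  }
}

The outer null check keeps the common, already-initialized case lock-free; the inner check under the lock is what guarantees compute() (getDependencies/getPartitions in the patch) is invoked only once even when several threads race on the first access.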