@@ -103,6 +103,17 @@ abstract class RDD[T: ClassTag](
     _sc
   }
 
+  /**
+   * Lock for all mutable state of this RDD (persistence, partitions, dependencies, etc.). We do
+   * not use `this` because RDDs are user-visible, so users might have added their own locking on
+   * RDDs; sharing that could lead to a deadlock.
+   *
+   * One thread might hold the lock on many of these, for a chain of RDD dependencies; but
+   * because DAGs are acyclic, and we only ever hold locks for one path in that DAG, there is no
+   * chance of deadlock.
+   */
+  private val stateLock = new Object()
+
   /** Construct an RDD with just a one-to-one dependency on one parent */
   def this(@transient oneParent: RDD[_]) =
     this(oneParent.context, List(new OneToOneDependency(oneParent)))
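A minimal sketch (my own illustration, not Spark code) of the idiom this hunk introduces: guard internal state with a private monitor object instead of `this`. Because callers can never reference the private lock, user code that synchronizes on the RDD itself can neither block the library's critical sections nor form a lock cycle with them. All names below are hypothetical.

```scala
// Illustrative only: `CachedValue` and its members are hypothetical names.
class CachedValue[T](compute: () => T) {
  // Dedicated monitor: invisible to callers, so only this class's own
  // methods can ever contend on it.
  private val stateLock = new Object()
  private var value: Option[T] = None

  def get: T = stateLock.synchronized {
    if (value.isEmpty) {
      value = Some(compute()) // guarded read-then-write
    }
    value.get
  }
}
```

Had `get` synchronized on `this` instead, a user writing `instance.synchronized { ... }` on another thread would contend on the same monitor, and whatever lock ordering that user's code establishes would silently involve the library's critical sections.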
@@ -167,7 +178,9 @@ abstract class RDD[T: ClassTag](
    * @param newLevel the target storage level
    * @param allowOverride whether to override any existing level with the new one
    */
-  private def persist(newLevel: StorageLevel, allowOverride: Boolean): this.type = {
+  private def persist(
+      newLevel: StorageLevel,
+      allowOverride: Boolean): this.type = stateLock.synchronized {
     // TODO: Handle changes of StorageLevel
     if (storageLevel != StorageLevel.NONE && newLevel != storageLevel && !allowOverride) {
       throw new UnsupportedOperationException(
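Moving the entire method body under `stateLock.synchronized`, rather than just the assignment, matters because the `allowOverride` guard is a check-then-act sequence: the check is only meaningful if no other thread can write `storageLevel` between the test and the store. A hedged sketch of the race, with hypothetical names (not Spark code):

```scala
// Illustrative only: a check-then-act race of the kind the
// synchronized persist() rules out.
object CheckThenAct {
  @volatile private var level: String = "NONE"
  private val lock = new Object()

  // Unsafe: two threads can both observe level == "NONE", both pass the
  // guard, and overwrite each other without either seeing an error.
  def setUnsafe(newLevel: String): Unit = {
    if (level != "NONE" && level != newLevel) {
      throw new UnsupportedOperationException(s"already set to $level")
    }
    level = newLevel
  }

  // Safe: the guard and the write execute as one atomic step.
  def setSafe(newLevel: String): Unit = lock.synchronized {
    setUnsafe(newLevel)
  }
}
```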
@@ -223,12 +236,12 @@ abstract class RDD[T: ClassTag](
   }
 
   /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
-  def getStorageLevel: StorageLevel = storageLevel
+  def getStorageLevel: StorageLevel = stateLock.synchronized { storageLevel }
 
   // Our dependencies and partitions will be gotten by calling subclass's methods below, and will
   // be overwritten when we're checkpointed
-  private var dependencies_ : Seq[Dependency[_]] = _
-  @transient private var partitions_ : Array[Partition] = _
+  @volatile private var dependencies_ : Seq[Dependency[_]] = _
+  @volatile @transient private var partitions_ : Array[Partition] = _
 
   /** An Option holding our checkpoint RDD, if we are checkpointed */
   private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
@@ -240,7 +253,11 @@ abstract class RDD[T: ClassTag](
   final def dependencies: Seq[Dependency[_]] = {
     checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
       if (dependencies_ == null) {
-        dependencies_ = getDependencies
+        stateLock.synchronized {
+          if (dependencies_ == null) {
+            dependencies_ = getDependencies
+          }
+        }
       }
       dependencies_
     }
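This hunk (mirrored for `partitions_` below) is the classic double-checked locking pattern, and it needs both halves of the change: the field is re-checked under the lock so `getDependencies` runs at most once, and the field must be `@volatile` (the earlier hunk) so a thread that sees a non-null reference on the unlocked fast path also sees the fully initialized value. A generic sketch with hypothetical names, not Spark code:

```scala
// Generic double-checked locking, illustrative only.
class Lazily[T <: AnyRef](compute: () => T) {
  private val lock = new Object()
  // @volatile gives the unlocked fast-path read a happens-before edge
  // with the locked write below.
  @volatile private var cached: T = _

  def get: T = {
    if (cached == null) {        // fast path: no lock once initialized
      lock.synchronized {
        if (cached == null) {    // re-check: another thread may have won
          cached = compute()
        }
      }
    }
    cached
  }
}
```

Without `@volatile`, the JVM memory model permits the fast-path reader to observe the non-null reference before the writes `compute()` made to the object's contents, which is exactly the bug double-checked locking is notorious for.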
@@ -253,10 +270,14 @@ abstract class RDD[T: ClassTag](
   final def partitions: Array[Partition] = {
     checkpointRDD.map(_.partitions).getOrElse {
       if (partitions_ == null) {
-        partitions_ = getPartitions
-        partitions_.zipWithIndex.foreach { case (partition, index) =>
-          require(partition.index == index,
-            s"partitions($index).partition == ${partition.index}, but it should equal $index")
+        stateLock.synchronized {
+          if (partitions_ == null) {
+            partitions_ = getPartitions
+            partitions_.zipWithIndex.foreach { case (partition, index) =>
+              require(partition.index == index,
+                s"partitions($index).partition == ${partition.index}, but it should equal $index")
+            }
+          }
         }
       }
       partitions_
@@ -285,7 +306,7 @@ abstract class RDD[T: ClassTag](
    * subclasses of RDD.
    */
   final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
-    if (storageLevel != StorageLevel.NONE) {
+    if (getStorageLevel != StorageLevel.NONE) {
       getOrCompute(split, context)
     } else {
       computeOrReadCheckpoint(split, context)
@@ -335,7 +356,7 @@ abstract class RDD[T: ClassTag](
     val blockId = RDDBlockId(id, partition.index)
     var readCachedBlock = true
     // This method is called on executors, so we need call SparkEnv.get instead of sc.env.
-    SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
+    SparkEnv.get.blockManager.getOrElseUpdate(blockId, getStorageLevel, elementClassTag, () => {
       readCachedBlock = false
       computeOrReadCheckpoint(partition, context)
     }) match {
@@ -1606,10 +1627,12 @@ abstract class RDD[T: ClassTag](
     // the storage level he/she specified to one that is appropriate for local checkpointing
     // (i.e. uses disk) to guarantee correctness.
 
-    if (storageLevel == StorageLevel.NONE) {
-      persist(LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
-    } else {
-      persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true)
+    stateLock.synchronized {
+      if (storageLevel == StorageLevel.NONE) {
+        persist(LocalRDDCheckpointData.DEFAULT_STORAGE_LEVEL)
+      } else {
+        persist(LocalRDDCheckpointData.transformStorageLevel(storageLevel), allowOverride = true)
+      }
     }
 
     // If this RDD is already checkpointed and materialized, its lineage is already truncated.
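Note that `localCheckpoint` now calls `persist` from inside a `stateLock.synchronized` block, and `persist` (earlier hunk) synchronizes on the same monitor. That is safe because JVM intrinsic locks are reentrant: a thread already holding a monitor may re-enter it freely. A tiny sketch, names hypothetical:

```scala
// JVM intrinsic locks are reentrant, so nesting synchronized blocks on one
// monitor (as localCheckpoint -> persist now does) cannot self-deadlock.
object Reentrancy extends App {
  private val lock = new Object()

  def outer(): Unit = lock.synchronized {
    inner() // re-acquires `lock` while already holding it: fine
  }

  def inner(): Unit = lock.synchronized {
    println("reached the inner critical section")
  }

  outer() // prints and returns without blocking
}
```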
@@ -1807,7 +1830,7 @@ abstract class RDD[T: ClassTag](
   /** A description of this RDD and its recursive dependencies for debugging. */
   def toDebugString: String = {
     // Get a debug description of an rdd without its children
-    def debugSelf(rdd: RDD[_]): Seq[String] = {
+    def debugSelf(rdd: RDD[_]): Seq[String] = stateLock.synchronized {
       import Utils.bytesToString
 
       val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else ""