Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,24 @@ abstract class RDD[T: ClassTag](
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel: StorageLevel = storageLevel

/**
* Lock for all mutable state of this RDD (persistence, partitions, dependencies, etc.). We do
* not use `this` because RDDs are user-visible, so users might have added their own locking on
* RDDs; sharing that could lead to a deadlock.
*
* One thread might hold the lock on many of these, for a chain of RDD dependencies; but
* because DAGs are acyclic, and we only ever hold locks for one path in that DAG, there is no
* chance of deadlock.
*
* The use of Integer is simply so this is serializable -- executors may reference the shared
* fields (though they should never mutate them, that only happens on the driver).
*/
private val stateLock = new Integer(0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Integer constructor has been deprecated already, see https://docs.oracle.com/javase/9/docs/api/java/lang/Integer.html . This yields the warning:

RDD.scala:240: constructor Integer in class Integer is deprecated: see corresponding Javadoc for more information.

Is it possible to replace it by something else?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to eliminate the warning in #27399


// Our dependencies and partitions will be gotten by calling subclass's methods below, and will
// be overwritten when we're checkpointed
private var dependencies_ : Seq[Dependency[_]] = _
@transient private var partitions_ : Array[Partition] = _
@volatile private var dependencies_ : Seq[Dependency[_]] = _
@volatile @transient private var partitions_ : Array[Partition] = _

/** An Option holding our checkpoint RDD, if we are checkpointed */
private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
Expand All @@ -240,7 +254,11 @@ abstract class RDD[T: ClassTag](
final def dependencies: Seq[Dependency[_]] = {
checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
if (dependencies_ == null) {
dependencies_ = getDependencies
stateLock.synchronized {
if (dependencies_ == null) {
dependencies_ = getDependencies
}
}
}
dependencies_
}
Expand All @@ -253,10 +271,14 @@ abstract class RDD[T: ClassTag](
final def partitions: Array[Partition] = {
checkpointRDD.map(_.partitions).getOrElse {
if (partitions_ == null) {
partitions_ = getPartitions
partitions_.zipWithIndex.foreach { case (partition, index) =>
require(partition.index == index,
s"partitions($index).partition == ${partition.index}, but it should equal $index")
stateLock.synchronized {
if (partitions_ == null) {
partitions_ = getPartitions
partitions_.zipWithIndex.foreach { case (partition, index) =>
require(partition.index == index,
s"partitions($index).partition == ${partition.index}, but it should equal $index")
}
}
}
}
partitions_
Expand Down Expand Up @@ -1788,7 +1810,7 @@ abstract class RDD[T: ClassTag](
* Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
* created from the checkpoint file, and forget its old dependencies and partitions.
*/
private[spark] def markCheckpointed(): Unit = {
private[spark] def markCheckpointed(): Unit = stateLock.synchronized {
clearDependencies()
partitions_ = null
deps = null // Forget the constructor argument for dependencies too
Expand All @@ -1800,7 +1822,7 @@ abstract class RDD[T: ClassTag](
* collected. Subclasses of RDD may override this method for implementing their own cleaning
* logic. See [[org.apache.spark.rdd.UnionRDD]] for an example.
*/
protected def clearDependencies(): Unit = {
protected def clearDependencies(): Unit = stateLock.synchronized {
dependencies_ = null
}

Expand Down Expand Up @@ -1959,6 +1981,7 @@ abstract class RDD[T: ClassTag](
deterministicLevelCandidates.maxBy(_.id)
}
}

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,8 @@ private[spark] class DAGScheduler(
if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
// Kind of ugly: need to register RDDs with the cache and map output tracker here
// since we can't do it in the RDD constructor because # of partitions is unknown
logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " +
s"shuffle ${shuffleDep.shuffleId}")
mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
}
stage
Expand Down Expand Up @@ -1080,7 +1081,8 @@ private[spark] class DAGScheduler(
private def submitStage(stage: Stage): Unit = {
val jobId = activeJobForStage(stage)
if (jobId.isDefined) {
logDebug("submitStage(" + stage + ")")
logDebug(s"submitStage($stage (name=${stage.name};" +
s"jobs=${stage.jobIds.toSeq.sorted.mkString(",")}))")
if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
val missing = getMissingParentStages(stage).sortBy(_.id)
logDebug("missing: " + missing)
Expand Down
15 changes: 15 additions & 0 deletions core/src/test/scala/org/apache/spark/DistributedSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,21 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
}
}

test("reference partitions inside a task") {
// Run a simple job which just makes sure there is no failure if we touch rdd.partitions
// inside a task. This requires the stateLock to be serializable. This is very convoluted
// use case, it's just a check for backwards-compatibility after the fix for SPARK-28917.
sc = new SparkContext("local-cluster[1,1,1024]", "test")
val rdd1 = sc.parallelize(1 to 10, 1)
val rdd2 = rdd1.map { x => x + 1}
// ensure we can force computation of rdd2.dependencies inside a task. Just touching
// it will force computation and touching the stateLock. The check for null is to just
// to make sure that we've setup our test correctly, and haven't precomputed dependencies
// in the driver
val dependencyComputeCount = rdd1.map { x => if (rdd2.dependencies == null) 1 else 0}.sum()
assert(dependencyComputeCount > 0)
}

}

object DistributedSuite {
Expand Down