Commit c4e975e
[SPARK-23626][CORE] Eagerly compute RDD.partitions on entire DAG when submitting job to DAGScheduler
### What changes were proposed in this pull request?

This PR fixes a longstanding issue where the `DAGScheduler`'s single-threaded event processing loop could become blocked by slow `RDD.getPartitions()` calls, preventing other events (like task completions and concurrent job submissions) from being processed in a timely manner. With this patch's change, Spark will now call `.partitions` on every RDD in the DAG before submitting a job to the scheduler, ensuring that the expensive `getPartitions()` calls occur outside of the scheduler event loop.

#### Background

The `RDD.partitions` method lazily computes an RDD's partitions by calling `RDD.getPartitions()`. The `getPartitions()` method is invoked only once per RDD and its result is cached in the `RDD.partitions_` private field. Sometimes the `getPartitions()` call can be expensive: for example, `HadoopRDD.getPartitions()` performs file listing operations.

The `.partitions` method is invoked at many different places in Spark's code, including many existing call sites that are outside of the scheduler event loop. As a result, it's _often_ the case that an RDD's partitions will have been computed before the RDD is submitted to the DAGScheduler. For example, [`submitJob` calls `rdd.partitions.length`](https://github.com/apache/spark/blob/3ba57f5edc5594ee676249cd309b8f0d8248462e/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L837), so the DAG root's partitions will be computed outside of the scheduler event loop.

However, there are still some cases where `partitions` gets evaluated for the first time inside of the `DAGScheduler` internals. For example, [`ShuffledRDD.getPartitions`](https://github.com/apache/spark/blob/3ba57f5edc5594ee676249cd309b8f0d8248462e/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala#L92-L94) doesn't call `.partitions` on the RDD being shuffled, so a plan with a ShuffledRDD at the root won't necessarily result in `.partitions` having been called on all RDDs prior to scheduler job submission.

#### Correctness: proving that we make no excess `.partitions` calls

This PR adds code to traverse the DAG prior to job submission and call `.partitions` on every RDD encountered. I'd like to argue that this results in no _excess_ `.partitions` calls: in every case where the new code calls `.partitions` there is existing code which would have called `.partitions` at some point during a successful job execution:

- Assume that this is the first time we are computing every RDD in the DAG.
- Every RDD appears in some stage.
- [`submitStage` will call `submitMissingTasks`](https://github.com/databricks/runtime/blob/1e83dfe4f685bad7f260621e77282b1b4cf9bca4/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L1438) on every stage root RDD.
- [`submitStage` calls `getPreferredLocsInternal`](https://github.com/databricks/runtime/blob/1e83dfe4f685bad7f260621e77282b1b4cf9bca4/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L1687-L1696) on every stage root RDD.
- [`getPreferredLocsInternal`](https://github.com/databricks/runtime/blob/1e83dfe4f685bad7f260621e77282b1b4cf9bca4/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L2995-L3043) visits the RDD and all of its parent RDDs that are computed in the same stage (via narrow dependencies) and calls `.partitions` on each RDD visited.
- Therefore `.partitions` is invoked on every RDD in the DAG by the time the job has successfully completed.
- Therefore this patch's change does not introduce any new calls to `.partitions` which would not have otherwise occurred (assuming the job succeeded).

#### Ordering of `.partitions` calls

I don't think the order in which `.partitions` calls occur matters for correctness: the DAGScheduler happens to invoke `.partitions` in a particular order today (defined by the DAG traversal order in internal scheduler methods), but there are already many out-of-order `.partitions` calls occurring elsewhere in the codebase.

#### Handling of exceptions in `.partitions`

I've chosen **not** to add special error handling for the new `.partitions` calls: if exceptions occur then they'll bubble up, unwrapped, to the user code submitting the Spark job. It's sometimes important to preserve exception wrapping behavior, but I don't think that concern is warranted in this particular case: whether `getPartitions` occurred inside or outside of the scheduler (impacting whether exceptions manifest in wrapped or unwrapped form, and impacting whether failed jobs appear in the Spark UI) was not crisply defined (and in some rare cases could even be [influenced by Spark settings in non-obvious ways](https://github.com/apache/spark/blob/10d5303174bf4a47508f6227bbdb1eaf4c92fcdb/core/src/main/scala/org/apache/spark/Partitioner.scala#L75-L79)), so I think it's both unlikely that users were relying on the old behavior and very difficult to preserve it.

#### Should this have a configuration flag?

Per discussion from a previous PR trying to solve this problem (#24438 (review)), I've decided to skip adding a configuration flag for this.

### Why are the changes needed?

This fixes a longstanding scheduler performance problem which has been reported by multiple users.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

I added a regression test in `BasicSchedulerIntegrationSuite` to cover the regular job submission codepath (`DAGScheduler.submitJob`). This test uses CountDownLatches to simulate the submission of a job containing an RDD with a slow `getPartitions()` call and checks that a concurrently-submitted job is not blocked. I have **not** added separate integration tests for the `runApproximateJob` and `submitMapStage` codepaths (both of which also received the same fix).

Closes #34265 from JoshRosen/SPARK-23626.

Authored-by: Josh Rosen <[email protected]>
Signed-off-by: Josh Rosen <[email protected]>
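As an illustration of the mechanism described in the Background section, here is a minimal, self-contained Scala sketch. `LazyPartitionsSketch`, `FakeRdd`, and the executor-based "event loop" are illustrative stand-ins rather than Spark's actual classes; they only model the lazy `partitions` caching and the way one slow `getPartitions()` call on a single-threaded loop delays every event queued behind it:

```scala
import java.util.concurrent.{Executors, TimeUnit}

object LazyPartitionsSketch {
  // Models the lazy caching in RDD.partitions: getPartitions() runs once,
  // and the result is cached for all later accesses.
  abstract class FakeRdd {
    @volatile private var cached: Array[Int] = _
    protected def getPartitions: Array[Int] // may be expensive, e.g. file listing
    final def partitions: Array[Int] = {
      if (cached == null) {
        cached = getPartitions // the first caller pays the full cost
      }
      cached
    }
  }

  def main(args: Array[String]): Unit = {
    val slowRdd = new FakeRdd {
      override protected def getPartitions: Array[Int] = {
        Thread.sleep(5000) // stands in for a slow file-listing call
        Array(0)
      }
    }
    // A single-threaded "event loop": if the first event is the one that
    // triggers the slow getPartitions(), the second event cannot run until
    // that slow call finishes.
    val eventLoop = Executors.newSingleThreadExecutor()
    eventLoop.submit(new Runnable { override def run(): Unit = slowRdd.partitions })
    eventLoop.submit(new Runnable { override def run(): Unit = println("second event handled") })
    eventLoop.shutdown()
    eventLoop.awaitTermination(10, TimeUnit.SECONDS)
  }
}
```

The commit's fix is the analogue of calling `slowRdd.partitions` on the submitting thread first, so that by the time the event loop touches the RDD its partitions are already cached.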
1 parent 2267d7f commit c4e975e

File tree: 3 files changed, +155 / -14 lines


core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 45 additions & 0 deletions

@@ -732,6 +732,35 @@ private[spark] class DAGScheduler(
     missing.toList
   }
 
+  /** Invoke `.partitions` on the given RDD and all of its ancestors */
+  private def eagerlyComputePartitionsForRddAndAncestors(rdd: RDD[_]): Unit = {
+    val startTime = System.nanoTime
+    val visitedRdds = new HashSet[RDD[_]]
+    // We are manually maintaining a stack here to prevent StackOverflowError
+    // caused by recursively visiting
+    val waitingForVisit = new ListBuffer[RDD[_]]
+    waitingForVisit += rdd
+
+    def visit(rdd: RDD[_]): Unit = {
+      if (!visitedRdds(rdd)) {
+        visitedRdds += rdd
+
+        // Eagerly compute:
+        rdd.partitions
+
+        for (dep <- rdd.dependencies) {
+          waitingForVisit.prepend(dep.rdd)
+        }
+      }
+    }
+
+    while (waitingForVisit.nonEmpty) {
+      visit(waitingForVisit.remove(0))
+    }
+    logDebug("eagerlyComputePartitionsForRddAndAncestors for RDD %d took %f seconds"
+      .format(rdd.id, (System.nanoTime - startTime) / 1e9))
+  }
+
   /**
    * Registers the given jobId among the jobs that need the given stage and
    * all of that stage's ancestors.
@@ -841,6 +870,11 @@ private[spark] class DAGScheduler(
         "Total number of partitions: " + maxPartitions)
     }
 
+    // SPARK-23626: `RDD.getPartitions()` can be slow, so we eagerly compute
+    // `.partitions` on every RDD in the DAG to ensure that `getPartitions()`
+    // is evaluated outside of the DAGScheduler's single-threaded event loop:
+    eagerlyComputePartitionsForRddAndAncestors(rdd)
+
     val jobId = nextJobId.getAndIncrement()
     if (partitions.isEmpty) {
       val clonedProperties = Utils.cloneProperties(properties)
@@ -930,6 +964,12 @@ private[spark] class DAGScheduler(
      listenerBus.post(SparkListenerJobEnd(jobId, time, JobSucceeded))
      return new PartialResult(evaluator.currentResult(), true)
    }
+
+    // SPARK-23626: `RDD.getPartitions()` can be slow, so we eagerly compute
+    // `.partitions` on every RDD in the DAG to ensure that `getPartitions()`
+    // is evaluated outside of the DAGScheduler's single-threaded event loop:
+    eagerlyComputePartitionsForRddAndAncestors(rdd)
+
    val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
    val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
    eventProcessLoop.post(JobSubmitted(
@@ -962,6 +1002,11 @@ private[spark] class DAGScheduler(
      throw SparkCoreErrors.cannotRunSubmitMapStageOnZeroPartitionRDDError()
    }
 
+    // SPARK-23626: `RDD.getPartitions()` can be slow, so we eagerly compute
+    // `.partitions` on every RDD in the DAG to ensure that `getPartitions()`
+    // is evaluated outside of the DAGScheduler's single-threaded event loop:
+    eagerlyComputePartitionsForRddAndAncestors(rdd)
+
    // We create a JobWaiter with only one "task", which will be marked as complete when the whole
    // map stage has completed, and will be passed the MapOutputStatistics for that stage.
    // This makes it easier to avoid race conditions between the user code and the map output
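A side note on the traversal itself: the new helper uses an explicit worklist plus a visited set instead of recursion, so very deep RDD lineages cannot overflow the JVM stack. A stripped-down sketch of that pattern over a hypothetical `Node` type (illustrative names only, not Spark code) might look like this:

```scala
import scala.collection.mutable.{HashSet, ListBuffer}

object WorklistTraversalSketch {
  // Hypothetical DAG node standing in for RDD[_] and its dependencies.
  class Node(val id: Int, val parents: Seq[Node]) {
    def eagerlyCompute(): Unit = println(s"computing node $id") // stands in for rdd.partitions
  }

  def visitAllAncestors(root: Node): Unit = {
    val visited = new HashSet[Node]
    // Explicit worklist rather than recursion, so arbitrarily deep lineages
    // cannot trigger a StackOverflowError.
    val waitingForVisit = new ListBuffer[Node]
    waitingForVisit += root
    while (waitingForVisit.nonEmpty) {
      val node = waitingForVisit.remove(0)
      if (!visited(node)) {
        visited += node
        node.eagerlyCompute()
        // Prepending parents gives a depth-first visit order, like the helper above.
        node.parents.foreach(waitingForVisit.prepend(_))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val leaf = new Node(0, Nil)
    val mid = new Node(1, Seq(leaf))
    visitAllAncestors(new Node(2, Seq(mid, leaf))) // the shared leaf is visited only once
  }
}
```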

core/src/test/scala/org/apache/spark/scheduler/HealthTrackerIntegrationSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -112,7 +112,7 @@ class HealthTrackerIntegrationSuite extends SchedulerIntegrationSuite[MultiExecu
      backend.taskFailed(taskDescription, new RuntimeException("test task failure"))
    }
    withBackend(runBackend _) {
-      val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray)
+      val jobFuture = submit(new MockRDD(sc, 10, Nil, Nil), (0 until 10).toArray)
      awaitJobTermination(jobFuture, duration)
      val pattern = (
        s"""|Aborting TaskSet 0.0 because task .*
@@ -150,7 +150,7 @@ class MockRDDWithLocalityPrefs(
    sc: SparkContext,
    numPartitions: Int,
    shuffleDeps: Seq[ShuffleDependency[Int, Int, Nothing]],
-    val preferredLoc: String) extends MockRDD(sc, numPartitions, shuffleDeps) {
+    val preferredLoc: String) extends MockRDD(sc, numPartitions, shuffleDeps, Nil) {
  override def getPreferredLocations(split: Partition): Seq[String] = {
    Seq(preferredLoc)
  }

core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala

Lines changed: 108 additions & 12 deletions

@@ -17,7 +17,7 @@
 package org.apache.spark.scheduler
 
 import java.util.Properties
-import java.util.concurrent.{TimeoutException, TimeUnit}
+import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
@@ -205,7 +205,13 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa
  def shuffle(nParts: Int, input: MockRDD): MockRDD = {
    val partitioner = new HashPartitioner(nParts)
    val shuffleDep = new ShuffleDependency[Int, Int, Nothing](input, partitioner)
-    new MockRDD(sc, nParts, List(shuffleDep))
+    new MockRDD(sc, nParts, List(shuffleDep), Nil)
+  }
+
+  /** models a one-to-one dependency within a stage, like a map or filter */
+  def oneToOne(input: MockRDD): MockRDD = {
+    val dep = new OneToOneDependency[(Int, Int)](input)
+    new MockRDD(sc, input.numPartitions, Nil, Seq(dep))
  }
 
  /** models a stage boundary with multiple dependencies, like a join */
@@ -214,7 +220,7 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa
    val shuffleDeps = inputs.map { inputRDD =>
      new ShuffleDependency[Int, Int, Nothing](inputRDD, partitioner)
    }
-    new MockRDD(sc, nParts, shuffleDeps)
+    new MockRDD(sc, nParts, shuffleDeps, Nil)
  }
 
  val backendException = new AtomicReference[Exception](null)
@@ -449,10 +455,11 @@ case class ExecutorTaskStatus(host: String, executorId: String, var freeCores: I
 class MockRDD(
    sc: SparkContext,
    val numPartitions: Int,
-    val shuffleDeps: Seq[ShuffleDependency[Int, Int, Nothing]]
-) extends RDD[(Int, Int)](sc, shuffleDeps) with Serializable {
+    val shuffleDeps: Seq[ShuffleDependency[Int, Int, Nothing]],
+    val oneToOneDeps: Seq[OneToOneDependency[(Int, Int)]]
+) extends RDD[(Int, Int)](sc, deps = shuffleDeps ++ oneToOneDeps) with Serializable {
 
-  MockRDD.validate(numPartitions, shuffleDeps)
+  MockRDD.validate(numPartitions, shuffleDeps, oneToOneDeps)
 
  override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
    throw new RuntimeException("should not be reached")
@@ -468,14 +475,25 @@
 object MockRDD extends AssertionsHelper with TripleEquals with Assertions {
  /**
   * make sure all the shuffle dependencies have a consistent number of output partitions
+   * and that one-to-one dependencies have the same partition counts as their parents
   * (mostly to make sure the test setup makes sense, not that Spark itself would get this wrong)
   */
-  def validate(numPartitions: Int, dependencies: Seq[ShuffleDependency[_, _, _]]): Unit = {
-    dependencies.foreach { dependency =>
+  def validate(
+      numPartitions: Int,
+      shuffleDependencies: Seq[ShuffleDependency[_, _, _]],
+      oneToOneDependencies: Seq[OneToOneDependency[_]]): Unit = {
+    shuffleDependencies.foreach { dependency =>
      val partitioner = dependency.partitioner
      assert(partitioner != null)
      assert(partitioner.numPartitions === numPartitions)
    }
+    oneToOneDependencies.foreach { dependency =>
+      // In order to support the SPARK-23626 testcase, we cast to MockRDD
+      // and access `numPartitions` instead of just calling `getNumPartitions`:
+      // `getNumPartitions` would call `getPartitions`, undermining the intention
+      // of the SPARK-23626 testcase.
+      assert(dependency.rdd.asInstanceOf[MockRDD].numPartitions === numPartitions)
+    }
  }
 }
@@ -539,7 +557,7 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor
      backend.taskSuccess(taskDescription, 42)
    }
    withBackend(runBackend _) {
-      val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray)
+      val jobFuture = submit(new MockRDD(sc, 10, Nil, Nil), (0 until 10).toArray)
      awaitJobTermination(jobFuture, duration)
    }
    assert(results === (0 until 10).map { _ -> 42 }.toMap)
@@ -564,7 +582,7 @@
      }
    }
 
-    val a = new MockRDD(sc, 2, Nil)
+    val a = new MockRDD(sc, 2, Nil, Nil)
    val b = shuffle(10, a)
    val c = shuffle(20, a)
    val d = join(30, b, c)
@@ -604,7 +622,7 @@
   * (b) we get a second attempt for stage 0 & stage 1
   */
  testScheduler("job with fetch failure") {
-    val input = new MockRDD(sc, 2, Nil)
+    val input = new MockRDD(sc, 2, Nil, Nil)
    val shuffledRdd = shuffle(10, input)
    val shuffleId = shuffledRdd.shuffleDeps.head.shuffleId
 
@@ -646,10 +664,88 @@
      backend.taskFailed(taskDescription, new RuntimeException("test task failure"))
    }
    withBackend(runBackend _) {
-      val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray)
+      val jobFuture = submit(new MockRDD(sc, 10, Nil, Nil), (0 until 10).toArray)
      awaitJobTermination(jobFuture, duration)
      assert(failure.getMessage.contains("test task failure"))
    }
    assertDataStructuresEmpty(noFailure = false)
  }
+
+  testScheduler("SPARK-23626: RDD with expensive getPartitions() doesn't block scheduler loop") {
+    // Before SPARK-23626, expensive `RDD.getPartitions()` calls might occur inside of the
+    // DAGScheduler event loop, causing concurrently-submitted jobs to block. This test case
+    // reproduces a scenario where that blocking could occur.
+
+    // We'll use latches to simulate an RDD with a slow getPartitions() call.
+    import MockRDDWithSlowGetPartitions._
+
+    // DAGScheduler.submitJob calls `.partitions` on the RDD passed to it.
+    // Therefore to write a proper regression test for SPARK-23626 we must
+    // ensure that the slow getPartitions() call occurs deeper in the RDD DAG:
+    val rddWithSlowGetPartitions = oneToOne(new MockRDDWithSlowGetPartitions(sc, 1))
+
+    // A RDD whose execution should not be blocked by the other RDD's slow getPartitions():
+    val simpleRdd = new MockRDD(sc, 1, Nil, Nil)
+
+    getPartitionsShouldNotHaveBeenCalledYet.set(false)
+
+    def runBackend(): Unit = {
+      val (taskDescription, _) = backend.beginTask()
+      backend.taskSuccess(taskDescription, 42)
+    }
+
+    withBackend(runBackend _) {
+      // Submit a job containing an RDD which will hang in getPartitions() until we release
+      // the countdown latch:
+      import scala.concurrent.ExecutionContext.Implicits.global
+      val slowJobFuture = Future { submit(rddWithSlowGetPartitions, Array(0)) }.flatten
+
+      // Block the current thread until the other thread has started the getPartitions() call:
+      beginGetPartitionsLatch.await(duration.toSeconds, SECONDS)
+
+      // Submit a concurrent job. This job's execution should not be blocked by the other job:
+      val fastJobFuture = submit(simpleRdd, Array(0))
+      awaitJobTermination(fastJobFuture, duration)
+
+      // The slow job should still be blocked in the getPartitions() call:
+      assert(!slowJobFuture.isCompleted)
+
+      // Allow it to complete:
+      endGetPartitionsLatch.countDown()
+      awaitJobTermination(slowJobFuture, duration)
+    }
+
+    assertDataStructuresEmpty()
+  }
+}
+
+/** Helper class used in SPARK-23626 test case */
+private object MockRDDWithSlowGetPartitions {
+  // Latch for blocking the test execution thread until getPartitions() has been called:
+  val beginGetPartitionsLatch = new CountDownLatch(1)
+
+  // Latch for blocking the getPartitions() call from completing:
+  val endGetPartitionsLatch = new CountDownLatch(1)
+
+  // Atomic boolean which is used to fail the test in case getPartitions() is called earlier
+  // than expected. This guards against false-negatives (e.g. the test passing because
+  // `.getPartitions()` was called in the test setup before we even submitted a job):
+  val getPartitionsShouldNotHaveBeenCalledYet = new AtomicBoolean(true)
+}
+
+/** Helper class used in SPARK-23626 test case */
+private class MockRDDWithSlowGetPartitions(
+    sc: SparkContext,
+    numPartitions: Int) extends MockRDD(sc, numPartitions, Nil, Nil) {
+  import MockRDDWithSlowGetPartitions._
+
+  override def getPartitions: Array[Partition] = {
+    if (getPartitionsShouldNotHaveBeenCalledYet.get()) {
+      throw new Exception("getPartitions() should not have been called at this point")
+    }
+    beginGetPartitionsLatch.countDown()
+    val partitions = super.getPartitions
+    endGetPartitionsLatch.await()
+    partitions
+  }
 }
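The regression test above coordinates two threads with a pair of `CountDownLatch`es: one proves that the slow `getPartitions()` call has started, the other holds it open until the concurrently submitted job has finished. Reduced to plain concurrency primitives (illustrative names, no Spark involved), the same handshake can be sketched as:

```scala
import java.util.concurrent.{CountDownLatch, TimeUnit}
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

object LatchHandshakeSketch {
  val beginLatch = new CountDownLatch(1) // signals that the slow call has started
  val endLatch = new CountDownLatch(1)   // released to let the slow call finish

  def slowOperation(): Int = {
    beginLatch.countDown() // tell the main thread we are inside the slow call
    endLatch.await()       // block until the main thread releases us
    42
  }

  def main(args: Array[String]): Unit = {
    val slow = Future { slowOperation() }
    // Wait until the slow call has definitely started...
    assert(beginLatch.await(10, TimeUnit.SECONDS))
    // ...verify that it has not finished on its own...
    assert(!slow.isCompleted)
    // ...then release it and wait for the result.
    endLatch.countDown()
    assert(Await.result(slow, 10.seconds) == 42)
  }
}
```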
