diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 482691c94f87e..c03e3e0bbaf59 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1571,13 +1571,13 @@ private[spark] class DAGScheduler(
           // guaranteed to be determinate, so the input data of the reducers will not change
           // even if the map tasks are re-tried.
           if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) {
-            // It's a little tricky to find all the succeeding stages of `failedStage`, because
+            // It's a little tricky to find all the succeeding stages of `mapStage`, because
             // each stage only know its parents not children. Here we traverse the stages from
             // the leaf nodes (the result stages of active jobs), and rollback all the stages
-            // in the stage chains that connect to the `failedStage`. To speed up the stage
+            // in the stage chains that connect to the `mapStage`. To speed up the stage
             // traversing, we collect the stages to rollback first. If a stage needs to
             // rollback, all its succeeding stages need to rollback to.
-            val stagesToRollback = HashSet(failedStage)
+            val stagesToRollback = HashSet[Stage](mapStage)
 
             def collectStagesToRollback(stageChain: List[Stage]): Unit = {
               if (stagesToRollback.contains(stageChain.head)) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index cff3ebf2fb7e0..2b3423f9a4d40 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -2741,27 +2741,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"),
       null))
-    val failedStages = scheduler.failedStages.toSeq
-    assert(failedStages.length == 2)
-    // Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry.
-    assert(failedStages.collect {
-      case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage
-    }.head.findMissingPartitions() == Seq(0))
-    // The result stage is still waiting for its 2 tasks to complete
-    assert(failedStages.collect {
-      case stage: ResultStage => stage
-    }.head.findMissingPartitions() == Seq(0, 1))
-
-    scheduler.resubmitFailedStages()
-
-    // The first task of the `shuffleMapRdd2` failed with fetch failure
-    runEvent(makeCompletionEvent(
-      taskSets(3).tasks(0),
-      FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"),
-      null))
-
-    // The job should fail because Spark can't rollback the shuffle map stage.
-    assert(failure != null && failure.getMessage.contains("Spark cannot rollback"))
+    // The second shuffle map stage needs to rerun, so the job will abort because of the
+    // indeterminate stage rerun.
+    // TODO: Extend this test once re-generating shuffle files is supported (SPARK-25341).
+    assert(failure != null && failure.getMessage
+      .contains("Spark cannot rollback the ShuffleMapStage 1"))
   }
 
   private def assertResultStageFailToRollback(mapRdd: MyRDD): Unit = {
@@ -2872,6 +2856,33 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     assert(latch.await(10, TimeUnit.SECONDS))
   }
 
+  test("SPARK-28699: abort stage if parent stage is indeterminate stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true)
+
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val shuffleId = shuffleDep.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    // Finish the first shuffle map stage.
+    complete(taskSets(0), Seq(
+      (Success, makeMapStatus("hostA", 2)),
+      (Success, makeMapStatus("hostB", 2))))
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
+      null))
+
+    // Shuffle blocks of "hostA" are lost, so the first task of the `shuffleMapRdd` needs to retry.
+    // The result stage is still waiting for its 2 tasks to complete.
+    // Because shuffleMapRdd is indeterminate, this job will be aborted.
+    assert(failure != null && failure.getMessage
+      .contains("Spark cannot rollback the ShuffleMapStage 0"))
+  }
+
   test("Completions in zombie tasksets update status of non-zombie taskset") {
     val parts = 4
     val shuffleMapRdd = new MyRDD(sc, parts, Nil)
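For reference, the comment corrected in the first hunk describes a parent-pointer traversal: stages only know their parents, so the scheduler starts from the result stages of active jobs and walks parent chains, rolling back every chain that reaches the indeterminate `mapStage`. Below is a minimal, self-contained sketch of that idea; the `Node` class and `collectToRollback` helper are hypothetical stand-ins for Spark's `Stage` and the real `collectStagesToRollback`, not the actual DAGScheduler code.

import scala.collection.mutable

object RollbackTraversalSketch extends App {
  // Hypothetical stand-in for Spark's Stage: a node that only knows its parents.
  final case class Node(id: Int, parents: List[Node])

  // Walk parent chains from the leaf (result) nodes. Once a chain reaches a node already
  // marked for rollback, every node after it in the chain (its descendants) is marked too,
  // mirroring "if a stage needs to rollback, all its succeeding stages" do as well.
  def collectToRollback(indeterminate: Node, leaves: Seq[Node]): mutable.HashSet[Node] = {
    val toRollback = mutable.HashSet[Node](indeterminate)
    def visit(chain: List[Node]): Unit = {
      if (toRollback.contains(chain.head)) {
        chain.drop(1).foreach(toRollback += _)
      } else {
        chain.head.parents.foreach(p => visit(p :: chain))
      }
    }
    leaves.foreach(leaf => visit(leaf :: Nil))
    toRollback
  }

  // Toy DAG: map <- shuffle <- result. Marking `map` pulls in every stage downstream of it.
  val map = Node(0, Nil)
  val shuffle = Node(1, List(map))
  val result = Node(2, List(shuffle))
  println(collectToRollback(map, Seq(result)).map(_.id)) // contains ids 0, 1 and 2
}

Collecting the full rollback set first, rather than rolling back while traversing, matches the intent stated in the patched comment about speeding up the stage traversal.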