@@ -96,6 +96,30 @@ class DAGScheduler(
   // Stages that must be resubmitted due to fetch failures
   private[scheduler] val failedStages = new HashSet[Stage]
 
+  // The maximum number of times to retry a stage before aborting
+  val maxStageFailures = 5
+
+  // To avoid cyclical stage failures (see SPARK-5945) we limit the number of times that a stage
+  // may be retried. However, it only makes sense to limit the number of times that a stage fails
+  // if it's failing for the same reason every time. Therefore, track why a stage fails as well as
+  // how many times it has failed.
+  case class StageFailure(failureReason: String) {
+    var count = 1
+    def fail() = { count += 1 }
+    def shouldAbort(): Boolean = { count >= maxStageFailures }
+
+    override def equals(other: Any): Boolean =
+      other match {
+        case that: StageFailure => that.failureReason.equals(this.failureReason)
+        case _ => false
+      }
+
+    override def hashCode: Int = failureReason.hashCode()
+  }
+
+  // Map to track failure reasons for a given stage (keyed by stage)
+  private[scheduler] val stageFailureReasons = new HashMap[Stage, HashSet[StageFailure]]
+
   private[scheduler] val activeJobs = new HashSet[ActiveJob]
 
   /**
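The `StageFailure` case class above keys equality and hashing on `failureReason` alone, so repeated failures with the same reason collapse to a single entry in the per-stage `HashSet` and only its counter grows. Below is a minimal, self-contained sketch of that behaviour; the names mirror the patch, but this is an illustration outside the diff, not DAGScheduler code.

```scala
import scala.collection.mutable.HashSet

object StageFailureDedupSketch {
  // Assumed threshold, matching the patch's default of 5
  val maxStageFailures = 5

  case class StageFailure(failureReason: String) {
    var count = 1
    def fail(): Unit = { count += 1 }
    def shouldAbort(): Boolean = count >= maxStageFailures
    // Equality/hashing depend only on the reason, never on the mutable counter
    override def equals(other: Any): Boolean = other match {
      case that: StageFailure => that.failureReason == this.failureReason
      case _ => false
    }
    override def hashCode: Int = failureReason.hashCode()
  }

  def main(args: Array[String]): Unit = {
    val failures = new HashSet[StageFailure]()
    failures.add(StageFailure("FetchFailed(hostA)"))

    // A second failure with the same reason is "equal" to the first,
    // so we bump the existing counter instead of growing the set.
    val repeat = StageFailure("FetchFailed(hostA)")
    failures.find(_ == repeat) match {
      case Some(f) => f.fail()
      case None => failures.add(repeat)
    }

    println(failures.size)       // 1
    println(failures.head.count) // 2
  }
}
```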
@@ -460,6 +484,10 @@ class DAGScheduler(
           logDebug("Removing stage %d from failed set.".format(stageId))
           failedStages -= stage
         }
+        if (stageFailureReasons.contains(stage)) {
+          logDebug("Removing stage %d from failure reasons set.".format(stageId))
+          stageFailureReasons -= stage
+        }
       }
       // data structures based on StageId
       stageIdToStage -= stageId
@@ -940,6 +968,29 @@ class DAGScheduler(
     }
   }
 
+  /**
+   * Check whether we should abort the failedStage due to repeated failures for the same reason.
+   * This method updates the running count of failures for a particular stage and returns the
+   * failure reason if the number of failures for any single reason reaches the maximum
+   * allowable number of failures.
+   * @return an Option containing the failure reason that triggered the abort, if any
+   */
+  def shouldAbortStage(failedStage: Stage, failureReason: String): Option[String] = {
+    if (!stageFailureReasons.contains(failedStage))
+      stageFailureReasons.put(failedStage, new HashSet[StageFailure]())
+
+    val failures = stageFailureReasons.get(failedStage).get
+    val failure = StageFailure(failureReason)
+    failures.find(s => s.equals(failure)) match {
+      case Some(f) => f.fail()
+      case None => failures.add(failure)
+    }
+    failures.find(_.shouldAbort()) match {
+      case Some(f) => Some(f.failureReason)
+      case None => None
+    }
+  }
+
   /**
   * Responds to a task finishing. This is called inside the event loop so it assumes that it can
   * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
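Taken together, `shouldAbortStage` only returns a reason once the same failure string has been seen `maxStageFailures` times for that stage; different stages (and different reasons below the threshold) are counted independently. The following self-contained sketch shows that contract, with `Stage` stubbed as a plain `String` key and `getOrElseUpdate` standing in for the patch's contains/put check, purely for illustration.

```scala
import scala.collection.mutable.{HashMap, HashSet}

object ShouldAbortSketch {
  val maxStageFailures = 5

  case class StageFailure(failureReason: String) {
    var count = 1
    def fail(): Unit = { count += 1 }
    def shouldAbort(): Boolean = count >= maxStageFailures
  }

  // stage -> distinct failure reasons seen so far (stage stubbed as a String here)
  val stageFailureReasons = new HashMap[String, HashSet[StageFailure]]

  def shouldAbortStage(failedStage: String, failureReason: String): Option[String] = {
    val failures = stageFailureReasons.getOrElseUpdate(failedStage, new HashSet[StageFailure]())
    val failure = StageFailure(failureReason)
    failures.find(_ == failure) match {
      case Some(f) => f.fail()
      case None => failures.add(failure)
    }
    failures.find(_.shouldAbort()).map(_.failureReason)
  }

  def main(args: Array[String]): Unit = {
    // The fifth identical failure crosses the threshold and requests an abort
    val results = (1 to 5).map(_ => shouldAbortStage("stage 3", "FetchFailed(hostA)"))
    println(results)
    // Vector(None, None, None, None, Some(FetchFailed(hostA)))

    // A different stage is tracked independently
    println(shouldAbortStage("stage 7", "FetchFailed(hostB)")) // None
  }
}
```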
@@ -1083,8 +1134,13 @@ class DAGScheduler(
           markStageAsFinished(failedStage, Some(failureMessage))
         }
 
+        val shouldAbort = shouldAbortStage(failedStage, failureMessage)
         if (disallowStageRetryForTest) {
           abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+        } else if (shouldAbort.isDefined) {
+          abortStage(failedStage, s"Fetch failure - aborting stage. Stage ${failedStage.name} " +
+            s"has failed the maximum allowable number of times: ${maxStageFailures}. " +
+            s"Failure reason: ${shouldAbort.get}")
         } else if (failedStages.isEmpty) {
           // Don't schedule an event to resubmit failed stages if failed isn't empty, because
           // in that case the event will already have been scheduled.