Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ class DAGScheduler(
private def failJobAndIndependentStages(job: ActiveJob, failureReason: String,
resultStage: Option[Stage]) {
val error = new SparkException(failureReason)
job.listener.jobFailed(error)
var ableToCancelStages = true

val shouldInterruptThread =
if (job.properties == null) false
Expand All @@ -1062,18 +1062,26 @@ class DAGScheduler(
// This is the only job that uses this stage, so fail the stage if it is running.
val stage = stageIdToStage(stageId)
if (runningStages.contains(stage)) {
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
val stageInfo = stageToInfos(stage)
stageInfo.stageFailed(failureReason)
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
taskScheduler.cancelTasks(stageId, shouldInterruptThread)
val stageInfo = stageToInfos(stage)
stageInfo.stageFailed(failureReason)
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
} catch {
case e: UnsupportedOperationException =>
logInfo(s"Could not cancel tasks for stage $stageId", e)
ableToCancelStages = false
}
}
}
}
}

cleanupStateForJobAndIndependentStages(job, resultStage)

listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
if (ableToCancelStages) {
job.listener.jobFailed(error)
cleanupStateForJobAndIndependentStages(job, resultStage)
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
}
}

/**
Expand Down Expand Up @@ -1155,7 +1163,11 @@ private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler)
case x: Exception =>
logError("eventProcesserActor failed due to the error %s; shutting down SparkContext"
.format(x.getMessage))
dagScheduler.doCancelAllJobs()
try {
dagScheduler.doCancelAllJobs()
} catch {
case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
}
dagScheduler.sc.stop()
Stop
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
sc = new SparkContext("local", "DAGSchedulerSuite")
sparkListener.successfulStages.clear()
sparkListener.failedStages.clear()
failure = null
sc.addSparkListener(sparkListener)
taskSets.clear()
cancelledStages.clear()
Expand Down Expand Up @@ -314,6 +315,53 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
assertDataStructuresEmpty
}

test("job cancellation no-kill backend") {
// make sure that the DAGScheduler doesn't crash when the TaskScheduler
// doesn't implement killTask()
val noKillTaskScheduler = new TaskScheduler() {
override def rootPool: Pool = null
override def schedulingMode: SchedulingMode = SchedulingMode.NONE
override def start() = {}
override def stop() = {}
override def submitTasks(taskSet: TaskSet) = {
taskSets += taskSet
}
override def cancelTasks(stageId: Int, interruptThread: Boolean) {
throw new UnsupportedOperationException
}
override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
}
val noKillScheduler = new DAGScheduler(
sc,
noKillTaskScheduler,
sc.listenerBus,
mapOutputTracker,
blockManagerMaster,
sc.env) {
override def runLocally(job: ActiveJob) {
// don't bother with the thread while unit testing
runLocallyWithinThread(job)
}
}
dagEventProcessTestActor = TestActorRef[DAGSchedulerEventProcessActor](
Props(classOf[DAGSchedulerEventProcessActor], noKillScheduler))(system)
val rdd = makeRdd(1, Nil)
val jobId = submit(rdd, Array(0))
cancel(jobId)
// Because the job wasn't actually cancelled, we shouldn't have received a failure message.
assert(failure === null)

// When the task set completes normally, state should be correctly updated.
complete(taskSets(0), Seq((Success, 42)))
assert(results === Map(0 -> 42))
assertDataStructuresEmpty

assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
assert(sparkListener.failedStages.isEmpty)
assert(sparkListener.successfulStages.contains(0))
}

test("run trivial shuffle") {
val shuffleMapRdd = makeRdd(2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
Expand Down