Commit aaacf30

changes to the test and the logic: ignore fetch failures from decommissioned hosts when deciding whether to abort the stage, and forcefully handle fetch failures from them (unregister their map outputs unconditionally)
1 parent 0c850c7 commit aaacf30
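In effect, a fetch failure whose source executor is known to be decommissioning no longer counts toward the stage-abort threshold. A simplified, standalone sketch of the new condition; the method name and signature here are illustrative, and only the boolean expression mirrors the DAGScheduler diff below.

```scala
import org.apache.spark.scheduler.ExecutorDecommissionInfo

// Illustrative sketch (not the actual DAGScheduler code): a fetch failure whose
// source executor is pending decommission does not push the stage toward the
// consecutive-attempt abort limit.
def shouldAbortStage(
    sourceDecommissioned: Option[ExecutorDecommissionInfo],
    failedAttemptCount: Int,
    maxConsecutiveStageAttempts: Int,
    disallowStageRetryForTest: Boolean): Boolean = {
  (sourceDecommissioned.isEmpty &&
    failedAttemptCount >= maxConsecutiveStageAttempts) ||
    disallowStageRetryForTest
}
```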

3 files changed, +75 -18 lines

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 17 additions & 6 deletions
@@ -1667,6 +1667,11 @@ private[spark] class DAGScheduler(
       case FetchFailed(bmAddress, shuffleId, _, mapIndex, _, failureMessage) =>
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleIdToMapStage(shuffleId)
+        val sourceDecommissioned = if (bmAddress != null && bmAddress.executorId != null) {
+          taskScheduler.getExecutorDecommissionInfo(bmAddress.executorId)
+        } else {
+          None
+        }

         if (failedStage.latestInfo.attemptNumber != task.stageAttemptId) {
           logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
@@ -1675,7 +1680,8 @@
         } else {
           failedStage.failedAttemptIds.add(task.stageAttemptId)
           val shouldAbortStage =
-            failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
+            sourceDecommissioned.isEmpty &&
+              failedStage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
             disallowStageRetryForTest

           // It is likely that we receive multiple FetchFailed for a single stage (because we have
@@ -1824,16 +1830,14 @@
         // TODO: mark the executor as failed only if there were lots of fetch failures on it
         if (bmAddress != null) {
           val externalShuffleServiceEnabled = env.blockManager.externalShuffleServiceEnabled
-          val isHostDecommissioned = taskScheduler
-            .getExecutorDecommissionInfo(bmAddress.executorId)
-            .exists(_.isHostDecommissioned)
+          val sourceHostDecommissioned = sourceDecommissioned.exists(_.isHostDecommissioned)

           // Shuffle output of all executors on host `bmAddress.host` may be lost if:
           // - External shuffle service is enabled, so we assume that all shuffle data on node is
           //   bad.
           // - Host is decommissioned, thus all executors on that host will die.
           val shuffleOutputOfEntireHostLost = externalShuffleServiceEnabled ||
-            isHostDecommissioned
+            sourceHostDecommissioned
           val hostToUnregisterOutputs = if (shuffleOutputOfEntireHostLost
               && unRegisterOutputOnHostOnFetchFailure) {
             Some(bmAddress.host)
@@ -1842,11 +1846,18 @@
             // reason to believe shuffle data has been lost for the entire host).
             None
           }
+          val maybeEpoch = if (sourceHostDecommissioned) {
+            // If we know that the host has been decommissioned, remove its map outputs
+            // unconditionally
+            None
+          } else {
+            Some(task.epoch)
+          }
           removeExecutorAndUnregisterOutputs(
             execId = bmAddress.executorId,
             fileLost = true,
             hostToUnregisterOutputs = hostToUnregisterOutputs,
-            maybeEpoch = Some(task.epoch))
+            maybeEpoch)
         }
       }
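The maybeEpoch change above leans on the scheduler's epoch guard: a failure reported with an epoch no newer than the one already recorded for that executor is ignored, while passing no epoch forces the removal, which is what the decommissioned-host branch wants. A minimal sketch of that guard with illustrative names (this is not removeExecutorAndUnregisterOutputs itself):

```scala
import scala.collection.mutable

// Illustrative-only epoch guard: outputs for an executor are dropped at most once
// per observed epoch, unless the caller supplies no epoch, which treats the failure
// as newer than anything seen before and forces the removal.
class OutputInvalidator {
  private val failedEpoch = mutable.HashMap.empty[String, Long]
  private var currentEpoch = 0L

  def removeOutputs(execId: String, maybeEpoch: Option[Long])(doRemove: () => Unit): Unit = {
    // No epoch supplied -> act unconditionally (the decommissioned-host case above).
    val epoch = maybeEpoch.getOrElse { currentEpoch += 1; currentEpoch }
    // Skip if a failure for this executor was already handled at this epoch or later.
    if (!failedEpoch.get(execId).exists(_ >= epoch)) {
      failedEpoch(execId) = epoch
      doRemove()
    }
  }
}
```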

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 17 additions & 2 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler

 import java.nio.ByteBuffer
+import java.util
 import java.util.{Timer, TimerTask}
 import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 import java.util.concurrent.atomic.AtomicLong
@@ -137,6 +138,8 @@ private[spark] class TaskSchedulerImpl(
   private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]

   private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
+  // map of second to list of executors to clear form the above map
+  private val decommissioningExecutorsToGc = new util.TreeMap[Long, mutable.ArrayBuffer[String]]()

   def runningTasksByExecutors: Map[String, Int] = synchronized {
     executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
@@ -921,7 +924,13 @@ private[spark] class TaskSchedulerImpl(

   override def getExecutorDecommissionInfo(executorId: String)
     : Option[ExecutorDecommissionInfo] = synchronized {
-      executorsPendingDecommission.get(executorId)
+      import scala.collection.JavaConverters._
+      // Garbage collect old decommissioning entries
+      val secondToGcUptil = math.floor(clock.getTimeMillis() / 1000.0).toLong
+      val headMap = decommissioningExecutorsToGc.headMap(secondToGcUptil)
+      headMap.values().asScala.flatten.foreach(executorsPendingDecommission -= _)
+      headMap.clear()
+      executorsPendingDecommission.get(executorId)
   }

   override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = {
@@ -1027,7 +1036,13 @@ private[spark] class TaskSchedulerImpl(
       }
     }

-    executorsPendingDecommission -= executorId
+
+    val decomInfo = executorsPendingDecommission.get(executorId)
+    if (decomInfo.isDefined) {
+      // TODO(dagrawal): make this timestamp configurable
+      val gcSecond = math.ceil(clock.getTimeMillis() / 1000.0).toLong + 60
+      decommissioningExecutorsToGc.getOrDefault(gcSecond, mutable.ArrayBuffer.empty) += executorId
+    }

     if (reason != LossReasonPending) {
       executorIdToHost -= executorId
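The getExecutorDecommissionInfo change keeps decommission entries around for a bounded time by bucketing them in a java.util.TreeMap keyed by the second at which they become eligible for cleanup, and draining every bucket older than "now" on each lookup. A self-contained sketch of that pattern with illustrative names (not the actual TaskSchedulerImpl fields):

```scala
import java.util
import scala.collection.JavaConverters._
import scala.collection.mutable

// Illustrative time-bucketed expiry, mirroring the pattern in the diff above.
class ExpiringIds(ttlSeconds: Long) {
  private val pending = mutable.HashSet.empty[String]
  // second at which a bucket becomes collectable -> ids added in that bucket
  private val gcBuckets = new util.TreeMap[Long, mutable.ArrayBuffer[String]]()

  def add(id: String, nowMillis: Long): Unit = {
    pending += id
    val gcSecond = math.ceil(nowMillis / 1000.0).toLong + ttlSeconds
    val bucket = Option(gcBuckets.get(gcSecond)).getOrElse {
      val b = mutable.ArrayBuffer.empty[String]
      gcBuckets.put(gcSecond, b)
      b
    }
    bucket += id
  }

  def contains(id: String, nowMillis: Long): Boolean = {
    // headMap is a live view of every bucket strictly before this second;
    // clearing the view also removes those buckets from the TreeMap.
    val expired = gcBuckets.headMap(math.floor(nowMillis / 1000.0).toLong)
    expired.values().asScala.flatten.foreach(pending -= _)
    expired.clear()
    pending.contains(id)
  }
}
```

The headMap/clear trick keeps the lookup path cheap: expired buckets are removed in bulk on the next query rather than tracked per entry with a timer.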

core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala

Lines changed: 41 additions & 10 deletions
@@ -84,6 +84,19 @@ class DecommissionWorkerSuite
     }
   }

+  // Unlike TestUtils.withListener, it also waits for the job to be done
+  def withListener(sc: SparkContext, listener: RootStageAwareListener)
+      (body: SparkListener => Unit): Unit = {
+    sc.addSparkListener(listener)
+    try {
+      body(listener)
+      sc.listenerBus.waitUntilEmpty()
+      listener.waitForJobDone()
+    } finally {
+      sc.listenerBus.removeListener(listener)
+    }
+  }
+
   test("decommission workers should not result in job failure") {
     val maxTaskFailures = 2
     val numTimesToKillWorkers = maxTaskFailures + 1
@@ -109,7 +122,7 @@ class DecommissionWorkerSuite
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 1, 1).map { _ =>
         Thread.sleep(5 * 1000L); 1
       }.count()
@@ -164,7 +177,7 @@ class DecommissionWorkerSuite
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => {
         val sleepTimeSeconds = if (pid == 0) 1 else 10
         Thread.sleep(sleepTimeSeconds * 1000L)
@@ -212,22 +225,27 @@
       override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
         val taskInfo = taskEnd.taskInfo
         if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 &&
-          taskEnd.stageAttemptId == 0) {
+          taskEnd.stageAttemptId == 0 && taskEnd.stageId == 0) {
           decommissionWorkerOnMaster(workerToDecom,
             "decommission worker after task on it is done")
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => {
         val executorId = SparkEnv.get.executorId
-        val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
-        Thread.sleep(sleepTimeSeconds * 1000L)
+        val context = TaskContext.get()
+        if (context.attemptNumber() == 0 && context.stageAttemptNumber() == 0) {
+          val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
+          Thread.sleep(sleepTimeSeconds * 1000L)
+        }
         List(1).iterator
       }, preservesPartitioning = true)
         .repartition(1).mapPartitions(iter => {
         val context = TaskContext.get()
         if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) {
+          // Wait a bit for the decommissioning to be triggered in the listener
+          Thread.sleep(5000)
           // MapIndex is explicitly -1 to force the entire host to be decommissioned
           // However, this will cause both the tasks in the preceding stage since the host here is
           // "localhost" (shortcoming of this single-machine unit test in that all the workers
@@ -265,23 +283,31 @@
       override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
         jobEnd.jobResult match {
           case JobSucceeded => jobDone.set(true)
+          case JobFailed(exception) => logError(s"Job failed", exception)
         }
       }

     protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {}

     protected def handleRootTaskStart(start: SparkListenerTaskStart) = {}

+    private def getSignature(taskInfo: TaskInfo, stageId: Int, stageAttemptId: Int):
+      String = {
+      s"${stageId}:${stageAttemptId}:" +
+        s"${taskInfo.index}:${taskInfo.attemptNumber}-${taskInfo.status}"
+    }
+
     override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+      val signature = getSignature(taskStart.taskInfo, taskStart.stageId, taskStart.stageAttemptId)
+      logInfo(s"Task started: $signature")
       if (isRootStageId(taskStart.stageId)) {
         rootTasksStarted.add(taskStart.taskInfo)
         handleRootTaskStart(taskStart)
       }
     }

     override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-      val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" +
-        s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}"
+      val taskSignature = getSignature(taskEnd.taskInfo, taskEnd.stageId, taskEnd.stageAttemptId)
       logInfo(s"Task End $taskSignature")
       tasksFinished.add(taskSignature)
       if (isRootStageId(taskEnd.stageId)) {
@@ -291,8 +317,13 @@ class DecommissionWorkerSuite
     }

     def getTasksFinished(): Seq[String] = {
-      assert(jobDone.get(), "Job isn't successfully done yet")
-      tasksFinished.asScala.toSeq
+      tasksFinished.asScala.toList
+    }
+
+    def waitForJobDone(): Unit = {
+      eventually(timeout(10.seconds), interval(100.milliseconds)) {
+        assert(jobDone.get(), "Job isn't successfully done yet")
+      }
     }
   }
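The suite's new withListener helper registers the listener, runs the test body, waits for the listener bus to drain, and then polls until the listener has seen the job finish. A condensed, self-contained sketch of that wait pattern with a hypothetical listener (not the suite's RootStageAwareListener):

```scala
import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
import org.scalatest.Assertions._
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

// Hypothetical listener: records that some job finished on this SparkContext.
class JobDoneListener extends SparkListener {
  val jobDone = new AtomicBoolean(false)

  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = jobDone.set(true)

  // Poll until the completion event has actually been delivered on the listener bus.
  def waitForJobDone(): Unit = {
    eventually(timeout(10.seconds), interval(100.milliseconds)) {
      assert(jobDone.get(), "Job isn't done yet")
    }
  }
}
```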
