diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java index ad7eb04afcd8c..746676c7aa583 100644 --- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java @@ -175,6 +175,8 @@ void closeAndWriteOutput() throws IOException { final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; final long[] partitionLengths; + final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final long initialFileLength = outputFile.length(); try { partitionLengths = mergeSpills(spills); } finally { @@ -184,7 +186,7 @@ void closeAndWriteOutput() throws IOException { } } } - shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths); + shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths, initialFileLength); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 018422827e1c8..3f77765d132c7 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -385,9 +385,10 @@ private[spark] object MapOutputTracker extends Logging { statuses.map { status => if (status == null) { - logError("Missing an output location for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) + val missing = statuses.iterator.zipWithIndex.filter{_._1 == null}.map{_._2}.mkString(",") + val msg = "Missing an output location for shuffle =" + shuffleId + "; mapIds =" + missing + logError(msg) + throw new MetadataFetchFailedException(shuffleId, reduceId, msg) } else { (status.location, status.getSizeForBlock(reduceId)) } diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index 2ebd7a7151a59..b7c2386fd7d87 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -30,3 +30,13 @@ class SparkException(message: String, cause: Throwable) */ private[spark] class SparkDriverExecutionException(cause: Throwable) extends SparkException("Execution error", cause) + +/** + * Exception indicating an error internal to Spark -- it is in an inconsistent state, not due + * to any error by the user + */ +class SparkIllegalStateException(message: String, cause: Throwable) + extends SparkException(message, cause) { + + def this(message: String) = this(message, null) +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index b4d572cb52313..077fa038bbe69 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -30,6 +30,7 @@ private[spark] class TaskContextImpl( override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, val runningLocally: Boolean = false, + val stageAttemptId: Int = 0, // for testing val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext with Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5d812918a13d1..bb634343049b4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -830,6 +830,10 @@ class DAGScheduler( logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() + stage match { + case smt: ShuffleMapStage => smt.clearPartitionComputeCount() + case _ => + } // First figure out the indexes of partition ids to compute. @@ -891,7 +895,7 @@ class DAGScheduler( partitionsToCompute.map { id => val locs = getPreferredLocs(stage.rdd, id) val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) + new ShuffleMapTask(stage.id, stage.attemptId, taskBinary, part, locs) } case stage: ResultStage => @@ -900,7 +904,7 @@ class DAGScheduler( val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) + new ResultTask(stage.id, stage.attemptId, taskBinary, part, locs, id) } } @@ -974,6 +978,7 @@ class DAGScheduler( val stageId = task.stageId val taskType = Utils.getFormattedClassName(task) + // REVIEWERS: does this need special handling for multiple completions of the same task? outputCommitCoordinator.taskCompleted(stageId, task.partitionId, event.taskInfo.attempt, event.reason) @@ -1031,15 +1036,31 @@ class DAGScheduler( case smt: ShuffleMapTask => val shuffleStage = stage.asInstanceOf[ShuffleMapStage] + val computeCount = shuffleStage.incComputeCount(smt.partitionId) updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) - if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { - logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId) + if (computeCount > 1) { + // REVIEWERS: do I need to worry about speculation here, when multiple completion + // events are normal? + + // REVIEWERS: is this really only a problem on a ShuffleMapTask?? does it also cause + // problems for ResultTask? + + // This can happen when a retry runs a task, but there was a lingering task from an + // earlier attempt which also finished. The results might be OK, or they might not. + // To be safe, we'll retry the task, and do it in yet another attempt, to avoid more + // task output clobbering. + logInfo(s"Multiple completion events for task $task. Results may be corrupt," + + s" assuming task needs to be rerun.") + shuffleStage.removeOutputLoc(task.partitionId) + } else if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) { + logInfo(s"Ignoring possibly bogus $smt completion from executor $execId") } else { shuffleStage.addOutputLoc(smt.partitionId, status) } + if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) { markStageAsFinished(shuffleStage) logInfo("looking for newly runnable stages") @@ -1103,9 +1124,14 @@ class DAGScheduler( // multiple tasks running concurrently on different executors). In that case, it is possible // the fetch failure has already been handled by the scheduler. 
if (runningStages.contains(failedStage)) { - logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + - s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some(failureMessage)) + if (failedStage.attemptId - 1 > task.stageAttemptId) { + logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" + + s" ${task.stageAttemptId}, which has already failed") + } else { + logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + + s"due to a fetch failure from $mapStage (${mapStage.name})") + markStageAsFinished(failedStage, Some(failureMessage)) + } } if (disallowStageRetryForTest) { @@ -1128,6 +1154,16 @@ class DAGScheduler( mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) } + // We also have to mark this map output as unavailable. It's possible that a *later* attempt + // has finished this task in the meantime, but when this task fails, it might end up + // deleting the map output from the earlier successful attempt. + failedStage match { + case smt: ShuffleMapStage => + smt.incComputeCount(reduceId) + smt.removeOutputLoc(reduceId) + case _ => + } + // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index c9a124113961f..9c2606e278c54 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD */ private[spark] class ResultTask[T, U]( stageId: Int, + stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient locs: Seq[TaskLocation], val outputId: Int) - extends Task[U](stageId, partition.index) with Serializable { + extends Task[U](stageId, stageAttemptId, partition.index) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index d02210743484c..da3916430c0b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import scala.collection.mutable.HashMap + import org.apache.spark.ShuffleDependency import org.apache.spark.rdd.RDD import org.apache.spark.storage.BlockManagerId @@ -43,6 +45,17 @@ private[spark] class ShuffleMapStage( val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil) + private val partitionComputeCount = HashMap[Int, Int]() + + def incComputeCount(partition: Int): Int = { + partitionComputeCount(partition) = partitionComputeCount.getOrElse(partition, 0) + 1 + partitionComputeCount(partition) + } + + def clearPartitionComputeCount(): Unit = { + partitionComputeCount.clear() + } + def addOutputLoc(partition: Int, status: MapStatus): Unit = { val prevList = outputLocs(partition) outputLocs(partition) = status :: prevList @@ -51,6 +64,13 @@ private[spark] class ShuffleMapStage( } } + def removeOutputLoc(partition: Int): Unit = { + if (outputLocs(partition) != Nil) { + outputLocs(partition) = Nil + numAvailableOutputs -= 1 + } + } + def
removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = { val prevList = outputLocs(partition) val newList = prevList.filterNot(_.location == bmAddress) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index bd3dd23dfe1ac..14c8c00961487 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter */ private[spark] class ShuffleMapTask( stageId: Int, + stageAttemptId: Int, taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, partition.index) with Logging { + extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, null, new Partition { override def index: Int = 0 }, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null) } @transient private val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 586d1e06204c1..b8bd0a5c10370 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -43,7 +43,10 @@ import org.apache.spark.util.Utils * @param stageId id of the stage this task belongs to * @param partitionId index of the number in the RDD */ -private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { +private[spark] abstract class Task[T]( + val stageId: Int, + val stageAttemptId: Int, + var partitionId: Int) extends Serializable { /** * Called by [[Executor]] to run this task. 
@@ -55,6 +58,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex final def run(taskAttemptId: Long, attemptNumber: Int): T = { context = new TaskContextImpl( stageId = stageId, + stageAttemptId = stageAttemptId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index b4b8a630694bb..d37f02e51e3bf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -159,6 +159,12 @@ private[spark] class TaskSchedulerImpl( this.synchronized { val manager = createTaskSetManager(taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager + val taskSetsPerStage = activeTaskSets.values.filterNot(_.isZombie).groupBy(_.stageId) + taskSetsPerStage.foreach { case (stage, taskSets) => + if (taskSets.size > 1) { + throw new SparkIllegalStateException("more than one active taskSet for stage " + stage) + } + } schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) if (!isLocal && !hasReceivedTask) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c4487d5b37247..066cbe3a88a3a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -620,6 +620,7 @@ private[spark] class TaskSetManager( val index = info.index info.markSuccessful() removeRunningTask(tid) + val task = tasks(index) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -627,11 +628,13 @@ private[spark] class TaskSetManager( // Note: "result.value()" only deserializes the value when it's called at the first time, so // here "result.value()" just returns the value and won't block other threads. sched.dagScheduler.taskEnded( - tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics) + task, Success, result.value(), result.accumUpdates, info, result.metrics) if (!successful(index)) { tasksSuccessful += 1 - logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format( - info.id, taskSet.id, info.taskId, info.duration, info.host, tasksSuccessful, numTasks)) + // include the partition here because on a retry, the partition is *not* the same as info.id + logInfo(("Finished task %s in stage %s (TID %d, partition %d) in %d ms on executor %s (%s) " + + "(%d/%d)").format(info.id, taskSet.id, info.taskId, task.partitionId, info.duration, + info.executorId, info.host, tasksSuccessful, numTasks)) // Mark successful and stop if all the tasks have succeeded.
successful(index) = true if (tasksSuccessful == numTasks) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d9c63b6e7bbb9..a6bd3020d147f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -74,12 +74,16 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB * end of the output file. This will be used by getBlockLocation to figure out where each block * begins and ends. * */ - def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = { + def writeIndexFile( + shuffleId: Int, + mapId: Int, + lengths: Array[Long], + initialFileLength: Long): Unit = { val indexFile = getIndexFile(shuffleId, mapId) val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile))) Utils.tryWithSafeFinally { // We take in lengths of each block, need to convert it to offsets. - var offset = 0L + var offset = initialFileLength out.writeLong(offset) for (length <- lengths) { offset += length diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index c9dd6bfc4c219..bbdd7a5c5cb99 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle.sort -import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext} +import org.apache.spark.{SparkEnv, Logging, TaskContext} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle} @@ -66,9 +66,12 @@ private[spark] class SortShuffleWriter[K, V, C]( // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). 
val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + // Because we append to the data file, we need the index file to know the current size of the + // data file as a starting point + val initialFileLength = outputFile.length() val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) - shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths) + shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths, initialFileLength) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index c2089b0e56a1f..8718f276a9245 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1009,7 +1009,7 @@ public void persist() { @Test public void iterator() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, 0, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java index 83d109115aa5c..ffbe8d37d025b 100644 --- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java @@ -175,7 +175,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; return null; } - }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class)); + }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class), eq(0L)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer( new Answer<Tuple2<TempShuffleBlockId, File>>() { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 6a8ae29aae675..29056eff0e57e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} import org.apache.spark.util.CallSite import org.apache.spark.executor.TaskMetrics @@ -535,7 +536,7 @@ class DAGSchedulerSuite // should be ignored for being too old runEvent(CompletionEvent( taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) - // should work because it's a non-failed host + // it's a non-failed host, but we can't be sure if the results were clobbered runEvent(CompletionEvent( taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) // should be ignored for being too old @@ -545,9 +546,21 @@ class DAGSchedulerSuite taskSet.tasks(1).epoch = newEpoch runEvent(CompletionEvent( taskSet.tasks(1),
Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null)) + + // now we should have a new taskSet for stage 0, which has us retry partition 0 + assert(taskSets.size === 2) + val newTaskSet = taskSets(1) + assert(newTaskSet.stageId === 0) + assert(newTaskSet.attempt === 1) + assert(newTaskSet.tasks.size === 1) + val newTask = newTaskSet.tasks(0) + assert(newTask.epoch === newEpoch + 1) + runEvent(CompletionEvent( + newTask, Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null)) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) - complete(taskSets(1), Seq((Success, 42), (Success, 43))) + complete(taskSets(2), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() } @@ -773,6 +786,66 @@ class DAGSchedulerSuite assertDataStructuresEmpty() } + ignore("no concurrent retries for stage attempts (SPARK-7308)") { + // see SPARK-7308 for a detailed description of the conditions this is trying to recreate. + // note that this is somewhat convoluted for a test case, but isn't actually very unusual + // under a real workload. Note that we only fail the first attempt of stage 2, but that + // could be enough to cause havoc. + + val conf = new SparkConf().set("spark.executor.memory", "100m") + val clusterSc = new SparkContext("local-cluster[5,4,100]", "test-cluster", conf) + val bms = ArrayBuffer[BlockManagerId]() + val stageFailureCount = HashMap[Int, Int]() + clusterSc.addSparkListener(new SparkListener { + override def onBlockManagerAdded(bmAdded: SparkListenerBlockManagerAdded): Unit = { + bms += bmAdded.blockManagerId + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + if (stageCompleted.stageInfo.failureReason.isDefined) { + val stage = stageCompleted.stageInfo.stageId + stageFailureCount(stage) = stageFailureCount.getOrElse(stage, 0) + 1 + println("stage " + stage + " failed: " + stageFailureCount(stage)) + } + } + }) + try { + val rawData = clusterSc.parallelize(1 to 1e6.toInt, 20).map { x => (x % 100) -> x }.cache() + rawData.count() + val aBm = bms(0) + val shuffled = rawData.groupByKey(100).mapPartitionsWithIndex { case (idx, itr) => + // we want one failure quickly, and more failures after stage 0 has finished its + // second attempt + if (TaskContext.get().asInstanceOf[TaskContextImpl].stageAttemptId == 0) { + if (idx == 0) { + throw new FetchFailedException(aBm, 0, 0, idx, + cause = new RuntimeException("simulated fetch failure")) + } else if (idx > 0 && math.random < 0.2) { + Thread.sleep(5000) + throw new FetchFailedException(aBm, 0, 0, idx, + cause = new RuntimeException("simulated fetch failure")) + } else { + // want to make sure plenty of these finish after task 0 fails, and some even finish + // after the previous stage is retried and this stage retry is started + Thread.sleep((500 + math.random * 5000).toLong) + } + } + itr.map { x => ((x._1 + 5) % 100) -> x._2 } + } + val shuffledAgain = shuffled.flatMap { case (k, vs) => vs.map(k -> _) }.groupByKey(100) + val data = shuffledAgain.mapPartitions { itr => itr.flatMap(_._2) }.cache().collect() + val count = data.size + assert(count === 1e6.toInt) + assert(data.toSet === (1 to 1e6.toInt).toSet) + // we should only get one failure from stage 2, everything else should be fine + assert(stageFailureCount(2) === 1) + assert(stageFailureCount.getOrElse(1, 0) === 0) + assert(stageFailureCount.getOrElse(3, 0) == 0) + } finally { + 
clusterSc.stop() + } + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 0a7cb69416a08..188dded7c02f7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import org.apache.spark.TaskContext -class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0) { +class FakeTask(stageId: Int, prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 9b92f8de56759..383855caefa2f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -25,7 +25,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. */ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0) { + extends Task[Array[Byte]](stageId, 0, 0) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 83ae8701243e5..34602a95e53ec 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -42,8 +42,8 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() - val task = new ResultTask[String, String]( - 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) + val task = new ResultTask[String, String](0, 0, + sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { task.run(0, 0) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 6198cea46ddf8..65688124607a7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -137,7 +137,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. */ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer)