@@ -175,6 +175,8 @@ void closeAndWriteOutput() throws IOException {
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
final long[] partitionLengths;
final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
final long initialFileLength = outputFile.length();
try {
partitionLengths = mergeSpills(spills);
} finally {
@@ -184,7 +186,7 @@ void closeAndWriteOutput() throws IOException {
}
}
}
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths, initialFileLength);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}

7 changes: 4 additions & 3 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -385,9 +385,10 @@ private[spark] object MapOutputTracker extends Logging {
statuses.map {
status =>
if (status == null) {
logError("Missing an output location for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
val missing = statuses.iterator.zipWithIndex.filter{_._1 == null}.map{_._2}.mkString(",")
val msg = "Missing an output location for shuffle =" + shuffleId + "; mapIds =" + missing
logError(msg)
throw new MetadataFetchFailedException(shuffleId, reduceId, msg)
} else {
(status.location, status.getSizeForBlock(reduceId))
}
10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkException.scala
@@ -30,3 +30,13 @@ class SparkException(message: String, cause: Throwable)
*/
private[spark] class SparkDriverExecutionException(cause: Throwable)
extends SparkException("Execution error", cause)

/**
* Exception indicating an error internal to Spark -- it is in an inconsistent state, not due
* to any error by the user
*/
class SparkIllegalStateException(message: String, cause: Throwable)
extends SparkException(message, cause) {

def this(message: String) = this(message, null)
}
1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -30,6 +30,7 @@ private[spark] class TaskContextImpl(
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
val runningLocally: Boolean = false,
val stageAttemptId: Int = 0, // for testing
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
with Logging {
50 changes: 43 additions & 7 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -830,6 +830,10 @@ class DAGScheduler(
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
stage.pendingTasks.clear()
stage match {
case smt: ShuffleMapStage => smt.clearPartitionComputeCount()
Contributor Author:

this (and everything related to partitionComputeCount) is for issues (3) & (4); a rough sketch of the idea follows at the end of this hunk.

case _ =>
}


// First figure out the indexes of partition ids to compute.
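To make the partitionComputeCount idea concrete, here is a minimal, self-contained sketch of the bookkeeping described in the comment above. The names (PartitionComputeTracker, recordCompletion) are illustrative only and do not appear in the PR; the real logic lives in ShuffleMapStage and DAGScheduler.handleTaskCompletion.

import scala.collection.mutable

// Simplified model of the per-partition compute-count bookkeeping: count how many
// completion events each map partition has produced in the current round of attempts,
// and treat a second completion as suspect.
class PartitionComputeTracker(numPartitions: Int) {
  private val computeCount = mutable.HashMap[Int, Int]()
  private val outputLocs = Array.fill[Option[String]](numPartitions)(None)

  // Called when the stage (re)submits its missing tasks.
  def clear(): Unit = computeCount.clear()

  // Called for every ShuffleMapTask completion event; returns the updated count.
  def incComputeCount(partition: Int): Int = {
    computeCount(partition) = computeCount.getOrElse(partition, 0) + 1
    computeCount(partition)
  }

  // If the same partition completes more than once, its output may have been clobbered
  // by a lingering task from an earlier attempt, so drop the location to force a rerun.
  def recordCompletion(partition: Int, location: String): Unit = {
    if (incComputeCount(partition) > 1) outputLocs(partition) = None
    else outputLocs(partition) = Some(location)
  }

  def isAvailable(partition: Int): Boolean = outputLocs(partition).isDefined
}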
@@ -891,7 +895,7 @@
partitionsToCompute.map { id =>
val locs = getPreferredLocs(stage.rdd, id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, taskBinary, part, locs)
new ShuffleMapTask(stage.id, stage.attemptId, taskBinary, part, locs)
}

case stage: ResultStage =>
@@ -900,7 +904,7 @@
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
new ResultTask(stage.id, taskBinary, part, locs, id)
new ResultTask(stage.id, stage.attemptId, taskBinary, part, locs, id)
}
}

@@ -974,6 +978,7 @@
val stageId = task.stageId
val taskType = Utils.getFormattedClassName(task)

// REVIEWERS: does this need special handling for multiple completions of the same task?
outputCommitCoordinator.taskCompleted(stageId, task.partitionId,
event.taskInfo.attempt, event.reason)

@@ -1031,15 +1036,31 @@

case smt: ShuffleMapTask =>
val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
val computeCount = shuffleStage.incComputeCount(smt.partitionId)
updateAccumulators(event)
val status = event.result.asInstanceOf[MapStatus]
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
if (computeCount > 1) {
// REVIEWERS: do I need to worry about speculation here, when multiple completion
// events are normal?

// REVIEWERS: is this really only a problem on a ShuffleMapTask?? does it also cause
// problems for ResultTask?

// This can happen when a retry runs a task, but there was a lingering task from an
// earlier attempt which also finished. The results might be OK, or they might not.
// To be safe, we'll retry the task, and do it in yet another attempt, to avoid more
// task output clobbering.
logInfo(s"Multiple completion events for task $task. Results may be corrupt," +
s" assuming task needs to be rerun.")
shuffleStage.removeOutputLoc(task.partitionId)
Contributor Author:

for (3) & (4)

} else if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
} else {
shuffleStage.addOutputLoc(smt.partitionId, status)
}

if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
markStageAsFinished(shuffleStage)
logInfo("looking for newly runnable stages")
@@ -1103,9 +1124,14 @@
// multiple tasks running concurrently on different executors). In that case, it is possible
// the fetch failure has already been handled by the scheduler.
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
if (failedStage.attemptId - 1 > task.stageAttemptId) {
Contributor Author:

This addresses issue (1) -- if we get a fetch failure, but we've already failed the attempt of the stage that caused the fetch failure, then do not resubmit the stage again. (Lots of other small changes to add stageAttemptId to the task)

Comment:

This block is a little awkward. Being a little more explicit would be good here.

Something like this:

// it is possible the failure has already been handled by the scheduler.
val failureRequiresHandling = runningStages.contains(failedStage)
if (failureRequiresHandling) {
  val stageHasFailed = failedStage.attemptId - 1 > task.stageAttemptId
  if (stageHasFailed) {
    ...
  }
}

Contributor:

I'm confused about how this works. Doesn't the stage still get added to failed stages on line 1149, so it will still be resubmitted?

Contributor Author:

hmm, good point. I think it works in my existing test case, because submitStage already checks if the stage is running before submitting it. So now this makes the stage simultaneously running and failed :/. Most likely this would result in issues if my test case had an even longer pipeline of stages in one job, so at some point a later attempt for this stage would succeed, so it would no longer be running and only be failed, and then it would get resubmitted for no reason. this is just from the top of my head though ... I'll need to look more carefully and try some more cases to see what is going on here.

(btw, thanks for looking at it in this state, I do still plan on splitting this apart some, just keep getting sidetracked ...)
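For readers without the DAGScheduler source at hand, the check being referred to is roughly the following: submitStage skips any stage that is already waiting, running, or failed. This is a toy sketch under that assumption, not the actual Spark code.

import scala.collection.mutable

object SubmitStageGuardSketch {
  case class Stage(id: Int)

  val waitingStages = mutable.HashSet[Stage]()
  val runningStages = mutable.HashSet[Stage]()
  val failedStages = mutable.HashSet[Stage]()

  // Toy model of the guard: a stage already tracked as waiting, running, or failed
  // is silently skipped rather than resubmitted.
  def submitStage(stage: Stage): Unit = {
    if (!waitingStages(stage) && !runningStages(stage) && !failedStages(stage)) {
      runningStages += stage // stand-in for actually launching the stage's tasks
    }
    // A stage sitting in runningStages *and* failedStages at the same time (the
    // situation described above) is skipped here, and would only be resubmitted once
    // a later attempt finishes and removes it from runningStages.
  }
}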

logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
s" ${task.stageAttemptId}, which has already failed")
} else {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
markStageAsFinished(failedStage, Some(failureMessage))
}
}

if (disallowStageRetryForTest) {
@@ -1128,6 +1154,16 @@
mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
}

// We also have to mark this map output as unavailable. It's possible that a *later* attempt
// has finished this task in the meantime, but when this task fails, it might end up
// deleting the mapOutput from the earlier successful attempt.
failedStage match {
case smt: ShuffleMapStage =>
smt.incComputeCount(reduceId)
smt.removeOutputLoc(reduceId)
case _ =>
}

// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
*/
private[spark] class ResultTask[T, U](
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient locs: Seq[TaskLocation],
val outputId: Int)
extends Task[U](stageId, partition.index) with Serializable {
extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {

@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
@@ -17,6 +17,8 @@

package org.apache.spark.scheduler

import scala.collection.mutable.HashMap

import org.apache.spark.ShuffleDependency
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.BlockManagerId
@@ -43,6 +45,17 @@ private[spark] class ShuffleMapStage(

val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)

private val partitionComputeCount = HashMap[Int, Int]()

def incComputeCount(partition: Int): Int = {
partitionComputeCount(partition) = partitionComputeCount.getOrElse(partition, 0) + 1
partitionComputeCount(partition)
}

def clearPartitionComputeCount(): Unit = {
partitionComputeCount.clear()
}

def addOutputLoc(partition: Int, status: MapStatus): Unit = {
val prevList = outputLocs(partition)
outputLocs(partition) = status :: prevList
@@ -51,6 +64,13 @@
}
}

def removeOutputLoc(partition: Int): Unit = {
if (outputLocs(partition) != Nil) {
outputLocs(partition) = Nil
numAvailableOutputs -= 1
}
}

def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = {
val prevList = outputLocs(partition)
val newList = prevList.filterNot(_.location == bmAddress)
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
*/
private[spark] class ShuffleMapTask(
stageId: Int,
stageAttemptId: Int,
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId, partition.index) with Logging {
extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {

/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
this(0, null, new Partition { override def index: Int = 0 }, null)
this(0, 0, null, new Partition { override def index: Int = 0 }, null)
}

@transient private val preferredLocs: Seq[TaskLocation] = {
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
* @param stageId id of the stage this task belongs to
* @param partitionId index of the number in the RDD
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
var partitionId: Int) extends Serializable {

/**
* Called by [[Executor]] to run this task.
@@ -55,6 +58,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
final def run(taskAttemptId: Long, attemptNumber: Int): T = {
context = new TaskContextImpl(
stageId = stageId,
stageAttemptId = stageAttemptId,
partitionId = partitionId,
taskAttemptId = taskAttemptId,
attemptNumber = attemptNumber,
@@ -159,6 +159,12 @@ private[spark] class TaskSchedulerImpl(
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
val taskSetsPerStage = activeTaskSets.values.filterNot(_.isZombie).groupBy(_.stageId)
taskSetsPerStage.foreach { case (stage, taskSets) =>
if (taskSets.size > 1) {
throw new SparkIllegalStateException("more than one active taskSet for stage " + stage)
}
}
Contributor Author:

this is defensive. Hopefully this PR will eliminate multiple concurrent attempts for the same stage, but I'd like to add this check in any case.

schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)

if (!isLocal && !hasReceivedTask) {
@@ -620,18 +620,21 @@ private[spark] class TaskSetManager(
val index = info.index
info.markSuccessful()
removeRunningTask(tid)
val task = tasks(index)
// This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
// "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
// "deserialize" the value when holding a lock to avoid blocking other threads. So we call
// "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here.
// Note: "result.value()" only deserializes the value when it's called at the first time, so
// here "result.value()" just returns the value and won't block other threads.
sched.dagScheduler.taskEnded(
tasks(index), Success, result.value(), result.accumUpdates, info, result.metrics)
task, Success, result.value(), result.accumUpdates, info, result.metrics)
if (!successful(index)) {
tasksSuccessful += 1
logInfo("Finished task %s in stage %s (TID %d) in %d ms on %s (%d/%d)".format(
info.id, taskSet.id, info.taskId, info.duration, info.host, tasksSuccessful, numTasks))
// include the partition here b/c on a retry, the partition is *not* the same as info.id
logInfo(("Finished task %s in stage %s (TID %d, partition %d) in %d ms on executor %s (%s) " +
"(%d/%d)").format(info.id, taskSet.id, task.partitionId, info.taskId, info.duration,
info.executorId, info.host, tasksSuccessful, numTasks))
// Mark successful and stop if all the tasks have succeeded.
successful(index) = true
if (tasksSuccessful == numTasks) {
@@ -74,12 +74,16 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver
* end of the output file. This will be used by getBlockLocation to figure out where each block
* begins and ends.
* */
def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
def writeIndexFile(
shuffleId: Int,
mapId: Int,
lengths: Array[Long],
initialFileLength: Long): Unit = {
val indexFile = getIndexFile(shuffleId, mapId)
val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
Utils.tryWithSafeFinally {
// We take in lengths of each block, need to convert it to offsets.
var offset = 0L
var offset = initialFileLength
Contributor Author:

to fix issue (2) -- data files are appended, but before this, index files always pointed to the beginning of the data file. (A small worked sketch of the resulting offset layout follows at the end of this hunk.)

out.writeLong(offset)
for (length <- lengths) {
offset += length
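To make the offset change concrete, here is a small self-contained sketch (not Spark code; the values are made up) of how the index entries relate to an appended data file: the first entry is the data file's length before this map task wrote its output, and each later entry adds one partition length.

// Sketch: building index offsets for a data file that is appended to, and reading
// back the byte range of one reduce partition.
val initialFileLength = 1024L        // data already in the file before this write
val lengths = Array(100L, 0L, 250L)  // bytes written for reduce partitions 0..2

// The index holds numPartitions + 1 cumulative offsets, starting at initialFileLength
// rather than 0, so the new blocks never overlap earlier contents of the data file.
val offsets = lengths.scanLeft(initialFileLength)(_ + _)
// offsets == Array(1024, 1124, 1124, 1374)

// A reader locating reduce partition r uses consecutive offsets as [start, end).
def blockRange(r: Int): (Long, Long) = (offsets(r), offsets(r + 1))
assert(blockRange(2) == (1124L, 1374L))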
@@ -17,7 +17,7 @@

package org.apache.spark.shuffle.sort

import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
import org.apache.spark.{SparkEnv, Logging, TaskContext}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
@@ -66,9 +66,12 @@ private[spark] class SortShuffleWriter[K, V, C](
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
// Because we append to the data file, we need the index file to know the current size of the
// data file as a starting point
val initialFileLength = outputFile.length()
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths, initialFileLength)

mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
2 changes: 1 addition & 1 deletion core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1009,7 +1009,7 @@ public void persist() {
@Test
public void iterator() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics());
TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, 0, new TaskMetrics());
Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue());
}

@@ -175,7 +175,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
return null;
}
}).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
}).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class), eq(0L));

when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
new Answer<Tuple2<TempShuffleBlockId, File>>() {