Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ private[spark] object AccumulatorSuite {
val listener = new SaveInfoListener
sc.addSparkListener(listener)
testBody
// wait until all events have been processed before proceeding to assert things
sc.listenerBus.waitUntilEmpty(10 * 1000)
val accums = listener.getCompletedStageInfos.flatMap(_.accumulables.values)
val isSet = accums.exists { a =>
a.name == Some(PEAK_EXECUTION_MEMORY) && a.value.exists(_.asInstanceOf[Long] > 0L)
Expand All @@ -321,35 +323,60 @@ private[spark] object AccumulatorSuite {
* A simple listener that keeps track of the TaskInfos and StageInfos of all completed jobs.
*/
private class SaveInfoListener extends SparkListener {
private val completedStageInfos: ArrayBuffer[StageInfo] = new ArrayBuffer[StageInfo]
private val completedTaskInfos: ArrayBuffer[TaskInfo] = new ArrayBuffer[TaskInfo]
private var jobCompletionCallback: (Int => Unit) = null // parameter is job ID
type StageId = Int
type StageAttemptId = Int

// Accesses must be synchronized to ensure failures in `jobCompletionCallback` are propagated
private val completedStageInfos = new ArrayBuffer[StageInfo]
private val completedTaskInfos =
new mutable.HashMap[(StageId, StageAttemptId), ArrayBuffer[TaskInfo]]

// Callback to call when a job completes. It takes no arguments.
@GuardedBy("this")
private var jobCompletionCallback: () => Unit = null
private var calledJobCompletionCallback: Boolean = false
private var exception: Throwable = null

def getCompletedStageInfos: Seq[StageInfo] = completedStageInfos.toArray.toSeq
def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.toArray.toSeq
def getCompletedTaskInfos: Seq[TaskInfo] = completedTaskInfos.values.flatten.toSeq
def getCompletedTaskInfos(stageId: StageId, stageAttemptId: StageAttemptId): Seq[TaskInfo] = {
  // Task infos are bucketed by (stage ID, stage attempt ID); an attempt with no recorded
  // tasks yields an empty sequence.
  val attemptKey = (stageId, stageAttemptId)
  completedTaskInfos.getOrElse(attemptKey, Seq.empty[TaskInfo])
}

/** Register a callback to be called on job end. */
def registerJobCompletionCallback(callback: (Int => Unit)): Unit = {
jobCompletionCallback = callback
/**
* If `jobCompletionCallback` is set, block until the next call has finished.
* If the callback failed with an exception, throw it.
*/
def awaitNextJobCompletion(): Unit = synchronized {
  if (jobCompletionCallback != null) {
    // Loop rather than wait once, so a wakeup that arrives before `onJobEnd` has set the
    // flag does not let us proceed early.
    while (!calledJobCompletionCallback) {
      wait()
    }
    // Reset the flag so that a later call blocks until the *next* job completion.
    calledJobCompletionCallback = false
    if (exception != null) {
      // Capture the failure before clearing the field. Clearing the field first and then
      // throwing it would throw `null`, which surfaces as a NullPointerException and hides
      // the real error raised by the callback.
      val failure = exception
      exception = null
      throw failure
    }
  }
}

/** Throw a stored exception, if any. */
def maybeThrowException(): Unit = synchronized {
if (exception != null) { throw exception }
/**
* Register a callback to be called on job end.
* A call to this should be followed by [[awaitNextJobCompletion]].
*/
def registerJobCompletionCallback(callback: () => Unit): Unit = synchronized {
  // `jobCompletionCallback` is annotated @GuardedBy("this") and is read under this lock by
  // `awaitNextJobCompletion` and `onJobEnd`, so the write takes the same lock.
  jobCompletionCallback = callback
}

override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized {
if (jobCompletionCallback != null) {
try {
jobCompletionCallback(jobEnd.jobId)
jobCompletionCallback()
} catch {
// Store any exception thrown here so we can throw them later in the main thread.
// Otherwise, if `jobCompletionCallback` threw something it wouldn't fail the test.
case NonFatal(e) => exception = e
} finally {
calledJobCompletionCallback = true
notify()
}
}
}
Expand All @@ -359,7 +386,8 @@ private class SaveInfoListener extends SparkListener {
}

override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
completedTaskInfos += taskEnd.taskInfo
completedTaskInfos.getOrElseUpdate(
(taskEnd.stageId, taskEnd.stageAttemptId), new ArrayBuffer[TaskInfo]) += taskEnd.taskInfo
}
}

Expand Down
128 changes: 63 additions & 65 deletions core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockStatus}


Expand Down Expand Up @@ -160,7 +161,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
iter
}
// Register asserts in job completion callback to avoid flakiness
listener.registerJobCompletionCallback { _ =>
listener.registerJobCompletionCallback { () =>
val stageInfos = listener.getCompletedStageInfos
val taskInfos = listener.getCompletedTaskInfos
assert(stageInfos.size === 1)
Expand All @@ -179,6 +180,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
}
rdd.count()
listener.awaitNextJobCompletion()
}

test("internal accumulators in multiple stages") {
Expand All @@ -205,7 +207,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
iter
}
// Register asserts in job completion callback to avoid flakiness
listener.registerJobCompletionCallback { _ =>
listener.registerJobCompletionCallback { () =>
// We ran 3 stages, and the accumulator values should be distinct
val stageInfos = listener.getCompletedStageInfos
assert(stageInfos.size === 3)
Expand All @@ -220,13 +222,66 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
rdd.count()
}

// TODO: these two tests are incorrect; they don't actually trigger stage retries.
ignore("internal accumulators in fully resubmitted stages") {
testInternalAccumulatorsWithFailedTasks((i: Int) => true) // fail all tasks
}
test("internal accumulators in resubmitted stages") {
val listener = new SaveInfoListener
val numPartitions = 10
sc = new SparkContext("local", "test")
sc.addSparkListener(listener)

// Simulate fetch failures in order to trigger a stage retry. Here we run 1 job with
// 2 stages. On the second stage, we trigger a fetch failure on the first stage attempt.
// This should retry both stages in the scheduler. Note that we only want to fail the
// first stage attempt because we want the stage to eventually succeed.
val x = sc.parallelize(1 to 100, numPartitions)
.mapPartitions { iter => TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1; iter }
.groupBy(identity)
val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId
val rdd = x.mapPartitionsWithIndex { case (i, iter) =>
// Fail the first stage attempt. Here we use the task attempt ID to determine this.
// This job runs 2 stages, and we're in the second stage. Therefore, any task attempt
// ID that's < 2 * numPartitions belongs to the first attempt of this stage.
val taskContext = TaskContext.get()
val isFirstStageAttempt = taskContext.taskAttemptId() < numPartitions * 2
if (isFirstStageAttempt) {
throw new FetchFailedException(
SparkEnv.get.blockManager.blockManagerId,
sid,
taskContext.partitionId(),
taskContext.partitionId(),
"simulated fetch failure")
} else {
iter
}
}

ignore("internal accumulators in partially resubmitted stages") {
testInternalAccumulatorsWithFailedTasks((i: Int) => i % 2 == 0) // fail a subset
// Register asserts in job completion callback to avoid flakiness
listener.registerJobCompletionCallback { () =>
val stageInfos = listener.getCompletedStageInfos
assert(stageInfos.size === 4) // 1 shuffle map stage + 1 result stage, both are retried
val mapStageId = stageInfos.head.stageId
val mapStageInfo1stAttempt = stageInfos.head
val mapStageInfo2ndAttempt = {
stageInfos.tail.find(_.stageId == mapStageId).getOrElse {
fail("expected two attempts of the same shuffle map stage.")
}
}
val stageAccum1stAttempt = findTestAccum(mapStageInfo1stAttempt.accumulables.values)
val stageAccum2ndAttempt = findTestAccum(mapStageInfo2ndAttempt.accumulables.values)
// Both map stages should have succeeded, since the fetch failure happened in the
// result stage, not the map stage. This means we should get the accumulator updates
// from all partitions.
assert(stageAccum1stAttempt.value.get.asInstanceOf[Long] === numPartitions)
assert(stageAccum2ndAttempt.value.get.asInstanceOf[Long] === numPartitions)
// Because this test resubmitted the map stage with all missing partitions, we should have
// created a fresh set of internal accumulators in the 2nd stage attempt. Assert this is
// the case by comparing the accumulator IDs between the two attempts.
// Note: it would be good to also test the case where the map stage is resubmitted with
// only a subset of the original partitions missing. However, this scenario is very
// difficult to construct without potentially introducing flakiness.
assert(stageAccum1stAttempt.id != stageAccum2ndAttempt.id)
}
rdd.count()
listener.awaitNextJobCompletion()
}

test("internal accumulators are registered for cleanups") {
Expand Down Expand Up @@ -257,63 +312,6 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext {
}
}

/**
* Test whether internal accumulators are merged properly if some tasks fail.
* TODO: make this actually retry the stage.
*/
private def testInternalAccumulatorsWithFailedTasks(failCondition: (Int => Boolean)): Unit = {
val listener = new SaveInfoListener
val numPartitions = 10
val numFailedPartitions = (0 until numPartitions).count(failCondition)
// This says use 1 core and retry tasks up to 2 times
sc = new SparkContext("local[1, 2]", "test")
sc.addSparkListener(listener)
val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitionsWithIndex { case (i, iter) =>
val taskContext = TaskContext.get()
taskContext.taskMetrics().getAccum(TEST_ACCUM) += 1
// Fail the first attempts of a subset of the tasks
if (failCondition(i) && taskContext.attemptNumber() == 0) {
throw new Exception("Failing a task intentionally.")
}
iter
}
// Register asserts in job completion callback to avoid flakiness
listener.registerJobCompletionCallback { _ =>
val stageInfos = listener.getCompletedStageInfos
val taskInfos = listener.getCompletedTaskInfos
assert(stageInfos.size === 1)
assert(taskInfos.size === numPartitions + numFailedPartitions)
val stageAccum = findTestAccum(stageInfos.head.accumulables.values)
// If all partitions failed, then we would resubmit the whole stage again and create a
// fresh set of internal accumulators. Otherwise, these internal accumulators do count
// failed values, so we must include the failed values.
val expectedAccumValue =
if (numPartitions == numFailedPartitions) {
numPartitions
} else {
numPartitions + numFailedPartitions
}
assert(stageAccum.value.get.asInstanceOf[Long] === expectedAccumValue)
val taskAccumValues = taskInfos.flatMap { taskInfo =>
if (!taskInfo.failed) {
// If a task succeeded, its update value should always be 1
val taskAccum = findTestAccum(taskInfo.accumulables)
assert(taskAccum.update.isDefined)
assert(taskAccum.update.get.asInstanceOf[Long] === 1L)
assert(taskAccum.value.isDefined)
Some(taskAccum.value.get.asInstanceOf[Long])
} else {
// If a task failed, we should not get its accumulator values
assert(taskInfo.accumulables.isEmpty)
None
}
}
assert(taskAccumValues.sorted === (1L to numPartitions).toSeq)
}
rdd.count()
listener.maybeThrowException()
}

/**
* A special [[ContextCleaner]] that saves the IDs of the accumulators registered for cleanup.
*/
Expand Down