diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 8cd1d1c96aa0..01d8973e1bb0 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -110,10 +110,10 @@ private[spark] class TaskContextImpl(
 
   /** Marks the task as completed and triggers the completion listeners. */
   @GuardedBy("this")
-  private[spark] def markTaskCompleted(): Unit = synchronized {
+  private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
     if (completed) return
     completed = true
-    invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) {
+    invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
       _.onTaskCompletion(this)
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 5c337b992c84..7767ef1803a0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -115,26 +115,33 @@ private[spark] abstract class Task[T](
         case t: Throwable =>
           e.addSuppressed(t)
       }
+      context.markTaskCompleted(Some(e))
       throw e
     } finally {
-      // Call the task completion callbacks.
-      context.markTaskCompleted()
       try {
-        Utils.tryLogNonFatalError {
-          // Release memory used by this thread for unrolling blocks
-          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
-          SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
-          // Notify any tasks waiting for execution memory to be freed to wake up and try to
-          // acquire memory again. This makes impossible the scenario where a task sleeps forever
-          // because there are no other tasks left to notify it. Since this is safe to do but may
-          // not be strictly necessary, we should revisit whether we can remove this in the future.
-          val memoryManager = SparkEnv.get.memoryManager
-          memoryManager.synchronized { memoryManager.notifyAll() }
-        }
+        // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
+        // one is no-op.
+        context.markTaskCompleted(None)
       } finally {
-        // Though we unset the ThreadLocal here, the context member variable itself is still queried
-        // directly in the TaskRunner to check for FetchFailedExceptions.
-        TaskContext.unset()
+        try {
+          Utils.tryLogNonFatalError {
+            // Release memory used by this thread for unrolling blocks
+            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+            SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
+              MemoryMode.OFF_HEAP)
+            // Notify any tasks waiting for execution memory to be freed to wake up and try to
+            // acquire memory again. This makes impossible the scenario where a task sleeps forever
+            // because there are no other tasks left to notify it. Since this is safe to do but may
+            // not be strictly necessary, we should revisit whether we can remove this in the
+            // future.
+            val memoryManager = SparkEnv.get.memoryManager
+            memoryManager.synchronized { memoryManager.notifyAll() }
+          }
+        } finally {
+          // Though we unset the ThreadLocal here, the context member variable itself is still
+          // queried directly in the TaskRunner to check for FetchFailedExceptions.
+          TaskContext.unset()
+        }
       }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
index 1be31e88ab68..51feccfb8342 100644
--- a/core/src/main/scala/org/apache/spark/util/taskListeners.scala
+++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
@@ -55,14 +55,16 @@ class TaskCompletionListenerException(
   extends RuntimeException {
 
   override def getMessage: String = {
-    if (errorMessages.size == 1) {
-      errorMessages.head
-    } else {
-      errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
-    } +
-    previousError.map { e =>
+    val listenerErrorMessage =
+      if (errorMessages.size == 1) {
+        errorMessages.head
+      } else {
+        errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
+      }
+    val previousErrorMessage = previousError.map { e =>
       "\n\nPrevious exception in task: " + e.getMessage + "\n" +
         e.getStackTrace.mkString("\t", "\n\t", "")
     }.getOrElse("")
+    listenerErrorMessage + previousErrorMessage
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index b22da565d86e..992d3396d203 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -100,7 +100,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     context.addTaskCompletionListener(_ => throw new Exception("blah"))
 
     intercept[TaskCompletionListenerException] {
-      context.markTaskCompleted()
+      context.markTaskCompleted(None)
     }
 
     verify(listener, times(1)).onTaskCompletion(any())
@@ -231,10 +231,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
   test("immediately call a completion listener if the context is completed") {
     var invocations = 0
     val context = TaskContext.empty()
-    context.markTaskCompleted()
+    context.markTaskCompleted(None)
     context.addTaskCompletionListener(_ => invocations += 1)
     assert(invocations == 1)
-    context.markTaskCompleted()
+    context.markTaskCompleted(None)
     assert(invocations == 1)
   }
 
@@ -254,6 +254,36 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
     assert(lastError == error)
     assert(invocations == 1)
   }
+
+  test("TaskCompletionListenerException.getMessage should include previousError") {
+    val listenerErrorMessage = "exception in listener"
+    val taskErrorMessage = "exception in task"
+    val e = new TaskCompletionListenerException(
+      Seq(listenerErrorMessage),
+      Some(new RuntimeException(taskErrorMessage)))
+    assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage))
+  }
+
+  test("all TaskCompletionListeners should be called even if some fail or a task") {
+    val context = TaskContext.empty()
+    val listener = mock(classOf[TaskCompletionListener])
+    context.addTaskCompletionListener(_ => throw new Exception("exception in listener1"))
+    context.addTaskCompletionListener(listener)
+    context.addTaskCompletionListener(_ => throw new Exception("exception in listener3"))
+
+    val e = intercept[TaskCompletionListenerException] {
+      context.markTaskCompleted(Some(new Exception("exception in task")))
+    }
+
+    // Make sure listener 2 was called.
+    verify(listener, times(1)).onTaskCompletion(any())
+
+    // also need to check failure in TaskCompletionListener does not mask earlier exception
+    assert(e.getMessage.contains("exception in listener1"))
+    assert(e.getMessage.contains("exception in listener3"))
+    assert(e.getMessage.contains("exception in task"))
+  }
+
 }
 
 private object TaskContextSuite {
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
index 3050f9a25023..535105379963 100644
--- a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
@@ -145,7 +145,7 @@ class PartiallySerializedBlockSuite
     try {
       TaskContext.setTaskContext(TaskContext.empty())
      val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
-      TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
+      TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None)
       Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
       Mockito.verifyNoMoreInteractions(memoryStore)
     } finally {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index e56e440380a5..9900d1edc4cb 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -192,7 +192,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
     // Complete the task; then the 2nd block buffer should be exhausted
     verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
-    taskContext.markTaskCompleted()
+    taskContext.markTaskCompleted(None)
     verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release()
 
     // The 3rd block should not be retained because the iterator is already in zombie state
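
Note on the taskListeners.scala hunk above: in the old getMessage, the "+" trailing the if/else closing brace attached the previousError suffix to the else branch only, so a single listener error message silently dropped the previous task error; the patch hoists the if/else into a val before concatenating. The following standalone sketch (not part of the patch; the object name and strings are made up for illustration) demonstrates that Scala parsing pitfall:

object PrecedencePitfallSketch {
  def main(args: Array[String]): Unit = {
    val single = true

    // Parses as: if (single) "listener error" else ("listener errors" + "\nprevious task error"),
    // so the previous-error suffix is lost whenever the if branch is taken.
    val buggy =
      if (single) {
        "listener error"
      } else {
        "listener errors"
      } +
      "\nprevious task error"

    // The patch's approach: bind the if/else result to a val first, then concatenate.
    val listenerMessage = if (single) "listener error" else "listener errors"
    val fixed = listenerMessage + "\nprevious task error"

    assert(buggy == "listener error") // previous task error dropped
    assert(fixed == "listener error\nprevious task error")
  }
}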