Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ private[spark] class TaskContextImpl(

/** Marks the task as completed and triggers the completion listeners. */
@GuardedBy("this")
private[spark] def markTaskCompleted(): Unit = synchronized {
private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
if (completed) return
completed = true
invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) {
invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
_.onTaskCompletion(this)
}
}
Expand Down
39 changes: 23 additions & 16 deletions core/src/main/scala/org/apache/spark/scheduler/Task.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,26 +115,33 @@ private[spark] abstract class Task[T](
case t: Throwable =>
e.addSuppressed(t)
}
context.markTaskCompleted(Some(e))
throw e
} finally {
// Call the task completion callbacks.
context.markTaskCompleted()
try {
Utils.tryLogNonFatalError {
// Release memory used by this thread for unrolling blocks
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
// Notify any tasks waiting for execution memory to be freed to wake up and try to
// acquire memory again. This makes impossible the scenario where a task sleeps forever
// because there are no other tasks left to notify it. Since this is safe to do but may
// not be strictly necessary, we should revisit whether we can remove this in the future.
val memoryManager = SparkEnv.get.memoryManager
memoryManager.synchronized { memoryManager.notifyAll() }
}
// Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
// one is no-op.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed this comment.
LGTM. Thanks for clarifying @zsxwing

context.markTaskCompleted(None)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just add try...finally to wrap this line and fix the style

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We end up calling markTaskCompleted twice when there is an exception thrown, right ?
Perhaps do this one when no Throwable is thrown.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We end up calling markTaskCompleted twice when there is an exception thrown, right ?

Yes.

Perhaps do this one when no Throwable is thrown.

Then if context.markTaskCompleted(None) throws an exception, context.markTaskFailed(e) will be called, so TaskFailureListener may be called after TaskCompletionListener. This is a slight behavior change. Not sure if it's safe. Someone may depend on the order of calling listeners?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I meant was, when there is an exception is throw, there will be two invocations of context.markTaskCompleted.
One with Throwable passed in, and another with None.

This would be confusing to the listeners - no ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mridulm there is a completed flag in markTaskCompleted.

} finally {
// Though we unset the ThreadLocal here, the context member variable itself is still queried
// directly in the TaskRunner to check for FetchFailedExceptions.
TaskContext.unset()
try {
Utils.tryLogNonFatalError {
// Release memory used by this thread for unrolling blocks
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
MemoryMode.OFF_HEAP)
// Notify any tasks waiting for execution memory to be freed to wake up and try to
// acquire memory again. This makes impossible the scenario where a task sleeps forever
// because there are no other tasks left to notify it. Since this is safe to do but may
// not be strictly necessary, we should revisit whether we can remove this in the
// future.
val memoryManager = SparkEnv.get.memoryManager
memoryManager.synchronized { memoryManager.notifyAll() }
}
} finally {
// Though we unset the ThreadLocal here, the context member variable itself is still
// queried directly in the TaskRunner to check for FetchFailedExceptions.
TaskContext.unset()
}
}
}
}
Expand Down
14 changes: 8 additions & 6 deletions core/src/main/scala/org/apache/spark/util/taskListeners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,16 @@ class TaskCompletionListenerException(
extends RuntimeException {

override def getMessage: String = {
if (errorMessages.size == 1) {
Copy link
Member Author

@zsxwing zsxwing May 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem here is

if (...) { ... } else {...} +
{...}

is equivalent to

if (...) { ... } else {   {...} + {...}  }

which is not the intention.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch !
Btw, file name should probably have been taskListeners.scala (unrelated to this PR) ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it's because it's not a class name. cc @rxin since you added this file.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside: Was just curious about the naming - interesting. Is this common pattern in spark code ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a common pattern in scala.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx for clarifying !

errorMessages.head
} else {
errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
} +
previousError.map { e =>
val listenerErrorMessage =
if (errorMessages.size == 1) {
errorMessages.head
} else {
errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
}
val previousErrorMessage = previousError.map { e =>
"\n\nPrevious exception in task: " + e.getMessage + "\n" +
e.getStackTrace.mkString("\t", "\n\t", "")
}.getOrElse("")
listenerErrorMessage + previousErrorMessage
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
context.addTaskCompletionListener(_ => throw new Exception("blah"))

intercept[TaskCompletionListenerException] {
context.markTaskCompleted()
context.markTaskCompleted(None)
}

verify(listener, times(1)).onTaskCompletion(any())
Expand Down Expand Up @@ -231,10 +231,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
test("immediately call a completion listener if the context is completed") {
var invocations = 0
val context = TaskContext.empty()
context.markTaskCompleted()
context.markTaskCompleted(None)
context.addTaskCompletionListener(_ => invocations += 1)
assert(invocations == 1)
context.markTaskCompleted()
context.markTaskCompleted(None)
assert(invocations == 1)
}

Expand All @@ -254,6 +254,36 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(lastError == error)
assert(invocations == 1)
}

test("TaskCompletionListenerException.getMessage should include previousError") {
val listenerErrorMessage = "exception in listener"
val taskErrorMessage = "exception in task"
val e = new TaskCompletionListenerException(
Seq(listenerErrorMessage),
Some(new RuntimeException(taskErrorMessage)))
assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage))
}

test("all TaskCompletionListeners should be called even if some fail or a task") {
val context = TaskContext.empty()
val listener = mock(classOf[TaskCompletionListener])
context.addTaskCompletionListener(_ => throw new Exception("exception in listener1"))
context.addTaskCompletionListener(listener)
context.addTaskCompletionListener(_ => throw new Exception("exception in listener3"))

val e = intercept[TaskCompletionListenerException] {
context.markTaskCompleted(Some(new Exception("exception in task")))
}

// Make sure listener 2 was called.
verify(listener, times(1)).onTaskCompletion(any())

// also need to check failure in TaskCompletionListener does not mask earlier exception
assert(e.getMessage.contains("exception in listener1"))
assert(e.getMessage.contains("exception in listener3"))
assert(e.getMessage.contains("exception in task"))
}

}

private object TaskContextSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class PartiallySerializedBlockSuite
try {
TaskContext.setTaskContext(TaskContext.empty())
val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None)
Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
Mockito.verifyNoMoreInteractions(memoryStore)
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT

// Complete the task; then the 2nd block buffer should be exhausted
verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
taskContext.markTaskCompleted()
taskContext.markTaskCompleted(None)
verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release()

// The 3rd block should not be retained because the iterator is already in zombie state
Expand Down