```diff
@@ -1111,9 +1111,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
           maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
           recordsWritten += 1
         }
-      } {
-        writer.close(hadoopContext)
-      }
+      }(finallyBlock = writer.close(hadoopContext))
       committer.commitTask(hadoopContext)
       outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
         om.setBytesWritten(callback())
```
```diff
@@ -1200,9 +1198,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
           maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
           recordsWritten += 1
         }
-      } {
-        writer.close()
-      }
+      }(finallyBlock = writer.close())
       writer.commit()
       outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
         om.setBytesWritten(callback())
```
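The call-site change in the two hunks above is syntactic fallout from the new signature of `Utils.tryWithSafeFinallyAndFailureCallbacks` (shown in the Utils.scala hunk further down): with a single by-name parameter in the second parameter list, the cleanup could be passed as a bare trailing block (`} { writer.close() }`), but once that list gains two defaulted parameters, the argument must be passed by name. A minimal sketch of the pattern, using illustrative helper names rather than Spark's:

```scala
object CurriedCleanupSketch {
  // One by-name parameter in the second list: callable with a bare trailing block.
  def withCleanup[T](block: => T)(cleanup: => Unit): T =
    try block finally cleanup

  // Two defaulted by-name parameters: callers name whichever blocks they pass.
  def withCallbacks[T](block: => T)
      (onFailure: => Unit = (), onFinally: => Unit = ()): T =
    try block catch {
      case e: Throwable => onFailure; throw e
    } finally onFinally

  def main(args: Array[String]): Unit = {
    withCleanup { println("write") } { println("close") }             // old call shape
    withCallbacks { println("write") }(onFinally = println("close"))  // new call shape
  }
}
```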
**core/src/main/scala/org/apache/spark/scheduler/Task.scala** (10 additions, 4 deletions)
```diff
@@ -80,10 +80,16 @@ private[spark] abstract class Task[T](
     }
     try {
       runTask(context)
-    } catch { case e: Throwable =>
-      // Catch all errors; run task failure callbacks, and rethrow the exception.
-      context.markTaskFailed(e)
-      throw e
+    } catch {
+      case e: Throwable =>
+        // Catch all errors; run task failure callbacks, and rethrow the exception.
+        try {
+          context.markTaskFailed(e)
+        } catch {
+          case t: Throwable =>
+            e.addSuppressed(t)
+        }
+        throw e
     } finally {
       // Call the task completion callbacks.
       context.markTaskCompleted()
```

> **Contributor** (on `e.addSuppressed(t)`): nice - didn't know this existed
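The reviewer comment refers to `java.lang.Throwable.addSuppressed`, a JVM API available since Java 7 and the same mechanism `try`-with-resources uses: a secondary failure is attached to the primary exception instead of replacing it, and both remain visible in stack traces and via `getSuppressed`. A small self-contained demonstration:

```scala
object SuppressedDemo {
  def main(args: Array[String]): Unit = {
    val primary = new RuntimeException("task body failed")
    try {
      // Imagine this is a failure callback that itself blows up.
      throw new IllegalStateException("failure callback also failed")
    } catch {
      case t: Throwable => primary.addSuppressed(t)
    }
    println(primary.getMessage)                       // task body failed
    primary.getSuppressed.foreach(s => println(s"suppressed: ${s.getMessage}"))
  }
}
```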
**core/src/main/scala/org/apache/spark/util/Utils.scala** (19 additions, 10 deletions)
```diff
@@ -1260,26 +1260,35 @@ private[spark] object Utils extends Logging {
   }

   /**
-   * Execute a block of code, call the failure callbacks before finally block if there is any
-   * exceptions happen. But if exceptions happen in the finally block, do not suppress the original
-   * exception.
+   * Execute a block of code and call the failure callbacks in the catch block. If exceptions occur
+   * in either the catch or the finally block, they are appended to the list of suppressed
+   * exceptions of the original exception, which is then rethrown.
    *
-   * This is primarily an issue with `finally { out.close() }` blocks, where
-   * close needs to be called to clean up `out`, but if an exception happened
-   * in `out.write`, it's likely `out` may be corrupted and `out.close` will
+   * This is primarily an issue with `catch { abort() }` or `finally { out.close() }` blocks,
+   * where the abort/close needs to be called to clean up `out`, but if an exception happened
+   * in `out.write`, it's likely `out` may be corrupted and `abort` or `out.close` will
    * fail as well. This would then suppress the original/likely more meaningful
    * exception from the original `out.write` call.
    */
-  def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)(finallyBlock: => Unit): T = {
+  def tryWithSafeFinallyAndFailureCallbacks[T](block: => T)
+      (catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = {
     var originalThrowable: Throwable = null
     try {
       block
     } catch {
-      case t: Throwable =>
+      case cause: Throwable =>
         // Purposefully not using NonFatal, because even fatal exceptions
         // we don't want to have our finallyBlock suppress
-        originalThrowable = t
-        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(t)
+        originalThrowable = cause
+        try {
+          logError("Aborting task", originalThrowable)
+          TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(originalThrowable)
+          catchBlock
+        } catch {
+          case t: Throwable =>
+            originalThrowable.addSuppressed(t)
+            logWarning(s"Suppressing exception in catch: " + t.getMessage, t)
+        }
         throw originalThrowable
     } finally {
       try {
```
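The hunk is truncated before the `finally` handling, but the visible code already shows the contract: the body's exception always wins, and failures from the catch (and finally) blocks ride along as suppressed exceptions. Below is a standalone sketch of that control flow; the `TaskContext`/`markTaskFailed` and logging calls are elided so it runs outside Spark, and the `else throw t` branch in the finally handling is an assumption about the truncated part:

```scala
object SafeFinallySketch {
  def tryWithCallbacks[T](block: => T)
      (catchBlock: => Unit = (), finallyBlock: => Unit = ()): T = {
    var original: Throwable = null
    try {
      block
    } catch {
      case cause: Throwable =>
        original = cause
        // A failing catch block is suppressed, not propagated.
        try catchBlock catch { case t: Throwable => original.addSuppressed(t) }
        throw original
    } finally {
      // A failing finally block is suppressed too, unless the body succeeded.
      try finallyBlock catch {
        case t: Throwable =>
          if (original != null) original.addSuppressed(t) else throw t
      }
    }
  }

  def main(args: Array[String]): Unit = {
    try {
      tryWithCallbacks { sys.error("write failed") }(
        catchBlock = sys.error("abort failed"),
        finallyBlock = sys.error("close failed"))
    } catch {
      case e: Throwable =>
        println(e.getMessage)                                // write failed
        e.getSuppressed.foreach(s => println(s.getMessage))  // abort failed, close failed
    }
  }
}
```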
```diff
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.UnsafeKVExternalSorter
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{SerializableConfiguration, Utils}

 /** A container for all the details required when writing to a table. */
 case class WriteRelation(
```
```diff
@@ -255,19 +255,16 @@ private[sql] class DefaultWriterContainer(

     // If anything below fails, we should abort the task.
     try {
-      while (iterator.hasNext) {
-        val internalRow = iterator.next()
-        writer.writeInternal(internalRow)
-      }
-
-      commitTask()
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        while (iterator.hasNext) {
+          val internalRow = iterator.next()
+          writer.writeInternal(internalRow)
+        }
+        commitTask()
+      }(catchBlock = abortTask())
     } catch {
-      case cause: Throwable =>
-        logError("Aborting task.", cause)
-        // call failure callbacks first, so we could have a chance to cleanup the writer.
-        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
-        abortTask()
-        throw new SparkException("Task failed while writing rows.", cause)
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
     }

     def commitTask(): Unit = {
```
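The practical win in this rewrite: under the old shape, an exception thrown by `abortTask()` inside the outer catch would escape before the `SparkException` wrap and mask the original write failure; with `catchBlock`, the helper suppresses the abort failure onto the original, which the outer catch then wraps. A before/after sketch with stand-in methods (not the container's real code):

```scala
object MaskingSketch {
  def writeRows(): Unit = sys.error("write failed")  // the error we care about
  def abortTask(): Unit = sys.error("abort failed")  // secondary cleanup failure

  // Old shape: abortTask() throws inside the catch, so "write failed" is lost.
  def oldShape(): Unit =
    try writeRows() catch {
      case cause: Throwable =>
        abortTask()
        throw new RuntimeException("Task failed while writing rows.", cause)
    }

  // New shape: the abort failure is suppressed onto the original, and the
  // original survives as the cause of the wrapping exception.
  def newShape(): Unit =
    try writeRows() catch {
      case cause: Throwable =>
        try abortTask() catch { case t: Throwable => cause.addSuppressed(t) }
        throw new RuntimeException("Task failed while writing rows", cause)
    }

  def main(args: Array[String]): Unit = {
    try oldShape() catch { case t: Throwable => println(s"old: ${t.getMessage}") }
    try newShape() catch { case t: Throwable => println(s"new cause: ${t.getCause.getMessage}") }
  }
}
```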
```diff
@@ -421,37 +418,37 @@ private[sql] class DynamicPartitionWriterContainer(
     // If anything below fails, we should abort the task.
     var currentWriter: OutputWriter = null
     try {
-      var currentKey: UnsafeRow = null
-      while (sortedIterator.next()) {
-        val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
-        if (currentKey != nextKey) {
-          if (currentWriter != null) {
-            currentWriter.close()
-            currentWriter = null
-          }
-          currentKey = nextKey.copy()
-          logDebug(s"Writing partition: $currentKey")
-
-          currentWriter = newOutputWriter(currentKey, getPartitionString)
-        }
-        currentWriter.writeInternal(sortedIterator.getValue)
-      }
-      if (currentWriter != null) {
-        currentWriter.close()
-        currentWriter = null
-      }
-
-      commitTask()
+      Utils.tryWithSafeFinallyAndFailureCallbacks {
+        var currentKey: UnsafeRow = null
+        while (sortedIterator.next()) {
+          val nextKey = getBucketingKey(sortedIterator.getKey).asInstanceOf[UnsafeRow]
+          if (currentKey != nextKey) {
+            if (currentWriter != null) {
+              currentWriter.close()
+              currentWriter = null
+            }
+            currentKey = nextKey.copy()
+            logDebug(s"Writing partition: $currentKey")
+
+            currentWriter = newOutputWriter(currentKey, getPartitionString)
+          }
+          currentWriter.writeInternal(sortedIterator.getValue)
+        }
+        if (currentWriter != null) {
+          currentWriter.close()
+          currentWriter = null
+        }
+
+        commitTask()
+      }(catchBlock = {
+        if (currentWriter != null) {
+          currentWriter.close()
+        }
+        abortTask()
+      })
     } catch {
-      case cause: Throwable =>
-        logError("Aborting task.", cause)
-        // call failure callbacks first, so we could have a chance to cleanup the writer.
-        TaskContext.get().asInstanceOf[TaskContextImpl].markTaskFailed(cause)
-        abortTask()
-        throw new SparkException("Task failed while writing rows.", cause)
+      case t: Throwable =>
+        throw new SparkException("Task failed while writing rows", t)
     }
   }
 }
```