@@ -76,13 +76,17 @@ object SparkHadoopWriter extends Logging {
// Try to write all RDD partitions as a Hadoop OutputFormat.
try {
val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => {
+// SPARK-24552: Generate a unique "attempt ID" based on the stage and task attempt numbers.
+// Assumes that there won't be more than Short.MaxValue attempts, at least not concurrently.
+val attemptId = (context.stageAttemptNumber << 16) | context.attemptNumber
+
executeTask(
context = context,
config = config,
jobTrackerId = jobTrackerId,
commitJobId = commitJobId,
sparkPartitionId = context.partitionId,
-sparkAttemptNumber = context.attemptNumber,
+sparkAttemptNumber = attemptId,
committer = committer,
iterator = iter)
})
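Note on the packed ID added above: the stage attempt is shifted into the high bits and the task attempt stays in the low 16 bits, so the combined value is unique per (stage attempt, task attempt) pair as long as both stay under Short.MaxValue. A minimal sketch of how the value composes and decomposes; the helper names are illustrative, not part of the patch:

```scala
object AttemptIdPacking {
  // Mirrors the expression in the patch: stage attempt in the high bits, task attempt in the low 16.
  def pack(stageAttemptNumber: Int, attemptNumber: Int): Int =
    (stageAttemptNumber << 16) | attemptNumber

  // Recovers the two counters from a packed value.
  def unpack(attemptId: Int): (Int, Int) =
    (attemptId >>> 16, attemptId & 0xFFFF)

  def main(args: Array[String]): Unit = {
    val packed = pack(2, 3) // stage attempt 2, task attempt 3
    val (stageAttempt, taskAttempt) = unpack(packed)
    assert(packed == 0x20003 && stageAttempt == 2 && taskAttempt == 3)
  }
}
```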
@@ -67,7 +67,7 @@ case class KafkaStreamWriterFactory(

override def createDataWriter(
partitionId: Int,
-attemptNumber: Int,
+taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
new KafkaStreamDataWriter(topic, producerParams, schema.toAttributes)
}
@@ -64,8 +64,8 @@ public interface DataSourceWriter {
DataWriterFactory<Row> createWriterFactory();

/**
-* Returns whether Spark should use the commit coordinator to ensure that at most one attempt for
-* each task commits.
+* Returns whether Spark should use the commit coordinator to ensure that at most one task for
+* each partition commits.
*
* @return true if commit coordinator should be used, false otherwise.
*/
@@ -90,9 +90,9 @@ default void onDataWriterCommit(WriterCommitMessage message) {}
* is undefined and @{@link #abort(WriterCommitMessage[])} may not be able to deal with it.
*
* Note that speculative execution may cause multiple tasks to run for a partition. By default,
-* Spark uses the commit coordinator to allow at most one attempt to commit. Implementations can
+* Spark uses the commit coordinator to allow at most one task to commit. Implementations can
* disable this behavior by overriding {@link #useCommitCoordinator()}. If disabled, multiple
-* attempts may have committed successfully and one successful commit message per task will be
+* tasks may have committed successfully and one successful commit message per task will be
* passed to this commit method. The remaining commit messages are ignored by Spark.
*/
void commit(WriterCommitMessage[] messages);
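To make the coordinator contract above concrete, here is a hedged sketch of a sink that opts out of coordination; every class and helper name below is invented for illustration and nothing here is prescribed by this patch. With `useCommitCoordinator()` returning false, several speculative tasks for one partition may all reach `commit()`, so the driver-side `commit(...)` keeps one message per partition and ignores the rest.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage}

// Hypothetical commit message carrying the partition it was written for.
case class PartitionCommit(partitionId: Int, stagedPath: String) extends WriterCommitMessage

// Hypothetical sink that disables driver-side commit coordination and deduplicates instead.
class DedupingSink(factory: DataWriterFactory[Row]) extends DataSourceWriter {
  override def createWriterFactory(): DataWriterFactory[Row] = factory

  // Opt out: with speculation, more than one task per partition may now reach commit().
  override def useCommitCoordinator(): Boolean = false

  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    // Keep one successful message per partition; extra speculative commits are ignored.
    val chosen = messages.collect { case m: PartitionCommit => m }
      .groupBy(_.partitionId).values.map(_.head)
    chosen.foreach(publish)
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit =
    messages.collect { case m: PartitionCommit => m }.foreach(discard)

  private def publish(m: PartitionCommit): Unit = ()  // placeholder: move staged output into place
  private def discard(m: PartitionCommit): Unit = ()  // placeholder: delete staged output
}
```

Leaving the coordinator enabled (the default) pushes this deduplication into Spark itself, which is why the default is the safer choice for most sinks.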
@@ -22,7 +22,7 @@
import org.apache.spark.annotation.InterfaceStability;

/**
-* A data writer returned by {@link DataWriterFactory#createDataWriter(int, int, long)} and is
+* A data writer returned by {@link DataWriterFactory#createDataWriter(int, long, long)} and is
* responsible for writing data for an input RDD partition.
*
* One Spark task has one exclusive data writer, so there is no thread-safe concern.
@@ -39,14 +39,14 @@
* {@link DataSourceWriter#commit(WriterCommitMessage[])} with commit messages from other data
* writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an
* exception will be sent to the driver side, and Spark may retry this writing task a few times.
-* In each retry, {@link DataWriterFactory#createDataWriter(int, int, long)} will receive a
-* different `attemptNumber`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])}
+* In each retry, {@link DataWriterFactory#createDataWriter(int, long, long)} will receive a
+* different `taskId`. Spark will call {@link DataSourceWriter#abort(WriterCommitMessage[])}
* when the configured number of retries is exhausted.
*
* Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task
* takes too long to finish. Different from retried tasks, which are launched one by one after the
* previous one fails, speculative tasks are running simultaneously. It's possible that one input
-* RDD partition has multiple data writers with different `attemptNumber` running at the same time,
+* RDD partition has multiple data writers with different `taskId` running at the same time,
* and data sources should guarantee that these data writers don't conflict and can work together.
* Implementations can coordinate with driver during {@link #commit()} to make sure only one of
* these data writers can commit successfully. Or implementations can allow all of them to commit
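The conflict-avoidance pattern described in the (truncated) paragraph above can be sketched roughly as follows; the staging layout and class names are assumptions made for the example, not part of this change. Each attempt writes under a path derived from its unique `taskId`, so speculative attempts for the same partition never touch each other's output, and the message returned from `commit()` tells the driver which staged result to publish.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

// Hypothetical commit message: points the driver at this attempt's staged output.
case class StagedOutput(partitionId: Int, stagingPath: String) extends WriterCommitMessage

class StagingDataWriter(partitionId: Int, taskId: Long, epochId: Long) extends DataWriter[Row] {
  // Unique per attempt because taskId is unique per attempt, even under speculation.
  private val stagingPath = s"/tmp/example-sink/_staging/epoch=$epochId/part=$partitionId-task=$taskId"
  private val buffered = scala.collection.mutable.ArrayBuffer.empty[Row]

  override def write(record: Row): Unit = buffered += record

  override def commit(): WriterCommitMessage = {
    // A real writer would flush `buffered` to `stagingPath` here; the driver later publishes
    // at most one staged path per partition (one per committing task if coordination is off).
    StagedOutput(partitionId, stagingPath)
  }

  override def abort(): Unit = buffered.clear() // a real writer would also delete stagingPath
}
```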
@@ -42,15 +42,12 @@ public interface DataWriterFactory<T> extends Serializable {
* Usually Spark processes many RDD partitions at the same time,
* implementations should use the partition id to distinguish writers for
* different partitions.
-* @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task
-* failed, Spark launches a new task wth the same task id but different
-* attempt number. Or a task is too slow, Spark launches new tasks wth the
-* same task id but different attempt number, which means there are multiple
-* tasks with the same task id running at the same time. Implementations can
-* use this attempt number to distinguish writers of different task attempts.
+* @param taskId A unique identifier for a task that is performing the write of the partition
+* data. Spark may run multiple tasks for the same partition (due to speculation
+* or task failures, for example).
* @param epochId A monotonically increasing id for streaming queries that are split in to
* discrete periods of execution. For non-streaming queries,
* this ID will always be 0.
*/
-DataWriter<T> createDataWriter(int partitionId, int attemptNumber, long epochId);
+DataWriter<T> createDataWriter(int partitionId, long taskId, long epochId);
}
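As a rough, self-contained illustration of the new factory signature (the class and message names are invented), a factory usually just threads the three identifiers through to the writer it creates; `taskId` is what distinguishes two concurrent attempts for the same `partitionId`:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}

// Invented message: records which attempt produced which partition's rows.
case class AttemptSummary(partitionId: Int, taskId: Long, epochId: Long, rows: Long)
  extends WriterCommitMessage

// Invented factory: stateless and Serializable, so it can be shipped to executors; all the
// per-attempt identity it needs arrives through createDataWriter's parameters.
class CountingWriterFactory extends DataWriterFactory[Row] {
  override def createDataWriter(partitionId: Int, taskId: Long, epochId: Long): DataWriter[Row] =
    new DataWriter[Row] {
      private var rows = 0L
      override def write(record: Row): Unit = rows += 1
      override def commit(): WriterCommitMessage = AttemptSummary(partitionId, taskId, epochId, rows)
      override def abort(): Unit = { rows = 0L }
    }
}
```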
@@ -29,10 +29,8 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.streaming.{MicroBatchExecution, StreamExecution}
-import org.apache.spark.sql.execution.streaming.continuous.{CommitPartitionEpoch, ContinuousExecution, EpochCoordinatorRef, SetWriterPartitions}
+import org.apache.spark.sql.execution.streaming.MicroBatchExecution
import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

@@ -111,9 +109,10 @@ object DataWritingSparkTask extends Logging {
val stageId = context.stageId()
val stageAttempt = context.stageAttemptNumber()
val partId = context.partitionId()
+val taskId = context.taskAttemptId()
val attemptId = context.attemptNumber()
val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0")
-val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong)
+val dataWriter = writeTask.createDataWriter(partId, taskId, epochId.toLong)

// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
@@ -125,12 +124,12 @@
val coordinator = SparkEnv.get.outputCommitCoordinator
val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId)
[Review comment from a contributor on the canCommit call above: "a note for the followup: since we decided to use taskId as a unique identifier for write tasks, the output coordinator can also use taskId instead of stage and task attempts."]
if (commitAuthorized) {
-logInfo(s"Writer for stage $stageId / $stageAttempt, " +
-s"task $partId.$attemptId is authorized to commit.")
+logInfo(s"Commit authorized for partition $partId (task $taskId, attempt $attemptId, " +
+s"stage $stageId.$stageAttempt)")
dataWriter.commit()
} else {
-val message = s"Stage $stageId / $stageAttempt, " +
-s"task $partId.$attemptId: driver did not authorize commit"
+val message = s"Commit denied for partition $partId (task $taskId, attempt $attemptId, " +
+s"stage $stageId.$stageAttempt)"
logInfo(message)
// throwing CommitDeniedException will trigger the catch block for abort
throw new CommitDeniedException(message, stageId, partId, attemptId)
@@ -141,15 +140,18 @@ object DataWritingSparkTask extends Logging {
dataWriter.commit()
}

-logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.")
+logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId, " +
+s"stage $stageId.$stageAttempt)")

msg

})(catchBlock = {
// If there is an error, abort this writer
-logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.")
+logError(s"Aborting commit for partition $partId (task $taskId, attempt $attemptId, " +
+s"stage $stageId.$stageAttempt)")
dataWriter.abort()
-logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.")
+logError(s"Aborted commit for partition $partId (task $taskId, attempt $attemptId, " +
+s"stage $stageId.$stageAttempt)")
})
}
}
@@ -160,10 +162,10 @@ class InternalRowDataWriterFactory(

override def createDataWriter(
partitionId: Int,
-attemptNumber: Int,
+taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
new InternalRowDataWriter(
-rowWriterFactory.createDataWriter(partitionId, attemptNumber, epochId),
+rowWriterFactory.createDataWriter(partitionId, taskId, epochId),
RowEncoder.apply(schema).resolveAndBind())
}
}
@@ -56,7 +56,7 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
val dataIterator = prev.compute(split, context)
dataWriter = writeTask.createDataWriter(
context.partitionId(),
-context.attemptNumber(),
+context.taskAttemptId(),
EpochTracker.getCurrentEpoch.get)
while (dataIterator.hasNext) {
dataWriter.write(dataIterator.next())
@@ -88,7 +88,7 @@ case class ForeachWriterFactory[T](
extends DataWriterFactory[InternalRow] {
override def createDataWriter(
partitionId: Int,
-attemptNumber: Int,
+taskId: Long,
epochId: Long): ForeachDataWriter[T] = {
new ForeachDataWriter(writer, rowConverter, partitionId, epochId)
}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, Dat
case object PackedRowWriterFactory extends DataWriterFactory[Row] {
override def createDataWriter(
partitionId: Int,
-attemptNumber: Int,
+taskId: Long,
epochId: Long): DataWriter[Row] = {
new PackedRowDataWriter()
}
@@ -180,7 +180,7 @@ class MemoryStreamWriter(
case class MemoryWriterFactory(outputMode: OutputMode) extends DataWriterFactory[Row] {
override def createDataWriter(
partitionId: Int,
-attemptNumber: Int,
+taskId: Long,
epochId: Long): DataWriter[Row] = {
new MemoryDataWriter(partitionId, outputMode)
}
@@ -209,10 +209,10 @@ class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: Serializable

override def createDataWriter(
partitionId: Int,
-attemptNumber: Int,
+taskId: Long,
epochId: Long): DataWriter[Row] = {
val jobPath = new Path(new Path(path, "_temporary"), jobId)
-val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
+val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId")
val fs = filePath.getFileSystem(conf.value)
new SimpleCSVDataWriter(fs, filePath)
}
@@ -245,10 +245,10 @@ class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: Seriali

override def createDataWriter(
partitionId: Int,
-attemptNumber: Int,
+taskId: Long,
epochId: Long): DataWriter[InternalRow] = {
val jobPath = new Path(new Path(path, "_temporary"), jobId)
-val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber")
+val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId")
val fs = filePath.getFileSystem(conf.value)
new InternalRowCSVDataWriter(fs, filePath)
}