@@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils}
 private sealed trait OutputCommitCoordinationMessage extends Serializable
 
 private case object StopCoordinator extends OutputCommitCoordinationMessage
-private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int)
+private case class AskPermissionToCommitOutput(
+    stage: Int,
+    stageAttempt: Int,
+    partition: Int,
+    attemptNumber: Int)
 
 /**
  * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
@@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   // Initialized by SparkEnv
   var coordinatorRef: Option[RpcEndpointRef] = None
 
-  private type StageId = Int
-  private type PartitionId = Int
-  private type TaskAttemptNumber = Int
-  private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1
+  // Class used to identify a committer. The task ID for a committer is implicitly defined by
+  // the partition being processed, but the coordinator needs to keep track of both the stage
+  // attempt and the task attempt, because in some situations the same task may be running
+  // concurrently in two different attempts of the same stage.
+  private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)
+
   private case class StageState(numPartitions: Int) {
-    val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER)
-    val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]()
+    val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null)
+    val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]()
   }
 
   /**
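
The heart of the change is in this hunk: a bare task attempt number no longer identifies a committer, because attempt numbers restart from zero in every stage attempt, so two concurrently running attempts of the same stage can each have a task attempt 0 for the same partition. The following standalone sketch (illustrative only, not part of the patch) shows the ambiguity that the TaskIdentifier pair resolves:

// Standalone sketch: raw task attempt numbers collide across stage attempts,
// while (stageAttempt, taskAttempt) pairs stay distinct.
case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)

object TaskIdentifierSketch extends App {
  // Two committers for the same partition, one from each attempt of the stage.
  val first = TaskIdentifier(stageAttempt = 0, taskAttempt = 0)
  val second = TaskIdentifier(stageAttempt = 1, taskAttempt = 0)

  assert(first.taskAttempt == second.taskAttempt) // raw numbers collide
  assert(first != second)                         // the pairs do not
}
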
@@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
    *
    * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
    */
-  private val stageStates = mutable.Map[StageId, StageState]()
+  private val stageStates = mutable.Map[Int, StageState]()
 
   /**
    * Returns whether the OutputCommitCoordinator's internal data structures are all empty.
@@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
    * @return true if this task is authorized to commit, false otherwise
    */
   def canCommit(
-      stage: StageId,
-      partition: PartitionId,
-      attemptNumber: TaskAttemptNumber): Boolean = {
-    val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber)
+      stage: Int,
+      stageAttempt: Int,
+      partition: Int,
+      attemptNumber: Int): Boolean = {
+    val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber)
     coordinatorRef match {
       case Some(endpointRef) =>
         ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg),
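
canCommit now threads the stage attempt through to the RPC message. On the executor side, all four arguments are available from the TaskContext. The sketch below is a hypothetical, simplified call site, not code from the patch: the real caller is Spark's commit protocol, and OutputCommitCoordinator is private[spark], so this shape only compiles inside Spark itself. It assumes a TaskContext that exposes stageAttemptNumber(), as recent Spark versions do:

// Hypothetical executor-side call, for illustration only: the four arguments
// mirror the fields of AskPermissionToCommitOutput, in order.
import org.apache.spark.TaskContext
import org.apache.spark.scheduler.OutputCommitCoordinator

def tryCommit(ctx: TaskContext, coordinator: OutputCommitCoordinator): Boolean = {
  coordinator.canCommit(
    ctx.stageId(),
    ctx.stageAttemptNumber(), // new in this change: identifies the stage attempt
    ctx.partitionId(),
    ctx.attemptNumber())
}
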
@@ -103,26 +110,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   }
 
   /**
-   * Called by the DAGScheduler when a stage starts.
+   * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't
+   * yet been initialized.
    *
    * @param stage the stage id.
    * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e.
    *                       the maximum possible value of `context.partitionId`).
    */
-  private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized {
-    stageStates(stage) = new StageState(maxPartitionId + 1)
+  private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized {
+    stageStates.get(stage) match {
+      case Some(state) =>
+        require(state.authorizedCommitters.length == maxPartitionId + 1)
+        logInfo(s"Reusing state from previous attempt of stage $stage.")
+
+      case _ =>
+        stageStates(stage) = new StageState(maxPartitionId + 1)
+    }
   }
 
   // Called by DAGScheduler
-  private[scheduler] def stageEnd(stage: StageId): Unit = synchronized {
+  private[scheduler] def stageEnd(stage: Int): Unit = synchronized {
     stageStates.remove(stage)
   }
 
   // Called by DAGScheduler
   private[scheduler] def taskCompleted(
-      stage: StageId,
-      partition: PartitionId,
-      attemptNumber: TaskAttemptNumber,
+      stage: Int,
+      stageAttempt: Int,
+      partition: Int,
+      attemptNumber: Int,
       reason: TaskEndReason): Unit = synchronized {
     val stageState = stageStates.getOrElse(stage, {
       logDebug(s"Ignoring task completion for completed stage")
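
stageStart is now idempotent across stage retries: the first call allocates a StageState, and a retry of the same stage keeps the existing one, so commit locks and the failure blacklist survive into the next stage attempt. A standalone model of just this behavior (not the patch itself):

// Standalone model (not Spark code) of the new stageStart behavior: state
// created on the first start is reused, failure history intact, on a retry.
import scala.collection.mutable

object StageStartSketch {
  case class StageState(numPartitions: Int) {
    val failures = mutable.Map[Int, mutable.Set[(Int, Int)]]()
  }

  private val stageStates = mutable.Map[Int, StageState]()

  def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized {
    stageStates.get(stage) match {
      case Some(state) =>
        // A retried stage must span the same partition range as before.
        require(state.numPartitions == maxPartitionId + 1)
      case _ =>
        stageStates(stage) = StageState(maxPartitionId + 1)
    }
  }

  def main(args: Array[String]): Unit = {
    stageStart(stage = 3, maxPartitionId = 9)
    stageStates(3).failures.getOrElseUpdate(0, mutable.Set()) += ((1, 0))
    stageStart(stage = 3, maxPartitionId = 9)           // stage retry: state reused
    assert(stageStates(3).failures(0).contains((1, 0))) // failure history survives
  }
}
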
@@ -131,16 +147,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
     reason match {
       case Success =>
       // The task output has been committed successfully
-      case denied: TaskCommitDenied =>
-        logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " +
-          s"attempt: $attemptNumber")
-      case otherReason =>
+      case _: TaskCommitDenied =>
+        logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " +
+          s"partition: $partition, attempt: $attemptNumber")
+      case _ =>
         // Mark the attempt as failed to blacklist from future commit protocol
-        stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber
-        if (stageState.authorizedCommitters(partition) == attemptNumber) {
+        val taskId = TaskIdentifier(stageAttempt, attemptNumber)
+        stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId
+        if (stageState.authorizedCommitters(partition) == taskId) {
           logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " +
             s"partition=$partition) failed; clearing lock")
-          stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
+          stageState.authorizedCommitters(partition) = null
         }
     }
   }
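
Note the equality check above: a failed task clears the commit lock only if its (stageAttempt, attemptNumber) pair is the current holder; failures of other attempts are blacklisted but leave the authorized committer untouched. A minimal illustrative model (not Spark code):

// Illustrative model: a failed attempt frees the lock only when it holds it.
object LockClearSketch extends App {
  case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)

  var authorized: TaskIdentifier = TaskIdentifier(0, 2) // current lock holder

  def onTaskFailed(failed: TaskIdentifier): Unit = {
    if (authorized == failed) authorized = null // the holder died: free the lock
  }

  onTaskFailed(TaskIdentifier(0, 1))         // an unrelated attempt fails
  assert(authorized == TaskIdentifier(0, 2)) // the lock is still held
  onTaskFailed(TaskIdentifier(0, 2))         // the holder itself fails
  assert(authorized == null)                 // the lock is cleared for a successor
}
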
@@ -155,47 +172,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
 
   // Marked private[scheduler] instead of private so this can be mocked in tests
   private[scheduler] def handleAskPermissionToCommit(
-      stage: StageId,
-      partition: PartitionId,
-      attemptNumber: TaskAttemptNumber): Boolean = synchronized {
+      stage: Int,
+      stageAttempt: Int,
+      partition: Int,
+      attemptNumber: Int): Boolean = synchronized {
     stageStates.get(stage) match {
-      case Some(state) if attemptFailed(state, partition, attemptNumber) =>
-        logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " +
-          s"partition=$partition as task attempt $attemptNumber has already failed.")
+      case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) =>
+        logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+          s"task attempt $attemptNumber already marked as failed.")
         false
       case Some(state) =>
-        state.authorizedCommitters(partition) match {
-          case NO_AUTHORIZED_COMMITTER =>
-            logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " +
-              s"partition=$partition")
-            state.authorizedCommitters(partition) = attemptNumber
-            true
-          case existingCommitter =>
-            // Coordinator should be idempotent when receiving AskPermissionToCommit.
-            if (existingCommitter == attemptNumber) {
-              logWarning(s"Authorizing duplicate request to commit for " +
-                s"attemptNumber=$attemptNumber to commit for stage=$stage, " +
-                s"partition=$partition; existingCommitter = $existingCommitter." +
-                s" This can indicate dropped network traffic.")
-              true
-            } else {
-              logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " +
-                s"partition=$partition; existingCommitter = $existingCommitter")
-              false
-            }
+        val existing = state.authorizedCommitters(partition)
+        if (existing == null) {
+          logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " +
+            s"task attempt $attemptNumber")
+          state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber)
+          true
+        } else {
+          logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+            s"already committed by $existing")
+          false
         }
       case None =>
-        logDebug(s"Stage $stage has completed, so not allowing " +
-          s"attempt number $attemptNumber of partition $partition to commit")
+        logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+          "stage already marked as completed.")
         false
     }
   }
 
   private def attemptFailed(
       stageState: StageState,
-      partition: PartitionId,
-      attempt: TaskAttemptNumber): Boolean = synchronized {
-    stageState.failures.get(partition).exists(_.contains(attempt))
+      stageAttempt: Int,
+      partition: Int,
+      attempt: Int): Boolean = synchronized {
+    val failInfo = TaskIdentifier(stageAttempt, attempt)
+    stageState.failures.get(partition).exists(_.contains(failInfo))
   }
 }
 
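
Together, handleAskPermissionToCommit and attemptFailed implement first-committer-wins with a per-partition blacklist, now keyed by (stageAttempt, taskAttempt). The behavioral sketch below (same semantics, none of the RPC plumbing; not code from the patch) traces one partition through the scenario this change fixes, where a task attempt number is reused by a later stage attempt:

// Behavioral model of the coordinator's decision logic for a single stage.
import scala.collection.mutable

object CommitDecisionSketch extends App {
  case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)

  val authorized = mutable.Map[Int, TaskIdentifier]()            // partition -> lock holder
  val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]() // partition -> blacklist

  def canCommit(stageAttempt: Int, partition: Int, attempt: Int): Boolean = {
    val id = TaskIdentifier(stageAttempt, attempt)
    if (failures.get(partition).exists(_.contains(id))) false // blacklisted: deny
    else authorized.get(partition) match {
      case None => authorized(partition) = id; true           // first committer wins
      case Some(_) => false                                   // lock already held
    }
  }

  def taskFailed(stageAttempt: Int, partition: Int, attempt: Int): Unit = {
    val id = TaskIdentifier(stageAttempt, attempt)
    failures.getOrElseUpdate(partition, mutable.Set()) += id
    if (authorized.get(partition).contains(id)) authorized.remove(partition)
  }

  assert(canCommit(0, 0, 0))  // first ask: authorized
  assert(!canCommit(0, 0, 1)) // second ask: denied, lock held
  taskFailed(0, 0, 0)         // holder fails: blacklisted, lock cleared
  assert(!canCommit(0, 0, 0)) // the failed attempt stays denied
  assert(canCommit(1, 0, 0))  // same attempt number, later stage attempt: allowed
}
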
@@ -215,9 +226,10 @@ private[spark] object OutputCommitCoordinator {
     }
 
     override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-      case AskPermissionToCommitOutput(stage, partition, attemptNumber) =>
+      case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) =>
         context.reply(
-          outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber))
+          outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition,
+            attemptNumber))
     }
   }
 }