1717
1818package org .apache .spark
1919
20- import java .nio .charset .StandardCharsets .UTF_8
2120import java .util .{Timer , TimerTask }
2221import java .util .concurrent .ConcurrentHashMap
2322import java .util .function .Consumer
2423
2524import scala .collection .mutable .ArrayBuffer
2625
27- import org .json4s .JsonAST ._
28- import org .json4s .JsonDSL ._
29- import org .json4s .jackson .JsonMethods .{compact , render }
30-
3126import org .apache .spark .internal .Logging
3227import org .apache .spark .rpc .{RpcCallContext , RpcEnv , ThreadSafeRpcEndpoint }
3328import org .apache .spark .scheduler .{LiveListenerBus , SparkListener , SparkListenerStageCompleted }
@@ -107,11 +102,13 @@ private[spark] class BarrierCoordinator(
107102 // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
108103 private val requesters : ArrayBuffer [RpcCallContext ] = new ArrayBuffer [RpcCallContext ](numTasks)
109104
110- // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call
111- private val allGatherMessages : ArrayBuffer [String ] = new Array [String ](numTasks).to[ArrayBuffer ]
105+ // Messages from each barrier task that have made a blocking runBarrier() call.
106+ // The messages will be replied to all tasks once sync finished.
107+ private val messages = Array .ofDim[String ](numTasks)
112108
113- // The blocking requestMethod called by tasks to sync up for this stage attempt
114- private var requestMethodToSync : RequestMethod .Value = RequestMethod .BARRIER
109+ // The request method which is called inside this barrier sync. All tasks should make sure
110+ // that they're calling the same method within the same barrier sync phase.
111+ private var requestMethod : RequestMethod .Value = _
115112
116113 // A timer task that ensures we may timeout for a barrier() call.
117114 private var timerTask : TimerTask = null
@@ -140,28 +137,18 @@ private[spark] class BarrierCoordinator(
140137
141138 // Process the global sync request. The barrier() call succeed if collected enough requests
142139 // within a configured time, otherwise fail all the pending requests.
143- def handleRequest (
144- requester : RpcCallContext ,
145- request : RequestToSync
146- ): Unit = synchronized {
140+ def handleRequest (requester : RpcCallContext , request : RequestToSync ): Unit = synchronized {
147141 val taskId = request.taskAttemptId
148142 val epoch = request.barrierEpoch
149- val requestMethod = request.requestMethod
150- val partitionId = request.partitionId
151- val allGatherMessage = request match {
152- case ag : AllGatherRequestToSync => ag.allGatherMessage
153- case _ => " "
154- }
155-
156- if (requesters.size == 0 ) {
157- requestMethodToSync = requestMethod
158- }
143+ val curReqMethod = request.requestMethod
159144
160- if (requestMethodToSync != requestMethod) {
145+ if (requesters.isEmpty) {
146+ requestMethod = curReqMethod
147+ } else if (requestMethod != curReqMethod) {
161148 requesters.foreach(
162149 _.sendFailure(new SparkException (s " $barrierId tried to use requestMethod " +
163- s " ` $requestMethod ` during barrier epoch $barrierEpoch, which does not match " +
164- s " the current synchronized requestMethod ` $requestMethodToSync ` "
150+ s " ` $curReqMethod ` during barrier epoch $barrierEpoch, which does not match " +
151+ s " the current synchronized requestMethod ` $requestMethod ` "
165152 ))
166153 )
167154 cleanupBarrierStage(barrierId)
@@ -186,10 +173,11 @@ private[spark] class BarrierCoordinator(
186173 }
187174 // Add the requester to array of RPCCallContexts pending for reply.
188175 requesters += requester
189- allGatherMessages( partitionId) = allGatherMessage
176+ messages(request. partitionId) = request.message
190177 logInfo(s " Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
191178 s " $taskId, current progress: ${requesters.size}/ $numTasks. " )
192- if (maybeFinishAllRequesters(requesters, numTasks)) {
179+ if (requesters.size == numTasks) {
180+ requesters.foreach(_.reply(messages))
193181 // Finished current barrier() call successfully, clean up ContextBarrierState and
194182 // increase the barrier epoch.
195183 logInfo(s " Barrier sync epoch $barrierEpoch from $barrierId received all updates from " +
@@ -201,25 +189,6 @@ private[spark] class BarrierCoordinator(
201189 }
202190 }
203191
204- // Finish all the blocking barrier sync requests from a stage attempt successfully if we
205- // have received all the sync requests.
206- private def maybeFinishAllRequesters (
207- requesters : ArrayBuffer [RpcCallContext ],
208- numTasks : Int ): Boolean = {
209- if (requesters.size == numTasks) {
210- requestMethodToSync match {
211- case RequestMethod .BARRIER =>
212- requesters.foreach(_.reply(" " ))
213- case RequestMethod .ALL_GATHER =>
214- val json : String = compact(render(allGatherMessages))
215- requesters.foreach(_.reply(json))
216- }
217- true
218- } else {
219- false
220- }
221- }
222-
223192 // Cleanup the internal state of a barrier stage attempt.
224193 def clear (): Unit = synchronized {
225194 // The global sync fails so the stage is expected to retry another attempt, all sync
@@ -239,11 +208,11 @@ private[spark] class BarrierCoordinator(
239208 }
240209
241210 override def receiveAndReply (context : RpcCallContext ): PartialFunction [Any , Unit ] = {
242- case request : RequestToSync =>
211+ case request @ RequestToSync (numTasks, stageId, stageAttemptId, _, _, _, _, _) =>
243212 // Get or init the ContextBarrierState correspond to the stage attempt.
244- val barrierId = ContextBarrierId (request. stageId, request. stageAttemptId)
213+ val barrierId = ContextBarrierId (stageId, stageAttemptId)
245214 states.computeIfAbsent(barrierId,
246- (key : ContextBarrierId ) => new ContextBarrierState (key, request. numTasks))
215+ (key : ContextBarrierId ) => new ContextBarrierState (key, numTasks))
247216 val barrierState = states.get(barrierId)
248217
249218 barrierState.handleRequest(context, request)
@@ -256,61 +225,28 @@ private[spark] class BarrierCoordinator(
256225
257226private [spark] sealed trait BarrierCoordinatorMessage extends Serializable
258227
259- private [spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
260- def numTasks : Int
261- def stageId : Int
262- def stageAttemptId : Int
263- def taskAttemptId : Long
264- def barrierEpoch : Int
265- def partitionId : Int
266- def requestMethod : RequestMethod .Value
267- }
268-
269- /**
270- * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
271- * identified by stageId + stageAttemptId + barrierEpoch.
272- *
273- * @param numTasks The number of global sync requests the BarrierCoordinator shall receive
274- * @param stageId ID of current stage
275- * @param stageAttemptId ID of current stage attempt
276- * @param taskAttemptId Unique ID of current task
277- * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
278- * @param partitionId ID of the current partition the task is assigned to
279- * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
280- */
281- private [spark] case class BarrierRequestToSync (
282- numTasks : Int ,
283- stageId : Int ,
284- stageAttemptId : Int ,
285- taskAttemptId : Long ,
286- barrierEpoch : Int ,
287- partitionId : Int ,
288- requestMethod : RequestMethod .Value
289- ) extends RequestToSync
290-
291228/**
292- * A global sync request message from BarrierTaskContext, by `allGather()` call . Each request is
229+ * A global sync request message from BarrierTaskContext. Each request is
293230 * identified by stageId + stageAttemptId + barrierEpoch.
294231 *
295232 * @param numTasks The number of global sync requests the BarrierCoordinator shall receive
296233 * @param stageId ID of current stage
297234 * @param stageAttemptId ID of current stage attempt
298235 * @param taskAttemptId Unique ID of current task
299- * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
236+ * @param barrierEpoch ID of a runBarrier() call, a task may consist multiple runBarrier() calls
300237 * @param partitionId ID of the current partition the task is assigned to
238+ * @param message Message sent from the BarrierTaskContext
301239 * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
302- * @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER
303240 */
304- private [spark] case class AllGatherRequestToSync (
241+ private [spark] case class RequestToSync (
305242 numTasks : Int ,
306243 stageId : Int ,
307244 stageAttemptId : Int ,
308245 taskAttemptId : Long ,
309246 barrierEpoch : Int ,
310247 partitionId : Int ,
311- requestMethod : RequestMethod .Value ,
312- allGatherMessage : String
313- ) extends RequestToSync
248+ message : String ,
249+ requestMethod : RequestMethod .Value ) extends BarrierCoordinatorMessage
314250
315251private [spark] object RequestMethod extends Enumeration {
316252 val BARRIER, ALL_GATHER = Value
0 commit comments