/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Consumer, Function}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}

/**
 * For each barrier stage attempt, at most one barrier() call can be active at any time, thus
 * we can use (stageId, stageAttemptId) to identify the stage attempt that a barrier() call
 * comes from.
 */
private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) {
  override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)"
}

/**
 * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync
 * request is generated by `BarrierTaskContext.barrier()` and identified by
 * stageId + stageAttemptId + barrierEpoch. The coordinator replies to all the blocking global
 * sync requests once it has received all the requests for a group of `barrier()` calls. If the
 * coordinator cannot collect enough global sync requests within a configured time, it fails all
 * the requests and returns an exception with a timeout message.
 */
private[spark] class BarrierCoordinator(
    timeoutInSecs: Long,
    listenerBus: LiveListenerBus,
    override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging {

  // TODO SPARK-25030: creating a Timer() in the mainClass submitted to SparkSubmit makes it
  // unable to fetch results; we shall fix the issue.
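  // Note: java.util.Timer runs all scheduled tasks on a single background thread, and each
  // ContextBarrierState schedules at most one outstanding timeout task at a time.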
  private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer")

  // Listen to StageCompleted event, clear corresponding ContextBarrierState.
  private val listener = new SparkListener {
    override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
      val stageInfo = stageCompleted.stageInfo
      val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber)
      // Clear ContextBarrierState from a finished stage attempt.
      cleanupBarrierStage(barrierId)
    }
  }

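  // Note: as a ThreadSafeRpcEndpoint this coordinator handles RPC messages one at a time, but
  // `states` is also touched from the listener bus thread (the SparkListener above) and from the
  // timer thread on timeout, hence the ConcurrentHashMap.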
  // Record all active stage attempts that make barrier() call(s), and the corresponding internal
  // state.
  private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState]

  override def onStart(): Unit = {
    super.onStart()
    listenerBus.addToStatusQueue(listener)
  }

  override def onStop(): Unit = {
    try {
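      // A parallelism threshold of 1 lets forEachValue split the cleanup across the common
      // ForkJoinPool when the map holds more than one entry; each state's clear() is
      // synchronized, so parallel cleanup is safe here.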
      states.forEachValue(1, clearStateConsumer)
      states.clear()
      listenerBus.removeListener(listener)
    } finally {
      super.onStop()
    }
  }

  /**
   * Provides the current state of a barrier() call. A state is created when a new stage attempt
   * sends out a barrier() call, and recycled when the stage completes.
   *
   * @param barrierId Identifier of the barrier stage that makes a barrier() call.
   * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage
   *                 shall collect `numTasks` requests to succeed.
   */
  private class ContextBarrierState(
      val barrierId: ContextBarrierId,
      val numTasks: Int) {

    // There may be multiple barrier() calls from a barrier stage attempt; `barrierEpoch` is used
    // to identify each barrier() call. It is incremented when a barrier() call succeeds, and
    // reset when a barrier() call fails due to timeout.
    private var barrierEpoch: Int = 0

    // An array of RpcCallContexts for barrier tasks that are waiting for a reply to a barrier()
    // call.
    private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)

    // A timer task that ensures a barrier() call can time out.
    private var timerTask: TimerTask = null

    // Init a TimerTask for a barrier() call.
    private def initTimerTask(): Unit = {
      timerTask = new TimerTask {
        // Synchronize on the enclosing state (not on the TimerTask itself), so the timeout
        // handler cannot race with handleRequest() or clear().
        override def run(): Unit = ContextBarrierState.this.synchronized {
          // Timeout current barrier() call, fail all the sync requests.
          requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " +
            s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " +
            s"$timeoutInSecs second(s).")))
          cleanupBarrierStage(barrierId)
        }
      }
    }

    // Cancel the current active TimerTask and release resources.
    private def cancelTimerTask(): Unit = {
      if (timerTask != null) {
        timerTask.cancel()
        timerTask = null
      }
    }

    // Process the global sync request. The barrier() call succeeds if we collect enough requests
    // within the configured time, otherwise we fail all the pending requests.
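    // For example, with numTasks = 3 and barrierEpoch = 0: the first request arms the timeout
    // timer, the second is buffered, and the third triggers replies to all three requesters,
    // after which barrierEpoch becomes 1. A late request still carrying epoch 0 (e.g. from a
    // task that was not properly killed) then gets a failure reply.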
    def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
      val taskId = request.taskAttemptId
      val epoch = request.barrierEpoch

      // Require the number of tasks is correctly set from the BarrierTaskContext.
      require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
        s"${request.numTasks} from Task $taskId, previously it was $numTasks.")

      // Check whether the epoch from the barrier tasks matches the current barrierEpoch.
      logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.")
      if (epoch != barrierEpoch) {
        requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " +
          s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " +
          "properly killed."))
      } else {
        // If this is the first sync message received for a barrier() call, start a timer to
        // ensure the sync can time out.
        if (requesters.isEmpty) {
          initTimerTask()
          timer.schedule(timerTask, timeoutInSecs * 1000)
        }
        // Add the requester to the array of RpcCallContexts pending reply.
        requesters += requester
        logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
          s"$taskId, current progress: ${requesters.size}/$numTasks.")
        if (maybeFinishAllRequesters(requesters, numTasks)) {
          // Finished current barrier() call successfully, clean up ContextBarrierState and
          // increase the barrier epoch.
          logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " +
            s"tasks, finished successfully.")
          barrierEpoch += 1
          requesters.clear()
          cancelTimerTask()
        }
      }
    }

    // Finish all the blocking barrier sync requests from a stage attempt successfully if we
    // have received all the sync requests.
    private def maybeFinishAllRequesters(
        requesters: ArrayBuffer[RpcCallContext],
        numTasks: Int): Boolean = {
      if (requesters.size == numTasks) {
        requesters.foreach(_.reply(()))
        true
      } else {
        false
      }
    }

    // Clean up the internal state of a barrier stage attempt.
    def clear(): Unit = synchronized {
      // The global sync failed, so the stage is expected to retry with another attempt; all sync
      // messages from the current stage attempt shall fail.
      barrierEpoch = -1
      requesters.clear()
      cancelTimerTask()
    }
  }

  // Clean up the [[ContextBarrierState]] that corresponds to a specific stage attempt.
  private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = {
    val barrierState = states.remove(barrierId)
    if (barrierState != null) {
      barrierState.clear()
    }
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
      // Get or init the ContextBarrierState corresponding to the stage attempt; computeIfAbsent
      // returns the existing or the newly created state.
      val barrierId = ContextBarrierId(stageId, stageAttemptId)
      val barrierState = states.computeIfAbsent(barrierId,
        new Function[ContextBarrierId, ContextBarrierState] {
          override def apply(key: ContextBarrierId): ContextBarrierState =
            new ContextBarrierState(key, numTasks)
        })

      barrierState.handleRequest(context, request)
  }

  private val clearStateConsumer = new Consumer[ContextBarrierState] {
    override def accept(state: ContextBarrierState): Unit = state.clear()
  }
}

private[spark] sealed trait BarrierCoordinatorMessage extends Serializable

/**
 * A global sync request message from BarrierTaskContext, sent by a `barrier()` call. Each request
 * is identified by stageId + stageAttemptId + barrierEpoch.
 *
 * @param numTasks The number of global sync requests the BarrierCoordinator shall receive
 * @param stageId ID of current stage
 * @param stageAttemptId ID of current stage attempt
 * @param taskAttemptId Unique ID of current task
 * @param barrierEpoch ID of the `barrier()` call, a task may make multiple `barrier()` calls.
 */
private[spark] case class RequestToSync(
    numTasks: Int,
    stageId: Int,
    stageAttemptId: Int,
    taskAttemptId: Long,
    barrierEpoch: Int) extends BarrierCoordinatorMessage
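
For context, here is a minimal driver-side sketch of how a barrier stage ends up exercising this coordinator, assuming the user-facing RDD.barrier() / BarrierTaskContext API that ships with barrier execution mode; the object name, application name, partition count and data below are illustrative only.

import org.apache.spark.{BarrierTaskContext, SparkConf, SparkContext}

object BarrierSyncExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("barrier-sync-example"))
    // A barrier stage with 4 tasks: each task's barrier() call sends a RequestToSync carrying
    // (numTasks = 4, stageId, stageAttemptId, taskAttemptId, barrierEpoch) to the coordinator.
    val result = sc.parallelize(1 to 100, numSlices = 4)
      .barrier()
      .mapPartitions { iter =>
        val ctx = BarrierTaskContext.get()
        // Blocks until the coordinator has collected all 4 sync requests for the current
        // barrier epoch, or fails with a SparkException once the configured timeout elapses.
        ctx.barrier()
        iter
      }
      .collect()
    println(s"Collected ${result.length} elements after the barrier sync.")
    sc.stop()
  }
}

If any of the 4 tasks never reaches barrier(), the coordinator's timer fails the whole epoch with the timeout message defined above.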