
Commit 42d86e1

Merge branch 'master' into SPARK-24251-add-append-data
2 parents e81790d + 1a29fec

File tree: 69 files changed (+1349, -459 lines)

core/src/main/scala/org/apache/spark/BarrierCoordinator.scala

Lines changed: 235 additions & 0 deletions

@@ -0,0 +1,235 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark

import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Consumer, Function}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}

/**
 * For each barrier stage attempt, at most one barrier() call can be active at any time, so we
 * can use (stageId, stageAttemptId) to identify the stage attempt a barrier() call comes from.
 */
private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) {
  override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)"
}

/**
 * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync
 * request is generated by `BarrierTaskContext.barrier()` and identified by
 * stageId + stageAttemptId + barrierEpoch. The coordinator replies to all the blocking global
 * sync requests once it has received all the requests for a group of `barrier()` calls. If it
 * cannot collect enough global sync requests within a configured time, it fails all the
 * requests with an Exception carrying a timeout message.
 */
private[spark] class BarrierCoordinator(
    timeoutInSecs: Long,
    listenerBus: LiveListenerBus,
    override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging {

  // TODO SPARK-25030: creating a Timer() in the mainClass submitted to SparkSubmit makes it
  // unable to fetch results; we shall fix the issue.
  private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer")

  // Listen to StageCompleted events and clear the corresponding ContextBarrierState.
  private val listener = new SparkListener {
    override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
      val stageInfo = stageCompleted.stageInfo
      val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber)
      // Clear ContextBarrierState from a finished stage attempt.
      cleanupBarrierStage(barrierId)
    }
  }

  // Record all active stage attempts that make barrier() call(s), and the corresponding
  // internal state.
  private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState]

  override def onStart(): Unit = {
    super.onStart()
    listenerBus.addToStatusQueue(listener)
  }

  override def onStop(): Unit = {
    try {
      states.forEachValue(1, clearStateConsumer)
      states.clear()
      listenerBus.removeListener(listener)
    } finally {
      super.onStop()
    }
  }

  /**
   * Tracks the current state of a barrier() call. A state is created when a new stage attempt
   * sends out a barrier() call, and recycled on stage completion.
   *
   * @param barrierId identifier of the barrier stage that makes the barrier() call.
   * @param numTasks number of tasks of the barrier stage; each barrier() call from the stage
   *                 shall collect `numTasks` requests to succeed.
   */
  private class ContextBarrierState(
      val barrierId: ContextBarrierId,
      val numTasks: Int) {

    // There may be multiple barrier() calls from a barrier stage attempt; `barrierEpoch`
    // identifies each of them. It is increased when a barrier() call succeeds, and reset when
    // a barrier() call fails due to timeout.
    private var barrierEpoch: Int = 0

    // An array of RpcCallContexts for barrier tasks that are waiting for a reply to a
    // barrier() call.
    private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)

    // A timer task that ensures we may time out a barrier() call.
    private var timerTask: TimerTask = null

    // Init a TimerTask for a barrier() call.
    private def initTimerTask(): Unit = {
      timerTask = new TimerTask {
        override def run(): Unit = synchronized {
          // Time out the current barrier() call; fail all the pending sync requests.
          requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " +
            s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " +
            s"$timeoutInSecs second(s).")))
          cleanupBarrierStage(barrierId)
        }
      }
    }

    // Cancel the current active TimerTask and release resources.
    private def cancelTimerTask(): Unit = {
      if (timerTask != null) {
        timerTask.cancel()
        timerTask = null
      }
    }

    // Process a global sync request. The barrier() call succeeds if enough requests are
    // collected within the configured time; otherwise all the pending requests fail.
    def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
      val taskId = request.taskAttemptId
      val epoch = request.barrierEpoch

      // Require that the number of tasks is correctly set from the BarrierTaskContext.
      require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
        s"${request.numTasks} from Task $taskId, previously it was $numTasks.")

      // Check whether the epoch from the barrier tasks matches the current barrierEpoch.
      logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.")
      if (epoch != barrierEpoch) {
        requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " +
          s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " +
          "properly killed."))
      } else {
        // If this is the first sync message received for a barrier() call, start a timer to
        // ensure we may time out the sync.
        if (requesters.isEmpty) {
          initTimerTask()
          timer.schedule(timerTask, timeoutInSecs * 1000)
        }
        // Add the requester to the array of RpcCallContexts pending reply.
        requesters += requester
        logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
          s"$taskId, current progress: ${requesters.size}/$numTasks.")
        if (maybeFinishAllRequesters(requesters, numTasks)) {
          // The current barrier() call finished successfully; clean up ContextBarrierState and
          // increase the barrier epoch.
          logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " +
            s"tasks, finished successfully.")
          barrierEpoch += 1
          requesters.clear()
          cancelTimerTask()
        }
      }
    }
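
    // Worked example of the epoch protocol (added commentary, not part of the patch): in a
    // 3-task stage attempt, the first barrier() call buffers requesters until
    // requesters.size == 3, replies () to all of them via maybeFinishAllRequesters, bumps
    // barrierEpoch from 0 to 1 and clears the buffer. A straggler request that still carries
    // epoch 0 then fails the `epoch != barrierEpoch` check above and gets a SparkException,
    // which is how a task that was not properly killed after a finished sync is rejected.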

    // Finish all the blocking barrier sync requests from a stage attempt successfully if we
    // have received all the sync requests.
    private def maybeFinishAllRequesters(
        requesters: ArrayBuffer[RpcCallContext],
        numTasks: Int): Boolean = {
      if (requesters.size == numTasks) {
        requesters.foreach(_.reply(()))
        true
      } else {
        false
      }
    }

    // Clean up the internal state of a barrier stage attempt.
    def clear(): Unit = synchronized {
      // The global sync failed, so the stage is expected to retry with another attempt; all
      // sync messages coming from the current stage attempt shall fail.
      barrierEpoch = -1
      requesters.clear()
      cancelTimerTask()
    }
  }

  // Clean up the [[ContextBarrierState]] that corresponds to a specific stage attempt.
  private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = {
    val barrierState = states.remove(barrierId)
    if (barrierState != null) {
      barrierState.clear()
    }
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
      // Get or init the ContextBarrierState corresponding to the stage attempt.
      val barrierId = ContextBarrierId(stageId, stageAttemptId)
      states.computeIfAbsent(barrierId, new Function[ContextBarrierId, ContextBarrierState] {
        override def apply(key: ContextBarrierId): ContextBarrierState =
          new ContextBarrierState(key, numTasks)
      })
      val barrierState = states.get(barrierId)

      barrierState.handleRequest(context, request)
  }

  private val clearStateConsumer = new Consumer[ContextBarrierState] {
    override def accept(state: ContextBarrierState) = state.clear()
  }
}

private[spark] sealed trait BarrierCoordinatorMessage extends Serializable

/**
 * A global sync request message from BarrierTaskContext, sent by a `barrier()` call. Each
 * request is identified by stageId + stageAttemptId + barrierEpoch.
 *
 * @param numTasks the number of global sync requests the BarrierCoordinator shall receive
 * @param stageId ID of the current stage
 * @param stageAttemptId ID of the current stage attempt
 * @param taskAttemptId unique ID of the current task
 * @param barrierEpoch ID of the `barrier()` call; a task may contain multiple `barrier()` calls.
 */
private[spark] case class RequestToSync(
    numTasks: Int,
    stageId: Int,
    stageAttemptId: Int,
    taskAttemptId: Long,
    barrierEpoch: Int) extends BarrierCoordinatorMessage
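
The coordinator above is an RPC endpoint, so something on the driver has to register it under the name "barrierSync" that BarrierTaskContext (next file) resolves. That registration happens in another of the 69 files in this commit; below is a minimal sketch of the wiring, assuming a LiveListenerBus in scope and using a hypothetical helper name:

import org.apache.spark.{BarrierCoordinator, SparkEnv}
import org.apache.spark.internal.config.BARRIER_SYNC_TIMEOUT
import org.apache.spark.scheduler.LiveListenerBus

def registerBarrierCoordinator(listenerBus: LiveListenerBus): Unit = {
  val env = SparkEnv.get
  // Timeout in seconds, read from spark.barrier.sync.timeout (default "365d", see below).
  val coordinator = new BarrierCoordinator(
    env.conf.get(BARRIER_SYNC_TIMEOUT), listenerBus, env.rpcEnv)
  // The endpoint name must match what RpcUtils.makeDriverRef("barrierSync", ...) looks up.
  env.rpcEnv.setupEndpoint("barrierSync", coordinator)
}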

core/src/main/scala/org/apache/spark/BarrierTaskContext.scala

Lines changed: 60 additions & 2 deletions

@@ -17,12 +17,17 @@

package org.apache.spark

-import java.util.Properties
+import java.util.{Properties, Timer, TimerTask}
+
+import scala.concurrent.duration._
+import scala.language.postfixOps

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
+import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout}
+import org.apache.spark.util.{RpcUtils, Utils}

/** A [[TaskContext]] with extra info and tooling for a barrier stage. */
class BarrierTaskContext(

@@ -39,6 +44,22 @@ class BarrierTaskContext(
  extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber,
    taskMemoryManager, localProperties, metricsSystem, taskMetrics) {

+  // Find the driver-side RpcEndpointRef of the coordinator that handles all the barrier() calls.
+  private val barrierCoordinator: RpcEndpointRef = {
+    val env = SparkEnv.get
+    RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv)
+  }
+
+  private val timer = new Timer("Barrier task timer for barrier() calls.")
+
+  // Local barrierEpoch that identifies a barrier() call from the current task; it shall be
+  // identical to the driver-side epoch.
+  private var barrierEpoch = 0
+
+  // Number of tasks of the current barrier stage; a barrier() call must collect enough requests
+  // from different tasks within the same barrier stage attempt to succeed.
+  private lazy val numTasks = getTaskInfos().size
+
  /**
   * :: Experimental ::
   * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to

@@ -80,7 +101,44 @@ class BarrierTaskContext(
  @Experimental
  @Since("2.4.0")
  def barrier(): Unit = {
-    // TODO SPARK-24817 implement global barrier.
+    logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
+      s"the global sync, current barrier epoch is $barrierEpoch.")
+    logTrace("Current callSite: " + Utils.getCallSite())
+
+    val startTime = System.currentTimeMillis()
+    val timerTask = new TimerTask {
+      override def run(): Unit = {
+        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
+          s"under the global sync since $startTime, has been waiting for " +
+          s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " +
+          s"is $barrierEpoch.")
+      }
+    }
+    // Log the status of the global sync every 60 seconds.
+    timer.schedule(timerTask, 60000, 60000)
+
+    try {
+      barrierCoordinator.askSync[Unit](
+        message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
+          barrierEpoch),
+        // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
+        // BarrierCoordinator on timeout, instead of an RpcTimeoutException from the RPC
+        // framework.
+        timeout = new RpcTimeout(31536000 /* = 3600 * 24 * 365 */ seconds, "barrierTimeout"))
+      barrierEpoch += 1
+      logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
+        "global sync successfully, waited for " +
+        s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch is " +
+        s"$barrierEpoch.")
+    } catch {
+      case e: SparkException =>
+        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
+          "to perform global sync, waited for " +
+          s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " +
+          s"is $barrierEpoch.")
+        throw e
+    } finally {
+      timerTask.cancel()
+    }
  }

  /**
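
For orientation, the caller side: a user writes a barrier stage against the RDDBarrier API from SPARK-24374, and every task funnels through the barrier() implementation above. A rough usage sketch, assuming a live SparkContext `sc` and the mapPartitions shape the API settled on during the 2.4 cycle:

val rdd = sc.parallelize(1 to 100, numSlices = 4)
val result = rdd.barrier().mapPartitions { iter =>
  val context = BarrierTaskContext.get()
  // All 4 tasks block here until the coordinator has collected 4 RequestToSync
  // messages for the current epoch, or every task fails with a SparkException
  // once the coordinator times out.
  context.barrier()
  iter
}.collect()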

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 6 additions & 6 deletions

@@ -1935,6 +1935,12 @@ class SparkContext(config: SparkConf) extends Logging {
    Utils.tryLogNonFatalError {
      _executorAllocationManager.foreach(_.stop())
    }
+    if (_dagScheduler != null) {
+      Utils.tryLogNonFatalError {
+        _dagScheduler.stop()
+      }
+      _dagScheduler = null
+    }
    if (_listenerBusStarted) {
      Utils.tryLogNonFatalError {
        listenerBus.stop()

@@ -1944,12 +1950,6 @@ class SparkContext(config: SparkConf) extends Logging {
    Utils.tryLogNonFatalError {
      _eventLogger.foreach(_.stop())
    }
-    if (_dagScheduler != null) {
-      Utils.tryLogNonFatalError {
-        _dagScheduler.stop()
-      }
-      _dagScheduler = null
-    }
    if (env != null && _heartbeatReceiver != null) {
      Utils.tryLogNonFatalError {
        env.rpcEnv.stop(_heartbeatReceiver)
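
Moving _dagScheduler.stop() ahead of listenerBus.stop() lines up with the new coordinator's cleanup: BarrierCoordinator.onStop() (first file above) calls listenerBus.removeListener, so the scheduler side must shut down while the bus is still usable. A minimal sketch of that ordering invariant, with hypothetical stand-in classes rather than Spark's real ones:

class ListenerBusLike {
  private var stopped = false
  // Models the hazard: deregistering after the bus has stopped is an error.
  def removeListener(l: AnyRef): Unit = require(!stopped, "listener removed after bus stop")
  def stop(): Unit = stopped = true
}

class CoordinatorLike(bus: ListenerBusLike, listener: AnyRef) {
  // Mirrors BarrierCoordinator.onStop(): deregister from the bus when stopping.
  def stop(): Unit = bus.removeListener(listener)
}

val bus = new ListenerBusLike
val coordinator = new CoordinatorLike(bus, new AnyRef)
coordinator.stop() // must run before bus.stop(), as in the moved block
bus.stop()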

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 10 additions & 0 deletions

@@ -567,4 +567,14 @@ package object config {
    .intConf
    .checkValue(v => v > 0, "The value should be a positive integer.")
    .createWithDefault(2000)
+
+  private[spark] val BARRIER_SYNC_TIMEOUT =
+    ConfigBuilder("spark.barrier.sync.timeout")
+      .doc("The timeout in seconds for each barrier() call from a barrier task. If the " +
+        "coordinator didn't receive all the sync messages from barrier tasks within the " +
+        "configured time, throw a SparkException to fail all the tasks. The default value is " +
+        "set to 31536000 (3600 * 24 * 365) so the barrier() call shall wait for one year.")
+      .timeConf(TimeUnit.SECONDS)
+      .checkValue(v => v > 0, "The value should be a positive time value.")
+      .createWithDefaultString("365d")
}
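
A quick sketch of how a job would tighten this timeout (hypothetical app name and value; any time string the config parser accepts works, since the entry is a timeConf):

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("barrier-sync-demo")
  // Fail any barrier() call whose sync requests don't all arrive within 5 minutes,
  // instead of the effectively unbounded 365d default.
  .set("spark.barrier.sync.timeout", "5m")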
