Skip to content

Commit b2e9e17

Browse files
Ngone51 and jiangxb1987
authored and committed
[SPARK-31344][CORE] Polish implementation of barrier() and allGather()
### What changes were proposed in this pull request? 1. Combine `BarrierRequestToSync` and `AllGatherRequestToSync` into a single `RequestToSync`, whose kind is distinguished by its `RequestMethod` type. 2. Remove unnecessary JSON serialization/deserialization. 3. Clean up some code to make `runBarrier()` and `BarrierCoordinator` more general. 4. Remove unused imports. ### Why are the changes needed? To make the code simpler and easier to maintain in the future. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? This is a pure code refactoring, so it should be covered by existing tests. Closes apache#28117 from Ngone51/refactor_barrier. Authored-by: yi.wu <[email protected]> Signed-off-by: Xingbo Jiang <[email protected]>
1 parent fab4ca5 commit b2e9e17

File tree

5 files changed

+46
-160
lines changed

5 files changed

+46
-160
lines changed

core/src/main/scala/org/apache/spark/BarrierCoordinator.scala

Lines changed: 25 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,12 @@
1717

1818
package org.apache.spark
1919

20-
import java.nio.charset.StandardCharsets.UTF_8
2120
import java.util.{Timer, TimerTask}
2221
import java.util.concurrent.ConcurrentHashMap
2322
import java.util.function.Consumer
2423

2524
import scala.collection.mutable.ArrayBuffer
2625

27-
import org.json4s.JsonAST._
28-
import org.json4s.JsonDSL._
29-
import org.json4s.jackson.JsonMethods.{compact, render}
30-
3126
import org.apache.spark.internal.Logging
3227
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
3328
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
@@ -107,11 +102,13 @@ private[spark] class BarrierCoordinator(
107102
// An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
108103
private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)
109104

110-
// An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call
111-
private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer]
105+
// Messages from each barrier task that has made a blocking runBarrier() call.
106+
// The messages will be sent back to all tasks once the sync has finished.
107+
private val messages = Array.ofDim[String](numTasks)
112108

113-
// The blocking requestMethod called by tasks to sync up for this stage attempt
114-
private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER
109+
// The request method which is called inside this barrier sync. All tasks should make sure
110+
// that they're calling the same method within the same barrier sync phase.
111+
private var requestMethod: RequestMethod.Value = _
115112

116113
// A timer task that ensures we may timeout for a barrier() call.
117114
private var timerTask: TimerTask = null
@@ -140,28 +137,18 @@ private[spark] class BarrierCoordinator(
140137

141138
// Process the global sync request. The barrier() call succeed if collected enough requests
142139
// within a configured time, otherwise fail all the pending requests.
143-
def handleRequest(
144-
requester: RpcCallContext,
145-
request: RequestToSync
146-
): Unit = synchronized {
140+
def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
147141
val taskId = request.taskAttemptId
148142
val epoch = request.barrierEpoch
149-
val requestMethod = request.requestMethod
150-
val partitionId = request.partitionId
151-
val allGatherMessage = request match {
152-
case ag: AllGatherRequestToSync => ag.allGatherMessage
153-
case _ => ""
154-
}
155-
156-
if (requesters.size == 0) {
157-
requestMethodToSync = requestMethod
158-
}
143+
val curReqMethod = request.requestMethod
159144

160-
if (requestMethodToSync != requestMethod) {
145+
if (requesters.isEmpty) {
146+
requestMethod = curReqMethod
147+
} else if (requestMethod != curReqMethod) {
161148
requesters.foreach(
162149
_.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " +
163-
s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " +
164-
s"the current synchronized requestMethod `$requestMethodToSync`"
150+
s"`$curReqMethod` during barrier epoch $barrierEpoch, which does not match " +
151+
s"the current synchronized requestMethod `$requestMethod`"
165152
))
166153
)
167154
cleanupBarrierStage(barrierId)
@@ -186,10 +173,11 @@ private[spark] class BarrierCoordinator(
186173
}
187174
// Add the requester to array of RPCCallContexts pending for reply.
188175
requesters += requester
189-
allGatherMessages(partitionId) = allGatherMessage
176+
messages(request.partitionId) = request.message
190177
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
191178
s"$taskId, current progress: ${requesters.size}/$numTasks.")
192-
if (maybeFinishAllRequesters(requesters, numTasks)) {
179+
if (requesters.size == numTasks) {
180+
requesters.foreach(_.reply(messages))
193181
// Finished current barrier() call successfully, clean up ContextBarrierState and
194182
// increase the barrier epoch.
195183
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " +
@@ -201,25 +189,6 @@ private[spark] class BarrierCoordinator(
201189
}
202190
}
203191

204-
// Finish all the blocking barrier sync requests from a stage attempt successfully if we
205-
// have received all the sync requests.
206-
private def maybeFinishAllRequesters(
207-
requesters: ArrayBuffer[RpcCallContext],
208-
numTasks: Int): Boolean = {
209-
if (requesters.size == numTasks) {
210-
requestMethodToSync match {
211-
case RequestMethod.BARRIER =>
212-
requesters.foreach(_.reply(""))
213-
case RequestMethod.ALL_GATHER =>
214-
val json: String = compact(render(allGatherMessages))
215-
requesters.foreach(_.reply(json))
216-
}
217-
true
218-
} else {
219-
false
220-
}
221-
}
222-
223192
// Cleanup the internal state of a barrier stage attempt.
224193
def clear(): Unit = synchronized {
225194
// The global sync fails so the stage is expected to retry another attempt, all sync
@@ -239,11 +208,11 @@ private[spark] class BarrierCoordinator(
239208
}
240209

241210
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
242-
case request: RequestToSync =>
211+
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _, _, _, _) =>
243212
// Get or init the ContextBarrierState correspond to the stage attempt.
244-
val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
213+
val barrierId = ContextBarrierId(stageId, stageAttemptId)
245214
states.computeIfAbsent(barrierId,
246-
(key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks))
215+
(key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
247216
val barrierState = states.get(barrierId)
248217

249218
barrierState.handleRequest(context, request)
@@ -256,61 +225,28 @@ private[spark] class BarrierCoordinator(
256225

257226
private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
258227

259-
private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
260-
def numTasks: Int
261-
def stageId: Int
262-
def stageAttemptId: Int
263-
def taskAttemptId: Long
264-
def barrierEpoch: Int
265-
def partitionId: Int
266-
def requestMethod: RequestMethod.Value
267-
}
268-
269-
/**
270-
* A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
271-
* identified by stageId + stageAttemptId + barrierEpoch.
272-
*
273-
* @param numTasks The number of global sync requests the BarrierCoordinator shall receive
274-
* @param stageId ID of current stage
275-
* @param stageAttemptId ID of current stage attempt
276-
* @param taskAttemptId Unique ID of current task
277-
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
278-
* @param partitionId ID of the current partition the task is assigned to
279-
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
280-
*/
281-
private[spark] case class BarrierRequestToSync(
282-
numTasks: Int,
283-
stageId: Int,
284-
stageAttemptId: Int,
285-
taskAttemptId: Long,
286-
barrierEpoch: Int,
287-
partitionId: Int,
288-
requestMethod: RequestMethod.Value
289-
) extends RequestToSync
290-
291228
/**
292-
* A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is
229+
* A global sync request message from BarrierTaskContext. Each request is
293230
* identified by stageId + stageAttemptId + barrierEpoch.
294231
*
295232
* @param numTasks The number of global sync requests the BarrierCoordinator shall receive
296233
* @param stageId ID of current stage
297234
* @param stageAttemptId ID of current stage attempt
298235
* @param taskAttemptId Unique ID of current task
299-
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
236+
* @param barrierEpoch ID of a runBarrier() call; a task may make multiple runBarrier() calls
300237
* @param partitionId ID of the current partition the task is assigned to
238+
* @param message Message sent from the BarrierTaskContext
301239
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
302-
* @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER
303240
*/
304-
private[spark] case class AllGatherRequestToSync(
241+
private[spark] case class RequestToSync(
305242
numTasks: Int,
306243
stageId: Int,
307244
stageAttemptId: Int,
308245
taskAttemptId: Long,
309246
barrierEpoch: Int,
310247
partitionId: Int,
311-
requestMethod: RequestMethod.Value,
312-
allGatherMessage: String
313-
) extends RequestToSync
248+
message: String,
249+
requestMethod: RequestMethod.Value) extends BarrierCoordinatorMessage
314250

315251
private[spark] object RequestMethod extends Enumeration {
316252
val BARRIER, ALL_GATHER = Value

core/src/main/scala/org/apache/spark/BarrierTaskContext.scala

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,13 @@
1717

1818
package org.apache.spark
1919

20-
import java.nio.charset.StandardCharsets.UTF_8
2120
import java.util.{Properties, Timer, TimerTask}
2221

2322
import scala.collection.JavaConverters._
24-
import scala.collection.mutable.ArrayBuffer
2523
import scala.concurrent.TimeoutException
2624
import scala.concurrent.duration._
2725
import scala.language.postfixOps
2826

29-
import org.json4s.DefaultFormats
30-
import org.json4s.JsonAST._
31-
import org.json4s.JsonDSL._
32-
import org.json4s.jackson.JsonMethods.parse
33-
3427
import org.apache.spark.annotation.{Experimental, Since}
3528
import org.apache.spark.executor.TaskMetrics
3629
import org.apache.spark.internal.Logging
@@ -67,31 +60,7 @@ class BarrierTaskContext private[spark] (
6760
// from different tasks within the same barrier stage attempt to succeed.
6861
private lazy val numTasks = getTaskInfos().size
6962

70-
private def getRequestToSync(
71-
numTasks: Int,
72-
stageId: Int,
73-
stageAttemptNumber: Int,
74-
taskAttemptId: Long,
75-
barrierEpoch: Int,
76-
partitionId: Int,
77-
requestMethod: RequestMethod.Value,
78-
allGatherMessage: String
79-
): RequestToSync = {
80-
requestMethod match {
81-
case RequestMethod.BARRIER =>
82-
BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
83-
barrierEpoch, partitionId, requestMethod)
84-
case RequestMethod.ALL_GATHER =>
85-
AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
86-
barrierEpoch, partitionId, requestMethod, allGatherMessage)
87-
}
88-
}
89-
90-
private def runBarrier(
91-
requestMethod: RequestMethod.Value,
92-
allGatherMessage: String = ""
93-
): String = {
94-
63+
private def runBarrier(message: String, requestMethod: RequestMethod.Value): Array[String] = {
9564
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
9665
s"the global sync, current barrier epoch is $barrierEpoch.")
9766
logTrace("Current callSite: " + Utils.getCallSite())
@@ -108,24 +77,24 @@ class BarrierTaskContext private[spark] (
10877
// Log the update of global sync every 60 seconds.
10978
timer.schedule(timerTask, 60000, 60000)
11079

111-
var json: String = ""
112-
11380
try {
114-
val abortableRpcFuture = barrierCoordinator.askAbortable[String](
115-
message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
116-
taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage),
81+
val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]](
82+
message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
83+
barrierEpoch, partitionId, message, requestMethod),
11784
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
11885
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
11986
timeout = new RpcTimeout(365.days, "barrierTimeout"))
12087

88+
// messages which consist of all barrier tasks' messages
89+
var messages: Array[String] = null
12190
// Wait the RPC future to be completed, but every 1 second it will jump out waiting
12291
// and check whether current spark task is killed. If killed, then throw
12392
// a `TaskKilledException`, otherwise continue wait RPC until it completes.
12493
try {
12594
while (!abortableRpcFuture.toFuture.isCompleted) {
12695
// wait RPC future for at most 1 second
12796
try {
128-
json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
97+
messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
12998
} catch {
13099
case _: TimeoutException | _: InterruptedException =>
131100
// If `TimeoutException` thrown, waiting RPC future reach 1 second.
@@ -144,6 +113,7 @@ class BarrierTaskContext private[spark] (
144113
"global sync successfully, waited for " +
145114
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
146115
s"current barrier epoch is $barrierEpoch.")
116+
messages
147117
} catch {
148118
case e: SparkException =>
149119
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
@@ -155,7 +125,6 @@ class BarrierTaskContext private[spark] (
155125
timerTask.cancel()
156126
timer.purge()
157127
}
158-
json
159128
}
160129

161130
/**
@@ -200,10 +169,7 @@ class BarrierTaskContext private[spark] (
200169
*/
201170
@Experimental
202171
@Since("2.4.0")
203-
def barrier(): Unit = {
204-
runBarrier(RequestMethod.BARRIER)
205-
()
206-
}
172+
def barrier(): Unit = runBarrier("", RequestMethod.BARRIER)
207173

208174
/**
209175
* :: Experimental ::
@@ -217,12 +183,7 @@ class BarrierTaskContext private[spark] (
217183
*/
218184
@Experimental
219185
@Since("3.0.0")
220-
def allGather(message: String): Array[String] = {
221-
val json = runBarrier(RequestMethod.ALL_GATHER, message)
222-
val jsonArray = parse(json)
223-
implicit val formats = DefaultFormats
224-
jsonArray.extract[Array[String]]
225-
}
186+
def allGather(message: String): Array[String] = runBarrier(message, RequestMethod.ALL_GATHER)
226187

227188
/**
228189
* :: Experimental ::

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -414,22 +414,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
414414
)
415415
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
416416
try {
417-
var result: String = ""
418-
requestMethod match {
417+
val messages = requestMethod match {
419418
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
420419
context.asInstanceOf[BarrierTaskContext].barrier()
421-
result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS
420+
Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS)
422421
case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
423-
val messages: Array[String] = context.asInstanceOf[BarrierTaskContext].allGather(
424-
message
425-
)
426-
result = compact(render(JArray(
427-
messages.map(
428-
(message) => JString(message)
429-
).toList
430-
)))
422+
context.asInstanceOf[BarrierTaskContext].allGather(message)
431423
}
432-
writeUTF(result, out)
424+
out.writeInt(messages.length)
425+
messages.foreach(writeUTF(_, out))
433426
} catch {
434427
case e: SparkException =>
435428
writeUTF(e.getMessage, out)

core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package org.apache.spark.scheduler
1919

2020
import java.io.File
2121

22-
import scala.collection.mutable.ArrayBuffer
2322
import scala.util.Random
2423

2524
import org.apache.spark._

python/pyspark/taskcontext.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import json
2020

2121
from pyspark.java_gateway import local_connect_and_auth
22-
from pyspark.serializers import write_int, write_with_length, UTF8Deserializer
22+
from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
2323

2424

2525
class TaskContext(object):
@@ -133,7 +133,10 @@ def _load_from_socket(port, auth_secret, function, all_gather_message=None):
133133
sockfile.flush()
134134

135135
# Collect result.
136-
res = UTF8Deserializer().loads(sockfile)
136+
len = read_int(sockfile)
137+
res = []
138+
for i in range(len):
139+
res.append(UTF8Deserializer().loads(sockfile))
137140

138141
# Release resources.
139142
sockfile.close()
@@ -232,13 +235,7 @@ def allGather(self, message=""):
232235
raise Exception("Not supported to call barrier() before initialize " +
233236
"BarrierTaskContext.")
234237
else:
235-
gathered_items = _load_from_socket(
236-
self._port,
237-
self._secret,
238-
ALL_GATHER_FUNCTION,
239-
message,
240-
)
241-
return [e for e in json.loads(gathered_items)]
238+
return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)
242239

243240
def getTaskInfos(self):
244241
"""

0 commit comments

Comments
 (0)