@@ -19,9 +19,10 @@ package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
 import java.util.Properties
-import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
 
+import akka.actor._
+import akka.util.duration._
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 
 import org.apache.spark._
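Note on the new imports: akka.actor._ supplies the Actor, ActorRef and Props types used below, while akka.util.duration._ (the duration DSL's home in the Akka version of that era; newer code gets the same conversions from scala.concurrent.duration._) is what makes the postfix RESUBMIT_TIMEOUT milliseconds syntax in the new preStart compile. A tiny illustration of that conversion, written against the newer package name:

    import scala.concurrent.duration._

    object DurationSketch extends App {
      // An Int picks up duration postfix methods through an implicit conversion,
      // so 200.milliseconds (or "200 milliseconds") builds a FiniteDuration.
      val interval: FiniteDuration = 200.milliseconds
      println(interval)
    }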
@@ -65,12 +66,12 @@ class DAGScheduler(
 
   // Called by TaskScheduler to report task's starting.
   def taskStarted(task: Task[_], taskInfo: TaskInfo) {
-    eventQueue.put(BeginEvent(task, taskInfo))
+    eventProcessActor ! BeginEvent(task, taskInfo)
   }
 
   // Called to report that a task has completed and results are being fetched remotely.
   def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
-    eventQueue.put(GettingResultEvent(task, taskInfo))
+    eventProcessActor ! GettingResultEvent(task, taskInfo)
   }
 
   // Called by TaskScheduler to report task completions or failures.
@@ -81,23 +82,23 @@ class DAGScheduler(
       accumUpdates: Map[Long, Any],
       taskInfo: TaskInfo,
       taskMetrics: TaskMetrics) {
-    eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
+    eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
   }
 
   // Called by TaskScheduler when an executor fails.
   def executorLost(execId: String) {
-    eventQueue.put(ExecutorLost(execId))
+    eventProcessActor ! ExecutorLost(execId)
   }
 
   // Called by TaskScheduler when a host is added
   def executorGained(execId: String, host: String) {
-    eventQueue.put(ExecutorGained(execId, host))
+    eventProcessActor ! ExecutorGained(execId, host)
   }
 
   // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
   // cancellation of the job itself.
   def taskSetFailed(taskSet: TaskSet, reason: String) {
-    eventQueue.put(TaskSetFailed(taskSet, reason))
+    eventProcessActor ! TaskSetFailed(taskSet, reason)
   }
 
   // The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
@@ -109,7 +110,30 @@ class DAGScheduler(
   // resubmit failed stages
   val POLL_TIMEOUT = 10L
 
-  private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
+  private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor {
+    override def preStart() {
+      context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) {
+        if (failed.size > 0) {
+          resubmitFailedStages()
+        }
+      }
+    }
+
+    /**
+     * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
+     * events and responds by launching tasks. It now runs inside this actor, which receives the
+     * events as messages instead of reading them from a queue in a dedicated thread.
+     */
+    def receive = {
+      case event: DAGSchedulerEvent =>
+        logDebug("Got event of type " + event.getClass.getName)
+
+        if (!processEvent(event))
+          submitWaitingStages()
+        else
+          context.stop(self)
+    }
+  }))
 
   private[scheduler] val nextJobId = new AtomicInteger(0)
 
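This hunk is the core of the change: the dedicated event-loop thread and its LinkedBlockingQueue are replaced by an actor whose mailbox serializes every DAGSchedulerEvent, and the old poll-with-timeout resubmission check becomes a periodic task registered in preStart. A rough standalone sketch of the same pattern, using the later classic-Akka API and made-up names (EventLoop, Work and Tick are illustrative, not from the Spark source):

    import akka.actor._
    import scala.concurrent.duration._

    sealed trait Event
    case object Tick extends Event            // periodic check, like resubmitFailedStages()
    case class Work(id: Int) extends Event    // an ordinary event, like BeginEvent or CompletionEvent

    class EventLoop extends Actor {
      import context.dispatcher               // ExecutionContext for the scheduler call below

      override def preStart(): Unit = {
        // A periodic self-message replaces the old poll(POLL_TIMEOUT) + RESUBMIT_TIMEOUT check.
        context.system.scheduler.schedule(200.millis, 200.millis, self, Tick)
      }

      // The mailbox hands the actor one message at a time, so no explicit locking is needed.
      def receive = {
        case Tick     => println("periodic check")
        case Work(id) => println("processing event " + id)
      }
    }

    object Sketch extends App {
      val system = ActorSystem("sketch")
      val loop = system.actorOf(Props[EventLoop], "eventLoop")
      loop ! Work(1)   // fire-and-forget, just like eventProcessActor ! BeginEvent(...)
    }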
@@ -150,16 +174,6 @@ class DAGScheduler(
 
   val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup)
 
-  // Start a thread to run the DAGScheduler event loop
-  def start() {
-    new Thread("DAGScheduler") {
-      setDaemon(true)
-      override def run() {
-        DAGScheduler.this.run()
-      }
-    }.start()
-  }
-
   def addSparkListener(listener: SparkListener) {
     listenerBus.addListener(listener)
   }
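Dropping start() works because actor creation also covers startup: env.actorSystem.actorOf returns an ActorRef that is already scheduled to run, and preStart() executes before the first message is processed, so there is no separate daemon thread for callers to kick off. A minimal illustration (names again made up):

    import akka.actor._

    class Loop extends Actor {
      override def preStart(): Unit = println("started automatically, before any message")
      def receive = { case msg => println("got " + msg) }
    }

    object LifecycleSketch extends App {
      val system = ActorSystem("lifecycle")
      val loop = system.actorOf(Props[Loop], "loop")  // no explicit start() needed
      loop ! "hello"
    }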
@@ -301,8 +315,7 @@ class DAGScheduler(
     assert(partitions.size > 0)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
-    eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite,
-      waiter, properties))
+    eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
     waiter
   }
 
@@ -337,8 +350,7 @@ class DAGScheduler(
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val partitions = (0 until rdd.partitions.size).toArray
     val jobId = nextJobId.getAndIncrement()
-    eventQueue.put(JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite,
-      listener, properties))
+    eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
     listener.awaitResult() // Will throw an exception if the job fails
   }
 
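In both submission paths the event is now delivered asynchronously with !, but the calling thread still blocks on the waiter or listener until the scheduler's event loop finishes the job. A simplified, hypothetical stand-in for that handshake (SimpleWaiter is not the real JobWaiter, just the idea expressed with a Promise):

    import scala.concurrent.{Await, Promise}
    import scala.concurrent.duration.Duration

    // The submitting thread parks in awaitResult(); the scheduler side completes
    // the promise from its event loop when the job succeeds or fails.
    class SimpleWaiter[T] {
      private val promise = Promise[T]()
      def jobSucceeded(result: T): Unit = promise.success(result)
      def jobFailed(exception: Exception): Unit = promise.failure(exception)
      def awaitResult(): T = Await.result(promise.future, Duration.Inf)
    }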
@@ -347,19 +359,19 @@ class DAGScheduler(
    */
   def cancelJob(jobId: Int) {
     logInfo("Asked to cancel job " + jobId)
-    eventQueue.put(JobCancelled(jobId))
+    eventProcessActor ! JobCancelled(jobId)
   }
 
   def cancelJobGroup(groupId: String) {
     logInfo("Asked to cancel job group " + groupId)
-    eventQueue.put(JobGroupCancelled(groupId))
+    eventProcessActor ! JobGroupCancelled(groupId)
   }
 
   /**
    * Cancel all jobs that are running or waiting in the queue.
    */
   def cancelAllJobs() {
-    eventQueue.put(AllJobsCancelled)
+    eventProcessActor ! AllJobsCancelled
   }
 
   /**
@@ -474,42 +486,6 @@ class DAGScheduler(
     }
   }
 
-
-  /**
-   * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
-   * events and responds by launching tasks. This runs in a dedicated thread and receives events
-   * via the eventQueue.
-   */
-  private def run() {
-    SparkEnv.set(env)
-
-    while (true) {
-      val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
-      if (event != null) {
-        logDebug("Got event of type " + event.getClass.getName)
-      }
-      this.synchronized { // needed in case other threads makes calls into methods of this class
-        if (event != null) {
-          if (processEvent(event)) {
-            return
-          }
-        }
-
-        val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
-        // Periodically resubmit failed stages if some map output fetches have failed and we have
-        // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
-        // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
-        // the same time, so we want to make sure we've identified all the reduce tasks that depend
-        // on the failed node.
-        if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
-          resubmitFailedStages()
-        } else {
-          submitWaitingStages()
-        }
-      }
-    }
-  }
-
   /**
    * Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
    * We run the operation in a separate thread just in case it takes a bunch of time, so that we
@@ -878,15 +854,15 @@ class DAGScheduler(
     // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
     // that has any placement preferences. Ideally we would choose based on transfer sizes,
     // but this will do for now.
-    rdd.dependencies.foreach(_ match {
+    rdd.dependencies.foreach {
       case n: NarrowDependency[_] =>
         for (inPart <- n.getParents(partition)) {
           val locs = getPreferredLocs(n.rdd, inPart)
           if (locs != Nil)
             return locs
         }
       case _ =>
-    })
+    }
     Nil
   }
 
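The change to getPreferredLocs is purely idiomatic Scala: a block of case clauses is already a pattern-matching anonymous function, so it can be passed to foreach directly and the wrapping _ match { ... } is redundant. For example:

    val xs = Seq(Some(1), None)

    // Before: explicitly wrapping the match in a function literal.
    xs.foreach(_ match {
      case Some(n) => println(n)
      case None    =>
    })

    // After: the case block itself is the function, so pass it directly.
    xs.foreach {
      case Some(n) => println(n)
      case None    =>
    }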
@@ -909,7 +885,7 @@ class DAGScheduler(
   }
 
   def stop() {
-    eventQueue.put(StopDAGScheduler)
+    eventProcessActor ! StopDAGScheduler
     metadataCleaner.cancel()
     taskSched.stop()
   }