
Commit cc76142

Add JobWaiter.toFuture to avoid blocking threads
1 parent d9a3e72 commit cc76142

File tree

5 files changed: +32 −12 lines

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 6 additions & 5 deletions
@@ -17,8 +17,6 @@
 
 package org.apache.spark
 
-import scala.language.implicitConversions
-
 import java.io._
 import java.lang.reflect.Constructor
 import java.net.URI
@@ -30,6 +28,8 @@ import scala.collection.{Map, Set}
 import scala.collection.JavaConversions._
 import scala.collection.generic.Growable
 import scala.collection.mutable.HashMap
+import scala.concurrent.Future
+import scala.language.implicitConversions
 import scala.reflect.{ClassTag, classTag}
 import scala.util.control.NonFatal
 
@@ -1860,13 +1860,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   }
 
   /**
-   * Submit a job for execution and return a FutureJob holding the result.
+   * Submit a job for execution and return a Future that can be used to monitor the job's
+   * success or failure.
    */
   private[spark] def submitAsyncJob[T, U, R](
       rdd: RDD[T],
       processPartition: (TaskContext, Iterator[T]) => U,
       resultHandler: (Int, U) => Unit,
-      resultFunc: => R): SimpleFutureAction[R] =
+      resultFunc: => R): Future[Unit] =
   {
     assertNotStopped()
     val cleanF = clean(processPartition)
@@ -1879,7 +1880,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
       allowLocal = false,
       resultHandler,
       localProperties.get)
-    new SimpleFutureAction(waiter, resultFunc)
+    waiter.toFuture
   }
 
   /**
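
For context, a hedged sketch of how a caller inside the org.apache.spark package (the method is private[spark]) might consume the Future[Unit] that submitAsyncJob now returns. The sc, rdd, and handler names are illustrative stand-ins, not part of this commit:

    import scala.concurrent.ExecutionContext.Implicits.global  // illustrative choice of EC
    import scala.concurrent.Future
    import scala.util.{Failure, Success}

    // Hypothetical call site: sum each partition, report per-partition results
    // through the handler, and watch the job's outcome without blocking the
    // submitting thread.
    val jobDone: Future[Unit] = sc.submitAsyncJob(
      rdd,
      (context: TaskContext, iter: Iterator[Int]) => iter.sum,
      (index: Int, sum: Int) => println(s"partition $index: $sum"),
      ())

    jobDone.onComplete {
      case Success(_) => println("job succeeded")
      case Failure(e) => println(s"job failed: ${e.getMessage}")
    }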

core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala

Lines changed: 12 additions & 0 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.scheduler
 
+import scala.concurrent.{Future, Promise}
+
 /**
  * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
  * results to the given handler function.
@@ -28,6 +30,8 @@ private[spark] class JobWaiter[T](
     resultHandler: (Int, T) => Unit)
   extends JobListener {
 
+  private val promise = Promise[Unit]
+
   private var finishedTasks = 0
 
   // Is the job as a whole finished (succeeded or failed)?
@@ -58,13 +62,15 @@
     if (finishedTasks == totalTasks) {
       _jobFinished = true
       jobResult = JobSucceeded
+      promise.success()
       this.notifyAll()
     }
   }
 
   override def jobFailed(exception: Exception): Unit = synchronized {
     _jobFinished = true
     jobResult = JobFailed(exception)
+    promise.failure(exception)
     this.notifyAll()
   }
 
@@ -74,4 +80,10 @@
     }
     return jobResult
   }
+
+  /**
+   * Return a Future for monitoring the job's success or failure. You can use this method to
+   * avoid blocking your thread.
+   */
+  def toFuture: Future[Unit] = promise.future
 }
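
This is the standard Promise-to-Future bridge: the JobListener callbacks complete a Promise exactly once, and toFuture exposes its read-only Future. A minimal standalone sketch of the same pattern (hypothetical class, not Spark code):

    import scala.concurrent.{Future, Promise}

    // A listener-style object that completes a promise once; callers observe
    // the outcome via the read-only future instead of blocking in a
    // synchronized wait loop.
    class CompletionListener {
      private val promise = Promise[Unit]()

      def succeeded(): Unit = promise.trySuccess(())   // idempotent success
      def failed(e: Exception): Unit = promise.tryFailure(e)

      def toFuture: Future[Unit] = promise.future
    }

The sketch uses trySuccess/tryFailure to stay idempotent; the diff itself relies on JobWaiter's own state machine to guarantee the promise is completed only once.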

streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala

Lines changed: 3 additions & 2 deletions
@@ -163,7 +163,8 @@ private[streaming] abstract class ReceiverSupervisor(
     stopReceiver("Restarting receiver with delay " + delay + "ms: " + message, error)
     logDebug("Sleeping for " + delay)
     Thread.sleep(delay)
-    if (rescheduleReceiver().contains(host)) {
+    val scheduledLocations = rescheduleReceiver()
+    if (scheduledLocations.isEmpty || scheduledLocations.contains(host)) {
       logInfo("Starting receiver again")
       startReceiver()
       logInfo("Receiver started again")
@@ -174,7 +175,7 @@
   }
 
   /** Reschedule this receiver and return a candidate executor list */
-  def rescheduleReceiver(): Seq[String]
+  def rescheduleReceiver(): Seq[String] = Seq.empty
 
   /** Check if receiver has been marked for stopping */
   def isReceiverStarted(): Boolean = {
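
With the default implementation now returning Seq.empty, an empty candidate list means "no scheduling constraint", so the supervisor restarts the receiver on the current host. A tiny sketch of that decision in isolation (hypothetical helper, not Spark code):

    // Empty candidates = no constraint, so the current host always qualifies.
    def shouldRestartHere(candidates: Seq[String], host: String): Boolean =
      candidates.isEmpty || candidates.contains(host)

    shouldRestartHere(Seq.empty, "host-1")      // true: run anywhere
    shouldRestartHere(Seq("host-2"), "host-1")  // false: must move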

streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverScheduler.scala

Lines changed: 5 additions & 0 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.streaming.scheduler
 import scala.collection.mutable
 import scala.util.Random
 
+import org.apache.spark.streaming.StreamingContext
 import org.apache.spark.streaming.scheduler.ReceiverState._
 
 private[streaming] case class ReceiverTrackingInfo(
@@ -31,6 +32,10 @@ private[streaming] case class ReceiverTrackingInfo(
 
 private[streaming] trait ReceiverScheduler {
 
+  /**
+   * Return a candidate executor list to run the receiver. If the list is empty, the caller can
+   * run this receiver on an arbitrary executor.
+   */
   def scheduleReceiver(
       receiverId: Int,
       preferredLocation: Option[String],
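
As a hedged illustration of the contract the new scaladoc states (the trait's full parameter list is cut off by this hunk, so the signature below is a simplified stand-in, not the real one):

    // Simplified stand-in for a scheduler honoring the documented contract:
    // return matching executors for a preferred location, or Seq.empty to let
    // the caller place the receiver on an arbitrary executor.
    def scheduleReceiverSketch(
        receiverId: Int,
        preferredLocation: Option[String],
        executors: Seq[String]): Seq[String] =
      preferredLocation match {
        case Some(host) => executors.filter(_ == host)
        case None => Seq.empty
      }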

streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala

Lines changed: 6 additions & 5 deletions
@@ -333,7 +333,7 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
   }
 
   private val submitJobThread = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("streaming-submit-job"))
+    ThreadUtils.newDaemonSingleThreadExecutor("streaming-submit-job"))
 
   /**
    * Get the receivers from the ReceiverInputDStreams, distributes them to the
@@ -390,8 +390,9 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
         }
 
         val self = this
+        val receiverId = receiver.streamId
         val scheduledLocations = scheduler.scheduleReceiver(
-          receiver.streamId,
+          receiverId,
           receiver.preferredLocation,
           getReceiverTrackingInfoMap(),
           getExecutors(ssc))
@@ -411,18 +412,18 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
             if (stopping) {
               receiverExitLatch.countDown()
             } else {
-              logInfo(s"Restarting Receiver ${receiver.streamId}")
+              logInfo(s"Restarting Receiver $receiverId")
               submitJobThread.execute(self)
             }
           case Failure(e) =>
             if (stopping) {
               receiverExitLatch.countDown()
             } else {
               logError("Receiver has been stopped. Try to restart it.", e)
-              logInfo(s"Restarting Receiver ${receiver.streamId}")
+              logInfo(s"Restarting Receiver $receiverId")
               submitJobThread.execute(self)
             }
-        }(submitJobThread)
+        }(ThreadUtils.sameThread)
         logInfo(s"Receiver ${receiver.streamId} started")
       }
     })
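
The last change registers the onComplete callback with ThreadUtils.sameThread, so the cheap callback runs on whichever thread completes the future instead of queueing behind the now single-threaded streaming-submit-job executor; actual restarts are still handed back to submitJobThread.execute. A generic sketch of the same-thread idea using only the standard library (not Spark code):

    import scala.concurrent.{ExecutionContext, Future, Promise}

    // A same-thread ExecutionContext: execute() invokes run() directly on the
    // completing thread, with no hand-off to another executor.
    val sameThread: ExecutionContext = new ExecutionContext {
      def execute(runnable: Runnable): Unit = runnable.run()
      def reportFailure(cause: Throwable): Unit = cause.printStackTrace()
    }

    val p = Promise[Unit]()
    p.future.onComplete { outcome =>
      println(s"callback on ${Thread.currentThread().getName}: $outcome")
    }(sameThread)
    p.success(())  // the callback fires here, on the completing thread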
