From 48820826b985b5ff131c383fd2e286254256e0b7 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 21 Sep 2014 15:15:22 -0700
Subject: [PATCH] Add SparkContext.runAsync and use it to re-implement AsyncRDDActions.

---
 .../scala/org/apache/spark/FutureAction.scala | 105 ++++++------------
 .../scala/org/apache/spark/SparkContext.scala |  32 ++++++
 .../apache/spark/rdd/AsyncRDDActions.scala    |  80 +++---------
 .../apache/spark/scheduler/DAGScheduler.scala |  11 +-
 .../spark/scheduler/SparkListenerSuite.scala  |   4 +-
 5 files changed, 95 insertions(+), 137 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 75ea535f2f57..4e3d23c1b0fa 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -154,88 +154,33 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
   def jobId = jobWaiter.jobId
 }
 
-
 /**
- * :: Experimental ::
- * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
- * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
- * action thread if it is being blocked by a job.
+ * This is an extension of the Scala Future interface to support cancellation.
  */
-@Experimental
-class ComplexFutureAction[T] extends FutureAction[T] {
+class RunAsyncResult[T] private[spark] (jobGroupId: String,
+    jobGroupDescription: String,
+    sc: SparkContext,
+    func: => T) extends FutureAction[T] {
 
-  // Pointer to the thread that is executing the action. It is set when the action is run.
+  // Pointer to the thread that is executing the action; it is set when the action is run.
   @volatile private var thread: Thread = _
 
-  // A flag indicating whether the future has been cancelled. This is used in case the future
-  // is cancelled before the action was even run (and thus we have no thread to interrupt).
-  @volatile private var _cancelled: Boolean = false
-
   // A promise used to signal the future.
   private val p = promise[T]()
 
-  override def cancel(): Unit = this.synchronized {
-    _cancelled = true
-    if (thread != null) {
-      thread.interrupt()
-    }
-  }
-
-  /**
-   * Executes some action enclosed in the closure. To properly enable cancellation, the closure
-   * should use runJob implementation in this promise. See takeAsync for example.
-   */
-  def run(func: => T)(implicit executor: ExecutionContext): this.type = {
-    scala.concurrent.future {
-      thread = Thread.currentThread
-      try {
-        p.success(func)
-      } catch {
-        case e: Exception => p.failure(e)
-      } finally {
-        thread = null
-      }
-    }
-    this
-  }
-
   /**
-   * Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
-   * to enable cancellation.
+   * Cancel this Future and any Spark jobs launched from it. The cancellation of Spark jobs is
+   * performed asynchronously.
    */
-  def runJob[T, U, R](
-      rdd: RDD[T],
-      processPartition: Iterator[T] => U,
-      partitions: Seq[Int],
-      resultHandler: (Int, U) => Unit,
-      resultFunc: => R) {
-    // If the action hasn't been cancelled yet, submit the job. The check and the submitJob
-    // command need to be in an atomic block.
-    val job = this.synchronized {
-      if (!cancelled) {
-        rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
-      } else {
-        throw new SparkException("Action has been cancelled")
-      }
-    }
-
-    // Wait for the job to complete. If the action is cancelled (with an interrupt),
-    // cancel the job and stop the execution. This is not in a synchronized block because
-    // Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
-    try {
-      Await.ready(job, Duration.Inf)
-    } catch {
-      case e: InterruptedException =>
-        job.cancel()
-        throw new SparkException("Action has been cancelled")
+  def cancel(): Unit = this.synchronized {
+    if (thread != null) {
+      thread.interrupt()
+      sc.cancelJobGroup(jobGroupId)
+      thread.join()
+      thread = null
     }
   }
 
-  /**
-   * Returns whether the promise has been cancelled.
-   */
-  def cancelled: Boolean = _cancelled
-
   @throws(classOf[InterruptedException])
   @throws(classOf[scala.concurrent.TimeoutException])
   override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
@@ -255,4 +200,26 @@ class ComplexFutureAction[T] extends FutureAction[T] {
   override def isCompleted: Boolean = p.isCompleted
 
   override def value: Option[Try[T]] = p.future.value
+
+  private def run() {
+    thread = new Thread(s"RunAsync for job group $jobGroupId") {
+      override def run() {
+        try {
+          sc.setJobGroup(jobGroupId, jobGroupDescription, interruptOnCancel = true)
+          val result: T = func // Force evaluation
+          p.success(result)
+        } catch {
+          case e: InterruptedException =>
+            p.failure(new SparkException("runAsync has been cancelled"))
+          case t: Throwable =>
+            p.failure(t)
+        } finally {
+          sc.clearJobGroup()
+        }
+      }
+    }
+    thread.start()
+  }
+
+  run()
 }
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 428f019b02a2..8cc2672245c0 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1143,6 +1143,38 @@ class SparkContext(config: SparkConf) extends Logging {
     runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler)
   }
 
+  /**
+   * Start an asynchronous computation that may launch Spark jobs.
+   * Returns a `FutureAction` object that contains the result of the computation and allows the
+   * computation to be cancelled.
+   *
+   * @tparam T the type of the result
+   * @param jobGroupId the job group for Spark jobs launched by this computation
+   * @param description a description of the job group
+   * @param body the asynchronous computation
+   * @return the `FutureAction` holding the result of the computation
+   */
+  def runAsync[T](jobGroupId: String,
+      description: String)
+      (body: => T): FutureAction[T] = {
+    new RunAsyncResult[T](jobGroupId, description, this, body)
+  }
+
+  /**
+   * Start an asynchronous computation that may launch Spark jobs.
+   * Returns a `FutureAction` object that contains the result of the computation and allows the
+   * computation to be cancelled.
+   *
+   * All Spark jobs launched by this computation will be created in the same (anonymous) job group.
+   *
+   * @tparam T the type of the result
+   * @param body the asynchronous computation
+   * @return the `FutureAction` holding the result of the computation
+   */
+  def runAsync[T](body: => T): FutureAction[T] = {
+    runAsync(s"RunAsync${UUID.randomUUID().toString}", "Anonymous runAsync job")(body)
+  }
+
   /**
    * :: DeveloperApi ::
    * Run a job that can return approximate results.
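
A minimal usage sketch of the two runAsync overloads added above. This is illustrative only and
not part of the patch: it assumes a live SparkContext named sc, and the job group id and
description strings are arbitrary example values.

    import scala.concurrent.Await
    import scala.concurrent.duration.Duration

    import org.apache.spark.FutureAction

    val rdd = sc.parallelize(1 to 1000, 4)

    // Explicit job group: every Spark job launched inside the block runs in the
    // "example-count" group, so it can also be cancelled via sc.cancelJobGroup.
    val counting: FutureAction[Long] = sc.runAsync("example-count", "count the input") {
      rdd.count()
    }

    // Anonymous overload: a random job group id is generated internally.
    val collecting: FutureAction[Array[Int]] = sc.runAsync { rdd.collect() }

    // A FutureAction is a scala.concurrent.Future, so the result can be awaited...
    val total = Await.result(counting, Duration.Inf)

    // ...or the whole computation can be cancelled, which interrupts the backing
    // thread and cancels its job group.
    collecting.cancel()
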
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index b62f3fbdc4a1..c58ec8ddbbb1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -17,13 +17,9 @@
 
 package org.apache.spark.rdd
 
-import java.util.concurrent.atomic.AtomicLong
-
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext.Implicits.global
 import scala.reflect.ClassTag
 
-import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
+import org.apache.spark.{FutureAction, Logging}
 import org.apache.spark.annotation.Experimental
 
 /**
@@ -38,90 +34,44 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
    * Returns a future for counting the number of elements in the RDD.
    */
   def countAsync(): FutureAction[Long] = {
-    val totalCount = new AtomicLong
-    self.context.submitJob(
-      self,
-      (iter: Iterator[T]) => {
-        var result = 0L
-        while (iter.hasNext) {
-          result += 1L
-          iter.next()
-        }
-        result
-      },
-      Range(0, self.partitions.size),
-      (index: Int, data: Long) => totalCount.addAndGet(data),
-      totalCount.get())
+    self.sparkContext.runAsync {
+      self.count()
+    }
   }
 
   /**
    * Returns a future for retrieving all elements of this RDD.
    */
   def collectAsync(): FutureAction[Seq[T]] = {
-    val results = new Array[Array[T]](self.partitions.size)
-    self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
-      (index, data) => results(index) = data, results.flatten.toSeq)
+    self.sparkContext.runAsync {
+      self.collect()
+    }
   }
 
   /**
    * Returns a future for retrieving the first num elements of the RDD.
    */
   def takeAsync(num: Int): FutureAction[Seq[T]] = {
-    val f = new ComplexFutureAction[Seq[T]]
-
-    f.run {
-      val results = new ArrayBuffer[T](num)
-      val totalParts = self.partitions.length
-      var partsScanned = 0
-      while (results.size < num && partsScanned < totalParts) {
-        // The number of partitions to try in this iteration. It is ok for this number to be
-        // greater than totalParts because we actually cap it at totalParts in runJob.
-        var numPartsToTry = 1
-        if (partsScanned > 0) {
-          // If we didn't find any rows after the first iteration, just try all partitions next.
-          // Otherwise, interpolate the number of partitions we need to try, but overestimate it
-          // by 50%.
-          if (results.size == 0) {
-            numPartsToTry = totalParts - 1
-          } else {
-            numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
-          }
-        }
-        numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
-
-        val left = num - results.size
-        val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
-
-        val buf = new Array[Array[T]](p.size)
-        f.runJob(self,
-          (it: Iterator[T]) => it.take(left).toArray,
-          p,
-          (index: Int, data: Array[T]) => buf(index) = data,
-          Unit)
-
-        buf.foreach(results ++= _.take(num - results.size))
-        partsScanned += numPartsToTry
-      }
-      results.toSeq
+    self.sparkContext.runAsync {
+      self.take(num)
     }
-
-    f
   }
 
   /**
    * Applies a function f to all elements of this RDD.
    */
   def foreachAsync(f: T => Unit): FutureAction[Unit] = {
-    val cleanF = self.context.clean(f)
-    self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size),
-      (index, data) => Unit, Unit)
+    self.sparkContext.runAsync {
+      self.foreach(f)
+    }
   }
 
   /**
    * Applies a function f to each partition of this RDD.
    */
   def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
-    self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
-      (index, data) => Unit, Unit)
+    self.sparkContext.runAsync {
+      self.foreachPartition(f)
+    }
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index b2774dfc4755..b3afa18455a8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -477,6 +477,10 @@ class DAGScheduler(
       resultHandler: (Int, U) => Unit,
       properties: Properties = null): JobWaiter[U] =
   {
+    if (Thread.currentThread().isInterrupted) {
+      throw new SparkException(
+        "Shouldn't submit jobs from interrupted threads (was the job cancelled?)")
+    }
     // Check to make sure we are not launching a task on a partition that does not exist.
     val maxPartitions = rdd.partitions.length
     partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
@@ -493,8 +497,11 @@ class DAGScheduler(
     assert(partitions.size > 0)
     val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
     val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
+    // Make a defensive copy of the mutable properties object (this fixes a race condition that
+    // could occur during cancellation).
+    val propertiesCopy = Option(properties).map(_.clone().asInstanceOf[Properties]).orNull
     eventProcessActor ! JobSubmitted(
-      jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+      jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, propertiesCopy)
     waiter
   }
 
@@ -675,7 +682,7 @@ class DAGScheduler(
     // Cancel all jobs belonging to this job group.
     // First finds all active jobs with this group id, and then kill stages for them.
     val activeInGroup = activeJobs.filter(activeJob =>
-      groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+      groupId == Option(activeJob.properties).map(_.get(SparkContext.SPARK_JOB_GROUP_ID)).orNull)
     val jobIds = activeInGroup.map(_.jobId)
     jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId)))
     submitWaitingStages()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index ab35e8edc4eb..8db6b1df4cea 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -308,7 +308,9 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
     sc.addSparkListener(listener)
 
     val numTasks = 10
-    val f = sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.countAsync()
+    val f = sc.runAsync {
+      sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.count()
+    }
     // Wait until one task has started (because we want to make sure that any tasks that are started
     // have corresponding end events sent to the listener).
     var finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS
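
For context, a sketch of how cancellation is expected to flow through the re-implemented async
actions. This is illustrative only and not part of the patch: it assumes a live SparkContext
named sc, and the sleep exists only to keep the job running long enough to be cancelled.

    import scala.concurrent.Await
    import scala.concurrent.duration.Duration
    import scala.util.Try

    import org.apache.spark.SparkContext._ // brings the rddToAsyncRDDActions implicit into scope

    // foreachAsync is now simply sc.runAsync { rdd.foreach(f) }, so the whole action
    // runs on a background thread inside its own (anonymous) job group.
    val action = sc.parallelize(1 to 100000, 10).foreachAsync { _ => Thread.sleep(10) }

    // cancel() interrupts the backing thread and cancels the action's job group, killing any
    // jobs the action has already submitted; the DAGScheduler change above additionally rejects
    // new submissions from the interrupted thread, so a multi-job action such as takeAsync
    // cannot keep launching jobs after being cancelled.
    action.cancel()

    // The backing promise is completed with a failure, so waiting on the future surfaces a
    // SparkException instead of hanging.
    assert(Try(Await.result(action, Duration.Inf)).isFailure)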