105 changes: 36 additions & 69 deletions core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -154,88 +154,33 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
def jobId = jobWaiter.jobId
}


/**
* :: Experimental ::
* A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take,
* takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
* action thread if it is being blocked by a job.
* This is an extension of the Scala Future interface to support cancellation.
*/
@Experimental
class ComplexFutureAction[T] extends FutureAction[T] {
class RunAsyncResult[T] private[spark] (jobGroupId: String,
jobGroupDescription: String,
sc: SparkContext,
func: => T) extends FutureAction[T] {

// Pointer to the thread that is executing the action. It is set when the action is run.
// Pointer to the thread that is executing the action; it is set when the action is run.
@volatile private var thread: Thread = _

// A flag indicating whether the future has been cancelled. This is used in case the future
// is cancelled before the action was even run (and thus we have no thread to interrupt).
@volatile private var _cancelled: Boolean = false

// A promise used to signal the future.
private val p = promise[T]()

override def cancel(): Unit = this.synchronized {
_cancelled = true
if (thread != null) {
thread.interrupt()
}
}

/**
* Executes some action enclosed in the closure. To properly enable cancellation, the closure
* should use runJob implementation in this promise. See takeAsync for example.
*/
def run(func: => T)(implicit executor: ExecutionContext): this.type = {
scala.concurrent.future {
thread = Thread.currentThread
try {
p.success(func)
} catch {
case e: Exception => p.failure(e)
} finally {
thread = null
}
}
this
}

/**
* Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
* to enable cancellation.
* Cancel this Future and any Spark jobs launched from it. The cancellation of Spark jobs is
* performed asynchronously.
*/
def runJob[T, U, R](
rdd: RDD[T],
processPartition: Iterator[T] => U,
partitions: Seq[Int],
resultHandler: (Int, U) => Unit,
resultFunc: => R) {
// If the action hasn't been cancelled yet, submit the job. The check and the submitJob
// command need to be in an atomic block.
val job = this.synchronized {
if (!cancelled) {
rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
} else {
throw new SparkException("Action has been cancelled")
}
}

// Wait for the job to complete. If the action is cancelled (with an interrupt),
// cancel the job and stop the execution. This is not in a synchronized block because
// Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
try {
Await.ready(job, Duration.Inf)
} catch {
case e: InterruptedException =>
job.cancel()
throw new SparkException("Action has been cancelled")
def cancel(): Unit = this.synchronized {
if (thread != null) {
thread.interrupt()
sc.cancelJobGroup(jobGroupId)
thread.join()
thread = null
}
}

/**
* Returns whether the promise has been cancelled.
*/
def cancelled: Boolean = _cancelled

@throws(classOf[InterruptedException])
@throws(classOf[scala.concurrent.TimeoutException])
override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
@@ -255,4 +200,26 @@ class ComplexFutureAction[T] extends FutureAction[T] {
override def isCompleted: Boolean = p.isCompleted

override def value: Option[Try[T]] = p.future.value

private def run() {
thread = new Thread(s"RunAsync for job group $jobGroupId") {
override def run() {
try {
sc.setJobGroup(jobGroupId, jobGroupDescription, interruptOnCancel = true)
val result: T = func // Force evaluation
p.success(result)
} catch {
case e: InterruptedException =>
p.failure(new SparkException("runAsync has been cancelled"))
case t: Throwable =>
p.failure(t)
} finally {
sc.clearJobGroup()
}
}
}
thread.start()
}

run()
}
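One consequence of the implementation above is worth spelling out: because `cancel()` interrupts the worker thread, cancels the job group, and then joins the thread, the promise is guaranteed to be completed by the time `cancel()` returns. A hedged sketch of what that means for callers (the RDD, group name, and sleep are illustrative; `sc.runAsync` is the SparkContext helper added later in this diff):

val action = sc.runAsync("slow-group", "Deliberately slow computation") {
  sc.parallelize(1 to 100, 100).map { i => Thread.sleep(1000); i }.count()
}

action.cancel()              // interrupt, cancel the job group, then join the worker thread
assert(action.isCompleted)   // holds because cancel() only returns after the worker thread exits
println(action.value)        // typically Some(Failure(...)) with the "runAsync has been cancelled" SparkException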
32 changes: 32 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1143,6 +1143,38 @@ class SparkContext(config: SparkConf) extends Logging {
runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler)
}

/**
* Start an asynchronous computation that may launch Spark jobs.
* Returns a `FutureAction` object that contains the result of the computation and allows the
* computation to be cancelled.
*
* @tparam T the type of the result
* @param jobGroupId the job group for Spark jobs launched by this computation
* @param description a description of the job group
* @param body the asynchronous computation
* @return the `FutureAction` holding the result of the computation
*/
def runAsync[T](jobGroupId: String,
description: String)
(body: => T): FutureAction[T] = {
new RunAsyncResult[T](jobGroupId, description, this, body)
}

/**
* Start an asynchronous computation that may launch Spark jobs.
* Returns a `FutureAction` object that contains the result of the computation and allows the
* computation to be cancelled.
*
* All Spark jobs launched by this computation will be created in the same (anonymous) job group.
*
* @tparam T the type of the result
* @param body the asynchronous computation
* @return the `FutureAction` holding the result of the computation
*/
def runAsync[T](body: => T): FutureAction[T] = {
runAsync(s"RunAsync${UUID.randomUUID().toString}", "Anonymous runAsync job")(body)
}
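For reference, a minimal usage sketch of the new API (the job group name, RDD, and timeout are made up for illustration, not taken from this diff):

import scala.concurrent.Await
import scala.concurrent.duration._

// Kick off a computation that may launch several Spark jobs; they all share one job group.
val action = sc.runAsync("example-group", "Example async computation") {
  val nums = sc.parallelize(1 to 1000000, 10)
  nums.map(_ * 2).count()
}

// Either block for the result (FutureAction extends scala.concurrent.Future)...
val total: Long = Await.result(action, 10.minutes)

// ...or give up on it: cancel() interrupts the worker thread and cancels the job group.
// action.cancel()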

/**
* :: DeveloperApi ::
* Run a job that can return approximate results.
80 changes: 15 additions & 65 deletions core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -17,13 +17,9 @@

package org.apache.spark.rdd

import java.util.concurrent.atomic.AtomicLong

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
import scala.reflect.ClassTag

import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
import org.apache.spark.{FutureAction, Logging}
import org.apache.spark.annotation.Experimental

/**
@@ -38,90 +34,44 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi
* Returns a future for counting the number of elements in the RDD.
*/
def countAsync(): FutureAction[Long] = {
val totalCount = new AtomicLong
self.context.submitJob(
self,
(iter: Iterator[T]) => {
var result = 0L
while (iter.hasNext) {
result += 1L
iter.next()
}
result
},
Range(0, self.partitions.size),
(index: Int, data: Long) => totalCount.addAndGet(data),
totalCount.get())
self.sparkContext.runAsync {
self.count()
}
}

/**
* Returns a future for retrieving all elements of this RDD.
*/
def collectAsync(): FutureAction[Seq[T]] = {
val results = new Array[Array[T]](self.partitions.size)
self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
(index, data) => results(index) = data, results.flatten.toSeq)
self.sparkContext.runAsync {
self.collect()
}
}

/**
* Returns a future for retrieving the first num elements of the RDD.
*/
def takeAsync(num: Int): FutureAction[Seq[T]] = {
val f = new ComplexFutureAction[Seq[T]]

f.run {
val results = new ArrayBuffer[T](num)
val totalParts = self.partitions.length
var partsScanned = 0
while (results.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
// If we didn't find any rows after the first iteration, just try all partitions next.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
// by 50%.
if (results.size == 0) {
numPartsToTry = totalParts - 1
} else {
numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
}
}
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions

val left = num - results.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)

val buf = new Array[Array[T]](p.size)
f.runJob(self,
(it: Iterator[T]) => it.take(left).toArray,
p,
(index: Int, data: Array[T]) => buf(index) = data,
Unit)

buf.foreach(results ++= _.take(num - results.size))
partsScanned += numPartsToTry
}
results.toSeq
self.sparkContext.runAsync {
self.take(num)
}

f
}

/**
* Applies a function f to all elements of this RDD.
*/
def foreachAsync(f: T => Unit): FutureAction[Unit] = {
val cleanF = self.context.clean(f)
self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size),
(index, data) => Unit, Unit)
self.sparkContext.runAsync {
self.foreach(f)
}
}

/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
(index, data) => Unit, Unit)
self.sparkContext.runAsync {
self.foreachPartition(f)
}
}
}
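Because the async actions above are now thin wrappers around `runAsync`, calling code should be unchanged. A hedged caller-side sketch, assuming the Spark 1.x-era implicit conversion (`rddToAsyncRDDActions` via `import org.apache.spark.SparkContext._`) that puts these methods on RDDs:

import org.apache.spark.SparkContext._                    // rddToAsyncRDDActions implicit (assumed, Spark 1.x)
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Success}

val rdd = sc.parallelize(1 to 100000, 4)
val pending = rdd.countAsync()                            // now implemented as sc.runAsync { rdd.count() }

pending.onComplete {
  case Success(n) => println(s"counted $n elements")
  case Failure(e) => println(s"count failed or was cancelled: $e")
}

// Cancelling interrupts the background thread and cancels the shared job group.
pending.cancel()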
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -477,6 +477,10 @@ class DAGScheduler(
resultHandler: (Int, U) => Unit,
properties: Properties = null): JobWaiter[U] =
{
if (Thread.currentThread().isInterrupted) {
throw new SparkException(
"Shouldn't submit jobs from interrupted threads (was the job cancelled?)")
}
// Check to make sure we are not launching a task on a partition that does not exist.
val maxPartitions = rdd.partitions.length
partitions.find(p => p >= maxPartitions || p < 0).foreach { p =>
Expand All @@ -493,8 +497,11 @@ class DAGScheduler(
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
// Make a defensive copy of the mutable properties object (this fixes a race condition that
// could occur during cancellation).
val propertiesCopy = Option(properties).map(_.clone().asInstanceOf[Properties]).orNull
eventProcessActor ! JobSubmitted(
jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, propertiesCopy)
waiter
}
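The defensive copy above guards against the submitting thread mutating its thread-local properties (for example via clearJobGroup) while the scheduler is still reading them during cancellation. A small, hypothetical illustration of the pattern (the property key is assumed for the example):

import java.util.Properties

// The caller hands its thread-local Properties to submitJob and may mutate them
// immediately afterwards, concurrently with a cancellation reading the job group id.
val callerProps = new Properties()
callerProps.setProperty("spark.jobGroup.id", "group-42")  // key name assumed for illustration

// Scheduler side: snapshot before sharing across threads, as in the diff above.
val schedulerCopy = Option(callerProps).map(_.clone().asInstanceOf[Properties]).orNull

callerProps.clear()                                       // caller moves on (e.g. clearJobGroup)
assert(schedulerCopy.getProperty("spark.jobGroup.id") == "group-42")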

@@ -675,7 +682,7 @@ class DAGScheduler(
// Cancel all jobs belonging to this job group.
// First finds all active jobs with this group id, and then kill stages for them.
val activeInGroup = activeJobs.filter(activeJob =>
groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
groupId == Option(activeJob.properties).map(_.get(SparkContext.SPARK_JOB_GROUP_ID)).orNull)
val jobIds = activeInGroup.map(_.jobId)
jobIds.foreach(handleJobCancellation(_, "part of cancelled job group %s".format(groupId)))
submitWaitingStages()
core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -308,7 +308,9 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
sc.addSparkListener(listener)

val numTasks = 10
val f = sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.countAsync()
val f = sc.runAsync {
sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.count()
}
// Wait until one task has started (because we want to make sure that any tasks that are started
// have corresponding end events sent to the listener).
var finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS