ExecuteThreadRunner.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connect.execution
 
+import scala.concurrent.{ExecutionContext, Promise}
+import scala.util.Try
 import scala.util.control.NonFatal
 
 import com.google.protobuf.Message
@@ -29,18 +31,20 @@ import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag}
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * This class launches the actual execution in an execution thread. The execution pushes the
  * responses to an ExecuteResponseObserver in executeHolder.
  */
 private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging {
 
+  private val promise: Promise[Unit] = Promise[Unit]()
+
   // The newly created thread will inherit all InheritableThreadLocals used by Spark,
   // e.g. SparkContext.localProperties. If considering implementing a thread pool,
   // forwarding of thread locals needs to be taken into account.
-  private var executionThread: Thread = new ExecutionThread()
+  private val executionThread: Thread = new ExecutionThread(promise)
 
   private var interrupted: Boolean = false
 
@@ -53,9 +57,11 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging {
     executionThread.start()
   }
 
-  /** Joins the background execution thread after it is finished. */
-  def join(): Unit = {
-    executionThread.join()
+  /**
+   * Registers a callback that is executed after the execution completes or is interrupted.
+   */
+  private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit = {
+    promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext)
   }
 
   /**
@@ -222,10 +228,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging {
       .build()
   }
 
-  private class ExecutionThread
+  private class ExecutionThread(onCompletionPromise: Promise[Unit])
       extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") {
     override def run(): Unit = {
-      execute()
+      try {
+        execute()
+        onCompletionPromise.success(())
+      } catch {
+        case NonFatal(e) =>
+          onCompletionPromise.failure(e)
+      }
     }
   }
 }
+
+private[connect] object ExecuteThreadRunner {
+  private implicit val namedExecutionContext: ExecutionContext = ExecutionContext
+    .fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback"))
+}
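Taken together, the changes above replace joining the execution thread with a Promise[Unit] that the ExecutionThread fulfills and that callers observe through Future.onComplete on a dedicated single-thread executor. Below is a minimal, self-contained sketch of that pattern; the names are illustrative only, and plain java.util.concurrent.Executors stands in for Spark's ThreadUtils:

import java.util.concurrent.Executors

import scala.concurrent.{ExecutionContext, Promise}
import scala.util.{Failure, Success}
import scala.util.control.NonFatal

object CompletionCallbackSketch {
  def main(args: Array[String]): Unit = {
    val executor = Executors.newSingleThreadExecutor()
    val callbackEc = ExecutionContext.fromExecutor(executor)
    val promise = Promise[Unit]()

    // Register a callback. It fires exactly once, after the promise is
    // completed, and runs on the callback executor rather than on the
    // worker thread.
    promise.future.onComplete {
      case Success(_) => println("execution finished cleanly")
      case Failure(e) => println(s"execution failed: ${e.getMessage}")
    }(callbackEc)

    // The worker fulfills the promise when it is done, mirroring
    // ExecutionThread.run() in the diff above.
    val worker = new Thread(() => {
      try {
        // ... the actual work would happen here ...
        promise.success(())
      } catch {
        case NonFatal(e) => promise.failure(e)
      }
    })
    worker.start()
    worker.join()       // demo only; the production code no longer joins
    executor.shutdown() // let the queued callback drain, then exit
  }
}

Dispatching callbacks to their own executor keeps them off the execution thread, which may already be interrupted or gone by the time a late-registered callback fires.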
ExecuteHolder.scala
@@ -114,6 +114,9 @@ private[connect] class ExecuteHolder(
       : mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]] =
     new mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]()
 
+  /** For testing. Whether the async completion callback was called. */
+  @volatile private[connect] var completionCallbackCalled: Boolean = false
+
   /**
    * Start the execution. The execution is started in a background thread in ExecuteThreadRunner.
    * Responses are produced and cached in ExecuteResponseObserver. A GRPC thread consumes the
@@ -125,13 +128,6 @@ private[connect] class ExecuteHolder(
     runner.start()
   }
 
-  /**
-   * Wait for the execution thread to finish and join it.
-   */
-  def join(): Unit = {
-    runner.join()
-  }
-
   /**
    * Attach an ExecuteGrpcResponseSender that will consume responses from the query and send them
    * out on the Grpc response stream. The sender will start from the start of the response stream.
@@ -234,8 +230,15 @@ private[connect] class ExecuteHolder(
     if (closedTime.isEmpty) {
       // interrupt execution, if still running.
       runner.interrupt()
-      // wait for execution to finish, to make sure no more results get pushed to responseObserver
-      runner.join()
+      // Do not wait for the execution to finish; clean up resources immediately.
+      runner.processOnCompletion { _ =>
+        completionCallbackCalled = true
+        // The execution may not get interrupted immediately; clean up any remaining
+        // resources when it finally stops.
+        responseObserver.removeAll()
+        // post closed to UI
+        eventsManager.postClosed()
+      }
       // interrupt any attached grpcResponseSenders
       grpcResponseSenders.foreach(_.interrupt())
       // if there were still any grpcResponseSenders, register detach time
@@ -245,8 +248,6 @@ private[connect] class ExecuteHolder(
       }
       // remove all cached responses from observer
       responseObserver.removeAll()
-      // post closed to UI
-      eventsManager.postClosed()
       closedTime = Some(System.currentTimeMillis())
     }
   }
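Note that responseObserver.removeAll() is now called twice on the close path: once inline, and once from the completion callback, since a still-running execution can push more responses between the two calls. This relies on the cleanup being safe to repeat. A sketch of that shape, using a hypothetical buffer class rather than the real ExecuteResponseObserver API:

import scala.collection.mutable

// Hypothetical response buffer whose cleanup is safe to run more than once:
// close() clears it immediately, and the completion callback clears it again
// in case the execution pushed more responses in between.
final class ResponseBuffer[T] {
  private val responses = mutable.ArrayBuffer.empty[T]

  def push(r: T): Unit = synchronized { responses += r }

  // Clearing is idempotent: a second call on an already-empty buffer is a no-op.
  def removeAll(): Unit = synchronized { responses.clear() }
}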
ReattachableExecuteSuite.scala
@@ -355,4 +355,26 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
       assertEventuallyNoActiveExecutions()
     }
   }
+
+  test("Async cleanup callback gets called after the execution is closed") {
+    withClient { client =>
+      val query1 = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
+      // Creating the iterator is lazy; calling hasNext triggers query1 to be sent.
+      query1.hasNext
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 1)
+      }
+      val executeHolder1 = SparkConnectService.executionManager.listExecuteHolders.head
+      // Close the execution.
+      SparkConnectService.executionManager.removeExecuteHolder(executeHolder1.key)
+      // Check that the query gets cancelled.
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 0)
+      }
+      // Check that the async cleanup callback gets called.
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(executeHolder1.completionCallbackCalled)
+      }
+    }
+  }
 }
SparkConnectServiceSuite.scala
@@ -31,6 +31,8 @@ import org.apache.arrow.vector.{BigIntVector, Float8Vector}
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.mockito.Mockito.when
 import org.scalatest.Tag
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
 import org.scalatestplus.mockito.MockitoSugar
 
 import org.apache.spark.{SparkContext, SparkEnv}
@@ -879,8 +881,11 @@ class SparkConnectServiceSuite
       assert(executeHolder.eventsManager.hasError.isDefined)
     }
     def onCompleted(producedRowCount: Option[Long] = None): Unit = {
-      assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
       assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount)
+      // The eventsManager is closed asynchronously.
+      Eventually.eventually(timeout(1.seconds)) {
+        assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
+      }
     }
     def onCanceled(): Unit = {
       assert(executeHolder.eventsManager.hasCanceled.contains(true))
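Since eventsManager.postClosed() now runs on the callback thread, the Closed status cannot be asserted synchronously after close; the test has to poll for it. Below is a standalone sketch of the scalatest Eventually pattern used above, assuming scalatest 3.x; the suite name and the flag it polls are hypothetical:

import org.scalatest.concurrent.Eventually
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.time.SpanSugar.convertIntToGrainOfTime

// eventually(...) retries the enclosed assertion until it passes or the
// timeout elapses, which is the standard way to assert on state that is
// updated by another thread.
class AsyncAssertionSketch extends AnyFunSuite with Eventually {
  test("status flips asynchronously") {
    @volatile var closed = false
    new Thread(() => { Thread.sleep(100); closed = true }).start()
    eventually(timeout(1.seconds)) {
      assert(closed)
    }
  }
}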