From d30859d5130ce3974b633b31d5e74a02a815c772 Mon Sep 17 00:00:00 2001
From: Xi Lyu
Date: Mon, 15 Apr 2024 14:33:13 +0200
Subject: [PATCH 1/2] Use asynchronous callback for execution cleanup

---
 .../execution/ExecuteThreadRunner.scala       | 31 ++++++++++++++-----
 .../sql/connect/service/ExecuteHolder.scala   | 16 +++++++---
 .../execution/ReattachableExecuteSuite.scala  | 22 +++++++++++++
 .../planner/SparkConnectServiceSuite.scala    |  7 ++++-
 4 files changed, 64 insertions(+), 12 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 62083d4892f7..d503dde3d18c 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connect.execution
 
+import scala.concurrent.{ExecutionContext, Promise}
+import scala.util.Try
 import scala.util.control.NonFatal
 
 import com.google.protobuf.Message
@@ -29,7 +31,7 @@ import org.apache.spark.sql.connect.common.ProtoUtils
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag}
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 /**
  * This class launches the actual execution in an execution thread. The execution pushes the
@@ -37,10 +39,12 @@ import org.apache.spark.util.Utils
  */
private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging {
 
+  private val promise: Promise[Unit] = Promise[Unit]()
+
   // The newly created thread will inherit all InheritableThreadLocals used by Spark,
   // e.g. SparkContext.localProperties. If considering implementing a thread-pool,
   // forwarding of thread locals needs to be taken into account.
-  private var executionThread: Thread = new ExecutionThread()
+  private val executionThread: Thread = new ExecutionThread(promise)
 
   private var interrupted: Boolean = false
 
@@ -53,9 +57,11 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
     executionThread.start()
   }
 
-  /** Joins the background execution thread after it is finished. */
-  def join(): Unit = {
-    executionThread.join()
+  /**
+   * Register a callback that gets executed after completion/interruption of the execution
+   */
+  private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit = {
+    promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext)
   }
 
   /**
@@ -222,10 +228,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
       .build()
   }
 
-  private class ExecutionThread
+  private class ExecutionThread(onCompletionPromise: Promise[Unit])
       extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") {
     override def run(): Unit = {
-      execute()
+      try {
+        execute()
+        onCompletionPromise.success(())
+      } catch {
+        case NonFatal(e) =>
+          onCompletionPromise.failure(e)
+      }
     }
   }
 }
+
+private[connect] object ExecuteThreadRunner {
+  private implicit val namedExecutionContext: ExecutionContext = ExecutionContext
+    .fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback"))
+}
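Aside: the heart of this change is replacing a blocking Thread.join() with a Promise that the execution thread completes exactly once, while completion callbacks run on a dedicated daemon thread. A minimal self-contained sketch of that pattern, assuming nothing from Spark (PromiseCallbackSketch, the worker body, and the println cleanup are illustrative, not part of the patch):

    import java.util.concurrent.Executors

    import scala.concurrent.{ExecutionContext, Promise}
    import scala.util.{Failure, Success}
    import scala.util.control.NonFatal

    object PromiseCallbackSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for ThreadUtils.newDaemonSingleThreadExecutor(...): all
        // callbacks run on one dedicated thread, never on the completing thread.
        val pool = Executors.newSingleThreadExecutor()
        implicit val callbackEc: ExecutionContext = ExecutionContext.fromExecutor(pool)

        val promise = Promise[Unit]()

        // Mirrors ExecutionThread.run(): complete the promise exactly once,
        // whether the body succeeds or throws.
        val worker = new Thread(() =>
          try {
            println("executing query ...") // stand-in for execute()
            promise.success(())
          } catch {
            case NonFatal(e) => promise.failure(e)
          })

        // Mirrors processOnCompletion(): register cleanup instead of join()ing.
        promise.future.onComplete {
          case Success(_) => println("cleanup after normal completion")
          case Failure(e) => println(s"cleanup after failure: ${e.getMessage}")
        }

        worker.start()
        worker.join() // demo only: let the callback get enqueued, then drain the pool
        pool.shutdown()
      }
    }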
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index 974c13b08e31..ae5e0c29e575 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -114,6 +114,9 @@ private[connect] class ExecuteHolder(
       : mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]] =
     new mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]()
 
+  /** For testing. Whether the async completion callback has been called. */
+  @volatile private[connect] var completionCallbackCalled: Boolean = false
+
   /**
    * Start the execution. The execution is started in a background thread in ExecuteThreadRunner.
    * Responses are produced and cached in ExecuteResponseObserver. A GRPC thread consumes the
@@ -234,8 +237,15 @@ private[connect] class ExecuteHolder(
     if (closedTime.isEmpty) {
       // interrupt execution, if still running.
       runner.interrupt()
-      // wait for execution to finish, to make sure no more results get pushed to responseObserver
-      runner.join()
+      // Do not wait for the execution to finish; clean up resources immediately.
+      runner.processOnCompletion { _ =>
+        completionCallbackCalled = true
+        // The execution may not get interrupted immediately; clean up any remaining
+        // resources when it does.
+        responseObserver.removeAll()
+        // post closed to UI
+        eventsManager.postClosed()
+      }
       // interrupt any attached grpcResponseSenders
       grpcResponseSenders.foreach(_.interrupt())
       // if there were still any grpcResponseSenders, register detach time
@@ -245,8 +255,6 @@ private[connect] class ExecuteHolder(
       }
       // remove all cached responses from observer
      responseObserver.removeAll()
-      // post closed to UI
-      eventsManager.postClosed()
       closedTime = Some(System.currentTimeMillis())
     }
   }
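Aside: note the shape of close() after this change: it performs one round of cleanup immediately and registers the same cleanup to run again once the execution thread actually stops, which is why responseObserver.removeAll() appears both inside the callback and further down in close(). A stripped-down sketch of that idempotent, non-blocking close, with hypothetical names (ExecutionHandle, finish, releaseResources):

    import scala.concurrent.{ExecutionContext, Promise}

    // Hypothetical names; releaseResources stands in for responseObserver.removeAll().
    class ExecutionHandle(implicit ec: ExecutionContext) {
      private val done = Promise[Unit]()
      @volatile private var closed = false

      /** Called by the worker thread when the execution really stops. */
      def finish(): Unit = done.trySuccess(())

      private def releaseResources(): Unit =
        println("dropping cached responses")

      /** Non-blocking close: one cleanup pass now, a second when finish() fires. */
      def close(): Unit = if (!closed) {
        closed = true
        done.future.onComplete(_ => releaseResources()) // deferred pass
        releaseResources()                              // immediate pass
      }
    }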
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
index 0e29a07b719a..06cd1a5666b6 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/execution/ReattachableExecuteSuite.scala
@@ -355,4 +355,26 @@ class ReattachableExecuteSuite extends SparkConnectServerTest {
       assertEventuallyNoActiveExecutions()
     }
   }
+
+  test("Async cleanup callback gets called after the execution is closed") {
+    withClient { client =>
+      val query1 = client.execute(buildPlan(MEDIUM_RESULTS_QUERY))
+      // Creating the iterator is lazy; trigger query1 to be sent.
+      query1.hasNext
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 1)
+      }
+      val executeHolder1 = SparkConnectService.executionManager.listExecuteHolders.head
+      // Close the execution.
+      SparkConnectService.executionManager.removeExecuteHolder(executeHolder1.key)
+      // Check that queries get cancelled.
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(SparkConnectService.executionManager.listExecuteHolders.length == 0)
+      }
+      // Check that the async cleanup callback gets called.
+      Eventually.eventually(timeout(eventuallyTimeout)) {
+        assert(executeHolder1.completionCallbackCalled)
+      }
+    }
+  }
 }
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 90c9d13def61..06508bfc6a7c 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -31,6 +31,8 @@ import org.apache.arrow.vector.{BigIntVector, Float8Vector}
 import org.apache.arrow.vector.ipc.ArrowStreamReader
 import org.mockito.Mockito.when
 import org.scalatest.Tag
+import org.scalatest.concurrent.Eventually
+import org.scalatest.time.SpanSugar.convertIntToGrainOfTime
 import org.scalatestplus.mockito.MockitoSugar
 
 import org.apache.spark.{SparkContext, SparkEnv}
@@ -879,8 +881,11 @@ class SparkConnectServiceSuite
         assert(executeHolder.eventsManager.hasError.isDefined)
       }
       def onCompleted(producedRowCount: Option[Long] = None): Unit = {
-        assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
         assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount)
+        // The eventsManager is closed asynchronously.
+        Eventually.eventually(timeout(1.seconds)) {
+          assert(executeHolder.eventsManager.status == ExecuteStatus.Closed)
+        }
       }
       def onCanceled(): Unit = {
         assert(executeHolder.eventsManager.hasCanceled.contains(true))
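Aside: both test files switch from asserting state immediately to polling for it, since closing now happens asynchronously. The polling pattern in isolation, assuming only ScalaTest (AsyncFlagSuite and the @volatile flag are illustrative stand-ins for completionCallbackCalled and ExecuteStatus.Closed):

    import org.scalatest.concurrent.Eventually
    import org.scalatest.funsuite.AnyFunSuite
    import org.scalatest.time.SpanSugar._

    class AsyncFlagSuite extends AnyFunSuite with Eventually {
      test("a flag set by a background thread is eventually observed") {
        @volatile var flag = false
        new Thread(() => { Thread.sleep(100); flag = true }).start()
        // Retries the assertion until it passes or 5 seconds elapse.
        eventually(timeout(5.seconds)) { assert(flag) }
      }
    }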
From 40790a4380e376cee8ea7fc0e9abf47e205f8675 Mon Sep 17 00:00:00 2001
From: Xi Lyu
Date: Mon, 15 Apr 2024 14:56:27 +0200
Subject: [PATCH 2/2] Remove join method from ExecuteHolder

---
 .../apache/spark/sql/connect/service/ExecuteHolder.scala | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index ae5e0c29e575..5cf63c2195ab 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -128,13 +128,6 @@ private[connect] class ExecuteHolder(
     runner.start()
   }
 
-  /**
-   * Wait for the execution thread to finish and join it.
-   */
-  def join(): Unit = {
-    runner.join()
-  }
-
   /**
    * Attach an ExecuteGrpcResponseSender that will consume responses from the query and send them
    * out on the Grpc response stream. The sender will start from the start of the response stream.
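Aside: deleting join() outright is safe because, after PATCH 1/2, no caller blocks on the execution thread anymore. If blocking semantics were ever needed again, they could be recovered from the completion future rather than the thread; a hypothetical helper, not part of the patch:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._

    object BlockingShim {
      // Blocks until the execution's completion future resolves, or throws
      // java.util.concurrent.TimeoutException once atMost elapses.
      def awaitCompletion(completion: Future[Unit], atMost: Duration = 1.minute): Unit =
        Await.result(completion, atMost)
    }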