From 90d59cf7e945c15adadf1fbb3760b816e1305b06 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Thu, 17 Jul 2025 16:17:43 -0700 Subject: [PATCH 1/7] wip --- .../pipelines/DataflowGraphRegistry.scala | 4 +- .../connect/pipelines/PipelinesHandler.scala | 44 ++++++++-------- .../sql/connect/service/SessionHolder.scala | 52 ++++++++++++++++++- .../pipelines/PythonPipelineSuite.scala | 2 +- ...SparkDeclarativePipelinesServerSuite.scala | 9 ++-- .../SparkDeclarativePipelinesServerTest.scala | 15 ++++-- 6 files changed, 89 insertions(+), 37 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala index 4402dde04f3c..a5edfec06531 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.pipelines.graph.GraphRegistrationContext // TODO(SPARK-51727): Currently DataflowGraphRegistry is a singleton, but it should instead be // scoped to a single SparkSession for proper isolation between pipelines that are run on the // same cluster. -object DataflowGraphRegistry { +class DataflowGraphRegistry { private val dataflowGraphs = new ConcurrentHashMap[String, GraphRegistrationContext]() @@ -55,7 +55,7 @@ object DataflowGraphRegistry { /** Retrieves the graph for a given id, and throws if the id could not be found. */ def getDataflowGraphOrThrow(dataflowGraphId: String): GraphRegistrationContext = - DataflowGraphRegistry.getDataflowGraph(dataflowGraphId).getOrElse { + getDataflowGraph(dataflowGraphId).getOrElse { throw new SparkException( errorClass = "DATAFLOW_GRAPH_NOT_FOUND", messageParameters = Map("graphId" -> dataflowGraphId), diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 7f92aa13944c..a86ad7b9e1a2 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.pipelines.Language.Python @@ -68,7 +67,7 @@ private[connect] object PipelinesHandler extends Logging { cmd.getCommandTypeCase match { case proto.PipelineCommand.CommandTypeCase.CREATE_DATAFLOW_GRAPH => val createdGraphId = - createDataflowGraph(cmd.getCreateDataflowGraph, sessionHolder.session) + createDataflowGraph(cmd.getCreateDataflowGraph, sessionHolder) PipelineCommandResult .newBuilder() .setCreateDataflowGraphResult( @@ -78,15 +77,15 @@ private[connect] object PipelinesHandler extends Logging { .build() case proto.PipelineCommand.CommandTypeCase.DROP_DATAFLOW_GRAPH => logInfo(s"Drop pipeline cmd received: $cmd") - DataflowGraphRegistry.dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId) + sessionHolder.dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId) defaultResponse case proto.PipelineCommand.CommandTypeCase.DEFINE_DATASET => logInfo(s"Define pipelines dataset cmd received: $cmd") - defineDataset(cmd.getDefineDataset, sessionHolder.session) + defineDataset(cmd.getDefineDataset, sessionHolder) defaultResponse case proto.PipelineCommand.CommandTypeCase.DEFINE_FLOW => logInfo(s"Define pipelines flow cmd received: $cmd") - defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder.session) + defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder) defaultResponse case proto.PipelineCommand.CommandTypeCase.START_RUN => logInfo(s"Start pipeline cmd received: $cmd") @@ -94,7 +93,7 @@ private[connect] object PipelinesHandler extends Logging { defaultResponse case proto.PipelineCommand.CommandTypeCase.DEFINE_SQL_GRAPH_ELEMENTS => logInfo(s"Register sql datasets cmd received: $cmd") - defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder.session) + defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder) defaultResponse case other => throw new UnsupportedOperationException(s"$other not supported") } @@ -102,24 +101,24 @@ private[connect] object PipelinesHandler extends Logging { private def createDataflowGraph( cmd: proto.PipelineCommand.CreateDataflowGraph, - spark: SparkSession): String = { + sessionHolder: SessionHolder): String = { val defaultCatalog = Option .when(cmd.hasDefaultCatalog)(cmd.getDefaultCatalog) .getOrElse { logInfo(s"No default catalog was supplied. Falling back to the current catalog.") - spark.catalog.currentCatalog() + sessionHolder.session.catalog.currentCatalog() } val defaultDatabase = Option .when(cmd.hasDefaultDatabase)(cmd.getDefaultDatabase) .getOrElse { logInfo(s"No default database was supplied. Falling back to the current database.") - spark.catalog.currentDatabase + sessionHolder.session.catalog.currentDatabase } val defaultSqlConf = cmd.getSqlConfMap.asScala.toMap - DataflowGraphRegistry.createDataflowGraph( + sessionHolder.createDataflowGraph( defaultCatalog = defaultCatalog, defaultDatabase = defaultDatabase, defaultSqlConf = defaultSqlConf) @@ -127,24 +126,25 @@ private[connect] object PipelinesHandler extends Logging { private def defineSqlGraphElements( cmd: proto.PipelineCommand.DefineSqlGraphElements, - session: SparkSession): Unit = { + sessionHolder: SessionHolder): Unit = { val dataflowGraphId = cmd.getDataflowGraphId - val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) + val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId) val sqlGraphElementRegistrationContext = new SqlGraphRegistrationContext(graphElementRegistry) - sqlGraphElementRegistrationContext.processSqlFile(cmd.getSqlText, cmd.getSqlFilePath, session) + sqlGraphElementRegistrationContext.processSqlFile( + cmd.getSqlText, cmd.getSqlFilePath, sessionHolder.session) } private def defineDataset( dataset: proto.PipelineCommand.DefineDataset, - sparkSession: SparkSession): Unit = { + sessionHolder: SessionHolder): Unit = { val dataflowGraphId = dataset.getDataflowGraphId - val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) + val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId) dataset.getDatasetType match { case proto.DatasetType.MATERIALIZED_VIEW | proto.DatasetType.TABLE => val tableIdentifier = - GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sparkSession) + GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sessionHolder.session) graphElementRegistry.registerTable( Table( identifier = tableIdentifier, @@ -165,7 +165,7 @@ private[connect] object PipelinesHandler extends Logging { isStreamingTable = dataset.getDatasetType == proto.DatasetType.TABLE)) case proto.DatasetType.TEMPORARY_VIEW => val viewIdentifier = - GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sparkSession) + GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sessionHolder.session) graphElementRegistry.registerView( TemporaryView( @@ -184,14 +184,14 @@ private[connect] object PipelinesHandler extends Logging { private def defineFlow( flow: proto.PipelineCommand.DefineFlow, transformRelationFunc: Relation => LogicalPlan, - sparkSession: SparkSession): Unit = { + sessionHolder: SessionHolder): Unit = { val dataflowGraphId = flow.getDataflowGraphId - val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) + val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId) val isImplicitFlow = flow.getFlowName == flow.getTargetDatasetName val flowIdentifier = GraphIdentifierManager - .parseTableIdentifier(name = flow.getFlowName, spark = sparkSession) + .parseTableIdentifier(name = flow.getFlowName, spark = sessionHolder.session) // If the flow is not an implicit flow (i.e. one defined as part of dataset creation), then // it must be a single-part identifier. @@ -205,7 +205,7 @@ private[connect] object PipelinesHandler extends Logging { new UnresolvedFlow( identifier = flowIdentifier, destinationIdentifier = GraphIdentifierManager - .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sparkSession), + .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sessionHolder.session), func = FlowAnalysis.createFlowFunctionFromLogicalPlan(transformRelationFunc(flow.getRelation)), sqlConf = flow.getSqlConfMap.asScala.toMap, @@ -224,7 +224,7 @@ private[connect] object PipelinesHandler extends Logging { responseObserver: StreamObserver[ExecutePlanResponse], sessionHolder: SessionHolder): Unit = { val dataflowGraphId = cmd.getDataflowGraphId - val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) + val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId) val tableFiltersResult = createTableFilters(cmd, graphElementRegistry, sessionHolder) // We will use this variable to store the run failure event if it occurs. This will be set diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index ada322fd859c..b5070d1b7315 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -38,10 +38,11 @@ import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.ml.MLCache +import org.apache.spark.sql.connect.pipelines.DataflowGraphRegistry import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC} -import org.apache.spark.sql.pipelines.graph.PipelineUpdateContext +import org.apache.spark.sql.pipelines.graph.{GraphRegistrationContext, PipelineUpdateContext} import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.{SystemClock, Utils} @@ -125,6 +126,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private lazy val pipelineExecutions = new ConcurrentHashMap[String, PipelineUpdateContext]() + // Registry for dataflow graphs specific to this session + private lazy val dataflowGraphRegistry: DataflowGraphRegistry = new DataflowGraphRegistry() + // Handles Python process clean up for streaming queries. Initialized on first use in a query. private[connect] lazy val streamingForeachBatchRunnerCleanerCache = new StreamingForeachBatchHelper.CleanerCache(this) @@ -320,6 +324,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio // Stops all pipeline execution and clears the pipeline execution cache removeAllPipelineExecutions() + // Clean up dataflow graphs + dropAllDataflowGraphs() + // if there is a server side listener, clean up related resources if (streamingServersideListenerHolder.isServerSideListenerRegistered) { streamingServersideListenerHolder.cleanUp() @@ -486,6 +493,49 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio Option(pipelineExecutions.get(graphId)) } + private[connect] def createDataflowGraph( + defaultCatalog: String, + defaultDatabase: String, + defaultSqlConf: Map[String, String]): String = { + dataflowGraphRegistry.createDataflowGraph(defaultCatalog, defaultDatabase, defaultSqlConf) + } + + /** + * Retrieves the dataflow graph for the given graph ID. + */ + private[connect] def getDataflowGraph(graphId: String): Option[GraphRegistrationContext] = { + dataflowGraphRegistry.getDataflowGraph(graphId) + } + + /** + * Retrieves the dataflow graph for the given graph ID, throwing if not found. + */ + private[connect] def getDataflowGraphOrThrow(graphId: String): GraphRegistrationContext = { + dataflowGraphRegistry.getDataflowGraphOrThrow(graphId) + } + + /** + * Removes the dataflow graph with the given ID. + */ + private[connect] def dropDataflowGraph(graphId: String): Unit = { + dataflowGraphRegistry.dropDataflowGraph(graphId) + } + + /** + * Returns all dataflow graphs in this session. + */ + private[connect] def getAllDataflowGraphs: Seq[GraphRegistrationContext] = { + dataflowGraphRegistry.getAllDataflowGraphs + } + + /** + * Removes all dataflow graphs from this session. + * Called during session cleanup. + */ + private[connect] def dropAllDataflowGraphs(): Unit = { + dataflowGraphRegistry.dropAllDataflowGraphs() + } + /** * An accumulator for Python executors. * diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index a9e8f9b5245b..4e289d9ea7ea 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -79,7 +79,7 @@ class PythonPipelineSuite s"Python process failed with exit code $exitCode. Output: ${output.mkString("\n")}") } - val dataflowGraphContexts = DataflowGraphRegistry.getAllDataflowGraphs + val dataflowGraphContexts = getSessionHolder.getAllDataflowGraphs assert(dataflowGraphContexts.size == 1) dataflowGraphContexts.head.toDataflowGraph diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala index 3b200c3d08ac..512bb10ff467 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ -41,8 +41,7 @@ class SparkDeclarativePipelinesServerSuite .newBuilder() .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId val definition = - DataflowGraphRegistry - .getDataflowGraphOrThrow(graphId) + getSessionHolder.getDataflowGraphOrThrow(graphId) assert(definition.defaultDatabase == "test_db") } } @@ -115,8 +114,7 @@ class SparkDeclarativePipelinesServerSuite |""".stripMargin) val definition = - DataflowGraphRegistry - .getDataflowGraphOrThrow(graphId) + getSessionHolder.getDataflowGraphOrThrow(graphId) val graph = definition.toDataflowGraph.resolve() @@ -161,8 +159,7 @@ class SparkDeclarativePipelinesServerSuite } val definition = - DataflowGraphRegistry - .getDataflowGraphOrThrow(graphId) + getSessionHolder.getDataflowGraphOrThrow(graphId) registerPipelineDatasets(pipeline) val graph = definition.toDataflowGraph diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala index 003fd30b6075..cbb01a4f755c 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala @@ -23,20 +23,25 @@ import org.apache.spark.connect.{proto => sc} import org.apache.spark.connect.proto.{PipelineCommand, PipelineEvent} import org.apache.spark.sql.connect.{SparkConnectServerTest, SparkConnectTestUtils} import org.apache.spark.sql.connect.planner.SparkConnectPlanner -import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} +import org.apache.spark.sql.connect.service.{SessionHolder, SessionKey, SparkConnectService} import org.apache.spark.sql.pipelines.utils.PipelineTest class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { override def afterEach(): Unit = { - SparkConnectService.sessionManager - .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) - .foreach(_.removeAllPipelineExecutions()) - DataflowGraphRegistry.dropAllDataflowGraphs() + getSessionHolder.removeAllPipelineExecutions() + getSessionHolder.dropAllDataflowGraphs() PipelineTest.cleanupMetastore(spark) super.afterEach() } + // Helper method to get the session holder + protected def getSessionHolder: SessionHolder = { + SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) + .getOrElse(throw new RuntimeException("Session not found")) + } + def buildPlanFromPipelineCommand(command: sc.PipelineCommand): sc.Plan = { sc.Plan .newBuilder() From b5db9cc9b0133ef9dcc5c55c7c7b91762428e3b7 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Thu, 17 Jul 2025 17:05:58 -0700 Subject: [PATCH 2/7] test --- .../connect/pipelines/PythonPipelineSuite.scala | 16 ++++++++++++++-- .../SparkDeclarativePipelinesServerSuite.scala | 6 +++--- .../SparkDeclarativePipelinesServerTest.scala | 4 +--- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index 4e289d9ea7ea..9a0ad74bd567 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -28,6 +28,7 @@ import scala.util.Try import org.apache.spark.api.python.PythonUtils import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.connect.service.SparkConnectService import org.apache.spark.sql.pipelines.graph.DataflowGraph import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} @@ -78,8 +79,19 @@ class PythonPipelineSuite throw new RuntimeException( s"Python process failed with exit code $exitCode. Output: ${output.mkString("\n")}") } - - val dataflowGraphContexts = getSessionHolder.getAllDataflowGraphs + val activateSessions = SparkConnectService.sessionManager.listActiveSessions + // there should be only one active session + assert(activateSessions.size == 1) + + // get the session holder for the active session + val sessionHolder = + SparkConnectService.sessionManager.getIsolatedSessionIfPresent(activateSessions.head.key) + .getOrElse { + throw new RuntimeException("Session not found") + } + + // get all dataflow graphs from the session holder + val dataflowGraphContexts = sessionHolder.getAllDataflowGraphs assert(dataflowGraphContexts.size == 1) dataflowGraphContexts.head.toDataflowGraph diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala index 512bb10ff467..f1cc4270bbb9 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ -41,7 +41,7 @@ class SparkDeclarativePipelinesServerSuite .newBuilder() .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId val definition = - getSessionHolder.getDataflowGraphOrThrow(graphId) + getDefaultSessionHolder.getDataflowGraphOrThrow(graphId) assert(definition.defaultDatabase == "test_db") } } @@ -114,7 +114,7 @@ class SparkDeclarativePipelinesServerSuite |""".stripMargin) val definition = - getSessionHolder.getDataflowGraphOrThrow(graphId) + getDefaultSessionHolder.getDataflowGraphOrThrow(graphId) val graph = definition.toDataflowGraph.resolve() @@ -159,7 +159,7 @@ class SparkDeclarativePipelinesServerSuite } val definition = - getSessionHolder.getDataflowGraphOrThrow(graphId) + getDefaultSessionHolder.getDataflowGraphOrThrow(graphId) registerPipelineDatasets(pipeline) val graph = definition.toDataflowGraph diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala index cbb01a4f755c..9e331b2c7720 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala @@ -29,14 +29,12 @@ import org.apache.spark.sql.pipelines.utils.PipelineTest class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { override def afterEach(): Unit = { - getSessionHolder.removeAllPipelineExecutions() - getSessionHolder.dropAllDataflowGraphs() PipelineTest.cleanupMetastore(spark) super.afterEach() } // Helper method to get the session holder - protected def getSessionHolder: SessionHolder = { + protected def getDefaultSessionHolder: SessionHolder = { SparkConnectService.sessionManager .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) .getOrElse(throw new RuntimeException("Session not found")) From 2ee48016c4b97e72caaadda2ec6e1ed6003017ab Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Thu, 17 Jul 2025 20:01:57 -0700 Subject: [PATCH 3/7] fmt --- .../connect/pipelines/DataflowGraphRegistry.scala | 3 --- .../sql/connect/pipelines/PipelinesHandler.scala | 12 +++++++++--- .../spark/sql/connect/service/SessionHolder.scala | 9 ++++----- .../sql/connect/pipelines/PythonPipelineSuite.scala | 7 ++++--- 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala index a5edfec06531..e0c7beb43001 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala @@ -28,9 +28,6 @@ import org.apache.spark.sql.pipelines.graph.GraphRegistrationContext * PipelinesHandler when CreateDataflowGraph is called, and the PipelinesHandler also supports * attaching flows/datasets to a graph. */ -// TODO(SPARK-51727): Currently DataflowGraphRegistry is a singleton, but it should instead be -// scoped to a single SparkSession for proper isolation between pipelines that are run on the -// same cluster. class DataflowGraphRegistry { private val dataflowGraphs = new ConcurrentHashMap[String, GraphRegistrationContext]() diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index a86ad7b9e1a2..cfc5864befb0 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -132,7 +132,9 @@ private[connect] object PipelinesHandler extends Logging { val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId) val sqlGraphElementRegistrationContext = new SqlGraphRegistrationContext(graphElementRegistry) sqlGraphElementRegistrationContext.processSqlFile( - cmd.getSqlText, cmd.getSqlFilePath, sessionHolder.session) + cmd.getSqlText, + cmd.getSqlFilePath, + sessionHolder.session) } private def defineDataset( @@ -144,7 +146,9 @@ private[connect] object PipelinesHandler extends Logging { dataset.getDatasetType match { case proto.DatasetType.MATERIALIZED_VIEW | proto.DatasetType.TABLE => val tableIdentifier = - GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sessionHolder.session) + GraphIdentifierManager.parseTableIdentifier( + dataset.getDatasetName, + sessionHolder.session) graphElementRegistry.registerTable( Table( identifier = tableIdentifier, @@ -165,7 +169,9 @@ private[connect] object PipelinesHandler extends Logging { isStreamingTable = dataset.getDatasetType == proto.DatasetType.TABLE)) case proto.DatasetType.TEMPORARY_VIEW => val viewIdentifier = - GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sessionHolder.session) + GraphIdentifierManager.parseTableIdentifier( + dataset.getDatasetName, + sessionHolder.session) graphElementRegistry.registerView( TemporaryView( diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index b5070d1b7315..fae3b565edf7 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -494,9 +494,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } private[connect] def createDataflowGraph( - defaultCatalog: String, - defaultDatabase: String, - defaultSqlConf: Map[String, String]): String = { + defaultCatalog: String, + defaultDatabase: String, + defaultSqlConf: Map[String, String]): String = { dataflowGraphRegistry.createDataflowGraph(defaultCatalog, defaultDatabase, defaultSqlConf) } @@ -529,8 +529,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } /** - * Removes all dataflow graphs from this session. - * Called during session cleanup. + * Removes all dataflow graphs from this session. Called during session cleanup. */ private[connect] def dropAllDataflowGraphs(): Unit = { dataflowGraphRegistry.dropAllDataflowGraphs() diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index 9a0ad74bd567..efc6e243e890 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -85,10 +85,11 @@ class PythonPipelineSuite // get the session holder for the active session val sessionHolder = - SparkConnectService.sessionManager.getIsolatedSessionIfPresent(activateSessions.head.key) + SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(activateSessions.head.key) .getOrElse { - throw new RuntimeException("Session not found") - } + throw new RuntimeException("Session not found") + } // get all dataflow graphs from the session holder val dataflowGraphContexts = sessionHolder.getAllDataflowGraphs From 8a9e1d2a43a4fadb887b6433bf6c4a92cc0afe28 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Thu, 17 Jul 2025 20:11:00 -0700 Subject: [PATCH 4/7] nit --- .../pipelines/SparkDeclarativePipelinesServerTest.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala index 9e331b2c7720..b3f38fcecfd5 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala @@ -30,6 +30,12 @@ class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { override def afterEach(): Unit = { PipelineTest.cleanupMetastore(spark) + SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) + .foreach(s => { + s.removeAllPipelineExecutions() + s.dropAllDataflowGraphs() + }) super.afterEach() } From b918491caa6211ec26a7bdee1e2b5dce6d3f1d58 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Thu, 17 Jul 2025 20:12:06 -0700 Subject: [PATCH 5/7] nit --- .../connect/pipelines/SparkDeclarativePipelinesServerTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala index b3f38fcecfd5..2f9b28fc5b32 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala @@ -29,13 +29,13 @@ import org.apache.spark.sql.pipelines.utils.PipelineTest class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { override def afterEach(): Unit = { - PipelineTest.cleanupMetastore(spark) SparkConnectService.sessionManager .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) .foreach(s => { s.removeAllPipelineExecutions() s.dropAllDataflowGraphs() }) + PipelineTest.cleanupMetastore(spark) super.afterEach() } From 5ed466c20dc3a8afbf73f2b6ae6924a3e37d1bc7 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Thu, 17 Jul 2025 22:34:00 -0700 Subject: [PATCH 6/7] test --- .../pipelines/PythonPipelineSuite.scala | 20 +- ...SparkDeclarativePipelinesServerSuite.scala | 242 ++++++++++++++++++ 2 files changed, 252 insertions(+), 10 deletions(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index efc6e243e890..809ba31033ee 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.pipelines import java.io.{BufferedReader, InputStreamReader} import java.nio.charset.StandardCharsets import java.nio.file.Paths +import java.util.UUID import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer @@ -43,6 +44,8 @@ class PythonPipelineSuite def buildGraph(pythonText: String): DataflowGraph = { val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n") + // create a unique identifier to allow identifying the session and dataflow graph + val identifier = UUID.randomUUID().toString val pythonCode = s""" |from pyspark.sql import SparkSession @@ -58,6 +61,7 @@ class PythonPipelineSuite |spark = SparkSession.builder \\ | .remote("sc://localhost:$serverPort") \\ | .config("spark.connect.grpc.channel.timeout", "5s") \\ + | .config("spark.custom.identifier", "$identifier") \\ | .create() | |dataflow_graph_id = create_dataflow_graph( @@ -80,16 +84,12 @@ class PythonPipelineSuite s"Python process failed with exit code $exitCode. Output: ${output.mkString("\n")}") } val activateSessions = SparkConnectService.sessionManager.listActiveSessions - // there should be only one active session - assert(activateSessions.size == 1) - - // get the session holder for the active session - val sessionHolder = - SparkConnectService.sessionManager - .getIsolatedSessionIfPresent(activateSessions.head.key) - .getOrElse { - throw new RuntimeException("Session not found") - } + + // get the session holder by finding the session with the custom UUID set in the conf + val sessionHolder = activateSessions + .map(info => SparkConnectService.sessionManager.getIsolatedSessionIfPresent(info.key).get) + .find(_.session.conf.get("spark.custom.identifier") == identifier) + .getOrElse(throw new RuntimeException(s"Session with app name $identifier not found")) // get all dataflow graphs from the session holder val dataflowGraphContexts = sessionHolder.getAllDataflowGraphs diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala index f1cc4270bbb9..676431849393 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.connect.pipelines +import java.util.UUID + import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{DatasetType, Expression, PipelineCommand, Relation, UnresolvedTableValuedFunction} import org.apache.spark.connect.proto.PipelineCommand.{DefineDataset, DefineFlow} import org.apache.spark.internal.Logging +import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} class SparkDeclarativePipelinesServerSuite extends SparkDeclarativePipelinesServerTest @@ -248,4 +251,243 @@ class SparkDeclarativePipelinesServerSuite assert(spark.table("spark_catalog.other.tableD").count() == 5) } } + + test("dataflow graphs are session-specific") { + withRawBlockingStub { implicit stub => + // Create a dataflow graph in the default session + val graphId1 = createDataflowGraph + + // Register a dataset in the default session + sendPlan( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDefineDataset( + DefineDataset + .newBuilder() + .setDataflowGraphId(graphId1) + .setDatasetName("session1_table") + .setDatasetType(DatasetType.MATERIALIZED_VIEW)) + .build())) + + // Verify the graph exists in the default session + assert(getDefaultSessionHolder.getAllDataflowGraphs.size == 1) + } + + // Create a second session with different user/session ID + val newSessionId = UUID.randomUUID().toString + val newSessionUserId = "session2_user" + + withRawBlockingStub { implicit stub => + // Override the test context to use different session + val newSessionExecuteRequest = buildExecutePlanRequest( + buildCreateDataflowGraphPlan( + proto.PipelineCommand.CreateDataflowGraph + .newBuilder() + .setDefaultCatalog("spark_catalog") + .setDefaultDatabase("default") + .build())).toBuilder + .setUserContext(proto.UserContext + .newBuilder() + .setUserId(newSessionUserId) + .build()) + .setSessionId(newSessionId) + .build() + + val response = stub.executePlan(newSessionExecuteRequest) + val graphId2 = + response.next().getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId + + // Register a different dataset in second session + val session2DefineRequest = buildExecutePlanRequest( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDefineDataset( + DefineDataset + .newBuilder() + .setDataflowGraphId(graphId2) + .setDatasetName("session2_table") + .setDatasetType(DatasetType.MATERIALIZED_VIEW)) + .build())).toBuilder + .setUserContext(proto.UserContext + .newBuilder() + .setUserId(newSessionUserId) + .build()) + .setSessionId(newSessionId) + .build() + + stub.executePlan(session2DefineRequest).next() + + // Verify session isolation - each session should only see its own graphs + val newSessionHolder = SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(SessionKey(newSessionUserId, newSessionId)) + .getOrElse(throw new RuntimeException("New session not found")) + + val defaultSessionGraphs = getDefaultSessionHolder.getAllDataflowGraphs + val newSessionGraphs = newSessionHolder.getAllDataflowGraphs + + assert(defaultSessionGraphs.size == 1) + assert(newSessionGraphs.size == 1) + + assert( + defaultSessionGraphs.head.toDataflowGraph.tables + .exists(_.identifier.table == "session1_table"), + "Session 1 should have its own table") + assert( + newSessionGraphs.head.toDataflowGraph.tables + .exists(_.identifier.table == "session2_table"), + "Session 2 should have its own table") + } + } + + test("dataflow graphs are cleaned up when session is closed") { + val testUserId = "test_user" + val testSessionId = UUID.randomUUID().toString + + // Create a session and dataflow graph + withRawBlockingStub { implicit stub => + val createGraphRequest = buildExecutePlanRequest( + buildCreateDataflowGraphPlan( + proto.PipelineCommand.CreateDataflowGraph + .newBuilder() + .setDefaultCatalog("spark_catalog") + .setDefaultDatabase("default") + .build())).toBuilder + .setUserContext(proto.UserContext + .newBuilder() + .setUserId(testUserId) + .build()) + .setSessionId(testSessionId) + .build() + + val response = stub.executePlan(createGraphRequest) + val graphId = + response.next().getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId + + // Register a dataset + val defineRequest = buildExecutePlanRequest( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDefineDataset( + DefineDataset + .newBuilder() + .setDataflowGraphId(graphId) + .setDatasetName("test_table") + .setDatasetType(DatasetType.MATERIALIZED_VIEW)) + .build())).toBuilder + .setUserContext(proto.UserContext + .newBuilder() + .setUserId(testUserId) + .build()) + .setSessionId(testSessionId) + .build() + + stub.executePlan(defineRequest).next() + + // Verify the graph exists + val sessionHolder = SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(SessionKey(testUserId, testSessionId)) + .get + + val graphsBefore = sessionHolder.getAllDataflowGraphs + assert(graphsBefore.size == 1) + + // Close the session + SparkConnectService.sessionManager.closeSession(SessionKey(testUserId, testSessionId)) + + // Verify the session is no longer available + val sessionAfterClose = SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(SessionKey(testUserId, testSessionId)) + + assert(sessionAfterClose.isEmpty, "Session should be cleaned up after close") + // Verify the graph is removed + val graphsAfter = sessionHolder.getAllDataflowGraphs + assert(graphsAfter.isEmpty, "Graph should be removed after session close") + } + } + + test("multiple dataflow graphs can exist in the same session") { + withRawBlockingStub { implicit stub => + // Create two dataflow graphs in the same session + val graphId1 = createDataflowGraph + val graphId2 = createDataflowGraph + + // Register datasets in both graphs + sendPlan( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDefineDataset( + DefineDataset + .newBuilder() + .setDataflowGraphId(graphId1) + .setDatasetName("graph1_table") + .setDatasetType(DatasetType.MATERIALIZED_VIEW)) + .build())) + + sendPlan( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDefineDataset( + DefineDataset + .newBuilder() + .setDataflowGraphId(graphId2) + .setDatasetName("graph2_table") + .setDatasetType(DatasetType.MATERIALIZED_VIEW)) + .build())) + + // Verify both graphs exist in the session + val sessionHolder = getDefaultSessionHolder + val graph1 = sessionHolder.getDataflowGraph(graphId1).getOrElse { + fail(s"Graph with ID $graphId1 not found in session") + } + val graph2 = sessionHolder.getDataflowGraph(graphId2).getOrElse { + fail(s"Graph with ID $graphId2 not found in session") + } + // Check that both graphs have their datasets registered + assert(graph1.toDataflowGraph.tables.exists(_.identifier.table == "graph1_table")) + assert(graph2.toDataflowGraph.tables.exists(_.identifier.table == "graph2_table")) + } + } + + test("dropping a dataflow graph removes it from session") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + + // Register a dataset + sendPlan( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDefineDataset( + DefineDataset + .newBuilder() + .setDataflowGraphId(graphId) + .setDatasetName("test_table") + .setDatasetType(DatasetType.MATERIALIZED_VIEW)) + .build())) + + // Verify the graph exists + val sessionHolder = getDefaultSessionHolder + val graphsBefore = sessionHolder.getAllDataflowGraphs + assert(graphsBefore.size == 1) + + // Drop the graph + sendPlan( + buildPlanFromPipelineCommand( + PipelineCommand + .newBuilder() + .setDropDataflowGraph(PipelineCommand.DropDataflowGraph + .newBuilder() + .setDataflowGraphId(graphId)) + .build())) + + // Verify the graph is removed + val graphsAfter = sessionHolder.getAllDataflowGraphs + assert(graphsAfter.isEmpty, "Graph should be removed after drop") + } + } } From f8aba9e5c02cbde4fcd5c0ca3120315251989fe4 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Fri, 18 Jul 2025 13:18:06 -0700 Subject: [PATCH 7/7] address feedback --- .../connect/pipelines/PipelinesHandler.scala | 17 ++++--- .../sql/connect/service/SessionHolder.scala | 48 ++----------------- .../pipelines/PythonPipelineSuite.scala | 17 +++---- ...SparkDeclarativePipelinesServerSuite.scala | 32 ++++++------- .../SparkDeclarativePipelinesServerTest.scala | 5 +- 5 files changed, 39 insertions(+), 80 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index cfc5864befb0..9ecc22cda13f 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -77,7 +77,8 @@ private[connect] object PipelinesHandler extends Logging { .build() case proto.PipelineCommand.CommandTypeCase.DROP_DATAFLOW_GRAPH => logInfo(s"Drop pipeline cmd received: $cmd") - sessionHolder.dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId) + sessionHolder.dataflowGraphRegistry + .dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId) defaultResponse case proto.PipelineCommand.CommandTypeCase.DEFINE_DATASET => logInfo(s"Define pipelines dataset cmd received: $cmd") @@ -118,7 +119,7 @@ private[connect] object PipelinesHandler extends Logging { val defaultSqlConf = cmd.getSqlConfMap.asScala.toMap - sessionHolder.createDataflowGraph( + sessionHolder.dataflowGraphRegistry.createDataflowGraph( defaultCatalog = defaultCatalog, defaultDatabase = defaultDatabase, defaultSqlConf = defaultSqlConf) @@ -129,7 +130,8 @@ private[connect] object PipelinesHandler extends Logging { sessionHolder: SessionHolder): Unit = { val dataflowGraphId = cmd.getDataflowGraphId - val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId) + val graphElementRegistry = + sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) val sqlGraphElementRegistrationContext = new SqlGraphRegistrationContext(graphElementRegistry) sqlGraphElementRegistrationContext.processSqlFile( cmd.getSqlText, @@ -141,7 +143,8 @@ private[connect] object PipelinesHandler extends Logging { dataset: proto.PipelineCommand.DefineDataset, sessionHolder: SessionHolder): Unit = { val dataflowGraphId = dataset.getDataflowGraphId - val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId) + val graphElementRegistry = + sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) dataset.getDatasetType match { case proto.DatasetType.MATERIALIZED_VIEW | proto.DatasetType.TABLE => @@ -192,7 +195,8 @@ private[connect] object PipelinesHandler extends Logging { transformRelationFunc: Relation => LogicalPlan, sessionHolder: SessionHolder): Unit = { val dataflowGraphId = flow.getDataflowGraphId - val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId) + val graphElementRegistry = + sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) val isImplicitFlow = flow.getFlowName == flow.getTargetDatasetName @@ -230,7 +234,8 @@ private[connect] object PipelinesHandler extends Logging { responseObserver: StreamObserver[ExecutePlanResponse], sessionHolder: SessionHolder): Unit = { val dataflowGraphId = cmd.getDataflowGraphId - val graphElementRegistry = sessionHolder.getDataflowGraphOrThrow(dataflowGraphId) + val graphElementRegistry = + sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) val tableFiltersResult = createTableFilters(cmd, graphElementRegistry, sessionHolder) // We will use this variable to store the run failure event if it occurs. This will be set diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index fae3b565edf7..1b43ea529ec0 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.connect.pipelines.DataflowGraphRegistry import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC} -import org.apache.spark.sql.pipelines.graph.{GraphRegistrationContext, PipelineUpdateContext} +import org.apache.spark.sql.pipelines.graph.PipelineUpdateContext import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.{SystemClock, Utils} @@ -127,7 +127,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio new ConcurrentHashMap[String, PipelineUpdateContext]() // Registry for dataflow graphs specific to this session - private lazy val dataflowGraphRegistry: DataflowGraphRegistry = new DataflowGraphRegistry() + private[connect] lazy val dataflowGraphRegistry = new DataflowGraphRegistry() // Handles Python process clean up for streaming queries. Initialized on first use in a query. private[connect] lazy val streamingForeachBatchRunnerCleanerCache = @@ -325,7 +325,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio removeAllPipelineExecutions() // Clean up dataflow graphs - dropAllDataflowGraphs() + dataflowGraphRegistry.dropAllDataflowGraphs() // if there is a server side listener, clean up related resources if (streamingServersideListenerHolder.isServerSideListenerRegistered) { @@ -493,48 +493,6 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio Option(pipelineExecutions.get(graphId)) } - private[connect] def createDataflowGraph( - defaultCatalog: String, - defaultDatabase: String, - defaultSqlConf: Map[String, String]): String = { - dataflowGraphRegistry.createDataflowGraph(defaultCatalog, defaultDatabase, defaultSqlConf) - } - - /** - * Retrieves the dataflow graph for the given graph ID. - */ - private[connect] def getDataflowGraph(graphId: String): Option[GraphRegistrationContext] = { - dataflowGraphRegistry.getDataflowGraph(graphId) - } - - /** - * Retrieves the dataflow graph for the given graph ID, throwing if not found. - */ - private[connect] def getDataflowGraphOrThrow(graphId: String): GraphRegistrationContext = { - dataflowGraphRegistry.getDataflowGraphOrThrow(graphId) - } - - /** - * Removes the dataflow graph with the given ID. - */ - private[connect] def dropDataflowGraph(graphId: String): Unit = { - dataflowGraphRegistry.dropDataflowGraph(graphId) - } - - /** - * Returns all dataflow graphs in this session. - */ - private[connect] def getAllDataflowGraphs: Seq[GraphRegistrationContext] = { - dataflowGraphRegistry.getAllDataflowGraphs - } - - /** - * Removes all dataflow graphs from this session. Called during session cleanup. - */ - private[connect] def dropAllDataflowGraphs(): Unit = { - dataflowGraphRegistry.dropAllDataflowGraphs() - } - /** * An accumulator for Python executors. * diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index 809ba31033ee..1bc2172d86e5 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -45,7 +45,7 @@ class PythonPipelineSuite def buildGraph(pythonText: String): DataflowGraph = { val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n") // create a unique identifier to allow identifying the session and dataflow graph - val identifier = UUID.randomUUID().toString + val customSessionIdentifier = UUID.randomUUID().toString val pythonCode = s""" |from pyspark.sql import SparkSession @@ -61,7 +61,7 @@ class PythonPipelineSuite |spark = SparkSession.builder \\ | .remote("sc://localhost:$serverPort") \\ | .config("spark.connect.grpc.channel.timeout", "5s") \\ - | .config("spark.custom.identifier", "$identifier") \\ + | .config("spark.custom.identifier", "$customSessionIdentifier") \\ | .create() | |dataflow_graph_id = create_dataflow_graph( @@ -83,16 +83,17 @@ class PythonPipelineSuite throw new RuntimeException( s"Python process failed with exit code $exitCode. Output: ${output.mkString("\n")}") } - val activateSessions = SparkConnectService.sessionManager.listActiveSessions + val activeSessions = SparkConnectService.sessionManager.listActiveSessions // get the session holder by finding the session with the custom UUID set in the conf - val sessionHolder = activateSessions - .map(info => SparkConnectService.sessionManager.getIsolatedSessionIfPresent(info.key).get) - .find(_.session.conf.get("spark.custom.identifier") == identifier) - .getOrElse(throw new RuntimeException(s"Session with app name $identifier not found")) + val sessionHolder = activeSessions + .map(info => SparkConnectService.sessionManager.getIsolatedSession(info.key, None)) + .find(_.session.conf.get("spark.custom.identifier") == customSessionIdentifier) + .getOrElse( + throw new RuntimeException(s"Session with identifier $customSessionIdentifier not found")) // get all dataflow graphs from the session holder - val dataflowGraphContexts = sessionHolder.getAllDataflowGraphs + val dataflowGraphContexts = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs assert(dataflowGraphContexts.size == 1) dataflowGraphContexts.head.toDataflowGraph diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala index 676431849393..ef5da0c014ee 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ -44,7 +44,7 @@ class SparkDeclarativePipelinesServerSuite .newBuilder() .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId val definition = - getDefaultSessionHolder.getDataflowGraphOrThrow(graphId) + getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId) assert(definition.defaultDatabase == "test_db") } } @@ -117,7 +117,7 @@ class SparkDeclarativePipelinesServerSuite |""".stripMargin) val definition = - getDefaultSessionHolder.getDataflowGraphOrThrow(graphId) + getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId) val graph = definition.toDataflowGraph.resolve() @@ -162,7 +162,7 @@ class SparkDeclarativePipelinesServerSuite } val definition = - getDefaultSessionHolder.getDataflowGraphOrThrow(graphId) + getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId) registerPipelineDatasets(pipeline) val graph = definition.toDataflowGraph @@ -271,7 +271,7 @@ class SparkDeclarativePipelinesServerSuite .build())) // Verify the graph exists in the default session - assert(getDefaultSessionHolder.getAllDataflowGraphs.size == 1) + assert(getDefaultSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs.size == 1) } // Create a second session with different user/session ID @@ -321,11 +321,11 @@ class SparkDeclarativePipelinesServerSuite // Verify session isolation - each session should only see its own graphs val newSessionHolder = SparkConnectService.sessionManager - .getIsolatedSessionIfPresent(SessionKey(newSessionUserId, newSessionId)) - .getOrElse(throw new RuntimeException("New session not found")) + .getIsolatedSession(SessionKey(newSessionUserId, newSessionId), None) - val defaultSessionGraphs = getDefaultSessionHolder.getAllDataflowGraphs - val newSessionGraphs = newSessionHolder.getAllDataflowGraphs + val defaultSessionGraphs = + getDefaultSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs + val newSessionGraphs = newSessionHolder.dataflowGraphRegistry.getAllDataflowGraphs assert(defaultSessionGraphs.size == 1) assert(newSessionGraphs.size == 1) @@ -391,7 +391,7 @@ class SparkDeclarativePipelinesServerSuite .getIsolatedSessionIfPresent(SessionKey(testUserId, testSessionId)) .get - val graphsBefore = sessionHolder.getAllDataflowGraphs + val graphsBefore = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs assert(graphsBefore.size == 1) // Close the session @@ -403,7 +403,7 @@ class SparkDeclarativePipelinesServerSuite assert(sessionAfterClose.isEmpty, "Session should be cleaned up after close") // Verify the graph is removed - val graphsAfter = sessionHolder.getAllDataflowGraphs + val graphsAfter = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs assert(graphsAfter.isEmpty, "Graph should be removed after session close") } } @@ -441,12 +441,8 @@ class SparkDeclarativePipelinesServerSuite // Verify both graphs exist in the session val sessionHolder = getDefaultSessionHolder - val graph1 = sessionHolder.getDataflowGraph(graphId1).getOrElse { - fail(s"Graph with ID $graphId1 not found in session") - } - val graph2 = sessionHolder.getDataflowGraph(graphId2).getOrElse { - fail(s"Graph with ID $graphId2 not found in session") - } + val graph1 = sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId1) + val graph2 = sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId2) // Check that both graphs have their datasets registered assert(graph1.toDataflowGraph.tables.exists(_.identifier.table == "graph1_table")) assert(graph2.toDataflowGraph.tables.exists(_.identifier.table == "graph2_table")) @@ -472,7 +468,7 @@ class SparkDeclarativePipelinesServerSuite // Verify the graph exists val sessionHolder = getDefaultSessionHolder - val graphsBefore = sessionHolder.getAllDataflowGraphs + val graphsBefore = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs assert(graphsBefore.size == 1) // Drop the graph @@ -486,7 +482,7 @@ class SparkDeclarativePipelinesServerSuite .build())) // Verify the graph is removed - val graphsAfter = sessionHolder.getAllDataflowGraphs + val graphsAfter = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs assert(graphsAfter.isEmpty, "Graph should be removed after drop") } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala index 2f9b28fc5b32..a31883677f92 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala @@ -33,7 +33,7 @@ class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) .foreach(s => { s.removeAllPipelineExecutions() - s.dropAllDataflowGraphs() + s.dataflowGraphRegistry.dropAllDataflowGraphs() }) PipelineTest.cleanupMetastore(spark) super.afterEach() @@ -42,8 +42,7 @@ class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { // Helper method to get the session holder protected def getDefaultSessionHolder: SessionHolder = { SparkConnectService.sessionManager - .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) - .getOrElse(throw new RuntimeException("Session not found")) + .getIsolatedSession(SessionKey(defaultUserId, defaultSessionId), None) } def buildPlanFromPipelineCommand(command: sc.PipelineCommand): sc.Plan = {