
Commit 0177265

jackywang-db authored and sryza committed
[SPARK-52432][SDP][SQL] Scope DataflowGraphRegistry to Session
### What changes were proposed in this pull request?

Scope `DataflowGraphRegistry` to the Spark Connect session. This is done by adding it as a member of the Spark Connect [SessionHolder](https://github.com/apache/spark/blob/master/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala#L54); it lives there because pipeline executions are also [scoped](https://github.com/apache/spark/blob/master/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala#L125) to that class. Getter/setter methods were added to access the dataflow graphs for the session, along with logic to drop all dataflow graphs when the session is closed.

### Why are the changes needed?

Currently `DataflowGraphRegistry` is a singleton, but it should instead be scoped to a single SparkSession for proper isolation between pipelines that run on the same cluster. This also allows proper cleanup of pipeline resources when the session is closed.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added new test cases covering dataflow graph session isolation and proper cleanup.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#51544 from JiaqiWang18/SPARK-52432-session-graphRegistry.

Authored-by: Jacky Wang <[email protected]>
Signed-off-by: Sandy Ryza <[email protected]>
1 parent 689e458 commit 0177265

File tree

6 files changed: +309 −37 lines changed

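Before the per-file diffs, a minimal sketch of the access pattern they introduce: handlers reach the dataflow graph registry through the `SessionHolder` that owns the Spark Connect session instead of the removed singleton. This sketch is illustrative only; it assumes code living in the `org.apache.spark.sql.connect` package (since `dataflowGraphRegistry` is `private[connect]`), and the catalog/database values are placeholders rather than values from this PR.

```scala
import org.apache.spark.sql.connect.pipelines.DataflowGraphRegistry
import org.apache.spark.sql.connect.service.SessionHolder
import org.apache.spark.sql.pipelines.graph.GraphRegistrationContext

// Sketch only: per-session registry lifecycle after this change.
// The sessionHolder is assumed to be provided by the caller, as in PipelinesHandler.
def registryLifecycleSketch(sessionHolder: SessionHolder): GraphRegistrationContext = {
  // One registry per Spark Connect session, owned by the SessionHolder.
  val registry: DataflowGraphRegistry = sessionHolder.dataflowGraphRegistry

  // Create a graph (placeholder defaults) and resolve it again by id, as the handlers do.
  val graphId: String = registry.createDataflowGraph(
    defaultCatalog = "spark_catalog",
    defaultDatabase = "default",
    defaultSqlConf = Map.empty[String, String])
  val graphCtx = registry.getDataflowGraphOrThrow(graphId)

  // On session close, SessionHolder now calls this to release all graphs for the session.
  registry.dropAllDataflowGraphs()
  graphCtx
}
```

The same `dropAllDataflowGraphs()` call runs automatically when the `SessionHolder` is closed, which is what the new test coverage exercises.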

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala

Lines changed: 2 additions & 5 deletions
```diff
@@ -28,10 +28,7 @@ import org.apache.spark.sql.pipelines.graph.GraphRegistrationContext
  * PipelinesHandler when CreateDataflowGraph is called, and the PipelinesHandler also supports
  * attaching flows/datasets to a graph.
  */
-// TODO(SPARK-51727): Currently DataflowGraphRegistry is a singleton, but it should instead be
-// scoped to a single SparkSession for proper isolation between pipelines that are run on the
-// same cluster.
-object DataflowGraphRegistry {
+class DataflowGraphRegistry {
 
   private val dataflowGraphs = new ConcurrentHashMap[String, GraphRegistrationContext]()
 
@@ -55,7 +52,7 @@
 
   /** Retrieves the graph for a given id, and throws if the id could not be found. */
   def getDataflowGraphOrThrow(dataflowGraphId: String): GraphRegistrationContext =
-    DataflowGraphRegistry.getDataflowGraph(dataflowGraphId).getOrElse {
+    getDataflowGraph(dataflowGraphId).getOrElse {
       throw new SparkException(
         errorClass = "DATAFLOW_GRAPH_NOT_FOUND",
         messageParameters = Map("graphId" -> dataflowGraphId),
```

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala

Lines changed: 33 additions & 22 deletions
```diff
@@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.pipelines.Language.Python
@@ -68,7 +67,7 @@ private[connect] object PipelinesHandler extends Logging {
     cmd.getCommandTypeCase match {
       case proto.PipelineCommand.CommandTypeCase.CREATE_DATAFLOW_GRAPH =>
         val createdGraphId =
-          createDataflowGraph(cmd.getCreateDataflowGraph, sessionHolder.session)
+          createDataflowGraph(cmd.getCreateDataflowGraph, sessionHolder)
         PipelineCommandResult
           .newBuilder()
           .setCreateDataflowGraphResult(
@@ -78,73 +77,81 @@
           .build()
       case proto.PipelineCommand.CommandTypeCase.DROP_DATAFLOW_GRAPH =>
         logInfo(s"Drop pipeline cmd received: $cmd")
-        DataflowGraphRegistry.dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId)
+        sessionHolder.dataflowGraphRegistry
+          .dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.DEFINE_DATASET =>
         logInfo(s"Define pipelines dataset cmd received: $cmd")
-        defineDataset(cmd.getDefineDataset, sessionHolder.session)
+        defineDataset(cmd.getDefineDataset, sessionHolder)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.DEFINE_FLOW =>
         logInfo(s"Define pipelines flow cmd received: $cmd")
-        defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder.session)
+        defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.START_RUN =>
         logInfo(s"Start pipeline cmd received: $cmd")
         startRun(cmd.getStartRun, responseObserver, sessionHolder)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.DEFINE_SQL_GRAPH_ELEMENTS =>
         logInfo(s"Register sql datasets cmd received: $cmd")
-        defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder.session)
+        defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder)
         defaultResponse
       case other => throw new UnsupportedOperationException(s"$other not supported")
     }
   }
 
   private def createDataflowGraph(
       cmd: proto.PipelineCommand.CreateDataflowGraph,
-      spark: SparkSession): String = {
+      sessionHolder: SessionHolder): String = {
     val defaultCatalog = Option
       .when(cmd.hasDefaultCatalog)(cmd.getDefaultCatalog)
       .getOrElse {
         logInfo(s"No default catalog was supplied. Falling back to the current catalog.")
-        spark.catalog.currentCatalog()
+        sessionHolder.session.catalog.currentCatalog()
       }
 
     val defaultDatabase = Option
       .when(cmd.hasDefaultDatabase)(cmd.getDefaultDatabase)
       .getOrElse {
         logInfo(s"No default database was supplied. Falling back to the current database.")
-        spark.catalog.currentDatabase
+        sessionHolder.session.catalog.currentDatabase
      }
 
     val defaultSqlConf = cmd.getSqlConfMap.asScala.toMap
 
-    DataflowGraphRegistry.createDataflowGraph(
+    sessionHolder.dataflowGraphRegistry.createDataflowGraph(
       defaultCatalog = defaultCatalog,
       defaultDatabase = defaultDatabase,
       defaultSqlConf = defaultSqlConf)
   }
 
   private def defineSqlGraphElements(
       cmd: proto.PipelineCommand.DefineSqlGraphElements,
-      session: SparkSession): Unit = {
+      sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = cmd.getDataflowGraphId
 
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
     val sqlGraphElementRegistrationContext = new SqlGraphRegistrationContext(graphElementRegistry)
-    sqlGraphElementRegistrationContext.processSqlFile(cmd.getSqlText, cmd.getSqlFilePath, session)
+    sqlGraphElementRegistrationContext.processSqlFile(
+      cmd.getSqlText,
+      cmd.getSqlFilePath,
+      sessionHolder.session)
   }
 
   private def defineDataset(
       dataset: proto.PipelineCommand.DefineDataset,
-      sparkSession: SparkSession): Unit = {
+      sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = dataset.getDataflowGraphId
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
 
     dataset.getDatasetType match {
       case proto.DatasetType.MATERIALIZED_VIEW | proto.DatasetType.TABLE =>
         val tableIdentifier =
-          GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sparkSession)
+          GraphIdentifierManager.parseTableIdentifier(
+            dataset.getDatasetName,
+            sessionHolder.session)
         graphElementRegistry.registerTable(
           Table(
             identifier = tableIdentifier,
@@ -165,7 +172,9 @@
             isStreamingTable = dataset.getDatasetType == proto.DatasetType.TABLE))
       case proto.DatasetType.TEMPORARY_VIEW =>
         val viewIdentifier =
-          GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sparkSession)
+          GraphIdentifierManager.parseTableIdentifier(
+            dataset.getDatasetName,
+            sessionHolder.session)
 
         graphElementRegistry.registerView(
           TemporaryView(
@@ -184,14 +193,15 @@
   private def defineFlow(
       flow: proto.PipelineCommand.DefineFlow,
       transformRelationFunc: Relation => LogicalPlan,
-      sparkSession: SparkSession): Unit = {
+      sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = flow.getDataflowGraphId
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
 
     val isImplicitFlow = flow.getFlowName == flow.getTargetDatasetName
 
     val flowIdentifier = GraphIdentifierManager
-      .parseTableIdentifier(name = flow.getFlowName, spark = sparkSession)
+      .parseTableIdentifier(name = flow.getFlowName, spark = sessionHolder.session)
 
     // If the flow is not an implicit flow (i.e. one defined as part of dataset creation), then
     // it must be a single-part identifier.
@@ -205,7 +215,7 @@
       new UnresolvedFlow(
         identifier = flowIdentifier,
         destinationIdentifier = GraphIdentifierManager
-          .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sparkSession),
+          .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sessionHolder.session),
         func =
           FlowAnalysis.createFlowFunctionFromLogicalPlan(transformRelationFunc(flow.getRelation)),
         sqlConf = flow.getSqlConfMap.asScala.toMap,
@@ -224,7 +234,8 @@
       responseObserver: StreamObserver[ExecutePlanResponse],
       sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = cmd.getDataflowGraphId
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
     val tableFiltersResult = createTableFilters(cmd, graphElementRegistry, sessionHolder)
 
     // We will use this variable to store the run failure event if it occurs. This will be set
```

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala

Lines changed: 7 additions & 0 deletions
```diff
@@ -38,6 +38,7 @@ import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connect.common.InvalidPlanInput
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.ml.MLCache
+import org.apache.spark.sql.connect.pipelines.DataflowGraphRegistry
 import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
 import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
 import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
@@ -125,6 +126,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
   private lazy val pipelineExecutions =
     new ConcurrentHashMap[String, PipelineUpdateContext]()
 
+  // Registry for dataflow graphs specific to this session
+  private[connect] lazy val dataflowGraphRegistry = new DataflowGraphRegistry()
+
   // Handles Python process clean up for streaming queries. Initialized on first use in a query.
   private[connect] lazy val streamingForeachBatchRunnerCleanerCache =
     new StreamingForeachBatchHelper.CleanerCache(this)
@@ -320,6 +324,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
     // Stops all pipeline execution and clears the pipeline execution cache
     removeAllPipelineExecutions()
 
+    // Clean up dataflow graphs
+    dataflowGraphRegistry.dropAllDataflowGraphs()
+
     // if there is a server side listener, clean up related resources
     if (streamingServersideListenerHolder.isServerSideListenerRegistered) {
       streamingServersideListenerHolder.cleanUp()
```

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala

Lines changed: 15 additions & 1 deletion
```diff
@@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.pipelines
 import java.io.{BufferedReader, InputStreamReader}
 import java.nio.charset.StandardCharsets
 import java.nio.file.Paths
+import java.util.UUID
 import java.util.concurrent.TimeUnit
 
 import scala.collection.mutable.ArrayBuffer
@@ -28,6 +29,7 @@ import scala.util.Try
 import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.connect.service.SparkConnectService
 import org.apache.spark.sql.pipelines.graph.DataflowGraph
 import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin}
 
@@ -42,6 +44,8 @@ class PythonPipelineSuite
 
   def buildGraph(pythonText: String): DataflowGraph = {
     val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n")
+    // create a unique identifier to allow identifying the session and dataflow graph
+    val customSessionIdentifier = UUID.randomUUID().toString
     val pythonCode =
       s"""
         |from pyspark.sql import SparkSession
@@ -57,6 +61,7 @@
         |spark = SparkSession.builder \\
         |  .remote("sc://localhost:$serverPort") \\
         |  .config("spark.connect.grpc.channel.timeout", "5s") \\
+        |  .config("spark.custom.identifier", "$customSessionIdentifier") \\
         |  .create()
         |
         |dataflow_graph_id = create_dataflow_graph(
@@ -78,8 +83,17 @@
       throw new RuntimeException(
         s"Python process failed with exit code $exitCode. Output: ${output.mkString("\n")}")
     }
+    val activeSessions = SparkConnectService.sessionManager.listActiveSessions
 
-    val dataflowGraphContexts = DataflowGraphRegistry.getAllDataflowGraphs
+    // get the session holder by finding the session with the custom UUID set in the conf
+    val sessionHolder = activeSessions
+      .map(info => SparkConnectService.sessionManager.getIsolatedSession(info.key, None))
+      .find(_.session.conf.get("spark.custom.identifier") == customSessionIdentifier)
+      .getOrElse(
+        throw new RuntimeException(s"Session with identifier $customSessionIdentifier not found"))
+
+    // get all dataflow graphs from the session holder
+    val dataflowGraphContexts = sessionHolder.dataflowGraphRegistry.getAllDataflowGraphs
     assert(dataflowGraphContexts.size == 1)
 
     dataflowGraphContexts.head.toDataflowGraph
```
