@@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.sql.pipelines.Language.Python
@@ -68,7 +67,7 @@ private[connect] object PipelinesHandler extends Logging {
     cmd.getCommandTypeCase match {
       case proto.PipelineCommand.CommandTypeCase.CREATE_DATAFLOW_GRAPH =>
         val createdGraphId =
-          createDataflowGraph(cmd.getCreateDataflowGraph, sessionHolder.session)
+          createDataflowGraph(cmd.getCreateDataflowGraph, sessionHolder)
         PipelineCommandResult
           .newBuilder()
           .setCreateDataflowGraphResult(
@@ -78,73 +77,81 @@ private[connect] object PipelinesHandler extends Logging {
           .build()
       case proto.PipelineCommand.CommandTypeCase.DROP_DATAFLOW_GRAPH =>
         logInfo(s"Drop pipeline cmd received: $cmd")
-        DataflowGraphRegistry.dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId)
+        sessionHolder.dataflowGraphRegistry
+          .dropDataflowGraph(cmd.getDropDataflowGraph.getDataflowGraphId)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.DEFINE_DATASET =>
         logInfo(s"Define pipelines dataset cmd received: $cmd")
-        defineDataset(cmd.getDefineDataset, sessionHolder.session)
+        defineDataset(cmd.getDefineDataset, sessionHolder)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.DEFINE_FLOW =>
         logInfo(s"Define pipelines flow cmd received: $cmd")
-        defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder.session)
+        defineFlow(cmd.getDefineFlow, transformRelationFunc, sessionHolder)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.START_RUN =>
         logInfo(s"Start pipeline cmd received: $cmd")
         startRun(cmd.getStartRun, responseObserver, sessionHolder)
         defaultResponse
       case proto.PipelineCommand.CommandTypeCase.DEFINE_SQL_GRAPH_ELEMENTS =>
         logInfo(s"Register sql datasets cmd received: $cmd")
-        defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder.session)
+        defineSqlGraphElements(cmd.getDefineSqlGraphElements, sessionHolder)
         defaultResponse
       case other => throw new UnsupportedOperationException(s"$other not supported")
     }
   }
 
   private def createDataflowGraph(
       cmd: proto.PipelineCommand.CreateDataflowGraph,
-      spark: SparkSession): String = {
+      sessionHolder: SessionHolder): String = {
     val defaultCatalog = Option
       .when(cmd.hasDefaultCatalog)(cmd.getDefaultCatalog)
       .getOrElse {
         logInfo(s"No default catalog was supplied. Falling back to the current catalog.")
-        spark.catalog.currentCatalog()
+        sessionHolder.session.catalog.currentCatalog()
       }
 
     val defaultDatabase = Option
       .when(cmd.hasDefaultDatabase)(cmd.getDefaultDatabase)
       .getOrElse {
         logInfo(s"No default database was supplied. Falling back to the current database.")
-        spark.catalog.currentDatabase
+        sessionHolder.session.catalog.currentDatabase
       }
 
     val defaultSqlConf = cmd.getSqlConfMap.asScala.toMap
 
-    DataflowGraphRegistry.createDataflowGraph(
+    sessionHolder.dataflowGraphRegistry.createDataflowGraph(
       defaultCatalog = defaultCatalog,
       defaultDatabase = defaultDatabase,
       defaultSqlConf = defaultSqlConf)
   }
 
   private def defineSqlGraphElements(
       cmd: proto.PipelineCommand.DefineSqlGraphElements,
-      session: SparkSession): Unit = {
+      sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = cmd.getDataflowGraphId
 
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
     val sqlGraphElementRegistrationContext = new SqlGraphRegistrationContext(graphElementRegistry)
-    sqlGraphElementRegistrationContext.processSqlFile(cmd.getSqlText, cmd.getSqlFilePath, session)
+    sqlGraphElementRegistrationContext.processSqlFile(
+      cmd.getSqlText,
+      cmd.getSqlFilePath,
+      sessionHolder.session)
   }
 
   private def defineDataset(
       dataset: proto.PipelineCommand.DefineDataset,
-      sparkSession: SparkSession): Unit = {
+      sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = dataset.getDataflowGraphId
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
 
     dataset.getDatasetType match {
       case proto.DatasetType.MATERIALIZED_VIEW | proto.DatasetType.TABLE =>
         val tableIdentifier =
-          GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sparkSession)
+          GraphIdentifierManager.parseTableIdentifier(
+            dataset.getDatasetName,
+            sessionHolder.session)
         graphElementRegistry.registerTable(
           Table(
             identifier = tableIdentifier,
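Taken together, these hunks replace the JVM-wide DataflowGraphRegistry singleton with a registry hung off each SessionHolder, so dataflow graphs are scoped to the Spark Connect session that created them rather than shared (and potentially colliding) across sessions. A minimal sketch of what such a session-scoped registry could look like; the method names and signatures are taken from the call sites in this diff, but the internals (UUID ids, ConcurrentHashMap) are assumptions, and GraphRegistrationContext is a hypothetical stand-in for the stored element-registry type:

import java.util.UUID
import java.util.concurrent.ConcurrentHashMap

// Hypothetical stand-in for whatever per-graph registration state the real
// code stores (tables, views, flows, defaults).
class GraphRegistrationContext(
    val defaultCatalog: String,
    val defaultDatabase: String,
    val defaultSqlConf: Map[String, String])

class DataflowGraphRegistry {
  // graphId -> registration context; concurrent because several Connect RPCs
  // for the same session may arrive in parallel.
  private val graphs = new ConcurrentHashMap[String, GraphRegistrationContext]()

  def createDataflowGraph(
      defaultCatalog: String,
      defaultDatabase: String,
      defaultSqlConf: Map[String, String]): String = {
    val graphId = UUID.randomUUID().toString
    graphs.put(
      graphId,
      new GraphRegistrationContext(defaultCatalog, defaultDatabase, defaultSqlConf))
    graphId
  }

  def getDataflowGraphOrThrow(dataflowGraphId: String): GraphRegistrationContext =
    Option(graphs.get(dataflowGraphId)).getOrElse {
      throw new IllegalArgumentException(
        s"Dataflow graph with id $dataflowGraphId was not found")
    }

  def dropDataflowGraph(dataflowGraphId: String): Unit =
    graphs.remove(dataflowGraphId)

  // Clearing everything when the session is released keeps graph state from
  // outliving the session that created it.
  def dropAllDataflowGraphs(): Unit = graphs.clear()
}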
@@ -165,7 +172,9 @@ private[connect] object PipelinesHandler extends Logging {
             isStreamingTable = dataset.getDatasetType == proto.DatasetType.TABLE))
       case proto.DatasetType.TEMPORARY_VIEW =>
         val viewIdentifier =
-          GraphIdentifierManager.parseTableIdentifier(dataset.getDatasetName, sparkSession)
+          GraphIdentifierManager.parseTableIdentifier(
+            dataset.getDatasetName,
+            sessionHolder.session)
 
         graphElementRegistry.registerView(
           TemporaryView(
@@ -184,14 +193,15 @@ private[connect] object PipelinesHandler extends Logging {
   private def defineFlow(
       flow: proto.PipelineCommand.DefineFlow,
       transformRelationFunc: Relation => LogicalPlan,
-      sparkSession: SparkSession): Unit = {
+      sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = flow.getDataflowGraphId
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
 
     val isImplicitFlow = flow.getFlowName == flow.getTargetDatasetName
 
     val flowIdentifier = GraphIdentifierManager
-      .parseTableIdentifier(name = flow.getFlowName, spark = sparkSession)
+      .parseTableIdentifier(name = flow.getFlowName, spark = sessionHolder.session)
 
     // If the flow is not an implicit flow (i.e. one defined as part of dataset creation), then
     // it must be a single-part identifier.
@@ -205,7 +215,7 @@ private[connect] object PipelinesHandler extends Logging {
       new UnresolvedFlow(
         identifier = flowIdentifier,
         destinationIdentifier = GraphIdentifierManager
-          .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sparkSession),
+          .parseTableIdentifier(name = flow.getTargetDatasetName, spark = sessionHolder.session),
         func =
           FlowAnalysis.createFlowFunctionFromLogicalPlan(transformRelationFunc(flow.getRelation)),
         sqlConf = flow.getSqlConfMap.asScala.toMap,
@@ -224,7 +234,8 @@ private[connect] object PipelinesHandler extends Logging {
       responseObserver: StreamObserver[ExecutePlanResponse],
       sessionHolder: SessionHolder): Unit = {
     val dataflowGraphId = cmd.getDataflowGraphId
-    val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
+    val graphElementRegistry =
+      sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
     val tableFiltersResult = createTableFilters(cmd, graphElementRegistry, sessionHolder)
 
     // We will use this variable to store the run failure event if it occurs. This will be set
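The only wiring the handlers rely on is a session-owned registry plus the underlying SparkSession, which SessionHolder might expose roughly as sketched below (illustrative only; SessionHolder's actual definition is outside this diff):

import org.apache.spark.sql.SparkSession

// Hypothetical wiring: each SessionHolder lazily creates its own registry the
// first time a pipelines command arrives, and every handler above reaches both
// the registry and the session through the holder it is passed.
class SessionHolder(val session: SparkSession) {
  lazy val dataflowGraphRegistry: DataflowGraphRegistry = new DataflowGraphRegistry
}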