From 76334c7ddb21fd6b032b6b16ffc1527e25b7ae74 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 8 Dec 2023 19:36:25 -0800
Subject: [PATCH] Support creating table using a Python data source in SQL

---
 .../main/resources/error/error-classes.json | 8 +-
 .../logical/pythonLogicalOperators.scala | 3 +-
 .../sql/catalyst/trees/TreePatterns.scala | 1 +
 .../sql/errors/QueryCompilationErrors.scala | 6 --
 .../datasources/v2/DataSourceV2Relation.scala | 3 +
 .../apache/spark/sql/DataFrameReader.scala | 51 +++----------
 .../apache/spark/sql/DataFrameWriter.scala | 5 +-
 .../spark/sql/DataSourceRegistration.scala | 2 +-
 .../analysis/ResolveSessionCatalog.scala | 8 +-
 .../execution/datasources/DataSource.scala | 23 +++++-
 .../datasources/DataSourceManager.scala | 41 +++++-----
 .../PlanPythonDataSourceScan.scala | 6 +-
 .../RewriteUserDefinedDataSource.scala | 32 ++++++++
 .../UserDefinedDataSourceTableProvider.scala | 76 +++++++++++++++++++
 .../spark/sql/execution/datasources/ddl.scala | 3 +-
 .../datasources/v2/DataSourceV2Utils.scala | 9 ++-
 .../datasources/v2/V2SessionCatalog.scala | 11 ++-
 .../python/UserDefinedPythonDataSource.scala | 43 ++++++-----
 .../internal/BaseSessionStateBuilder.scala | 5 +-
 .../python/PythonDataSourceSuite.scala | 20 +++--
 .../sql/hive/HiveSessionStateBuilder.scala | 3 +-
 21 files changed, 237 insertions(+), 122 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RewriteUserDefinedDataSource.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/UserDefinedDataSourceTableProvider.scala

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 62d10c0d34cb..bb1dd849f365 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -875,15 +875,9 @@
 ],
 "sqlState" : "42710"
 },
- "DATA_SOURCE_NOT_EXIST" : {
- "message" : [
- "Data source '<provider>' not found. Please make sure the data source is registered."
- ],
- "sqlState" : "42704"
- },
 "DATA_SOURCE_NOT_FOUND" : {
 "message" : [
- "Failed to find the data source: <provider>. Please find packages at `https://spark.apache.org/third-party-projects.html`."
+ "Failed to find the data source: <provider>. Please find packages at `https://spark.apache.org/third-party-projects.html`, or register it first."
], "sqlState" : "42K02" }, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index d4ed673c3513..82ffdde40ea8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -108,8 +108,7 @@ case class PythonMapInArrow( */ case class PythonDataSource( dataSource: PythonFunction, - outputSchema: StructType, - override val output: Seq[Attribute]) extends LeafNode { + output: Seq[Attribute]) extends LeafNode { require(output.forall(_.resolved), "Unresolved attributes found when constructing PythonDataSource.") override protected def stringArgs: Iterator[Any] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 1f0df8f3b8ab..75b47d5306e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -42,6 +42,7 @@ object TreePattern extends Enumeration { val CREATE_NAMED_STRUCT: Value = Value val CURRENT_LIKE: Value = Value val DESERIALIZE_TO_OBJECT: Value = Value + val DATA_SOURCE_V2_RELATION: Value = Value val DYNAMIC_PRUNING_EXPRESSION: Value = Value val DYNAMIC_PRUNING_SUBQUERY: Value = Value val EXISTS_SUBQUERY = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index b7e10dc194a0..91984da6eba3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3844,12 +3844,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("provider" -> name)) } - def dataSourceDoesNotExist(name: String): Throwable = { - new AnalysisException( - errorClass = "DATA_SOURCE_NOT_EXIST", - messageParameters = Map("provider" -> name)) - } - def foundMultipleDataSources(provider: String): Throwable = { new AnalysisException( errorClass = "FOUND_MULTIPLE_DATA_SOURCES", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 573b0274e958..a7d5947f9f4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataColumns, Histogram, HistogramBin, LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, 
truncatedString, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability} @@ -50,6 +51,8 @@ case class DataSourceV2Relation( import DataSourceV2Implicits._ + final override val nodePatterns: Seq[TreePattern] = Seq(DATA_SOURCE_V2_RELATION) + lazy val funCatalog: Option[FunctionCatalog] = catalog.collect { case c: FunctionCatalog => c } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c29ffb329072..5d0c33d14403 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql -import java.util.{Locale, Properties, ServiceConfigurationError} +import java.util.{Locale, Properties} import scala.jdk.CollectionConverters._ -import scala.util.{Failure, Success, Try} -import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable} +import org.apache.spark.Partition import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging @@ -209,45 +208,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError() } - val isUserDefinedDataSource = - sparkSession.sessionState.dataSourceManager.dataSourceExists(source) - - Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match { - case Success(providerOpt) => - // The source can be successfully loaded as either a V1 or a V2 data source. - // Check if it is also a user-defined data source. - if (isUserDefinedDataSource) { - throw QueryCompilationErrors.foundMultipleDataSources(source) - } - providerOpt.flatMap { provider => - DataSourceV2Utils.loadV2Source( - sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*) - }.getOrElse(loadV1Source(paths: _*)) - case Failure(exception) => - // Exceptions are thrown while trying to load the data source as a V1 or V2 data source. - // For the following not found exceptions, if the user-defined data source is defined, - // we can instead return the user-defined data source. - val isNotFoundError = exception match { - case _: NoClassDefFoundError | _: SparkClassNotFoundException => true - case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND" - case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError] - case _ => false - } - if (isNotFoundError && isUserDefinedDataSource) { - loadUserDefinedDataSource(paths) - } else { - // Throw the original exception. - throw exception - } - } - } - - private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = { - val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source) - // Add `path` and `paths` options to the extra options if specified. 
- val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*) - val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath) - Dataset.ofRows(sparkSession, plan) + DataSource.lookupDataSourceV2( + source, + sparkSession.sessionState.conf, + sparkSession.sessionState.dataSourceManager).flatMap { provider => + DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions, + source, paths: _*) + }.getOrElse(loadV1Source(paths: _*)) } private def loadV1Source(paths: String*) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index c8727146160b..f9ed77b0517e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -894,7 +894,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } private def lookupV2Provider(): Option[TableProvider] = { - DataSource.lookupDataSourceV2(source, df.sparkSession.sessionState.conf) match { + DataSource.lookupDataSourceV2( + source, + df.sparkSession.sessionState.conf, + df.sparkSession.sessionState.dataSourceManager) match { // TODO(SPARK-28396): File source v2 write path is currently broken. case Some(_: FileDataSourceV2) => None case other => other diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala index 15d26418984b..c3dcb9396fd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala @@ -43,6 +43,6 @@ private[sql] class DataSourceRegistration private[sql] (dataSourceManager: DataS | pythonExec: ${dataSource.dataSourceCls.pythonExec} """.stripMargin) - dataSourceManager.registerDataSource(name, dataSource.builder) + dataSourceManager.registerDataSource(name, dataSource.getBuilder) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index d44de0b260b2..45e41feec2e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, Lo import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} +import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1, DataSourceManager} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.connector.V1Function @@ -44,7 +44,9 @@ import org.apache.spark.util.ArrayImplicits._ * identifiers to construct the v1 commands, so that v1 commands do not need to qualify identifiers * again, which may lead to inconsistent behavior if the current database is changed in the middle. 
*/ -class ResolveSessionCatalog(val catalogManager: CatalogManager) +class ResolveSessionCatalog( + val catalogManager: CatalogManager, + dataSourceManager: DataSourceManager = new DataSourceManager) extends Rule[LogicalPlan] with LookupCatalog { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.CatalogV2Util._ @@ -612,7 +614,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) } private def isV2Provider(provider: String): Boolean = { - DataSourceV2Utils.getTableProvider(provider, conf).isDefined + DataSourceV2Utils.getTableProvider(provider, conf, dataSourceManager).isDefined } private object DatabaseInSessionCatalog { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 668d2538e03f..6071ab958b22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -25,7 +25,7 @@ import scala.util.{Failure, Success, Try} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.SparkException +import org.apache.spark.{SparkClassNotFoundException, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils} import org.apache.spark.sql.connector.catalog.TableProvider import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.command.DataWritingCommand +import org.apache.spark.sql.execution.datasources.UserDefinedDataSourceTableProvider import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -705,10 +706,26 @@ object DataSource extends Logging { * there is no corresponding Data Source V2 implementation, or the provider is configured to * fallback to Data Source V1 code path. 
*/ - def lookupDataSourceV2(provider: String, conf: SQLConf): Option[TableProvider] = { + def lookupDataSourceV2( + provider: String, + conf: SQLConf, + dataSourceManager: DataSourceManager): Option[TableProvider] = { val useV1Sources = conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT) .split(",").map(_.trim) - val cls = lookupDataSource(provider, conf) + val cls = try { + lookupDataSource(provider, conf) + } catch { + case e: SparkClassNotFoundException if e.getErrorClass == "DATA_SOURCE_NOT_FOUND" => + val registeredDataSourceOpt = dataSourceManager.getDataSource(provider) + if (registeredDataSourceOpt.isDefined) { + return Some(new UserDefinedDataSourceTableProvider(provider, registeredDataSourceOpt.get)) + } else { + throw e + } + } + if (dataSourceManager.dataSourceExists(provider)) { + throw QueryCompilationErrors.foundMultipleDataSources(provider) + } val instance = try { cls.getDeclaredConstructor().newInstance() } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala index 1cdc3d9cb69e..d26ff9a9339a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala @@ -21,26 +21,16 @@ import java.util.Locale import java.util.concurrent.ConcurrentHashMap import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** - * A manager for user-defined data sources. It is used to register and lookup data sources by - * their short names or fully qualified names. + * A manager for user-defined data sources. It is used to register and lookup data sources by names. */ class DataSourceManager extends Logging { - - private type DataSourceBuilder = ( - SparkSession, // Spark session - String, // provider name - Option[StructType], // user specified schema - CaseInsensitiveMap[String] // options - ) => LogicalPlan - - private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]() + private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedDataSourceBuilder]() private def normalize(name: String): String = name.toLowerCase(Locale.ROOT) @@ -48,7 +38,7 @@ class DataSourceManager extends Logging { * Register a data source builder for the given provider. * Note that the provider name is case-insensitive. */ - def registerDataSource(name: String, builder: DataSourceBuilder): Unit = { + def registerDataSource(name: String, builder: UserDefinedDataSourceBuilder): Unit = { val normalizedName = normalize(name) val previousValue = dataSourceBuilders.put(normalizedName, builder) if (previousValue != null) { @@ -60,12 +50,8 @@ class DataSourceManager extends Logging { * Returns a data source builder for the given provider and throw an exception if * it does not exist. 
*/ - def lookupDataSource(name: String): DataSourceBuilder = { - if (dataSourceExists(name)) { - dataSourceBuilders.get(normalize(name)) - } else { - throw QueryCompilationErrors.dataSourceDoesNotExist(name) - } + def getDataSource(name: String): Option[UserDefinedDataSourceBuilder] = { + Option(dataSourceBuilders.get(normalize(name))) } /** @@ -81,3 +67,16 @@ class DataSourceManager extends Logging { manager } } + +trait UserDefinedDataSourceBuilder { + def build( + provider: String, + userSpecifiedSchema: Option[StructType], + options: CaseInsensitiveStringMap): UserDefinedDataSourcePlanBuilder +} + +trait UserDefinedDataSourcePlanBuilder { + def schema: StructType + + def build(output: Seq[Attribute]): LogicalPlan +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala index ec4c7c188fa0..e2bc61493c04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala @@ -50,8 +50,8 @@ import org.apache.spark.util.ArrayImplicits._ object PlanPythonDataSourceScan extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning( _.containsPattern(PYTHON_DATA_SOURCE)) { - case ds @ PythonDataSource(dataSource: PythonFunction, schema, _) => - val info = new UserDefinedPythonDataSourceReadRunner(dataSource, schema).runInPython() + case ds @ PythonDataSource(dataSource: PythonFunction, _) => + val info = new UserDefinedPythonDataSourceReadRunner(dataSource, ds.schema).runInPython() val readerFunc = SimplePythonFunction( command = info.func.toImmutableArraySeq, @@ -69,7 +69,7 @@ object PlanPythonDataSourceScan extends Rule[LogicalPlan] { val pythonUDTF = PythonUDTF( name = "python_data_source_read", func = readerFunc, - elementSchema = schema, + elementSchema = ds.schema, children = partitionPlan.output, evalType = PythonEvalType.SQL_TABLE_UDF, udfDeterministic = false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RewriteUserDefinedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RewriteUserDefinedDataSource.scala new file mode 100644 index 000000000000..c2c7cb45f280 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/RewriteUserDefinedDataSource.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation + +object RewriteUserDefinedDataSource extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( + _.containsPattern(DATA_SOURCE_V2_RELATION)) { + case r: DataSourceV2Relation if r.table.isInstanceOf[UserDefinedDataSourceTable] => + val table = r.table.asInstanceOf[UserDefinedDataSourceTable] + table.builder.build(r.output) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/UserDefinedDataSourceTableProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/UserDefinedDataSourceTableProvider.scala new file mode 100644 index 000000000000..d0ac7f26ecd4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/UserDefinedDataSourceTableProvider.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.util + +import org.apache.spark.SparkException +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class UserDefinedDataSourceTableProvider( + name: String, + builder: UserDefinedDataSourceBuilder) extends TableProvider { + private var planBuilder: UserDefinedDataSourcePlanBuilder = _ + + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + assert(planBuilder == null) + // When we reach here, it means there is no user-specified schema + planBuilder = builder.build(name, userSpecifiedSchema = None, options = options) + planBuilder.schema + } + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: util.Map[String, String]): Table = { + assert(partitioning.isEmpty) + if (planBuilder == null) { + // When we reach here, it means there is user-specified schema + planBuilder = builder.build( + name, + userSpecifiedSchema = Some(schema), + options = new CaseInsensitiveStringMap(properties)) + } else { + assert(schema == planBuilder.schema) + } + UserDefinedDataSourceTable(name, planBuilder) + } +} + +case class UserDefinedDataSourceTable( + name: String, + builder: UserDefinedDataSourcePlanBuilder) extends Table + with SupportsRead with SupportsWrite { + override def schema(): StructType = builder.schema + override def capabilities(): util.Set[TableCapability] = { + util.EnumSet.of(TableCapability.BATCH_READ, TableCapability.BATCH_WRITE) + } + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + throw SparkException.internalError( + "UserDefinedDataSourceTable.newScanBuilder should not be called.") + } + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + throw SparkException.internalError( + "UserDefinedDataSourceTable.newWriteBuilder should not be called.") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index fc6cba786c4e..df11fee82b28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -99,7 +99,8 @@ case class CreateTempViewUsing( } val catalog = sparkSession.sessionState.catalog - val analyzedPlan = DataSource.lookupDataSourceV2(provider, sparkSession.sessionState.conf) + val analyzedPlan = DataSource.lookupDataSourceV2( + provider, sparkSession.sessionState.conf, sparkSession.sessionState.dataSourceManager) .flatMap { tblProvider => DataSourceV2Utils.loadV2Source(sparkSession, tblProvider, userSpecifiedSchema, CaseInsensitiveMap(options), provider) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index 9ffa0d728ca2..96814982c087 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -32,7 +32,7 
@@ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SessionConfigSuppo import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceManager} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -156,11 +156,14 @@ private[sql] object DataSourceV2Utils extends Logging { /** * Returns the table provider for the given format, or None if it cannot be found. */ - def getTableProvider(provider: String, conf: SQLConf): Option[TableProvider] = { + def getTableProvider( + provider: String, + conf: SQLConf, + dataSourceManager: DataSourceManager): Option[TableProvider] = { // Return earlier since `lookupDataSourceV2` may fail to resolve provider "hive" to // `HiveFileFormat`, when running tests in sql/core. if (DDLUtils.isHiveTable(Some(provider))) return None - DataSource.lookupDataSourceV2(provider, conf) match { + DataSource.lookupDataSourceV2(provider, conf, dataSourceManager) match { // TODO(SPARK-28396): Currently file source v2 can't work with tables. case Some(p) if !p.isInstanceOf[FileDataSourceV2] => Some(p) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index 5b1ff7c67b26..d68fa2434601 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.NamespaceChange.RemoveProperty import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceManager} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types.StructType @@ -43,7 +43,9 @@ import org.apache.spark.util.ArrayImplicits._ /** * A [[TableCatalog]] that translates calls to the v1 SessionCatalog. */ -class V2SessionCatalog(catalog: SessionCatalog) +class V2SessionCatalog( + catalog: SessionCatalog, + dataSourceManager: DataSourceManager = new DataSourceManager) extends TableCatalog with FunctionCatalog with SupportsNamespaces with SQLConfHelper { import V2SessionCatalog._ @@ -85,7 +87,7 @@ class V2SessionCatalog(catalog: SessionCatalog) try { val table = catalog.getTableMetadata(ident.asTableIdentifier) if (table.provider.isDefined) { - DataSourceV2Utils.getTableProvider(table.provider.get, conf) match { + DataSourceV2Utils.getTableProvider(table.provider.get, conf, dataSourceManager) match { case Some(provider) => // Get the table properties during creation and append the path option // to the properties. 
@@ -173,7 +175,8 @@ class V2SessionCatalog(catalog: SessionCatalog) CatalogTableType.MANAGED } - val (newSchema, newPartitions) = DataSourceV2Utils.getTableProvider(provider, conf) match { + val (newSchema, newPartitions) = DataSourceV2Utils.getTableProvider( + provider, conf, dataSourceManager) match { // If the provider does not support external metadata, users should not be allowed to // specify custom schema when creating the data source table, since the schema will not // be used when loading the table. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index 7044ef65c638..5147bd1d18a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.execution.python import java.io.{DataInputStream, DataOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ import net.razorvine.pickle.Pickler import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths} import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSource} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.{UserDefinedDataSourceBuilder, UserDefinedDataSourcePlanBuilder} import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ /** @@ -39,12 +42,23 @@ import org.apache.spark.util.ArrayImplicits._ */ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { - def builder( + def getBuilder: PythonDataSourceBuilder = new PythonDataSourceBuilder(dataSourceCls) + + def apply( sparkSession: SparkSession, provider: String, - userSpecifiedSchema: Option[StructType], - options: CaseInsensitiveMap[String]): LogicalPlan = { + userSpecifiedSchema: Option[StructType] = None, + options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty()): DataFrame = { + val planBuilder = getBuilder.build(provider, userSpecifiedSchema, options) + Dataset.ofRows(sparkSession, planBuilder.build(toAttributes(planBuilder.schema))) + } +} +class PythonDataSourceBuilder(dataSourceCls: PythonFunction) extends UserDefinedDataSourceBuilder { + override def build( + provider: String, + userSpecifiedSchema: Option[StructType], + options: CaseInsensitiveStringMap): UserDefinedDataSourcePlanBuilder = { val runner = new UserDefinedPythonDataSourceRunner( dataSourceCls, provider, userSpecifiedSchema, options) @@ -59,19 +73,14 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { pythonVer = dataSourceCls.pythonVer, broadcastVars = dataSourceCls.broadcastVars, accumulator = dataSourceCls.accumulator) - val schema = result.schema - - PythonDataSource(dataSource, schema, output = toAttributes(schema)) + new PythonDataSourcePlanBuilder(result.schema, dataSource) } +} - def apply( - sparkSession: SparkSession, - provider: String, - userSpecifiedSchema: Option[StructType] = 
None, - options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = { - val plan = builder(sparkSession, provider, userSpecifiedSchema, options) - Dataset.ofRows(sparkSession, plan) - } +class PythonDataSourcePlanBuilder( + override val schema: StructType, + dataSource: SimplePythonFunction) extends UserDefinedDataSourcePlanBuilder { + override def build(output: Seq[Attribute]): LogicalPlan = PythonDataSource(dataSource, output) } /** @@ -88,7 +97,7 @@ class UserDefinedPythonDataSourceRunner( dataSourceCls: PythonFunction, provider: String, userSpecifiedSchema: Option[StructType], - options: CaseInsensitiveMap[String]) + options: CaseInsensitiveStringMap) extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) { override val workerModule = "pyspark.sql.worker.create_data_source" @@ -106,7 +115,7 @@ class UserDefinedPythonDataSourceRunner( // Send the options dataOut.writeInt(options.size) - options.iterator.foreach { case (key, value) => + options.asCaseSensitiveMap().asScala.iterator.foreach { case (key, value) => PythonWorkerUtils.writeUTF(key, dataOut) PythonWorkerUtils.writeUTF(value, dataOut) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 00c72294ca07..ac463dda9743 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -171,7 +171,7 @@ abstract class BaseSessionStateBuilder( catalog } - protected lazy val v2SessionCatalog = new V2SessionCatalog(catalog) + protected lazy val v2SessionCatalog = new V2SessionCatalog(catalog, dataSourceManager) protected lazy val catalogManager = new CatalogManager(v2SessionCatalog, catalog) @@ -199,10 +199,11 @@ abstract class BaseSessionStateBuilder( protected def analyzer: Analyzer = new Analyzer(catalogManager) { override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = new FindDataSourceTable(session) +: + RewriteUserDefinedDataSource +: new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: - new ResolveSessionCatalog(this.catalogManager) +: + new ResolveSessionCatalog(this.catalogManager, dataSourceManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: customResolutionRules diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index ec2f8c19b02b..558d0b815ff2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.execution.python import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row} import org.apache.spark.sql.catalyst.plans.logical.{BatchEvalPythonUDTF, PythonDataSourcePartitions} -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class PythonDataSourceSuite extends QueryTest with SharedSparkSession { import IntegratedUDFTestUtils._ @@ -146,9 +147,10 @@ class 
PythonDataSourceSuite extends QueryTest with SharedSparkSession { val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) spark.dataSource.registerPython(dataSourceName, dataSource) assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName)) - val ds1 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName) + val ds1 = spark.sessionState.dataSourceManager.getDataSource(dataSourceName).get + val planBuilder1 = ds1.build(dataSourceName, None, CaseInsensitiveStringMap.empty()) checkAnswer( - ds1(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)), + planBuilder1.build(toAttributes(planBuilder1.schema)), Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1))) // Should be able to override an already registered data source. @@ -170,9 +172,10 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { spark.dataSource.registerPython(dataSourceName, newDataSource) assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName)) - val ds2 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName) + val ds2 = spark.sessionState.dataSourceManager.getDataSource(dataSourceName).get + val planBuilder2 = ds2.build(dataSourceName, None, CaseInsensitiveStringMap.empty()) checkAnswer( - ds2(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)), + planBuilder2.build(toAttributes(planBuilder2.schema)), Seq(Row(0))) } @@ -219,6 +222,13 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1))) checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1))) checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1))) + + // Test SQL + withTable("tblA") { + sql("CREATE TABLE tblA USING test") + // The path will be the actual temp path. + checkAnswer(spark.table("tblA").selectExpr("value"), Seq(Row(1))) + } } test("reader not implemented") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 8a33645853cb..c5e79aef6295 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -87,10 +87,11 @@ class HiveSessionStateBuilder( override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = new ResolveHiveSerdeTable(session) +: new FindDataSourceTable(session) +: + RewriteUserDefinedDataSource +: new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: - new ResolveSessionCatalog(catalogManager) +: + new ResolveSessionCatalog(catalogManager, dataSourceManager) +: ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: new DetermineTableStats(session) +:
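
Illustrative usage sketch (not part of the patch): the flow below assumes an active SparkSession named `spark` and an already-constructed UserDefinedPythonDataSource value `dataSource`, registered under the short name "test" as in the updated PythonDataSourceSuite. Once registered, the provider name can be used directly in SQL DDL; DataSource.lookupDataSourceV2 falls back to the registered builder, and the resulting scan is rewritten by RewriteUserDefinedDataSource.

    // Register the Python data source in this session, as the test suite does via
    // createUserDefinedPythonDataSource; "test" and `dataSource` are placeholders here.
    spark.dataSource.registerPython("test", dataSource)

    // lookupDataSourceV2 cannot find "test" on the classpath, so it falls back to the
    // registered builder and returns a UserDefinedDataSourceTableProvider, which lets
    // CREATE TABLE ... USING resolve the Python data source through V2SessionCatalog.
    spark.sql("CREATE TABLE tblA USING test")

    // Reading the table yields a DataSourceV2Relation over UserDefinedDataSourceTable;
    // RewriteUserDefinedDataSource replaces it with the plan built by the registered
    // builder (a PythonDataSource node planned by PlanPythonDataSourceScan).
    spark.sql("SELECT * FROM tblA").show()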