common/utils/src/main/resources/error/error-classes.json (8 changes: 1 addition & 7 deletions)
@@ -875,15 +875,9 @@
],
"sqlState" : "42710"
},
"DATA_SOURCE_NOT_EXIST" : {
"message" : [
"Data source '<provider>' not found. Please make sure the data source is registered."
],
"sqlState" : "42704"
},
"DATA_SOURCE_NOT_FOUND" : {
"message" : [
"Failed to find the data source: <provider>. Please find packages at `https://spark.apache.org/third-party-projects.html`."
"Failed to find the data source: <provider>. Please find packages at `https://spark.apache.org/third-party-projects.html`, or register it first."
],
"sqlState" : "42K02"
},
@@ -108,8 +108,7 @@ case class PythonMapInArrow(
*/
case class PythonDataSource(
dataSource: PythonFunction,
outputSchema: StructType,
override val output: Seq[Attribute]) extends LeafNode {
output: Seq[Attribute]) extends LeafNode {
require(output.forall(_.resolved),
"Unresolved attributes found when constructing PythonDataSource.")
override protected def stringArgs: Iterator[Any] = {
@@ -42,6 +42,7 @@ object TreePattern extends Enumeration {
val CREATE_NAMED_STRUCT: Value = Value
val CURRENT_LIKE: Value = Value
val DESERIALIZE_TO_OBJECT: Value = Value
val DATA_SOURCE_V2_RELATION: Value = Value
val DYNAMIC_PRUNING_EXPRESSION: Value = Value
val DYNAMIC_PRUNING_SUBQUERY: Value = Value
val EXISTS_SUBQUERY = Value
@@ -3844,12 +3844,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("provider" -> name))
}

def dataSourceDoesNotExist(name: String): Throwable = {
new AnalysisException(
errorClass = "DATA_SOURCE_NOT_EXIST",
messageParameters = Map("provider" -> name))
}

def foundMultipleDataSources(provider: String): Throwable = {
new AnalysisException(
errorClass = "FOUND_MULTIPLE_DATA_SOURCES",
@@ -21,6 +21,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataColumns, Histogram, HistogramBin, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils}
import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability}
@@ -50,6 +51,8 @@ case class DataSourceV2Relation(

import DataSourceV2Implicits._

final override val nodePatterns: Seq[TreePattern] = Seq(DATA_SOURCE_V2_RELATION)

lazy val funCatalog: Option[FunctionCatalog] = catalog.collect {
case c: FunctionCatalog => c
}
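
Declaring DATA_SOURCE_V2_RELATION in nodePatterns is what lets tree-pattern-pruned traversals skip subtrees that cannot contain such a relation; the new RewriteUserDefinedDataSource rule at the end of this diff relies on exactly that. A minimal sketch of the pruning idiom (illustration only, not code from this PR):

    plan.resolveOperatorsUpWithPruning(_.containsPattern(DATA_SOURCE_V2_RELATION)) {
      case r: DataSourceV2Relation => r // only visited when the pattern bit is set on the subtree
    }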
sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala (51 changes: 9 additions & 42 deletions)
@@ -17,12 +17,11 @@

package org.apache.spark.sql

import java.util.{Locale, Properties, ServiceConfigurationError}
import java.util.{Locale, Properties}

import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable}
import org.apache.spark.Partition
import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
@@ -209,45 +208,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
}

val isUserDefinedDataSource =
sparkSession.sessionState.dataSourceManager.dataSourceExists(source)

Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
case Success(providerOpt) =>
// The source can be successfully loaded as either a V1 or a V2 data source.
// Check if it is also a user-defined data source.
if (isUserDefinedDataSource) {
throw QueryCompilationErrors.foundMultipleDataSources(source)
}
providerOpt.flatMap { provider =>
DataSourceV2Utils.loadV2Source(
sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*)
}.getOrElse(loadV1Source(paths: _*))
case Failure(exception) =>
// Exceptions are thrown while trying to load the data source as a V1 or V2 data source.
// For the following not found exceptions, if the user-defined data source is defined,
// we can instead return the user-defined data source.
val isNotFoundError = exception match {
case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError]
case _ => false
}
if (isNotFoundError && isUserDefinedDataSource) {
loadUserDefinedDataSource(paths)
} else {
// Throw the original exception.
throw exception
}
}
}

private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
// Add `path` and `paths` options to the extra options if specified.
val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
Dataset.ofRows(sparkSession, plan)
DataSource.lookupDataSourceV2(
source,
sparkSession.sessionState.conf,
sparkSession.sessionState.dataSourceManager).flatMap { provider =>
DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions,
source, paths: _*)
}.getOrElse(loadV1Source(paths: _*))
}
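
With the lookup now unified in DataSource.lookupDataSourceV2, a registered user-defined source is read like any other format. A hedged usage sketch, assuming a data source has already been registered under the hypothetical name "my_source" (for example through the Python data source API):

    // Resolution goes through lookupDataSourceV2, which consults the session's DataSourceManager.
    val df = spark.read.format("my_source").load()
    df.show()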

private def loadV1Source(paths: String*) = {
@@ -894,7 +894,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}

private def lookupV2Provider(): Option[TableProvider] = {
DataSource.lookupDataSourceV2(source, df.sparkSession.sessionState.conf) match {
DataSource.lookupDataSourceV2(
source,
df.sparkSession.sessionState.conf,
df.sparkSession.sessionState.dataSourceManager) match {
// TODO(SPARK-28396): File source v2 write path is currently broken.
case Some(_: FileDataSourceV2) => None
case other => other
@@ -43,6 +43,6 @@ private[sql] class DataSourceRegistration private[sql] (dataSourceManager: DataS
| pythonExec: ${dataSource.dataSourceCls.pythonExec}
""".stripMargin)

dataSourceManager.registerDataSource(name, dataSource.builder)
dataSourceManager.registerDataSource(name, dataSource.getBuilder)
}
}
@@ -31,7 +31,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, Lo
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1}
import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1, DataSourceManager}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.internal.connector.V1Function
@@ -44,7 +44,9 @@ import org.apache.spark.util.ArrayImplicits._
* identifiers to construct the v1 commands, so that v1 commands do not need to qualify identifiers
* again, which may lead to inconsistent behavior if the current database is changed in the middle.
*/
class ResolveSessionCatalog(val catalogManager: CatalogManager)
class ResolveSessionCatalog(
val catalogManager: CatalogManager,
dataSourceManager: DataSourceManager = new DataSourceManager)
extends Rule[LogicalPlan] with LookupCatalog {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.CatalogV2Util._
@@ -612,7 +614,7 @@
}

private def isV2Provider(provider: String): Boolean = {
DataSourceV2Utils.getTableProvider(provider, conf).isDefined
DataSourceV2Utils.getTableProvider(provider, conf, dataSourceManager).isDefined
}

private object DatabaseInSessionCatalog {
@@ -25,7 +25,7 @@ import scala.util.{Failure, Success, Try}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkException
import org.apache.spark.{SparkClassNotFoundException, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, TypeUtils}
import org.apache.spark.sql.connector.catalog.TableProvider
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.execution.datasources.UserDefinedDataSourceTableProvider
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
@@ -705,10 +706,26 @@
* there is no corresponding Data Source V2 implementation, or the provider is configured to
* fallback to Data Source V1 code path.
*/
def lookupDataSourceV2(provider: String, conf: SQLConf): Option[TableProvider] = {
def lookupDataSourceV2(
provider: String,
conf: SQLConf,
dataSourceManager: DataSourceManager): Option[TableProvider] = {
val useV1Sources = conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT)
.split(",").map(_.trim)
val cls = lookupDataSource(provider, conf)
val cls = try {
lookupDataSource(provider, conf)
} catch {
case e: SparkClassNotFoundException if e.getErrorClass == "DATA_SOURCE_NOT_FOUND" =>
val registeredDataSourceOpt = dataSourceManager.getDataSource(provider)
if (registeredDataSourceOpt.isDefined) {
return Some(new UserDefinedDataSourceTableProvider(provider, registeredDataSourceOpt.get))
} else {
throw e
}
}
if (dataSourceManager.dataSourceExists(provider)) {
throw QueryCompilationErrors.foundMultipleDataSources(provider)
}
val instance = try {
cls.getDeclaredConstructor().newInstance()
} catch {
@@ -21,34 +21,24 @@ import java.util.Locale
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* A manager for user-defined data sources. It is used to register and lookup data sources by
* their short names or fully qualified names.
* A manager for user-defined data sources. It is used to register and lookup data sources by names.
*/
class DataSourceManager extends Logging {

private type DataSourceBuilder = (
SparkSession, // Spark session
String, // provider name
Option[StructType], // user specified schema
CaseInsensitiveMap[String] // options
) => LogicalPlan

private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()
private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedDataSourceBuilder]()

private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)

/**
* Register a data source builder for the given provider.
* Note that the provider name is case-insensitive.
*/
def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
def registerDataSource(name: String, builder: UserDefinedDataSourceBuilder): Unit = {
val normalizedName = normalize(name)
val previousValue = dataSourceBuilders.put(normalizedName, builder)
if (previousValue != null) {
@@ -60,12 +50,8 @@ class DataSourceManager extends Logging {
* Returns a data source builder for the given provider and throw an exception if
* it does not exist.
*/
def lookupDataSource(name: String): DataSourceBuilder = {
if (dataSourceExists(name)) {
dataSourceBuilders.get(normalize(name))
} else {
throw QueryCompilationErrors.dataSourceDoesNotExist(name)
}
def getDataSource(name: String): Option[UserDefinedDataSourceBuilder] = {
Option(dataSourceBuilders.get(normalize(name)))
}

/**
@@ -81,3 +67,16 @@
manager
}
}

trait UserDefinedDataSourceBuilder {
def build(
provider: String,
userSpecifiedSchema: Option[StructType],
options: CaseInsensitiveStringMap): UserDefinedDataSourcePlanBuilder
}

trait UserDefinedDataSourcePlanBuilder {
def schema: StructType

def build(output: Seq[Attribute]): LogicalPlan
}
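
To make the two-phase contract above concrete, here is a minimal, hypothetical sketch (not part of this PR; ExampleBuilder, ExamplePlanBuilder and the fallback schema are assumptions). The first build resolves options and schema; the second receives the already-resolved output attributes and produces the logical plan that stands in for the relation.

    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
    import org.apache.spark.sql.util.CaseInsensitiveStringMap

    class ExampleBuilder extends UserDefinedDataSourceBuilder {
      override def build(
          provider: String,
          userSpecifiedSchema: Option[StructType],
          options: CaseInsensitiveStringMap): UserDefinedDataSourcePlanBuilder = {
        // Fall back to a fixed single-column schema when the user did not supply one.
        val resolved = userSpecifiedSchema.getOrElse(StructType(Seq(StructField("id", IntegerType))))
        new ExamplePlanBuilder(resolved)
      }
    }

    class ExamplePlanBuilder(override val schema: StructType) extends UserDefinedDataSourcePlanBuilder {
      // The caller derives `output` from `schema`; the returned plan must produce exactly these attributes.
      override def build(output: Seq[Attribute]): LogicalPlan = LocalRelation(output)
    }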
@@ -50,8 +50,8 @@ import org.apache.spark.util.ArrayImplicits._
object PlanPythonDataSourceScan extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
_.containsPattern(PYTHON_DATA_SOURCE)) {
case ds @ PythonDataSource(dataSource: PythonFunction, schema, _) =>
val info = new UserDefinedPythonDataSourceReadRunner(dataSource, schema).runInPython()
case ds @ PythonDataSource(dataSource: PythonFunction, _) =>
val info = new UserDefinedPythonDataSourceReadRunner(dataSource, ds.schema).runInPython()

val readerFunc = SimplePythonFunction(
command = info.func.toImmutableArraySeq,
Expand All @@ -69,7 +69,7 @@ object PlanPythonDataSourceScan extends Rule[LogicalPlan] {
val pythonUDTF = PythonUDTF(
name = "python_data_source_read",
func = readerFunc,
elementSchema = schema,
elementSchema = ds.schema,
children = partitionPlan.output,
evalType = PythonEvalType.SQL_TABLE_UDF,
udfDeterministic = false,
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

object RewriteUserDefinedDataSource extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
_.containsPattern(DATA_SOURCE_V2_RELATION)) {
case r: DataSourceV2Relation if r.table.isInstanceOf[UserDefinedDataSourceTable] =>
val table = r.table.asInstanceOf[UserDefinedDataSourceTable]
table.builder.build(r.output)
}
}
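
This rule is presumably registered with the analyzer elsewhere in the PR (not shown in the hunks above). Purely to illustrate how a comparable resolution rule can be attached from user code, a hedged sketch using the public SparkSessionExtensions API; this wiring is an assumption, not what the PR does:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      .withExtensions { extensions =>
        // Inject the rewrite as an extra resolution rule run during analysis.
        extensions.injectResolutionRule(_ => RewriteUserDefinedDataSource)
      }
      .getOrCreate()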