12 changes: 12 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
@@ -818,6 +818,12 @@
],
"sqlState" : "42710"
},
"DATA_SOURCE_NOT_EXIST" : {
"message" : [
"Data source '<provider>' not found. Please make sure the data source is registered."
],
"sqlState" : "42704"
},
"DATA_SOURCE_NOT_FOUND" : {
"message" : [
"Failed to find the data source: <provider>. Please find packages at `https://spark.apache.org/third-party-projects.html`."
@@ -1063,6 +1069,12 @@
],
"sqlState" : "42809"
},
"FOUND_MULTIPLE_DATA_SOURCES" : {
"message" : [
"Detected multiple data sources with the name '<provider>'. Please check the data source isn't simultaneously registered and located in the classpath."
],
"sqlState" : "42710"
},
"GENERATED_COLUMN_WITH_DEFAULT_VALUE" : {
"message" : [
"A column cannot have both a default value and a generation expression but column <colName> has default value: (<defaultValue>) and generation expression: (<genExpr>)."
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -511,6 +511,7 @@ def __hash__(self):
"pyspark.sql.tests.pandas.test_pandas_udf_window",
"pyspark.sql.tests.pandas.test_converter",
"pyspark.sql.tests.test_pandas_sqlmetrics",
"pyspark.sql.tests.test_python_datasource",
"pyspark.sql.tests.test_readwriter",
"pyspark.sql.tests.test_serde",
"pyspark.sql.tests.test_session",
12 changes: 12 additions & 0 deletions docs/sql-error-conditions.md
@@ -437,6 +437,12 @@ DataType `<type>` requires a length parameter, for example `<type>`(10). Please

Data source '`<provider>`' already exists in the registry. Please use a different name for the new data source.

### DATA_SOURCE_NOT_EXIST

[SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

Data source '`<provider>`' not found. Please make sure the data source is registered.

### DATA_SOURCE_NOT_FOUND

[SQLSTATE: 42K02](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
@@ -652,6 +658,12 @@ No such struct field `<fieldName>` in `<fields>`.

The operation `<statement>` is not allowed on the `<objectType>`: `<objectName>`.

### FOUND_MULTIPLE_DATA_SOURCES

[SQLSTATE: 42710](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

Detected multiple data sources with the name '`<provider>`'. Please check the data source isn't simultaneously registered and located in the classpath.
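
For illustration only (not part of this change), a hedged sketch of one way this condition can arise once Python data sources can be registered at runtime: a user-defined source reusing the short name of a source already on the classpath, here the built-in `json` source. The class name below is made up, and `spark` is assumed to be an existing SparkSession.

```python
from pyspark.sql.datasource import DataSource


class ConflictingJsonSource(DataSource):
    @classmethod
    def name(cls):
        # Same short name as the built-in JSON data source on the classpath.
        return "json"


spark.dataSource.register(ConflictingJsonSource)
# With the DataFrameReader change in this PR, both sources are detected for the
# name "json" and FOUND_MULTIPLE_DATA_SOURCES is raised.
spark.read.format("json").load("/tmp/people.json")
```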

### GENERATED_COLUMN_WITH_DEFAULT_VALUE

[SQLSTATE: 42623](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
4 changes: 4 additions & 0 deletions python/pyspark/sql/session.py
@@ -873,6 +873,10 @@ def dataSource(self) -> "DataSourceRegistration":
Returns
-------
:class:`DataSourceRegistration`

Notes
-----
This feature is experimental and unstable.
"""
from pyspark.sql.datasource import DataSourceRegistration

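For context, a minimal sketch (not part of this diff) of the flow that `SparkSession.dataSource` exposes: subclass `DataSource`, register it through the returned `DataSourceRegistration`, and read it back by its short name. The class names and schema below are made up for illustration, `spark` is an existing SparkSession, and the API is experimental as the note above says.

```python
from pyspark.sql.datasource import DataSource, DataSourceReader


class ExampleReader(DataSourceReader):
    def read(self, partition):
        # Emit a single row matching the schema declared by the data source.
        yield (0, "zero")


class ExampleDataSource(DataSource):
    @classmethod
    def name(cls):
        return "example"

    def schema(self):
        return "id INT, value STRING"

    def reader(self, schema):
        return ExampleReader()


# Registration goes through the DataSourceRegistration returned by spark.dataSource.
spark.dataSource.register(ExampleDataSource)
spark.read.format("example").load().show()
```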
97 changes: 89 additions & 8 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -14,10 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import unittest

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import Row
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.testing.utils import SPARK_HOME


class BasePythonDataSourceTestsMixin:
@@ -45,16 +49,93 @@ def read(self, partition):
self.assertEqual(list(reader.partitions()), [None])
self.assertEqual(list(reader.read(None)), [(None,)])

def test_register_data_source(self):
    class MyDataSource(DataSource):
        ...

    self.spark.dataSource.register(MyDataSource)

    self.assertTrue(
        self.spark._jsparkSession.sharedState()
        .dataSourceRegistry()
        .dataSourceExists("MyDataSource")
    )

def test_in_memory_data_source(self):
class InMemDataSourceReader(DataSourceReader):
DEFAULT_NUM_PARTITIONS: int = 3

def __init__(self, paths, options):
self.paths = paths
self.options = options

def partitions(self):
if "num_partitions" in self.options:
num_partitions = int(self.options["num_partitions"])
else:
num_partitions = self.DEFAULT_NUM_PARTITIONS
return range(num_partitions)

def read(self, partition):
yield partition, str(partition)

class InMemoryDataSource(DataSource):
@classmethod
def name(cls):
return "memory"

def schema(self):
return "x INT, y STRING"

def reader(self, schema) -> "DataSourceReader":
return InMemDataSourceReader(self.paths, self.options)

self.spark.dataSource.register(InMemoryDataSource)
df = self.spark.read.format("memory").load()
self.assertEqual(df.rdd.getNumPartitions(), 3)
assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1"), Row(x=2, y="2")])

df = self.spark.read.format("memory").option("num_partitions", 2).load()
assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
self.assertEqual(df.rdd.getNumPartitions(), 2)

def test_custom_json_data_source(self):
import json

class JsonDataSourceReader(DataSourceReader):
def __init__(self, paths, options):
self.paths = paths
self.options = options

def partitions(self):
return iter(self.paths)

def read(self, path):
with open(path, "r") as file:
for line in file.readlines():
if line.strip():
data = json.loads(line)
yield data.get("name"), data.get("age")

class JsonDataSource(DataSource):
@classmethod
def name(cls):
return "my-json"

def schema(self):
return "name STRING, age INT"

def reader(self, schema) -> "DataSourceReader":
return JsonDataSourceReader(self.paths, self.options)

self.spark.dataSource.register(JsonDataSource)
path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
df1 = self.spark.read.format("my-json").load(path1)
self.assertEqual(df1.rdd.getNumPartitions(), 1)
assertDataFrameEqual(
df1,
[Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)],
)

df2 = self.spark.read.format("my-json").load([path1, path2])
self.assertEqual(df2.rdd.getNumPartitions(), 2)
assertDataFrameEqual(
df2,
[
Row(name="Michael", age=None),
Row(name="Andy", age=30),
Row(name="Justin", age=19),
Row(name="Jonathan", age=None),
],
)


16 changes: 14 additions & 2 deletions python/pyspark/sql/worker/create_data_source.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import inspect
import os
import sys
from typing import IO, List

from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import (
read_bool,
@@ -84,8 +84,20 @@ def main(infile: IO, outfile: IO) -> None:
},
)

# Check that the name() method is a classmethod.
if not inspect.ismethod(data_source_cls.name):
raise PySparkTypeError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
message_parameters={
"expected": "'name()' method to be a classmethod",
"actual": f"'{type(data_source_cls.name).__name__}'",
},
)

# Receive the provider name.
provider = utf8_deserializer.loads(infile)

# Check if the provider name matches the data source's name.
if provider.lower() != data_source_cls.name().lower():
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
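A hedged sketch (not part of this diff) of what the new `name()` check rejects: if `name` is a plain instance method rather than a classmethod, `inspect.ismethod(data_source_cls.name)` is False on the class object and the worker fails with PYTHON_DATA_SOURCE_TYPE_MISMATCH before the provider-name comparison runs. The class names below are hypothetical.

```python
import inspect

from pyspark.sql.datasource import DataSource


class BadNameSource(DataSource):
    def name(self):  # plain function on the class, not a classmethod
        return "bad"


class GoodNameSource(DataSource):
    @classmethod
    def name(cls):  # bound to the class, so the check passes
        return "good"


print(inspect.ismethod(BadNameSource.name))   # False -> rejected by the new check
print(inspect.ismethod(GoodNameSource.name))  # True  -> accepted
```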
@@ -3811,4 +3811,16 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
errorClass = "DATA_SOURCE_ALREADY_EXISTS",
messageParameters = Map("provider" -> name))
}

def dataSourceDoesNotExist(name: String): Throwable = {
new AnalysisException(
errorClass = "DATA_SOURCE_NOT_EXIST",
messageParameters = Map("provider" -> name))
}

def foundMultipleDataSources(provider: String): Throwable = {
new AnalysisException(
errorClass = "FOUND_MULTIPLE_DATA_SOURCES",
messageParameters = Map("provider" -> provider))
}
}
48 changes: 42 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -17,11 +17,12 @@

package org.apache.spark.sql

import java.util.{Locale, Properties}
import java.util.{Locale, Properties, ServiceConfigurationError}

import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

import org.apache.spark.Partition
import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable}
import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
@@ -208,10 +209,45 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
}

DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider =>
DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions,
source, paths: _*)
}.getOrElse(loadV1Source(paths: _*))
val isUserDefinedDataSource =
  sparkSession.sharedState.dataSourceManager.dataSourceExists(source)

Review comments on this block:

Member: @cloud-fan @allisonwang-db do we want to support this data source via the USING syntax, unlike DSv2? If that's the case, the logic for loading the data source has to live in DataSource.lookupDataSource and/or DataSource.providingInstance. I don't think we should mix that logic with DSv2 here.

Member: Let's at least separate the logic into a separate function if possible.

Contributor: Unfortunately, DS v2 TableProvider does not support USING yet. That's why the code is a bit messy here, as it's not shared with the SQL USING path. We should support it, though...

Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
case Success(providerOpt) =>
// The source can be successfully loaded as either a V1 or a V2 data source.
// Check if it is also a user-defined data source.
if (isUserDefinedDataSource) {
throw QueryCompilationErrors.foundMultipleDataSources(source)
}
providerOpt.flatMap { provider =>
DataSourceV2Utils.loadV2Source(
sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*)
}.getOrElse(loadV1Source(paths: _*))
case Failure(exception) =>
// An exception was thrown while trying to load the source as a V1 or V2 data source.
// For the "not found" exceptions below, fall back to the user-defined data source
// with this name if one is registered.
val isNotFoundError = exception match {
case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError]
case _ => false
}
if (isNotFoundError && isUserDefinedDataSource) {
loadUserDefinedDataSource(paths)
} else {
// Throw the original exception.
throw exception
}
}
}

private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
// Unless the legacy path option behavior is enabled, the extraOptions here
// should not include "path" or "paths" as keys.
val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions)
Dataset.ofRows(sparkSession, plan)
}

private def loadV1Source(paths: String*) = {
@@ -22,33 +22,56 @@ import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
* A manager for user-defined data sources. It is used to register and lookup data sources by
* their short names or fully qualified names.
*/
class DataSourceManager {

private type DataSourceBuilder = (
SparkSession, // Spark session
String, // provider name
Seq[String], // paths
Option[StructType], // user specified schema
CaseInsensitiveStringMap // options
CaseInsensitiveMap[String] // options
) => LogicalPlan

private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()

private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)

/**
* Register a data source builder for the given provider.
* Note that the provider name is case-insensitive.
*/
def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
val normalizedName = normalize(name)
if (dataSourceBuilders.containsKey(normalizedName)) {
throw QueryCompilationErrors.dataSourceAlreadyExists(name)
}
// TODO(SPARK-45639): check if the data source is a DSv1 or DSv2 using loadDataSource.
dataSourceBuilders.put(normalizedName, builder)
}

/**
* Returns a data source builder for the given provider and throws an exception if
* it does not exist.
*/
def lookupDataSource(name: String): DataSourceBuilder = {
if (dataSourceExists(name)) {
dataSourceBuilders.get(normalize(name))
} else {
throw QueryCompilationErrors.dataSourceDoesNotExist(name)
}
}

/**
* Checks if a data source with the specified name exists (case-insensitive).
*/
def dataSourceExists(name: String): Boolean = {
dataSourceBuilders.containsKey(normalize(name))
}
}
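
As a hedged usage sketch (not in this diff), the case-insensitive normalization above means the format name used for reading does not have to match the registered casing. The names below are made up for illustration, and `spark` is assumed to be an existing SparkSession.

```python
from pyspark.sql.datasource import DataSource, DataSourceReader


class FancyReader(DataSourceReader):
    def read(self, partition):
        yield (1,)


class FancySource(DataSource):
    @classmethod
    def name(cls):
        return "FancySource"

    def schema(self):
        return "id INT"

    def reader(self, schema):
        return FancyReader()


spark.dataSource.register(FancySource)
# Names are normalized with Locale.ROOT on both register and lookup,
# so a differently cased format name still resolves to the same builder.
spark.read.format("fancysource").load().show()
```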