
Commit 9d93b71

allisonwang-db authored and HyukjinKwon committed
[SPARK-45639][SQL][PYTHON] Support loading Python data sources in DataFrameReader
### What changes were proposed in this pull request?

This PR supports `spark.read.format(...).load()` for Python data sources. After this PR, users can use a Python data source directly like this:

```python
from pyspark.sql.datasource import DataSource, DataSourceReader


class MyReader(DataSourceReader):
    def read(self, partition):
        yield (0, 1)


class MyDataSource(DataSource):
    @classmethod
    def name(cls):
        return "my-source"

    def schema(self):
        return "id INT, value INT"

    def reader(self, schema):
        return MyReader()


spark.dataSource.register(MyDataSource)

df = spark.read.format("my-source").load()
df.show()
+---+-----+
| id|value|
+---+-----+
|  0|    1|
+---+-----+
```

### Why are the changes needed?

To support Python data sources.

### Does this PR introduce _any_ user-facing change?

Yes. After this PR, users can load a custom Python data source using `spark.read.format(...).load()`.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43630 from allisonwang-db/spark-45639-ds-lookup.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
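To complement the example above, here is a hedged sketch (the source name `my-range` and the `num_rows` option are illustrative, not part of this commit) of how reader options set with `.option()` flow through to a Python data source as `self.options`; the new in-memory test below exercises the same pattern with a `num_partitions` option:

```python
from pyspark.sql.datasource import DataSource, DataSourceReader


class RangeReader(DataSourceReader):
    def __init__(self, options):
        # Options set on the DataFrameReader arrive here as strings in a dict-like map.
        if "num_rows" in options:
            self.num_rows = int(options["num_rows"])
        else:
            self.num_rows = 3

    def read(self, partition):
        for i in range(self.num_rows):
            yield (i,)


class RangeDataSource(DataSource):
    @classmethod
    def name(cls):
        return "my-range"  # illustrative name

    def schema(self):
        return "id INT"

    def reader(self, schema):
        # self.options is populated from the reader options.
        return RangeReader(self.options)


spark.dataSource.register(RangeDataSource)
spark.read.format("my-range").option("num_rows", 5).load().show()
```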
1 parent e331de0 commit 9d93b71

File tree

11 files changed (+255, -28 lines)

common/utils/src/main/resources/error/error-classes.json

Lines changed: 12 additions & 0 deletions
```diff
@@ -850,6 +850,12 @@
     ],
     "sqlState" : "42710"
   },
+  "DATA_SOURCE_NOT_EXIST" : {
+    "message" : [
+      "Data source '<provider>' not found. Please make sure the data source is registered."
+    ],
+    "sqlState" : "42704"
+  },
   "DATA_SOURCE_NOT_FOUND" : {
     "message" : [
       "Failed to find the data source: <provider>. Please find packages at `https://spark.apache.org/third-party-projects.html`."
@@ -1095,6 +1101,12 @@
     ],
     "sqlState" : "42809"
   },
+  "FOUND_MULTIPLE_DATA_SOURCES" : {
+    "message" : [
+      "Detected multiple data sources with the name '<provider>'. Please check the data source isn't simultaneously registered and located in the classpath."
+    ],
+    "sqlState" : "42710"
+  },
   "GENERATED_COLUMN_WITH_DEFAULT_VALUE" : {
     "message" : [
       "A column cannot have both a default value and a generation expression but column <colName> has default value: (<defaultValue>) and generation expression: (<genExpr>)."
```

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -511,6 +511,7 @@ def __hash__(self):
         "pyspark.sql.tests.pandas.test_pandas_udf_window",
         "pyspark.sql.tests.pandas.test_converter",
         "pyspark.sql.tests.test_pandas_sqlmetrics",
+        "pyspark.sql.tests.test_python_datasource",
         "pyspark.sql.tests.test_readwriter",
         "pyspark.sql.tests.test_serde",
         "pyspark.sql.tests.test_session",
```

docs/sql-error-conditions.md

Lines changed: 12 additions & 0 deletions
```diff
@@ -454,6 +454,12 @@ DataType `<type>` requires a length parameter, for example `<type>`(10). Please
 
 Data source '`<provider>`' already exists in the registry. Please use a different name for the new data source.
 
+### DATA_SOURCE_NOT_EXIST
+
+[SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Data source '`<provider>`' not found. Please make sure the data source is registered.
+
 ### DATA_SOURCE_NOT_FOUND
 
 [SQLSTATE: 42K02](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
@@ -669,6 +675,12 @@ No such struct field `<fieldName>` in `<fields>`.
 
 The operation `<statement>` is not allowed on the `<objectType>`: `<objectName>`.
 
+### FOUND_MULTIPLE_DATA_SOURCES
+
+[SQLSTATE: 42710](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Detected multiple data sources with the name '`<provider>`'. Please check the data source isn't simultaneously registered and located in the classpath.
+
 ### GENERATED_COLUMN_WITH_DEFAULT_VALUE
 
 [SQLSTATE: 42623](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
```

python/pyspark/sql/session.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -884,6 +884,10 @@ def dataSource(self) -> "DataSourceRegistration":
         Returns
         -------
         :class:`DataSourceRegistration`
+
+        Notes
+        -----
+        This feature is experimental and unstable.
         """
         from pyspark.sql.datasource import DataSourceRegistration
```

python/pyspark/sql/tests/test_python_datasource.py

Lines changed: 89 additions & 8 deletions
```diff
@@ -14,10 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import unittest
 
 from pyspark.sql.datasource import DataSource, DataSourceReader
+from pyspark.sql.types import Row
+from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import SPARK_HOME
 
 
 class BasePythonDataSourceTestsMixin:
@@ -45,16 +49,93 @@ def read(self, partition):
         self.assertEqual(list(reader.partitions()), [None])
         self.assertEqual(list(reader.read(None)), [(None,)])
 
-    def test_register_data_source(self):
-        class MyDataSource(DataSource):
-            ...
+    def test_in_memory_data_source(self):
+        class InMemDataSourceReader(DataSourceReader):
+            DEFAULT_NUM_PARTITIONS: int = 3
+
+            def __init__(self, paths, options):
+                self.paths = paths
+                self.options = options
+
+            def partitions(self):
+                if "num_partitions" in self.options:
+                    num_partitions = int(self.options["num_partitions"])
+                else:
+                    num_partitions = self.DEFAULT_NUM_PARTITIONS
+                return range(num_partitions)
+
+            def read(self, partition):
+                yield partition, str(partition)
+
+        class InMemoryDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "memory"
+
+            def schema(self):
+                return "x INT, y STRING"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return InMemDataSourceReader(self.paths, self.options)
+
+        self.spark.dataSource.register(InMemoryDataSource)
+        df = self.spark.read.format("memory").load()
+        self.assertEqual(df.rdd.getNumPartitions(), 3)
+        assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1"), Row(x=2, y="2")])
 
-        self.spark.dataSource.register(MyDataSource)
+        df = self.spark.read.format("memory").option("num_partitions", 2).load()
+        assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
+        self.assertEqual(df.rdd.getNumPartitions(), 2)
+
+    def test_custom_json_data_source(self):
+        import json
+
+        class JsonDataSourceReader(DataSourceReader):
+            def __init__(self, paths, options):
+                self.paths = paths
+                self.options = options
+
+            def partitions(self):
+                return iter(self.paths)
+
+            def read(self, path):
+                with open(path, "r") as file:
+                    for line in file.readlines():
+                        if line.strip():
+                            data = json.loads(line)
+                            yield data.get("name"), data.get("age")
+
+        class JsonDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "my-json"
+
+            def schema(self):
+                return "name STRING, age INT"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return JsonDataSourceReader(self.paths, self.options)
+
+        self.spark.dataSource.register(JsonDataSource)
+        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+        path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
+        df1 = self.spark.read.format("my-json").load(path1)
+        self.assertEqual(df1.rdd.getNumPartitions(), 1)
+        assertDataFrameEqual(
+            df1,
+            [Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)],
+        )
 
-        self.assertTrue(
-            self.spark._jsparkSession.sharedState()
-            .dataSourceRegistry()
-            .dataSourceExists("MyDataSource")
+        df2 = self.spark.read.format("my-json").load([path1, path2])
+        self.assertEqual(df2.rdd.getNumPartitions(), 2)
+        assertDataFrameEqual(
+            df2,
+            [
+                Row(name="Michael", age=None),
+                Row(name="Andy", age=30),
+                Row(name="Justin", age=19),
+                Row(name="Jonathan", age=None),
+            ],
         )
```

python/pyspark/sql/worker/create_data_source.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -14,13 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import inspect
 import os
 import sys
 from typing import IO, List
 
 from pyspark.accumulators import _accumulatorRegistry
-from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
 from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import (
     read_bool,
@@ -84,8 +84,20 @@ def main(infile: IO, outfile: IO) -> None:
             },
         )
 
+        # Check the name method is a class method.
+        if not inspect.ismethod(data_source_cls.name):
+            raise PySparkTypeError(
+                error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+                message_parameters={
+                    "expected": "'name()' method to be a classmethod",
+                    "actual": f"'{type(data_source_cls.name).__name__}'",
+                },
+            )
+
         # Receive the provider name.
         provider = utf8_deserializer.loads(infile)
+
+        # Check if the provider name matches the data source's name.
         if provider.lower() != data_source_cls.name().lower():
             raise PySparkAssertionError(
                 error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
```

sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala

Lines changed: 12 additions & 0 deletions
```diff
@@ -3805,4 +3805,16 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
       errorClass = "DATA_SOURCE_ALREADY_EXISTS",
       messageParameters = Map("provider" -> name))
   }
+
+  def dataSourceDoesNotExist(name: String): Throwable = {
+    new AnalysisException(
+      errorClass = "DATA_SOURCE_NOT_EXIST",
+      messageParameters = Map("provider" -> name))
+  }
+
+  def foundMultipleDataSources(provider: String): Throwable = {
+    new AnalysisException(
+      errorClass = "FOUND_MULTIPLE_DATA_SOURCES",
+      messageParameters = Map("provider" -> provider))
+  }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 42 additions & 6 deletions
```diff
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql
 
-import java.util.{Locale, Properties}
+import java.util.{Locale, Properties, ServiceConfigurationError}
 
 import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success, Try}
 
-import org.apache.spark.Partition
+import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable}
 import org.apache.spark.annotation.Stable
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.internal.Logging
@@ -208,10 +209,45 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
     }
 
-    DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider =>
-      DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions,
-        source, paths: _*)
-    }.getOrElse(loadV1Source(paths: _*))
+    val isUserDefinedDataSource =
+      sparkSession.sharedState.dataSourceManager.dataSourceExists(source)
+
+    Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
+      case Success(providerOpt) =>
+        // The source can be successfully loaded as either a V1 or a V2 data source.
+        // Check if it is also a user-defined data source.
+        if (isUserDefinedDataSource) {
+          throw QueryCompilationErrors.foundMultipleDataSources(source)
+        }
+        providerOpt.flatMap { provider =>
+          DataSourceV2Utils.loadV2Source(
+            sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*)
+        }.getOrElse(loadV1Source(paths: _*))
+      case Failure(exception) =>
+        // Exceptions are thrown while trying to load the data source as a V1 or V2 data source.
+        // For the following not found exceptions, if the user-defined data source is defined,
+        // we can instead return the user-defined data source.
+        val isNotFoundError = exception match {
+          case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
+          case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
+          case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError]
+          case _ => false
+        }
+        if (isNotFoundError && isUserDefinedDataSource) {
+          loadUserDefinedDataSource(paths)
+        } else {
+          // Throw the original exception.
+          throw exception
+        }
+    }
+  }
+
+  private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
+    val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
+    // Unless the legacy path option behavior is enabled, the extraOptions here
+    // should not include "path" or "paths" as keys.
+    val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions)
+    Dataset.ofRows(sparkSession, plan)
   }
 
   private def loadV1Source(paths: String*) = {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala

Lines changed: 27 additions & 4 deletions
```diff
@@ -22,33 +22,56 @@ import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
+/**
+ * A manager for user-defined data sources. It is used to register and lookup data sources by
+ * their short names or fully qualified names.
+ */
 class DataSourceManager {
 
   private type DataSourceBuilder = (
     SparkSession, // Spark session
     String, // provider name
    Seq[String], // paths
     Option[StructType], // user specified schema
-    CaseInsensitiveStringMap // options
+    CaseInsensitiveMap[String] // options
   ) => LogicalPlan
 
   private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()
 
   private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
 
+  /**
+   * Register a data source builder for the given provider.
+   * Note that the provider name is case-insensitive.
+   */
   def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
     val normalizedName = normalize(name)
     if (dataSourceBuilders.containsKey(normalizedName)) {
       throw QueryCompilationErrors.dataSourceAlreadyExists(name)
     }
-    // TODO(SPARK-45639): check if the data source is a DSv1 or DSv2 using loadDataSource.
     dataSourceBuilders.put(normalizedName, builder)
   }
 
-  def dataSourceExists(name: String): Boolean =
+  /**
+   * Returns a data source builder for the given provider and throw an exception if
+   * it does not exist.
+   */
+  def lookupDataSource(name: String): DataSourceBuilder = {
+    if (dataSourceExists(name)) {
+      dataSourceBuilders.get(normalize(name))
+    } else {
+      throw QueryCompilationErrors.dataSourceDoesNotExist(name)
+    }
+  }
+
+  /**
+   * Checks if a data source with the specified name exists (case-insensitive).
+   */
+  def dataSourceExists(name: String): Boolean = {
    dataSourceBuilders.containsKey(normalize(name))
+  }
 }
```
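One behavioral detail worth noting from `normalize`: registered provider names are lowercased, so registration and lookup are case-insensitive. A hedged sketch from the Python side (the source name is illustrative and assumes nothing named "casedemo" exists on the classpath):

```python
from pyspark.sql.datasource import DataSource, DataSourceReader


class OneRowReader(DataSourceReader):
    def read(self, partition):
        yield (0,)


class CaseDemoSource(DataSource):
    @classmethod
    def name(cls):
        return "CaseDemo"  # registered under the normalized name "casedemo"

    def schema(self):
        return "id INT"

    def reader(self, schema):
        return OneRowReader()


spark.dataSource.register(CaseDemoSource)
# Both of these should resolve to the same registered builder, since
# DataSourceManager normalizes names with toLowerCase(Locale.ROOT).
spark.read.format("casedemo").load().show()
spark.read.format("CASEDEMO").load().show()
```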
