
Commit 25ee62e

allisonwang-db authored and dongjoon-hyun committed
[SPARK-45927][PYTHON] Update path handling for Python data source
### What changes were proposed in this pull request?

This PR updates how `path` values from the `load()` method are handled. It changes the `DataSource` class constructor and adds `path` as a key-value pair in the options field. It also blocks loading multiple paths.

### Why are the changes needed?

To make the behavior consistent with the existing data source APIs.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #43809 from allisonwang-db/spark-45927-path.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
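For illustration, a minimal sketch of a user-defined Python data source under the updated contract. The class names here are hypothetical, but the pattern mirrors the updated tests in this commit: the reader takes only `options` and pulls the `load()` path from `self.options` rather than from a `paths` constructor argument.

```python
from pyspark.sql.datasource import DataSource, DataSourceReader


class SingleFileReader(DataSourceReader):
    def __init__(self, options):
        self.options = options

    def read(self, partition):
        # After this change, the path passed to load() is exposed via the
        # "path" option instead of a separate `paths` constructor argument.
        path = self.options.get("path")
        if path is None:
            raise Exception("path is not specified")
        with open(path, "r") as file:
            for line in file.readlines():
                if line.strip():
                    yield (line.strip(),)


class SingleFileDataSource(DataSource):
    def schema(self):
        return "value STRING"

    def reader(self, schema):
        return SingleFileReader(self.options)


# Hypothetical usage, following the pattern in the tests below:
#   spark.dataSource.register(SingleFileDataSource)
#   df = spark.read.format(SingleFileDataSource.name()).load("/tmp/data.txt")
```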
1 parent 5ded567 commit 25ee62e

8 files changed: +34 -70 lines changed


python/pyspark/sql/datasource.py

Lines changed: 3 additions & 14 deletions

@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 from abc import ABC, abstractmethod
-from typing import final, Any, Dict, Iterator, List, Optional, Tuple, Type, Union, TYPE_CHECKING
+from typing import final, Any, Dict, Iterator, List, Tuple, Type, Union, TYPE_CHECKING

 from pyspark import since
 from pyspark.sql import Row
@@ -45,30 +45,19 @@ class DataSource(ABC):
     """

     @final
-    def __init__(
-        self,
-        paths: List[str],
-        userSpecifiedSchema: Optional[StructType],
-        options: Dict[str, "OptionalPrimitiveType"],
-    ) -> None:
+    def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None:
         """
-        Initializes the data source with user-provided information.
+        Initializes the data source with user-provided options.

         Parameters
         ----------
-        paths : list
-            A list of paths to the data source.
-        userSpecifiedSchema : StructType, optional
-            The user-specified schema of the data source.
         options : dict
             A dictionary representing the options for this data source.

         Notes
         -----
         This method should not be overridden.
         """
-        self.paths = paths
-        self.userSpecifiedSchema = userSpecifiedSchema
         self.options = options

     @classmethod
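The practical effect on user code, sketched with a hypothetical `MyDataSource` subclass: the old call matches the constructor removed above, and the new call matches how the worker now instantiates data sources.

```python
# Before this commit: paths and the user-specified schema were explicit constructor arguments.
ds = MyDataSource(paths=["/tmp/people.json"], userSpecifiedSchema=None, options={"mode": "strict"})
ds.paths  # ["/tmp/people.json"]

# After this commit: the constructor accepts only `options`, and a path passed to
# load() arrives as the "path" entry of that dict ("mode" is an illustrative option).
ds = MyDataSource(options={"path": "/tmp/people.json", "mode": "strict"})
ds.options["path"]  # "/tmp/people.json"
```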

python/pyspark/sql/tests/test_python_datasource.py

Lines changed: 12 additions & 24 deletions

@@ -30,7 +30,7 @@ class MyDataSource(DataSource):
             ...

         options = dict(a=1, b=2)
-        ds = MyDataSource(paths=[], userSpecifiedSchema=None, options=options)
+        ds = MyDataSource(options=options)
         self.assertEqual(ds.options, options)
         self.assertEqual(ds.name(), "MyDataSource")
         with self.assertRaises(NotImplementedError):
@@ -53,8 +53,7 @@ def test_in_memory_data_source(self):
         class InMemDataSourceReader(DataSourceReader):
             DEFAULT_NUM_PARTITIONS: int = 3

-            def __init__(self, paths, options):
-                self.paths = paths
+            def __init__(self, options):
                 self.options = options

             def partitions(self):
@@ -76,7 +75,7 @@ def schema(self):
                 return "x INT, y STRING"

             def reader(self, schema) -> "DataSourceReader":
-                return InMemDataSourceReader(self.paths, self.options)
+                return InMemDataSourceReader(self.options)

         self.spark.dataSource.register(InMemoryDataSource)
         df = self.spark.read.format("memory").load()
@@ -91,14 +90,13 @@ def test_custom_json_data_source(self):
         import json

         class JsonDataSourceReader(DataSourceReader):
-            def __init__(self, paths, options):
-                self.paths = paths
+            def __init__(self, options):
                 self.options = options

-            def partitions(self):
-                return iter(self.paths)
-
-            def read(self, path):
+            def read(self, partition):
+                path = self.options.get("path")
+                if path is None:
+                    raise Exception("path is not specified")
                 with open(path, "r") as file:
                     for line in file.readlines():
                         if line.strip():
@@ -114,28 +112,18 @@ def schema(self):
                 return "name STRING, age INT"

             def reader(self, schema) -> "DataSourceReader":
-                return JsonDataSourceReader(self.paths, self.options)
+                return JsonDataSourceReader(self.options)

         self.spark.dataSource.register(JsonDataSource)
         path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
         path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
-        df1 = self.spark.read.format("my-json").load(path1)
-        self.assertEqual(df1.rdd.getNumPartitions(), 1)
         assertDataFrameEqual(
-            df1,
+            self.spark.read.format("my-json").load(path1),
             [Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)],
         )
-
-        df2 = self.spark.read.format("my-json").load([path1, path2])
-        self.assertEqual(df2.rdd.getNumPartitions(), 2)
         assertDataFrameEqual(
-            df2,
-            [
-                Row(name="Michael", age=None),
-                Row(name="Andy", age=30),
-                Row(name="Justin", age=19),
-                Row(name="Jonathan", age=None),
-            ],
+            self.spark.read.format("my-json").load(path2),
+            [Row(name="Jonathan", age=None)],
         )

python/pyspark/sql/worker/create_data_source.py

Lines changed: 2 additions & 13 deletions

@@ -17,7 +17,7 @@
 import inspect
 import os
 import sys
-from typing import IO, List
+from typing import IO

 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
@@ -55,7 +55,6 @@ def main(infile: IO, outfile: IO) -> None:
     The JVM sends the following information to this process:
     - a `DataSource` class representing the data source to be created.
     - a provider name in string.
-    - a list of paths in string.
     - an optional user-specified schema in json string.
     - a dictionary of options in string.
@@ -107,12 +106,6 @@ def main(infile: IO, outfile: IO) -> None:
                 },
             )

-        # Receive the paths.
-        num_paths = read_int(infile)
-        paths: List[str] = []
-        for _ in range(num_paths):
-            paths.append(utf8_deserializer.loads(infile))
-
         # Receive the user-specified schema
         user_specified_schema = None
         if read_bool(infile):
@@ -136,11 +129,7 @@ def main(infile: IO, outfile: IO) -> None:

         # Instantiate a data source.
         try:
-            data_source = data_source_cls(
-                paths=paths,
-                userSpecifiedSchema=user_specified_schema,  # type: ignore
-                options=options,
-            )
+            data_source = data_source_cls(options=options)
         except Exception as e:
             raise PySparkRuntimeError(
                 error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 3 additions & 3 deletions

@@ -244,9 +244,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

  private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
    val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
-    // Unless the legacy path option behavior is enabled, the extraOptions here
-    // should not include "path" or "paths" as keys.
-    val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions)
+    // Add `path` and `paths` options to the extra options if specified.
+    val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
+    val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
    Dataset.ofRows(sparkSession, plan)
  }
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala

Lines changed: 0 additions & 1 deletion

@@ -35,7 +35,6 @@ class DataSourceManager {
  private type DataSourceBuilder = (
      SparkSession,  // Spark session
      String,  // provider name
-      Seq[String],  // paths
      Option[StructType],  // user specified schema
      CaseInsensitiveMap[String]  // options
  ) => LogicalPlan

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala

Lines changed: 1 addition & 1 deletion

@@ -152,7 +152,7 @@ private[sql] object DataSourceV2Utils extends Logging {
  }

  private lazy val objectMapper = new ObjectMapper()
-  private def getOptionsWithPaths(
+  def getOptionsWithPaths(
      extraOptions: CaseInsensitiveMap[String],
      paths: String*): CaseInsensitiveMap[String] = {
    if (paths.isEmpty) {

sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala

Lines changed: 2 additions & 9 deletions

@@ -42,12 +42,11 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
  def builder(
      sparkSession: SparkSession,
      provider: String,
-      paths: Seq[String],
      userSpecifiedSchema: Option[StructType],
      options: CaseInsensitiveMap[String]): LogicalPlan = {

    val runner = new UserDefinedPythonDataSourceRunner(
-      dataSourceCls, provider, paths, userSpecifiedSchema, options)
+      dataSourceCls, provider, userSpecifiedSchema, options)

    val result = runner.runInPython()
    val pickledDataSourceInstance = result.dataSource
@@ -68,10 +67,9 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
  def apply(
      sparkSession: SparkSession,
      provider: String,
-      paths: Seq[String] = Seq.empty,
      userSpecifiedSchema: Option[StructType] = None,
      options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = {
-    val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options)
+    val plan = builder(sparkSession, provider, userSpecifiedSchema, options)
    Dataset.ofRows(sparkSession, plan)
  }
}
@@ -89,7 +87,6 @@ case class PythonDataSourceCreationResult(
class UserDefinedPythonDataSourceRunner(
    dataSourceCls: PythonFunction,
    provider: String,
-    paths: Seq[String],
    userSpecifiedSchema: Option[StructType],
    options: CaseInsensitiveMap[String])
  extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
@@ -103,10 +100,6 @@ class UserDefinedPythonDataSourceRunner(
    // Send the provider name
    PythonWorkerUtils.writeUTF(provider, dataOut)

-    // Send the paths
-    dataOut.writeInt(paths.length)
-    paths.foreach(PythonWorkerUtils.writeUTF(_, dataOut))
-
    // Send the user-specified schema, if provided
    dataOut.writeBoolean(userSpecifiedSchema.isDefined)
    userSpecifiedSchema.map(_.json).foreach(PythonWorkerUtils.writeUTF(_, dataOut))

sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala

Lines changed: 11 additions & 5 deletions

@@ -160,13 +160,20 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
    val dataSourceScript =
      s"""
         |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |import json
+         |
         |class SimpleDataSourceReader(DataSourceReader):
-         |    def __init__(self, paths, options):
-         |        self.paths = paths
+         |    def __init__(self, options):
         |        self.options = options
         |
         |    def partitions(self):
-         |        return iter(self.paths)
+         |        if "paths" in self.options:
+         |            paths = json.loads(self.options["paths"])
+         |        elif "path" in self.options:
+         |            paths = [self.options["path"]]
+         |        else:
+         |            paths = []
+         |        return paths
         |
         |    def read(self, path):
         |        yield (path, 1)
@@ -180,11 +187,10 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
         |        return "id STRING, value INT"
         |
         |    def reader(self, schema):
-         |        return SimpleDataSourceReader(self.paths, self.options)
+         |        return SimpleDataSourceReader(self.options)
         |""".stripMargin
    val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
    spark.dataSource.registerPython("test", dataSource)
-
    checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
    checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
    checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1)))
