 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import unittest

 from pyspark.sql.datasource import DataSource, DataSourceReader
+from pyspark.sql.types import Row
+from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import SPARK_HOME


 class BasePythonDataSourceTestsMixin:
@@ -45,16 +49,93 @@ def read(self, partition):
         self.assertEqual(list(reader.partitions()), [None])
         self.assertEqual(list(reader.read(None)), [(None,)])

-    def test_register_data_source(self):
-        class MyDataSource(DataSource):
-            ...
+    def test_in_memory_data_source(self):
+        class InMemDataSourceReader(DataSourceReader):
+            DEFAULT_NUM_PARTITIONS: int = 3
+
+            def __init__(self, paths, options):
+                self.paths = paths
+                self.options = options
+
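+            # Partition count: the "num_partitions" option takes precedence,
+            # otherwise fall back to the class default of 3.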
+            def partitions(self):
+                if "num_partitions" in self.options:
+                    num_partitions = int(self.options["num_partitions"])
+                else:
+                    num_partitions = self.DEFAULT_NUM_PARTITIONS
+                return range(num_partitions)
+
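+            # Each partition yields a single (int, str) row matching the
+            # "x INT, y STRING" schema declared below.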
+            def read(self, partition):
+                yield partition, str(partition)
+
+        class InMemoryDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "memory"
+
+            def schema(self):
+                return "x INT, y STRING"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return InMemDataSourceReader(self.paths, self.options)
+
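+        # Registering the class makes it usable as a format under its name(), "memory".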
+        self.spark.dataSource.register(InMemoryDataSource)
+        df = self.spark.read.format("memory").load()
+        self.assertEqual(df.rdd.getNumPartitions(), 3)
+        assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1"), Row(x=2, y="2")])

-        self.spark.dataSource.register(MyDataSource)
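+        # The "num_partitions" option overrides the reader's default partition count.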
+        df = self.spark.read.format("memory").option("num_partitions", 2).load()
+        assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
+        self.assertEqual(df.rdd.getNumPartitions(), 2)
+
+    def test_custom_json_data_source(self):
+        import json
+
+        class JsonDataSourceReader(DataSourceReader):
+            def __init__(self, paths, options):
+                self.paths = paths
+                self.options = options
+
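+            # One partition per input path, so each file can be read by a separate task.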
+            def partitions(self):
+                return iter(self.paths)
+
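+            # Parse one newline-delimited JSON file, skipping blank lines.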
+            def read(self, path):
+                with open(path, "r") as file:
+                    for line in file.readlines():
+                        if line.strip():
+                            data = json.loads(line)
+                            yield data.get("name"), data.get("age")
+
+        class JsonDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "my-json"
+
+            def schema(self):
+                return "name STRING, age INT"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return JsonDataSourceReader(self.paths, self.options)
+
+        self.spark.dataSource.register(JsonDataSource)
+        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+        path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
+        df1 = self.spark.read.format("my-json").load(path1)
+        self.assertEqual(df1.rdd.getNumPartitions(), 1)
+        assertDataFrameEqual(
+            df1,
+            [Row(name="Michael", age=None), Row(name="Andy", age=30), Row(name="Justin", age=19)],
+        )

-        self.assertTrue(
-            self.spark._jsparkSession.sharedState()
-            .dataSourceRegistry()
-            .dataSourceExists("MyDataSource")
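+        # Two input paths should map to two partitions, one per file.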
+        df2 = self.spark.read.format("my-json").load([path1, path2])
+        self.assertEqual(df2.rdd.getNumPartitions(), 2)
+        assertDataFrameEqual(
+            df2,
+            [
+                Row(name="Michael", age=None),
+                Row(name="Andy", age=30),
+                Row(name="Justin", age=19),
+                Row(name="Jonathan", age=None),
+            ],
         )
