
Commit 5497b9f

Fix
1 parent 82e8ec7 commit 5497b9f

2 files changed: +26 -23 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala

Lines changed: 13 additions & 15 deletions
@@ -56,8 +56,11 @@ class PythonTableProvider(shortName: String) extends TableProvider {
       schema: StructType,
       partitioning: Array[Transform],
       properties: java.util.Map[String, String]): Table = {
+    assert(partitioning.isEmpty)
     new PythonTable(shortName, source, schema)
   }
+
+  override def supportsExternalMetadata(): Boolean = true
 }
 
 class PythonTable(shortName: String, source: UserDefinedPythonDataSource, givenSchema: StructType)
@@ -70,8 +73,8 @@ class PythonTable(shortName: String, source: UserDefinedPythonDataSource, givenSchema: StructType)
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
     new ScanBuilder with Batch with Scan {
 
-      private lazy val pythonFunc: PythonFunction = source.createPythonFunction(
-        shortName, options, Some(givenSchema))
+      private lazy val pythonFunc: PythonFunction =
+        source.createPythonFunction(shortName, options, givenSchema)
 
       private lazy val info: PythonDataSourceReadInfo =
         new UserDefinedPythonDataSourceReadRunner(
@@ -163,37 +166,32 @@ class PythonPartitionReaderFactory(
  */
 case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
 
-  private var pythonResult: PythonDataSourceCreationResult = _
-
-  private def getOrCreatePythonResult(
+  private def createPythonResult(
       shortName: String,
       options: CaseInsensitiveStringMap,
       userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult = {
-    if (pythonResult != null) return pythonResult
-    val runner = new UserDefinedPythonDataSourceRunner(
+    new UserDefinedPythonDataSourceRunner(
       dataSourceCls,
       shortName,
       userSpecifiedSchema,
-      CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap))
-    pythonResult = runner.runInPython()
-    pythonResult
+      CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)).runInPython()
   }
 
   def inferSchema(
       shortName: String,
       options: CaseInsensitiveStringMap): StructType = {
-    getOrCreatePythonResult(shortName, options, None).schema
+    createPythonResult(shortName, options, None).schema
   }
 
   def createPythonFunction(
       shortName: String,
       options: CaseInsensitiveStringMap,
-      userSpecifiedSchema: Option[StructType]): PythonFunction = {
-    val pickledDataSourceInstance = getOrCreatePythonResult(
-      shortName, options, userSpecifiedSchema).dataSource
+      givenSchema: StructType): PythonFunction = {
+    val dataSource = createPythonResult(
+      shortName, options, Some(givenSchema)).dataSource
 
     SimplePythonFunction(
-      command = pickledDataSourceInstance.toImmutableArraySeq,
+      command = dataSource.toImmutableArraySeq,
       envVars = dataSourceCls.envVars,
       pythonIncludes = dataSourceCls.pythonIncludes,
       pythonExec = dataSourceCls.pythonExec,
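The refactor above removes the shared pythonResult memo from UserDefinedPythonDataSource and makes createPythonResult stateless; caching now happens at the call site through the scan builder's lazy val pythonFunc. Below is a minimal, Spark-free sketch of that pattern; every name in it (CreationResult, ScanBuilderLike, createResult) is a hypothetical stand-in, not a Spark API.

// Sketch only: hypothetical stand-ins, not Spark code.
object LazyValCachingSketch {
  final case class CreationResult(schema: String, pickled: Array[Byte])

  // Stateless, like createPythonResult: every call pays the full cost
  // (a println here stands in for launching a Python worker).
  def createResult(shortName: String): CreationResult = {
    println(s"creating data source instance for '$shortName'")
    CreationResult(schema = "id INT, partition INT", pickled = Array[Byte](1, 2, 3))
  }

  // Each builder caches its own result; `lazy val` evaluates at most once
  // per instance, replacing the old shared `private var` memo.
  final class ScanBuilderLike(shortName: String) {
    private lazy val result: CreationResult = createResult(shortName)
    def schema: String = result.schema
    def pickledSize: Int = result.pickled.length
  }

  def main(args: Array[String]): Unit = {
    val builder = new ScanBuilderLike("my_source")
    println(builder.schema)      // triggers createResult once
    println(builder.pickledSize) // served from the cached lazy val
  }
}

One consequence, visible in the diff itself: inferSchema and createPythonFunction each trigger a separate runInPython() call; reuse within a scan comes only from the lazy val.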

sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala

Lines changed: 13 additions & 8 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType

@@ -55,6 +56,10 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
     val df = spark.read.format(dataSourceName).schema(schema).load()
     assert(df.rdd.getNumPartitions == 2)
     val plan = df.queryExecution.optimizedPlan
+    plan match {
+      case s: DataSourceV2ScanRelation if s.relation.table.isInstanceOf[PythonTable] =>
+      case _ => fail(s"Plan did not match the expected pattern. Actual plan:\n$plan")
+    }
     checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
   }

@@ -164,12 +169,12 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
        |            paths = []
        |        return [InputPartition(p) for p in paths]
        |
-       |    def read(self, path):
-       |        if path is not None:
-       |            assert isinstance(path, InputPartition)
-       |            yield (path.value, 1)
+       |    def read(self, part):
+       |        if part is not None:
+       |            assert isinstance(part, InputPartition)
+       |            yield (part.value, 1)
        |        else:
-       |            yield (path, 1)
+       |            yield (part, 1)
        |
        |class $dataSourceName(DataSource):
        |    @classmethod
@@ -256,7 +261,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("data source read with custom partitions") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
         |from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
@@ -288,7 +293,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("data source read with empty partitions") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val dataSourceScript =
       s"""
         |from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -316,7 +321,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
   }
 
   test("data source read with invalid partitions") {
-    assume(shouldTestPythonUDFs)
+    assume(shouldTestPandasUDFs)
     val reader1 =
       s"""
         |class SimpleDataSourceReader(DataSourceReader):
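The plan check added to the first test uses a match-or-fail idiom: pattern-match the optimized plan against the expected node and fail with the actual plan printed otherwise. Below is a minimal, Spark-free sketch of the same idiom; the plan nodes here (ScanRelation, Filter) are hypothetical stand-ins for Spark's DataSourceV2ScanRelation and friends.

// Sketch only: toy plan nodes, not Spark's Catalyst plan types.
object PlanMatchSketch {
  sealed trait Plan
  final case class ScanRelation(tableClass: String) extends Plan
  final case class Filter(child: Plan) extends Plan

  // Match the expected shape; otherwise fail with the actual plan in the
  // message, mirroring the test's fail(...) branch.
  def assertPythonScan(plan: Plan): Unit = plan match {
    case ScanRelation("PythonTable") => // expected shape: nothing more to check
    case other =>
      throw new AssertionError(
        s"Plan did not match the expected pattern. Actual plan:\n$other")
  }

  def main(args: Array[String]): Unit = {
    assertPythonScan(ScanRelation("PythonTable")) // passes silently
    try assertPythonScan(Filter(ScanRelation("PythonTable")))
    catch { case e: AssertionError => println(e.getMessage) } // prints the plan
  }
}

An empty case body on a successful match is the usual way to express "this shape is acceptable" in such assertions; only the fallback branch fails.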
