30 changes: 30 additions & 0 deletions python/pyspark/sql/tests/test_python_datasource.py
@@ -49,6 +49,36 @@ def read(self, partition):
         self.assertEqual(list(reader.partitions()), [None])
         self.assertEqual(list(reader.read(None)), [(None,)])
 
+    def test_data_source_register(self):
+        class TestReader(DataSourceReader):
+            def read(self, partition):
+                yield (0, 1)
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "a INT, b INT"
+
+            def reader(self, schema):
+                return TestReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("TestDataSource").load()
+        assertDataFrameEqual(df, [Row(a=0, b=1)])
+
+        class MyDataSource(TestDataSource):
+            @classmethod
+            def name(cls):
+                return "TestDataSource"
+
+            def schema(self):
+                return "c INT, d INT"
+
+        # Should be able to register the data source with the same name.
+        self.spark.dataSource.register(MyDataSource)
+
+        df = self.spark.read.format("TestDataSource").load()
+        assertDataFrameEqual(df, [Row(c=0, d=1)])
+
     def test_in_memory_data_source(self):
         class InMemDataSourceReader(DataSourceReader):
             DEFAULT_NUM_PARTITIONS: int = 3
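The test above exercises the key behavioral change on the Python side: registering a second data source under an existing short name now replaces the earlier registration instead of raising DATA_SOURCE_ALREADY_EXISTS. Below is a minimal, self-contained sketch of the same flow outside the test harness; it assumes the Spark 4.0 Python data source API (pyspark.sql.datasource), and the class and format names (MySource, MyNewSource, my_source) are hypothetical.

from pyspark.sql import SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader

spark = SparkSession.builder.getOrCreate()

class SingleRowReader(DataSourceReader):
    def read(self, partition):
        # Emit a single two-column row.
        yield (0, 1)

class MySource(DataSource):
    @classmethod
    def name(cls):
        return "my_source"

    def schema(self):
        return "a INT, b INT"

    def reader(self, schema):
        return SingleRowReader()

spark.dataSource.register(MySource)
spark.read.format("my_source").load().show()   # one row: a=0, b=1

class MyNewSource(MySource):
    # Inherits name() -> "my_source" and the reader; only the schema changes.
    def schema(self):
        return "c INT, d INT"

# Re-registering under the same name no longer raises an error;
# the latest registration wins (a warning is logged on the JVM side).
spark.dataSource.register(MyNewSource)
spark.read.format("my_source").load().show()   # one row: c=0, d=1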
DataFrameReader.scala
@@ -210,7 +210,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
 
     val isUserDefinedDataSource =
-      sparkSession.sharedState.dataSourceManager.dataSourceExists(source)
+      sparkSession.sessionState.dataSourceManager.dataSourceExists(source)
 
     Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
       case Success(providerOpt) =>
@@ -243,7 +243,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
   }
 
   private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
-    val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
+    val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
     // Add `path` and `paths` options to the extra options if specified.
     val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
     val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
SparkSession.scala
@@ -233,7 +233,7 @@ class SparkSession private(
   /**
    * A collection of methods for registering user-defined data sources.
    */
-  private[sql] def dataSource: DataSourceRegistration = sharedState.dataSourceRegistration
+  private[sql] def dataSource: DataSourceRegistration = sessionState.dataSourceRegistration
 
   /**
    * Returns a `StreamingQueryManager` that allows managing all the
DataSourceManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import java.util.Locale
 import java.util.concurrent.ConcurrentHashMap
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.StructType
  * A manager for user-defined data sources. It is used to register and lookup data sources by
  * their short names or fully qualified names.
  */
-class DataSourceManager {
+class DataSourceManager extends Logging {
 
   private type DataSourceBuilder = (
     SparkSession, // Spark session
@@ -49,10 +50,10 @@ class DataSourceManager {
    */
   def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
     val normalizedName = normalize(name)
-    if (dataSourceBuilders.containsKey(normalizedName)) {
-      throw QueryCompilationErrors.dataSourceAlreadyExists(name)
+    val previousValue = dataSourceBuilders.put(normalizedName, builder)
+    if (previousValue != null) {
+      logWarning(f"The data source $name replaced a previously registered data source.")
     }
-    dataSourceBuilders.put(normalizedName, builder)
   }
 
   /**
@@ -73,4 +74,10 @@ class DataSourceManager {
   def dataSourceExists(name: String): Boolean = {
     dataSourceBuilders.containsKey(normalize(name))
   }
+
+  override def clone(): DataSourceManager = {
+    val manager = new DataSourceManager
+    dataSourceBuilders.forEach((k, v) => manager.registerDataSource(k, v))
+    manager
+  }
 }
BaseSessionStateBuilder.scala
@@ -120,6 +120,13 @@ abstract class BaseSessionStateBuilder(
       .getOrElse(extensions.registerTableFunctions(TableFunctionRegistry.builtin.clone()))
   }
 
+  /**
+   * Manages the registration of data sources.
+   */
+  protected lazy val dataSourceManager: DataSourceManager = {
+    parentState.map(_.dataSourceManager.clone()).getOrElse(new DataSourceManager)
+  }
+
   /**
    * Experimental methods that can be used to define custom optimization rules and custom planning
    * strategies.
@@ -178,6 +185,12 @@
 
   protected def udtfRegistration: UDTFRegistration = new UDTFRegistration(tableFunctionRegistry)
 
+  /**
+   * A collection of methods used for registering user-defined data sources.
+   */
+  protected def dataSourceRegistration: DataSourceRegistration =
+    new DataSourceRegistration(dataSourceManager)
+
   /**
    * Logical query plan analyzer for resolving unresolved attributes and relations.
    *
@@ -376,6 +389,8 @@
       tableFunctionRegistry,
       udfRegistration,
       udtfRegistration,
+      dataSourceManager,
+      dataSourceRegistration,
       () => catalog,
       sqlParser,
       () => analyzer,
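Since dataSourceManager is now built per session, cloned from the parent session state when one exists and created empty otherwise, user-defined data source registrations no longer live in SharedState and are no longer shared across every session of the application. The sketch below illustrates the user-visible consequence, assuming spark.newSession() builds its session state without a parent to clone from; the class and format names are hypothetical.

from pyspark.sql import SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader

spark = SparkSession.builder.getOrCreate()

class OneColumnReader(DataSourceReader):
    def read(self, partition):
        yield (1,)

class ScopedSource(DataSource):
    @classmethod
    def name(cls):
        return "scoped_source"

    def schema(self):
        return "id INT"

    def reader(self, schema):
        return OneColumnReader()

spark.dataSource.register(ScopedSource)
spark.read.format("scoped_source").load().show()   # resolves in this session

other = spark.newSession()
# Before this change the registry lived in SharedState, so it was visible
# here as well. With a per-session DataSourceManager, a brand-new session
# starts with an empty registry, and this lookup would fail unless
# ScopedSource is also registered on `other`:
# other.read.format("scoped_source").load()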
SessionState.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AdaptiveRulesHolder
+import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.streaming.StreamingQueryManager
 import org.apache.spark.sql.util.ExecutionListenerManager
 import org.apache.spark.util.{DependencyUtils, Utils}
@@ -49,6 +50,8 @@ import org.apache.spark.util.{DependencyUtils, Utils}
  * @param udfRegistration Interface exposed to the user for registering user-defined functions.
  * @param udtfRegistration Interface exposed to the user for registering user-defined
  *                         table functions.
+ * @param dataSourceManager Internal catalog for managing data sources registered by users.
+ * @param dataSourceRegistration Interface exposed to users for registering data sources.
  * @param catalogBuilder a function to create an internal catalog for managing table and database
  *                       states.
  * @param sqlParser Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
@@ -73,6 +76,8 @@ private[sql] class SessionState(
     val tableFunctionRegistry: TableFunctionRegistry,
     val udfRegistration: UDFRegistration,
     val udtfRegistration: UDTFRegistration,
+    val dataSourceManager: DataSourceManager,
+    val dataSourceRegistration: DataSourceRegistration,
     catalogBuilder: () => SessionCatalog,
     val sqlParser: ParserInterface,
     analyzerBuilder: () => Analyzer,
SharedState.scala
@@ -30,11 +30,9 @@ import org.apache.hadoop.fs.{FsUrlStreamHandlerFactory, Path}
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.DataSourceRegistration
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.CacheManager
-import org.apache.spark.sql.execution.datasources.DataSourceManager
 import org.apache.spark.sql.execution.streaming.StreamExecution
 import org.apache.spark.sql.execution.ui.{SQLAppStatusListener, SQLAppStatusStore, SQLTab, StreamingQueryStatusStore}
 import org.apache.spark.sql.internal.StaticSQLConf._
@@ -107,16 +105,6 @@ private[sql] class SharedState(
   @GuardedBy("activeQueriesLock")
   private[sql] val activeStreamingQueries = new ConcurrentHashMap[UUID, StreamExecution]()
 
-  /**
-   * A data source manager shared by all sessions.
-   */
-  lazy val dataSourceManager = new DataSourceManager()
-
-  /**
-   * A collection of method used for registering user-defined data sources.
-   */
-  lazy val dataSourceRegistration = new DataSourceRegistration(dataSourceManager)
-
   /**
    * A status store to query SQL status/metrics of this Spark application, based on SQL-specific
    * [[org.apache.spark.scheduler.SparkListenerEvent]]s.
PythonDataSourceSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
 import org.apache.spark.sql.catalyst.plans.logical.{BatchEvalPythonUDTF, PythonDataSourcePartitions}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 
@@ -143,16 +144,35 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
 
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
     spark.dataSource.registerPython(dataSourceName, dataSource)
-    assert(spark.sharedState.dataSourceManager.dataSourceExists(dataSourceName))
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+    val ds1 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
+    checkAnswer(
+      ds1(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+      Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
 
-    // Check error when registering a data source with the same name.
-    val err = intercept[AnalysisException] {
-      spark.dataSource.registerPython(dataSourceName, dataSource)
-    }
-    checkError(
-      exception = err,
-      errorClass = "DATA_SOURCE_ALREADY_EXISTS",
-      parameters = Map("provider" -> dataSourceName))
+    // Should be able to override an already registered data source.
+    val newScript =
+      s"""
+        |from pyspark.sql.datasource import DataSource, DataSourceReader
+        |class SimpleDataSourceReader(DataSourceReader):
+        |    def read(self, partition):
+        |        yield (0, )
+        |
+        |class $dataSourceName(DataSource):
+        |    def schema(self) -> str:
+        |        return "id INT"
+        |
+        |    def reader(self, schema):
+        |        return SimpleDataSourceReader()
+        |""".stripMargin
+    val newDataSource = createUserDefinedPythonDataSource(dataSourceName, newScript)
+    spark.dataSource.registerPython(dataSourceName, newDataSource)
+    assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
+
+    val ds2 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
+    checkAnswer(
+      ds2(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+      Seq(Row(0)))
   }
 
   test("load data source") {