@@ -17,13 +17,11 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, PythonUDTF}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types.{BinaryType, StructType}
import org.apache.spark.sql.types.StructType

/**
* FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame.
@@ -103,42 +101,6 @@ case class PythonMapInArrow(
copy(child = newChild)
}

/**
* Represents a Python data source.
*/
case class PythonDataSource(
dataSource: PythonFunction,
outputSchema: StructType,
override val output: Seq[Attribute]) extends LeafNode {
require(output.forall(_.resolved),
"Unresolved attributes found when constructing PythonDataSource.")
override protected def stringArgs: Iterator[Any] = {
Iterator(output)
}
final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_DATA_SOURCE)
}

/**
* Represents a list of Python data source partitions.
*/
case class PythonDataSourcePartitions(
output: Seq[Attribute],
partitions: Seq[Array[Byte]]) extends LeafNode {
override protected def stringArgs: Iterator[Any] = {
if (partitions.isEmpty) {
Iterator("<empty>", output)
} else {
Iterator(output)
}
}
}

object PythonDataSourcePartitions {
def schema: StructType = new StructType().add("partition", BinaryType)

def getOutputAttrs: Seq[Attribute] = toAttributes(schema)
}

/**
* Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe -> pandas.Dataframe
* This is used by DataFrame.groupby().cogroup().apply().
48 changes: 6 additions & 42 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -17,12 +17,11 @@

package org.apache.spark.sql

import java.util.{Locale, Properties, ServiceConfigurationError}
import java.util.{Locale, Properties}

import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable}
import org.apache.spark.Partition
import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
@@ -209,45 +208,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
}

val isUserDefinedDataSource =
sparkSession.sessionState.dataSourceManager.dataSourceExists(source)

Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
case Success(providerOpt) =>
// The source can be successfully loaded as either a V1 or a V2 data source.
// Check if it is also a user-defined data source.
if (isUserDefinedDataSource) {
throw QueryCompilationErrors.foundMultipleDataSources(source)
}
providerOpt.flatMap { provider =>
DataSourceV2Utils.loadV2Source(
sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*)
}.getOrElse(loadV1Source(paths: _*))
case Failure(exception) =>
// Exceptions are thrown while trying to load the data source as a V1 or V2 data source.
// For the following not found exceptions, if the user-defined data source is defined,
// we can instead return the user-defined data source.
val isNotFoundError = exception match {
case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError]
case _ => false
}
if (isNotFoundError && isUserDefinedDataSource) {
loadUserDefinedDataSource(paths)
} else {
// Throw the original exception.
throw exception
}
}
}

private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
// Add `path` and `paths` options to the extra options if specified.
val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
Dataset.ofRows(sparkSession, plan)
DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider =>
DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions,
source, paths: _*)
}.getOrElse(loadV1Source(paths: _*))
}

private def loadV1Source(paths: String*) = {
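The end-user entry point is unchanged by this refactoring: once a Python data source is registered under a short name, it resolves through the same format(...).load() path as any other source, with the lookup now handled inside DataSource.lookupDataSource. A small usage sketch (the source name and option are hypothetical):

    // Assumes a Python data source was registered under "my_source" via the
    // session's DataSourceRegistration shown below; DataFrameReader no longer
    // branches on user-defined sources itself.
    val df = spark.read
      .format("my_source")
      .option("path", "/tmp/example")  // hypothetical option for the sketch
      .load()
    df.show()
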
@@ -43,6 +43,6 @@ private[sql] class DataSourceRegistration private[sql] (dataSourceManager: DataS
| pythonExec: ${dataSource.dataSourceCls.pythonExec}
""".stripMargin)

dataSourceManager.registerDataSource(name, dataSource.builder)
dataSourceManager.registerDataSource(name, dataSource)
}
}
@@ -780,7 +780,7 @@ class SparkSession private(
DataSource.lookupDataSource(runner, sessionState.conf) match {
Member Author: Actually @cloud-fan that would not work. E.g., if PythonDataSource implements ExternalCommandRunner, we should load it here.

Member Author: Let me fix it separately. Reading the code path, I don't think it will have much of an effect either way.

Contributor: Let's worry about it when we actually add this ability to the Python data source. We may never add it, for simplicity.

case source if classOf[ExternalCommandRunner].isAssignableFrom(source) =>
Dataset.ofRows(self, ExternalCommandExecutor(
source.getDeclaredConstructor().newInstance()
DataSource.newDataSourceInstance(runner, source)
Contributor: It is arguable whether this is a breaking change. Now people need to worry about the Python data source in code that is meant to deal with DS v1 only.

Member Author: ExternalCommandRunner is a DSv2 API.

.asInstanceOf[ExternalCommandRunner], command, options))

case _ =>
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.{PlanPythonDataSourceScan, PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs}
@@ -42,8 +42,7 @@ class SparkOptimizer(
V2ScanRelationPushDown :+
V2ScanPartitioningAndOrdering :+
V2Writes :+
PruneFileSourcePartitions :+
PlanPythonDataSourceScan
PruneFileSourcePartitions

override def preCBORules: Seq[Rule[LogicalPlan]] =
OptimizeMetadataOnlyDeleteFromTable :: Nil
@@ -102,8 +101,7 @@ class SparkOptimizer(
V2ScanRelationPushDown.ruleName :+
V2ScanPartitioningAndOrdering.ruleName :+
V2Writes.ruleName :+
ReplaceCTERefWithRepartition.ruleName :+
PlanPythonDataSourceScan.ruleName
ReplaceCTERefWithRepartition.ruleName

/**
* Optimization batches that are executed before the regular optimization batches (also before
@@ -753,8 +753,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ArrowEvalPythonUDTF(udtf, requiredChildOutput, resultAttrs, child, evalType) =>
ArrowEvalPythonUDTFExec(
udtf, requiredChildOutput, resultAttrs, planLater(child), evalType) :: Nil
case PythonDataSourcePartitions(output, partitions) =>
PythonDataSourcePartitionsExec(output, partitions) :: Nil
case _ =>
Nil
}
@@ -45,7 +45,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError
import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types._
@@ -1025,7 +1025,9 @@ object DDLUtils extends Logging {

def checkDataColNames(provider: String, schema: StructType): Unit = {
val source = try {
DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
DataSource.newDataSourceInstance(
provider,
DataSource.lookupDataSource(provider, SQLConf.get))
} catch {
case e: Throwable =>
logError(s"Failed to find data source: $provider when check data column names.", e)
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -264,8 +264,9 @@ case class AlterTableAddColumnsCommand(
}

if (DDLUtils.isDatasourceTable(catalogTable)) {
DataSource.lookupDataSource(catalogTable.provider.get, conf).
getConstructor().newInstance() match {
DataSource.newDataSourceInstance(
catalogTable.provider.get,
DataSource.lookupDataSource(catalogTable.provider.get, conf)) match {
// For datasource table, this command can only support the following File format.
// TextFileFormat only default to one column "value"
// Hive type is already considered as hive serde table, so the logic will not
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
import org.apache.spark.sql.execution.datasources.xml.XmlFileFormat
import org.apache.spark.sql.execution.python.PythonTableProvider
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider}
import org.apache.spark.sql.internal.SQLConf
@@ -105,13 +106,14 @@ case class DataSource(
// [[FileDataSourceV2]] will still be used if we call the load()/save() method in
// [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
// instead of `providingClass`.
cls.getDeclaredConstructor().newInstance() match {
DataSource.newDataSourceInstance(className, cls) match {
case f: FileDataSourceV2 => f.fallbackFileFormat
Member Author: and here too.

Member Author: and tables.scala as well:

    if (DDLUtils.isDatasourceTable(catalogTable)) {
      DataSource.newDataSourceInstance(
          catalogTable.provider.get,
          DataSource.lookupDataSource(catalogTable.provider.get, conf)) match {
        // For datasource table, this command can only support the following File format.
        // TextFileFormat only default to one column "value"
        // Hive type is already considered as hive serde table, so the logic will not
        // come in here.
        case _: CSVFileFormat | _: JsonFileFormat | _: ParquetFileFormat =>
        case _: JsonDataSourceV2 | _: CSVDataSourceV2 |
             _: OrcDataSourceV2 | _: ParquetDataSourceV2 =>
        case s if s.getClass.getCanonicalName.endsWith("OrcFileFormat") =>
        case s =>
          throw QueryCompilationErrors.alterAddColNotSupportDatasourceTableError(s, table)
      }
    }
    catalogTable

Member Author: and DataStreamReader:

    val v1DataSource = DataSource(
      sparkSession,
      userSpecifiedSchema = userSpecifiedSchema,
      className = source,
      options = optionsWithPath.originalMap)
    val v1Relation = ds match {
      case _: StreamSourceProvider => Some(StreamingRelation(v1DataSource))
      case _ => None
    }
    ds match {
      // file source v2 does not support streaming yet.
      case provider: TableProvider if !provider.isInstanceOf[FileDataSourceV2] =>

Member Author: and DataStreamWriter:

      val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
      val disabledSources =
        Utils.stringToSeq(df.sparkSession.sessionState.conf.disabledV2StreamingWriters)
      val useV1Source = disabledSources.contains(cls.getCanonicalName) ||
        // file source v2 does not support streaming yet.
        classOf[FileDataSourceV2].isAssignableFrom(cls)

      val optionsWithPath = if (path.isEmpty) {
        extraOptions
      } else {
        extraOptions + ("path" -> path.get)
      }

      val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) {

case _ => cls
}
}

private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance()
private[sql] def providingInstance(): Any =
DataSource.newDataSourceInstance(className, providingClass)

private def newHadoopConfiguration(): Configuration =
sparkSession.sessionState.newHadoopConfWithOptions(options)
@@ -622,6 +624,15 @@ object DataSource extends Logging {
"org.apache.spark.sql.sources.HadoopFsRelationProvider",
"org.apache.spark.Logging")

/** Create the instance of the datasource */
def newDataSourceInstance(provider: String, providingClass: Class[_]): Any = {
providingClass match {
case cls if classOf[PythonTableProvider].isAssignableFrom(cls) =>
cls.getDeclaredConstructor(classOf[String]).newInstance(provider)
case cls => cls.getDeclaredConstructor().newInstance()
}
}

/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
@@ -649,6 +660,9 @@ object DataSource extends Logging {
// Found the data source using fully qualified path
dataSource
case Failure(error) =>
// TODO(SPARK-45600): should be session-based.
val isUserDefinedDataSource = SparkSession.getActiveSession.exists(
_.sessionState.dataSourceManager.dataSourceExists(provider))
if (provider1.startsWith("org.apache.spark.sql.hive.orc")) {
throw QueryCompilationErrors.orcNotUsedWithHiveEnabledError()
} else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
Expand All @@ -657,6 +671,8 @@ object DataSource extends Logging {
throw QueryCompilationErrors.failedToFindAvroDataSourceError(provider1)
} else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
throw QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1)
} else if (isUserDefinedDataSource) {
classOf[PythonTableProvider]
} else {
throw QueryExecutionErrors.dataSourceNotFoundError(provider1, error)
}
@@ -673,6 +689,14 @@
}
case head :: Nil =>
// there is exactly one registered alias
// TODO(SPARK-45600): should be session-based.
val isUserDefinedDataSource = SparkSession.getActiveSession.exists(
_.sessionState.dataSourceManager.dataSourceExists(provider))
// The source can be successfully loaded as either a V1 or a V2 data source.
// Check if it is also a user-defined data source.
if (isUserDefinedDataSource) {
throw QueryCompilationErrors.foundMultipleDataSources(provider)
}
head.getClass
case sources =>
// There are multiple registered aliases for the input. If there is single datasource
@@ -708,17 +732,18 @@ object DataSource extends Logging {
def lookupDataSourceV2(provider: String, conf: SQLConf): Option[TableProvider] = {
Member Author: Yeah, I like that idea. Can I do it in a follow-up, though? I would like to extract some changes from your PR and make another PR.

Contributor: It's not a follow-up... I have a concern about changing lookupDataSource, which is only used for the DS v1 path. Let's avoid the risk of breaking anything. It's also less code change if we only instantiate the PythonTableProvider here, so that existing callers of lookupDataSource can still instantiate the objects directly instead of calling the new newDataSourceInstance function.

Member Author: Oh, okie dokie. I was actually thinking about porting more changes from your PR. I will fix that one alone here for now.

val useV1Sources = conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT)
.split(",").map(_.trim)
val cls = lookupDataSource(provider, conf)
val providingClass = lookupDataSource(provider, conf)
val instance = try {
cls.getDeclaredConstructor().newInstance()
newDataSourceInstance(provider, providingClass)
} catch {
// Throw the original error from the data source implementation.
case e: java.lang.reflect.InvocationTargetException => throw e.getCause
}
instance match {
case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => None
case t: TableProvider
if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
if !useV1Sources.contains(
providingClass.getCanonicalName.toLowerCase(Locale.ROOT)) =>
Some(t)
case _ => None
}
@@ -21,36 +21,28 @@ import java.util.Locale
import java.util.concurrent.ConcurrentHashMap

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource


/**
* A manager for user-defined data sources. It is used to register and lookup data sources by
* their short names or fully qualified names.
*/
class DataSourceManager extends Logging {

private type DataSourceBuilder = (
SparkSession, // Spark session
String, // provider name
Option[StructType], // user specified schema
CaseInsensitiveMap[String] // options
) => LogicalPlan

private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()
// TODO(SPARK-45917): Statically load Python Data Source so idempotently Python
// Data Sources can be loaded even when the Driver is restarted.
private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]()

private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)

/**
* Register a data source builder for the given provider.
* Note that the provider name is case-insensitive.
*/
def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
def registerDataSource(name: String, source: UserDefinedPythonDataSource): Unit = {
val normalizedName = normalize(name)
val previousValue = dataSourceBuilders.put(normalizedName, builder)
val previousValue = dataSourceBuilders.put(normalizedName, source)
if (previousValue != null) {
logWarning(f"The data source $name replaced a previously registered data source.")
}
@@ -60,7 +52,7 @@ class DataSourceManager extends Logging {
* Returns a data source builder for the given provider and throw an exception if
* it does not exist.
*/
def lookupDataSource(name: String): DataSourceBuilder = {
def lookupDataSource(name: String): UserDefinedPythonDataSource = {
if (dataSourceExists(name)) {
dataSourceBuilders.get(normalize(name))
} else {
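Taken together, the registration path now hands the UserDefinedPythonDataSource itself, rather than a builder function, to DataSourceManager, and DataSource.lookupDataSource consults the same manager. A rough sketch of the round trip, assuming an already-constructed dataSource and internal (private[sql]) access to the session state:

    // Register the user-defined source, then look it up the way the new code paths do.
    val manager = spark.sessionState.dataSourceManager

    manager.registerDataSource("my_source", dataSource)

    if (manager.dataSourceExists("my_source")) {
      // Returns the UserDefinedPythonDataSource registered above, which the
      // PythonTableProvider created by DataSource.newDataSourceInstance can then
      // retrieve by provider name (an assumption based on the diff, not shown here).
      val registered: UserDefinedPythonDataSource = manager.lookupDataSource("my_source")
    }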