@@ -1513,7 +1513,7 @@ object CodeGenerator extends Logging {
(evaluator.getClazz().getConstructor().newInstance().asInstanceOf[GeneratedClass], codeStats)
}

private def logGeneratedCode(code: CodeAndComment): Unit = {
def logGeneratedCode(code: CodeAndComment): Unit = {
val maxLines = SQLConf.get.loggingMaxLinesForCodegen
if (Utils.isTesting) {
logError(s"\n${CodeFormatter.format(code, maxLines)}")
@@ -1527,7 +1527,7 @@ object CodeGenerator extends Logging {
* # of inner classes) of generated classes by inspecting Janino classes.
* Also, this method updates the metrics information.
*/
private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): ByteCodeStats = {
def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): ByteCodeStats = {
// First retrieve the generated classes.
val classes = evaluator.getBytecodes.asScala

48 changes: 6 additions & 42 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -17,12 +17,11 @@

package org.apache.spark.sql

import java.util.{Locale, Properties, ServiceConfigurationError}
import java.util.{Locale, Properties}

import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable}
import org.apache.spark.Partition
import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
@@ -209,45 +208,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
}

val isUserDefinedDataSource =
sparkSession.sessionState.dataSourceManager.dataSourceExists(source)

Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
case Success(providerOpt) =>
// The source can be successfully loaded as either a V1 or a V2 data source.
// Check if it is also a user-defined data source.
if (isUserDefinedDataSource) {
throw QueryCompilationErrors.foundMultipleDataSources(source)
}
providerOpt.flatMap { provider =>
DataSourceV2Utils.loadV2Source(
sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*)
}.getOrElse(loadV1Source(paths: _*))
case Failure(exception) =>
// Exceptions are thrown while trying to load the data source as a V1 or V2 data source.
// For the following not found exceptions, if the user-defined data source is defined,
// we can instead return the user-defined data source.
val isNotFoundError = exception match {
case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError]
case _ => false
}
if (isNotFoundError && isUserDefinedDataSource) {
loadUserDefinedDataSource(paths)
} else {
// Throw the original exception.
throw exception
}
}
}

private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
// Add `path` and `paths` options to the extra options if specified.
val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
Dataset.ofRows(sparkSession, plan)
DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider =>
DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions,
source, paths: _*)
}.getOrElse(loadV1Source(paths: _*))
}

private def loadV1Source(paths: String*) = {
@@ -649,6 +649,9 @@ object DataSource extends Logging {
// Found the data source using fully qualified path
dataSource
case Failure(error) =>
// TODO(SPARK-45600): should be session-based.
val isUserDefinedDataSource = SparkSession.getActiveSession.exists(
_.sessionState.dataSourceManager.dataSourceExists(provider))
if (provider1.startsWith("org.apache.spark.sql.hive.orc")) {
throw QueryCompilationErrors.orcNotUsedWithHiveEnabledError()
} else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
@@ -657,6 +660,12 @@
throw QueryCompilationErrors.failedToFindAvroDataSourceError(provider1)
} else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
throw QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1)
} else if (isUserDefinedDataSource) {
// TODO(SPARK-45916): Try the Python Data Source. We should probably cache the
// generated class to avoid regenerating it every time, but if we want to allow
// runtime updates of the Python data source, we have to regenerate it every time.
PythonDataSourceCodeGenerator.generateClass(provider)
} else {
throw QueryExecutionErrors.dataSourceNotFoundError(provider1, error)
}
@@ -673,6 +682,14 @@
}
case head :: Nil =>
// there is exactly one registered alias
// TODO(SPARK-45600): should be session-based.
val isUserDefinedDataSource = SparkSession.getActiveSession.exists(
_.sessionState.dataSourceManager.dataSourceExists(provider))
// The source can be successfully loaded as either a V1 or a V2 data source.
// Check if it is also a user-defined data source.
if (isUserDefinedDataSource) {
throw QueryCompilationErrors.foundMultipleDataSources(provider)
}
head.getClass
case sources =>
// There are multiple registered aliases for the input. If there is single datasource
@@ -20,12 +20,29 @@ package org.apache.spark.sql.execution.datasources
import java.util.Locale
import java.util.concurrent.ConcurrentHashMap

import scala.jdk.CollectionConverters._

import org.apache.commons.text.StringEscapeUtils
import org.codehaus.commons.compiler.{CompileException, InternalCompilerException}
import org.codehaus.janino.ClassBodyEvaluator

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate.newCodeGenContext
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, TableScan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.{ParentClassLoader, Utils}


/**
* A manager for user-defined data sources. It is used to register and lookup data sources by
@@ -40,6 +57,8 @@ class DataSourceManager extends Logging {
CaseInsensitiveMap[String] // options
) => LogicalPlan

// TODO(SPARK-45917): Statically load Python Data Sources so that they can be
// loaded idempotently even when the driver is restarted.
private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()

private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
@@ -81,3 +100,125 @@
manager
}
}
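
For context, a minimal sketch of the kind of builder the manager stores, matching the (partially shown) DataSourceBuilder signature above. This is illustrative only: OneRowRelation stands in for the logical plan a real Python-backed builder would return, and the manager's registration call itself is not shown in this hunk.

// Illustrative sketch only: a builder with the DataSourceBuilder shape shown above.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types.StructType

val exampleBuilder: (SparkSession, String, Option[StructType], CaseInsensitiveMap[String]) => LogicalPlan =
  (session, shortName, userSpecifiedSchema, options) =>
    // A real builder would return the plan produced from the Python data source;
    // OneRowRelation() keeps this sketch self-contained.
    OneRowRelation()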

/**
* A Data Source V2 default source wrapper for a Python Data Source.
*/
abstract class PythonDefaultSource
extends TableProvider
with DataSourceRegister {

private var sourceDataFrame: DataFrame = _

private def getOrCreateSourceDataFrame(
options: CaseInsensitiveStringMap, maybeSchema: Option[StructType]): DataFrame = {
if (sourceDataFrame != null) return sourceDataFrame
// TODO(SPARK-45600): should be session-based.
val builder = SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName())
val plan = builder(
SparkSession.active,
shortName(),
maybeSchema,
CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap))
sourceDataFrame = Dataset.ofRows(SparkSession.active, plan)
sourceDataFrame
}

override def inferSchema(options: CaseInsensitiveStringMap): StructType =
getOrCreateSourceDataFrame(options, None).schema

override def getTable(
schema: StructType,
partitioning: Array[Transform],
properties: java.util.Map[String, String]): Table = {
val givenSchema = schema
new Table with SupportsRead {
override def name(): String = shortName()

override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ)

override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new ScanBuilder with V1Scan {
Member Author:

@cloud-fan and @allisonwang-db, note that I still use the V1Scan interface here.

In order to fully leverage DSv2, we should actually refactor the whole PlanPythonDataSourceScan and UserDefinedPythonDataSource:

  1. First, remove the PlanPythonDataSourceScan rule so that DataSourceV2Strategy can resolve the DSv2 source.
  2. Second, fix/port the partitioning and reading logic from UserDefinedPythonDataSource into this Scan and ScanBuilder implementation.

While I don't think this is a problem now, we should do it eventually, e.g. for the write path. I would like that to be done separately if you don't mind (and I would like to focus on the static/runtime registration part here).

Member Author:

Or maybe it's good enough for reads, since we can mix in write support, etc. separately?

override def build(): Scan = this
override def toV1TableScan[T <: BaseRelation with TableScan](
context: SQLContext): T = {
new BaseRelation with TableScan {
// Avoid Row <> InternalRow conversion
override val needConversion: Boolean = false
override def buildScan(): RDD[Row] =
getOrCreateSourceDataFrame(options, Some(givenSchema))
.queryExecution.toRdd.asInstanceOf[RDD[Row]]
override def schema: StructType = givenSchema
override def sqlContext: SQLContext = context
}.asInstanceOf[T]
}
override def readSchema(): StructType = givenSchema
}
}

override def schema(): StructType = givenSchema
}
}
}
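
As the review comment above suggests, a follow-up could replace the V1Scan bridge with a native DSv2 scan. Below is a minimal, illustrative sketch of that shape, not part of this PR: the class names are hypothetical, the partition planning is hard-coded, and the empty reader is a placeholder for a real Python-backed reader.

// Illustrative sketch only: a native DSv2 Scan/Batch for a Python source.
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan}
import org.apache.spark.sql.types.StructType

// Hypothetical partition descriptor; a real one would carry the Python partition value.
case class PythonInputPartition(index: Int) extends InputPartition

class PythonReaderFactorySketch extends PartitionReaderFactory {
  override def createReader(partition: InputPartition): PartitionReader[InternalRow] =
    new PartitionReader[InternalRow] {
      // Placeholder: a real reader would stream rows from the Python worker.
      override def next(): Boolean = false
      override def get(): InternalRow = throw new NoSuchElementException("empty sketch reader")
      override def close(): Unit = {}
    }
}

class PythonScanSketch(schema: StructType) extends Scan with Batch {
  override def readSchema(): StructType = schema
  override def toBatch(): Batch = this

  // A real implementation would ask the Python data source for its partitions.
  override def planInputPartitions(): Array[InputPartition] = Array(PythonInputPartition(0))

  override def createReaderFactory(): PartitionReaderFactory = new PythonReaderFactorySketch
}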


/**
* Responsible for generating a class for a Python Data Source that implements the Scala
* Data Source interfaces so that other features work together with Python Data Sources.
*/
object PythonDataSourceCodeGenerator extends Logging {
/**
* When you invoke `generateClass`, it generates a class that inherits [[PythonDefaultSource]]
* and has a different short name. The generated class is named as follows:
* "org.apache.spark.sql.execution.python.datasources.$shortName.DefaultSource".
*
* The `shortName` should be registered via `spark.dataSource.register` first; then
* this method can generate the corresponding Scala Data Source wrapper for the Python
* Data Source.
*
* @param shortName The short name registered for the Python Data Source.
* @return The generated class that wraps the Python Data Source.
*/
def generateClass(shortName: String): Class[_] = {
val ctx = newCodeGenContext()

val codeBody = s"""
@Override
public String shortName() {
return "${StringEscapeUtils.escapeJava(shortName)}";
}"""

val evaluator = new ClassBodyEvaluator()
val parentClassLoader = new ParentClassLoader(Utils.getContextOrSparkClassLoader)
evaluator.setParentClassLoader(parentClassLoader)
evaluator.setClassName(
s"org.apache.spark.sql.execution.python.datasources.$shortName.DefaultSource")
evaluator.setExtendedClass(classOf[PythonDefaultSource])

val code = CodeFormatter.stripOverlappingComments(
new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))

// Note that the default `CodeGenerator.compile` wraps everything into a `GeneratedClass`
// class, and the defined DataSource becomes a nested class that cannot properly define
// getConstructors, etc. Therefore, we cannot simply reuse this.
try {
evaluator.cook("generated.java", code.body)
CodeGenerator.updateAndGetCompilationStats(evaluator)
} catch {
case e: InternalCompilerException =>
val msg = QueryExecutionErrors.failedToCompileMsg(e)
logError(msg, e)
CodeGenerator.logGeneratedCode(code)
throw QueryExecutionErrors.internalCompilerError(e)
case e: CompileException =>
val msg = QueryExecutionErrors.failedToCompileMsg(e)
logError(msg, e)
CodeGenerator.logGeneratedCode(code)
throw QueryExecutionErrors.compilerError(e)
}

logDebug(s"Generated Python Data Source':\n${CodeFormatter.format(code)}")
evaluator.getClazz()
}
}
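
For illustration, a rough sketch of how the generated wrapper could be exercised once a Python data source has been registered; the short name "my_source" below is hypothetical.

// Illustrative only: assumes a Python data source named "my_source" was registered already.
import org.apache.spark.sql.sources.DataSourceRegister

val cls = PythonDataSourceCodeGenerator.generateClass("my_source")
val register = cls.getConstructor().newInstance().asInstanceOf[DataSourceRegister]
assert(register.shortName() == "my_source")
// Per the DataSource.lookupDataSource change above, this generated class is what gets
// returned for a registered Python data source that has no Scala/Java implementation.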
@@ -219,6 +219,13 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1)))

// Test SQL
withTable("tblA") {
sql("CREATE TABLE tblA USING test")
// The path will be the actual temp path.
checkAnswer(spark.table("tblA").selectExpr("value"), Seq(Row(1)))
}
}

test("reader not implemented") {