From 8be32065cc60b45f7675a478446b65e94e7ace31 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 13 Nov 2023 19:34:58 +0900 Subject: [PATCH 1/8] Reusing existing codegeneration logic --- .../expressions/codegen/CodeGenerator.scala | 4 +- .../apache/spark/sql/DataFrameReader.scala | 5 +- .../datasources/DataSourceManager.scala | 121 +++++++++++++++++- .../python/PythonDataSourceSuite.scala | 8 ++ 4 files changed, 131 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d10e4a1ced1b..518c467c584a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1513,7 +1513,7 @@ object CodeGenerator extends Logging { (evaluator.getClazz().getConstructor().newInstance().asInstanceOf[GeneratedClass], codeStats) } - private def logGeneratedCode(code: CodeAndComment): Unit = { + def logGeneratedCode(code: CodeAndComment): Unit = { val maxLines = SQLConf.get.loggingMaxLinesForCodegen if (Utils.isTesting) { logError(s"\n${CodeFormatter.format(code, maxLines)}") @@ -1527,7 +1527,7 @@ object CodeGenerator extends Logging { * # of inner classes) of generated classes by inspecting Janino classes. * Also, this method updates the metrics information. */ - private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): ByteCodeStats = { + def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): ByteCodeStats = { // First retrieve the generated classes. val classes = evaluator.getBytecodes.asScala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c29ffb329072..ea930e38156b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql -import java.util.{Locale, Properties, ServiceConfigurationError} +import java.util.{Locale, Properties} import scala.jdk.CollectionConverters._ -import scala.util.{Failure, Success, Try} -import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable} +import org.apache.spark.Partition import org.apache.spark.annotation.Stable import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala index 1cdc3d9cb69e..ac68e1b87154 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala @@ -20,12 +20,23 @@ package org.apache.spark.sql.execution.datasources import java.util.Locale import java.util.concurrent.ConcurrentHashMap +import scala.collection.immutable.Map + +import org.apache.commons.text.StringEscapeUtils +import org.codehaus.commons.compiler.{CompileException, InternalCompilerException} +import org.codehaus.janino.ClassBodyEvaluator + import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset, Row, 
SparkSession, SQLContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate.newCodeGenContext import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider, SchemaRelationProvider, TableScan} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ParentClassLoader, Utils} /** * A manager for user-defined data sources. It is used to register and lookup data sources by @@ -40,6 +51,8 @@ class DataSourceManager extends Logging { CaseInsensitiveMap[String] // options ) => LogicalPlan + // TODO(SPARK-45917): Statically load Python Data Source so idempotently Python + // Data Sources can be loaded even when the Driver is restarted. private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]() private def normalize(name: String): String = name.toLowerCase(Locale.ROOT) @@ -81,3 +94,107 @@ class DataSourceManager extends Logging { manager } } + +/** + * Data Source V1 default source wrapper for Python Data Source. + */ +abstract class PythonDefaultSource + extends RelationProvider + with SchemaRelationProvider + with DataSourceRegister { + + override def createRelation( + sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = + new PythonRelation(shortName(), sqlContext, parameters, None) + + override def createRelation( + sqlContext: SQLContext, + parameters: Map[String, String], + schema: StructType): BaseRelation = + new PythonRelation(shortName(), sqlContext, parameters, Some(schema)) +} + +/** + * Data Source V1 relation wrapper for Python Data Source. + */ +class PythonRelation( + source: String, + override val sqlContext: SQLContext, + parameters: Map[String, String], + maybeSchema: Option[StructType]) extends BaseRelation with TableScan { + + private lazy val sourceDf: DataFrame = { + val caseInsensitiveMap = CaseInsensitiveMap(parameters) + // TODO(SPARK-45600): should be session-based. + val builder = sqlContext.sparkSession.sharedState.dataSourceManager.lookupDataSource(source) + val plan = builder( + sqlContext.sparkSession, source, caseInsensitiveMap.get("path").toSeq, + maybeSchema, caseInsensitiveMap) + Dataset.ofRows(sqlContext.sparkSession, plan) + } + + override def schema: StructType = sourceDf.schema + + override def buildScan(): RDD[Row] = sourceDf.rdd +} + +/** + * Responsible for generating a class for Python Data Source + * that inherits Scala Data Source interface so other features work together + * with Python Data Source. + */ +object PythonDataSourceCodeGenerator extends Logging { + /** + * When you invoke `generateClass`, it generates a class that inherits [[PythonDefaultSource]] + * that has a different short name. The generated class name as follows: + * "org.apache.spark.sql.execution.datasources.$shortName.DefaultSource". + * + * The `shortName` should be registered via `spark.dataSource.register` first, then + * this method can generate corresponding Scala Data Source wrapper for the Python + * Data Source. + * + * @param shortName The short name registered for Python Data Source. 
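+   *                  For example, a (hypothetical) Python Data Source registered as
+   *                  "my_source" yields a generated `DefaultSource` subclass whose
+   *                  `shortName()` returns "my_source".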
+ * @return + */ + def generateClass(shortName: String): Class[_] = { + val ctx = newCodeGenContext() + + val codeBody = s""" + @Override + public String shortName() { + return "${StringEscapeUtils.escapeJava(shortName)}"; + }""" + + val evaluator = new ClassBodyEvaluator() + val parentClassLoader = new ParentClassLoader(Utils.getContextOrSparkClassLoader) + evaluator.setParentClassLoader(parentClassLoader) + evaluator.setClassName( + s"org.apache.spark.sql.execution.python.datasources.$shortName.DefaultSource") + evaluator.setExtendedClass(classOf[PythonDefaultSource]) + + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + + // Note that the default `CodeGenerator.compile` wraps everything into a `GeneratedClass` + // class, and the defined DataSource becomes a nested class that cannot properly define + // getConstructors, etc. Therefore, we cannot simply reuse this. + try { + evaluator.cook("generated.java", code.body) + CodeGenerator.updateAndGetCompilationStats(evaluator) + } catch { + case e: InternalCompilerException => + val msg = QueryExecutionErrors.failedToCompileMsg(e) + logError(msg, e) + CodeGenerator.logGeneratedCode(code) + throw QueryExecutionErrors.internalCompilerError(e) + case e: CompileException => + val msg = QueryExecutionErrors.failedToCompileMsg(e) + logError(msg, e) + CodeGenerator.logGeneratedCode(code) + throw QueryExecutionErrors.compilerError(e) + } + + logDebug(s"Generated Python Data Source':\n${CodeFormatter.format(code)}") + evaluator.getClazz() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index 6bc9166117f2..dcf06b3311e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -215,9 +215,17 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { |""".stripMargin val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) spark.dataSource.registerPython("test", dataSource) + checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1))) checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1))) checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1))) + + // Test SQL + withTable("tblA") { + sql("CREATE TABLE tblA USING test") + // The path will be the actual temp path. 
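+      // (and therefore not deterministic), so only the `value` column is checked here.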
+ checkAnswer(spark.table("tblA").selectExpr("value"), Seq(Row(1))) + } } test("reader not implemented") { From 82e8ec73f2e51b8888673581d58702f23d4f9f62 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 6 Dec 2023 16:40:19 +0900 Subject: [PATCH 2/8] DSv2 exec version --- .../expressions/codegen/CodeGenerator.scala | 4 +- .../apache/spark/sql/DataFrameReader.scala | 43 +---- .../spark/sql/DataSourceRegistration.scala | 2 +- .../org/apache/spark/sql/SparkSession.scala | 2 +- .../spark/sql/execution/SparkOptimizer.scala | 8 +- .../spark/sql/execution/command/ddl.scala | 6 +- .../spark/sql/execution/command/tables.scala | 7 +- .../execution/datasources/DataSource.scala | 35 +++- .../datasources/DataSourceManager.scala | 139 +------------- .../PlanPythonDataSourceScan.scala | 89 --------- .../ApplyInPandasWithStatePythonRunner.scala | 6 +- .../python/ArrowEvalPythonUDTFExec.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 6 +- .../python/ArrowPythonUDTFRunner.scala | 2 +- .../python/CoGroupedArrowPythonRunner.scala | 6 +- .../python/FlatMapGroupsInPythonExec.scala | 2 +- .../python/MapInBatchEvaluatorFactory.scala | 2 +- .../sql/execution/python/MapInBatchExec.scala | 2 +- .../execution/python/PythonArrowInput.scala | 4 +- .../execution/python/PythonArrowOutput.scala | 6 +- .../python/UserDefinedPythonDataSource.scala | 176 +++++++++++++++--- .../sql/streaming/DataStreamReader.scala | 5 +- .../sql/streaming/DataStreamWriter.scala | 2 +- .../python/PythonDataSourceSuite.scala | 69 +++---- 24 files changed, 254 insertions(+), 371 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 518c467c584a..d10e4a1ced1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1513,7 +1513,7 @@ object CodeGenerator extends Logging { (evaluator.getClazz().getConstructor().newInstance().asInstanceOf[GeneratedClass], codeStats) } - def logGeneratedCode(code: CodeAndComment): Unit = { + private def logGeneratedCode(code: CodeAndComment): Unit = { val maxLines = SQLConf.get.loggingMaxLinesForCodegen if (Utils.isTesting) { logError(s"\n${CodeFormatter.format(code, maxLines)}") @@ -1527,7 +1527,7 @@ object CodeGenerator extends Logging { * # of inner classes) of generated classes by inspecting Janino classes. * Also, this method updates the metrics information. */ - def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): ByteCodeStats = { + private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): ByteCodeStats = { // First retrieve the generated classes. 
val classes = evaluator.getBytecodes.asScala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ea930e38156b..9992d8cbba07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -208,45 +208,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError() } - val isUserDefinedDataSource = - sparkSession.sessionState.dataSourceManager.dataSourceExists(source) - - Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match { - case Success(providerOpt) => - // The source can be successfully loaded as either a V1 or a V2 data source. - // Check if it is also a user-defined data source. - if (isUserDefinedDataSource) { - throw QueryCompilationErrors.foundMultipleDataSources(source) - } - providerOpt.flatMap { provider => - DataSourceV2Utils.loadV2Source( - sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*) - }.getOrElse(loadV1Source(paths: _*)) - case Failure(exception) => - // Exceptions are thrown while trying to load the data source as a V1 or V2 data source. - // For the following not found exceptions, if the user-defined data source is defined, - // we can instead return the user-defined data source. - val isNotFoundError = exception match { - case _: NoClassDefFoundError | _: SparkClassNotFoundException => true - case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND" - case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError] - case _ => false - } - if (isNotFoundError && isUserDefinedDataSource) { - loadUserDefinedDataSource(paths) - } else { - // Throw the original exception. - throw exception - } - } - } - - private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = { - val builder = sparkSession.sessionState.dataSourceManager.lookupDataSource(source) - // Add `path` and `paths` options to the extra options if specified. 
- val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*) - val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath) - Dataset.ofRows(sparkSession, plan) + DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider => + DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions, + source, paths: _*) + }.getOrElse(loadV1Source(paths: _*)) } private def loadV1Source(paths: String*) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala index 15d26418984b..936286eb0da5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala @@ -43,6 +43,6 @@ private[sql] class DataSourceRegistration private[sql] (dataSourceManager: DataS | pythonExec: ${dataSource.dataSourceCls.pythonExec} """.stripMargin) - dataSourceManager.registerDataSource(name, dataSource.builder) + dataSourceManager.registerDataSource(name, dataSource) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 15eeca87dcf6..44a4d82c1dac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -780,7 +780,7 @@ class SparkSession private( DataSource.lookupDataSource(runner, sessionState.conf) match { case source if classOf[ExternalCommandRunner].isAssignableFrom(source) => Dataset.ofRows(self, ExternalCommandExecutor( - source.getDeclaredConstructor().newInstance() + DataSource.newDataSourceInstance(runner, source) .asInstanceOf[ExternalCommandRunner], command, options)) case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 00328910f5b6..70a35ea91153 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager -import org.apache.spark.sql.execution.datasources.{PlanPythonDataSourceScan, PruneFileSourcePartitions, SchemaPruning, V1Writes} +import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes} import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes} import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering} import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs} @@ -42,8 +42,7 @@ class SparkOptimizer( V2ScanRelationPushDown :+ V2ScanPartitioningAndOrdering :+ V2Writes :+ - PruneFileSourcePartitions :+ - PlanPythonDataSourceScan + PruneFileSourcePartitions override def preCBORules: Seq[Rule[LogicalPlan]] = OptimizeMetadataOnlyDeleteFromTable :: Nil @@ -102,8 +101,7 @@ class SparkOptimizer( 
V2ScanRelationPushDown.ruleName :+ V2ScanPartitioningAndOrdering.ruleName :+ V2Writes.ruleName :+ - ReplaceCTERefWithRepartition.ruleName :+ - PlanPythonDataSourceScan.ruleName + ReplaceCTERefWithRepartition.ruleName /** * Optimization batches that are executed before the regular optimization batches (also before diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index dc1c5b3fd580..199c8728a5c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM import org.apache.spark.sql.connector.catalog.SupportsNamespaces._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError -import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types._ @@ -1025,7 +1025,9 @@ object DDLUtils extends Logging { def checkDataColNames(provider: String, schema: StructType): Unit = { val source = try { - DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance() + DataSource.newDataSourceInstance( + provider, + DataSource.lookupDataSource(provider, SQLConf.get)) } catch { case e: Throwable => logError(s"Failed to find data source: $provider when check data column names.", e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 2f8fca7cfd73..9771ee08b258 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -264,8 +264,9 @@ case class AlterTableAddColumnsCommand( } if (DDLUtils.isDatasourceTable(catalogTable)) { - DataSource.lookupDataSource(catalogTable.provider.get, conf). - getConstructor().newInstance() match { + DataSource.newDataSourceInstance( + catalogTable.provider.get, + DataSource.lookupDataSource(catalogTable.provider.get, conf)) match { // For datasource table, this command can only support the following File format. 
// TextFileFormat only default to one column "value" // Hive type is already considered as hive serde table, so the logic will not diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 71b6d4b886b4..9612d8ff24f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.execution.datasources.xml.XmlFileFormat +import org.apache.spark.sql.execution.python.PythonTableProvider import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider, TextSocketSourceProvider} import org.apache.spark.sql.internal.SQLConf @@ -105,13 +106,14 @@ case class DataSource( // [[FileDataSourceV2]] will still be used if we call the load()/save() method in // [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource` // instead of `providingClass`. - cls.getDeclaredConstructor().newInstance() match { + DataSource.newDataSourceInstance(className, cls) match { case f: FileDataSourceV2 => f.fallbackFileFormat case _ => cls } } - private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance() + private[sql] def providingInstance(): Any = + DataSource.newDataSourceInstance(className, providingClass) private def newHadoopConfiguration(): Configuration = sparkSession.sessionState.newHadoopConfWithOptions(options) @@ -622,6 +624,15 @@ object DataSource extends Logging { "org.apache.spark.sql.sources.HadoopFsRelationProvider", "org.apache.spark.Logging") + /** Create the instance of the datasource */ + def newDataSourceInstance(provider: String, providingClass: Class[_]): Any = { + providingClass match { + case cls if classOf[PythonTableProvider].isAssignableFrom(cls) => + cls.getDeclaredConstructor(classOf[String]).newInstance(provider) + case cls => cls.getDeclaredConstructor().newInstance() + } + } + /** Given a provider name, look up the data source class definition. */ def lookupDataSource(provider: String, conf: SQLConf): Class[_] = { val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match { @@ -649,6 +660,9 @@ object DataSource extends Logging { // Found the data source using fully qualified path dataSource case Failure(error) => + // TODO(SPARK-45600): should be session-based. 
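+                // Check whether the provider name refers to a user-defined (Python) data
+                // source registered in the current session; if so, fall back to the Python
+                // DSv2 wrapper below instead of failing with DATA_SOURCE_NOT_FOUND.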
+ val isUserDefinedDataSource = SparkSession.getActiveSession.exists( + _.sessionState.dataSourceManager.dataSourceExists(provider)) if (provider1.startsWith("org.apache.spark.sql.hive.orc")) { throw QueryCompilationErrors.orcNotUsedWithHiveEnabledError() } else if (provider1.toLowerCase(Locale.ROOT) == "avro" || @@ -657,6 +671,8 @@ object DataSource extends Logging { throw QueryCompilationErrors.failedToFindAvroDataSourceError(provider1) } else if (provider1.toLowerCase(Locale.ROOT) == "kafka") { throw QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1) + } else if (isUserDefinedDataSource) { + classOf[PythonTableProvider] } else { throw QueryExecutionErrors.dataSourceNotFoundError(provider1, error) } @@ -673,6 +689,14 @@ object DataSource extends Logging { } case head :: Nil => // there is exactly one registered alias + // TODO(SPARK-45600): should be session-based. + val isUserDefinedDataSource = SparkSession.getActiveSession.exists( + _.sessionState.dataSourceManager.dataSourceExists(provider)) + // The source can be successfully loaded as either a V1 or a V2 data source. + // Check if it is also a user-defined data source. + if (isUserDefinedDataSource) { + throw QueryCompilationErrors.foundMultipleDataSources(provider) + } head.getClass case sources => // There are multiple registered aliases for the input. If there is single datasource @@ -708,9 +732,9 @@ object DataSource extends Logging { def lookupDataSourceV2(provider: String, conf: SQLConf): Option[TableProvider] = { val useV1Sources = conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT) .split(",").map(_.trim) - val cls = lookupDataSource(provider, conf) + val providingClass = lookupDataSource(provider, conf) val instance = try { - cls.getDeclaredConstructor().newInstance() + newDataSourceInstance(provider, providingClass) } catch { // Throw the original error from the data source implementation. 
case e: java.lang.reflect.InvocationTargetException => throw e.getCause @@ -718,7 +742,8 @@ object DataSource extends Logging { instance match { case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => None case t: TableProvider - if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) => + if !useV1Sources.contains( + providingClass.getCanonicalName.toLowerCase(Locale.ROOT)) => Some(t) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala index ac68e1b87154..e6c4749df60a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala @@ -20,40 +20,19 @@ package org.apache.spark.sql.execution.datasources import java.util.Locale import java.util.concurrent.ConcurrentHashMap -import scala.collection.immutable.Map - -import org.apache.commons.text.StringEscapeUtils -import org.codehaus.commons.compiler.{CompileException, InternalCompilerException} -import org.codehaus.janino.ClassBodyEvaluator - import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession, SQLContext} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodeGenerator} -import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate.newCodeGenContext -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider, SchemaRelationProvider, TableScan} -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{ParentClassLoader, Utils} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource + /** * A manager for user-defined data sources. It is used to register and lookup data sources by * their short names or fully qualified names. */ class DataSourceManager extends Logging { - - private type DataSourceBuilder = ( - SparkSession, // Spark session - String, // provider name - Option[StructType], // user specified schema - CaseInsensitiveMap[String] // options - ) => LogicalPlan - // TODO(SPARK-45917): Statically load Python Data Source so idempotently Python // Data Sources can be loaded even when the Driver is restarted. - private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]() + private val dataSourceBuilders = new ConcurrentHashMap[String, UserDefinedPythonDataSource]() private def normalize(name: String): String = name.toLowerCase(Locale.ROOT) @@ -61,9 +40,9 @@ class DataSourceManager extends Logging { * Register a data source builder for the given provider. * Note that the provider name is case-insensitive. 
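+   * Re-registering an existing name replaces the previously registered data source and
+   * logs a warning.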
*/ - def registerDataSource(name: String, builder: DataSourceBuilder): Unit = { + def registerDataSource(name: String, source: UserDefinedPythonDataSource): Unit = { val normalizedName = normalize(name) - val previousValue = dataSourceBuilders.put(normalizedName, builder) + val previousValue = dataSourceBuilders.put(normalizedName, source) if (previousValue != null) { logWarning(f"The data source $name replaced a previously registered data source.") } @@ -73,7 +52,7 @@ class DataSourceManager extends Logging { * Returns a data source builder for the given provider and throw an exception if * it does not exist. */ - def lookupDataSource(name: String): DataSourceBuilder = { + def lookupDataSource(name: String): UserDefinedPythonDataSource = { if (dataSourceExists(name)) { dataSourceBuilders.get(normalize(name)) } else { @@ -94,107 +73,3 @@ class DataSourceManager extends Logging { manager } } - -/** - * Data Source V1 default source wrapper for Python Data Source. - */ -abstract class PythonDefaultSource - extends RelationProvider - with SchemaRelationProvider - with DataSourceRegister { - - override def createRelation( - sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = - new PythonRelation(shortName(), sqlContext, parameters, None) - - override def createRelation( - sqlContext: SQLContext, - parameters: Map[String, String], - schema: StructType): BaseRelation = - new PythonRelation(shortName(), sqlContext, parameters, Some(schema)) -} - -/** - * Data Source V1 relation wrapper for Python Data Source. - */ -class PythonRelation( - source: String, - override val sqlContext: SQLContext, - parameters: Map[String, String], - maybeSchema: Option[StructType]) extends BaseRelation with TableScan { - - private lazy val sourceDf: DataFrame = { - val caseInsensitiveMap = CaseInsensitiveMap(parameters) - // TODO(SPARK-45600): should be session-based. - val builder = sqlContext.sparkSession.sharedState.dataSourceManager.lookupDataSource(source) - val plan = builder( - sqlContext.sparkSession, source, caseInsensitiveMap.get("path").toSeq, - maybeSchema, caseInsensitiveMap) - Dataset.ofRows(sqlContext.sparkSession, plan) - } - - override def schema: StructType = sourceDf.schema - - override def buildScan(): RDD[Row] = sourceDf.rdd -} - -/** - * Responsible for generating a class for Python Data Source - * that inherits Scala Data Source interface so other features work together - * with Python Data Source. - */ -object PythonDataSourceCodeGenerator extends Logging { - /** - * When you invoke `generateClass`, it generates a class that inherits [[PythonDefaultSource]] - * that has a different short name. The generated class name as follows: - * "org.apache.spark.sql.execution.datasources.$shortName.DefaultSource". - * - * The `shortName` should be registered via `spark.dataSource.register` first, then - * this method can generate corresponding Scala Data Source wrapper for the Python - * Data Source. - * - * @param shortName The short name registered for Python Data Source. 
- * @return - */ - def generateClass(shortName: String): Class[_] = { - val ctx = newCodeGenContext() - - val codeBody = s""" - @Override - public String shortName() { - return "${StringEscapeUtils.escapeJava(shortName)}"; - }""" - - val evaluator = new ClassBodyEvaluator() - val parentClassLoader = new ParentClassLoader(Utils.getContextOrSparkClassLoader) - evaluator.setParentClassLoader(parentClassLoader) - evaluator.setClassName( - s"org.apache.spark.sql.execution.python.datasources.$shortName.DefaultSource") - evaluator.setExtendedClass(classOf[PythonDefaultSource]) - - val code = CodeFormatter.stripOverlappingComments( - new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) - - // Note that the default `CodeGenerator.compile` wraps everything into a `GeneratedClass` - // class, and the defined DataSource becomes a nested class that cannot properly define - // getConstructors, etc. Therefore, we cannot simply reuse this. - try { - evaluator.cook("generated.java", code.body) - CodeGenerator.updateAndGetCompilationStats(evaluator) - } catch { - case e: InternalCompilerException => - val msg = QueryExecutionErrors.failedToCompileMsg(e) - logError(msg, e) - CodeGenerator.logGeneratedCode(code) - throw QueryExecutionErrors.internalCompilerError(e) - case e: CompileException => - val msg = QueryExecutionErrors.failedToCompileMsg(e) - logError(msg, e) - CodeGenerator.logGeneratedCode(code) - throw QueryExecutionErrors.compilerError(e) - } - - logDebug(s"Generated Python Data Source':\n${CodeFormatter.format(code)}") - evaluator.getClazz() - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala deleted file mode 100644 index 7ffd61a4a266..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources - -import org.apache.spark.api.python.{PythonEvalType, PythonFunction, SimplePythonFunction} -import org.apache.spark.sql.catalyst.expressions.PythonUDF -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, PythonDataSource, PythonDataSourcePartitions, PythonMapInArrow} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_DATA_SOURCE -import org.apache.spark.sql.execution.python.UserDefinedPythonDataSourceReadRunner -import org.apache.spark.util.ArrayImplicits._ - -/** - * A logical rule to plan reads from a Python data source. 
- * - * This rule creates a Python process and invokes the `DataSource.reader` method to create an - * instance of the user-defined data source reader, generates partitions if any, and returns - * the information back to JVM (this rule) to construct the logical plan for Python data source. - * - * For example, prior to applying this rule, the plan might look like: - * - * PythonDataSource(dataSource, schema, output) - * - * Here, `dataSource` is a serialized Python function that contains an instance of the DataSource - * class. Post this rule, the plan is transformed into: - * - * Project [output] - * +- PythonMapInArrow [read_from_data_source, ...] - * +- PythonDataSourcePartitions [partition_bytes] - * - * The PythonDataSourcePartitions contains a list of serialized partition values for the data - * source. The `DataSourceReader.read` method will be planned as a MapInArrow operator that - * accepts a partition value and yields the scanning output. - */ -object PlanPythonDataSourceScan extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning( - _.containsPattern(PYTHON_DATA_SOURCE)) { - case ds @ PythonDataSource(dataSource: PythonFunction, schema, _) => - val inputSchema = PythonDataSourcePartitions.schema - - val info = new UserDefinedPythonDataSourceReadRunner( - dataSource, inputSchema, schema).runInPython() - - val readerFunc = SimplePythonFunction( - command = info.func.toImmutableArraySeq, - envVars = dataSource.envVars, - pythonIncludes = dataSource.pythonIncludes, - pythonExec = dataSource.pythonExec, - pythonVer = dataSource.pythonVer, - broadcastVars = dataSource.broadcastVars, - accumulator = dataSource.accumulator) - - val partitionPlan = PythonDataSourcePartitions( - PythonDataSourcePartitions.getOutputAttrs, info.partitions) - - val pythonUDF = PythonUDF( - name = "read_from_data_source", - func = readerFunc, - dataType = schema, - children = partitionPlan.output, - evalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF, - udfDeterministic = false) - - // Construct the plan. - val plan = PythonMapInArrow( - pythonUDF, - ds.output, - partitionPlan, - isBarrier = false) - - // Project out partition values. 
- Project(ds.output, plan) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala index cfe01f85cbe7..936ab866f5bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -61,12 +61,14 @@ class ApplyInPandasWithStatePythonRunner( keySchema: StructType, outputSchema: StructType, stateValueSchema: StructType, - val pythonMetrics: Map[String, SQLMetric], + pyMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets, jobArtifactUUID) with PythonArrowInput[InType] with PythonArrowOutput[OutType] { + override val pythonMetrics: Option[Map[String, SQLMetric]] = Some(pyMetrics) + override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( funcs.head.funcs.head.pythonExec) @@ -149,7 +151,7 @@ class ApplyInPandasWithStatePythonRunner( pandasWriter.finalizeGroup() val deltaData = dataOut.size() - startData - pythonMetrics("pythonDataSent") += deltaData + pythonMetrics.foreach(_("pythonDataSent") += deltaData) true } else { pandasWriter.finalizeData() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala index 9e210bf5241b..2503deae7d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala @@ -70,7 +70,7 @@ case class ArrowEvalPythonUDTFExec( sessionLocalTimeZone, largeVarTypes, pythonRunnerConf, - pythonMetrics, + Some(pythonMetrics), jobArtifactUUID).compute(batchIter, context.partitionId(), context) columnarBatchIter.map { batch => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index a9eaf79c9db0..5dcb79cc2b91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -35,7 +35,7 @@ abstract class BaseArrowPythonRunner( _timeZoneId: String, protected override val largeVarTypes: Boolean, protected override val workerConf: Map[String, String], - val pythonMetrics: Map[String, SQLMetric], + override val pythonMetrics: Option[Map[String, SQLMetric]], jobArtifactUUID: Option[String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, evalType, argOffsets, jobArtifactUUID) @@ -74,7 +74,7 @@ class ArrowPythonRunner( _timeZoneId: String, largeVarTypes: Boolean, workerConf: Map[String, String], - pythonMetrics: Map[String, SQLMetric], + pythonMetrics: Option[Map[String, SQLMetric]], jobArtifactUUID: Option[String]) extends BaseArrowPythonRunner( funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes, workerConf, @@ -100,7 +100,7 @@ class ArrowPythonWithNamedArgumentRunner( jobArtifactUUID: Option[String]) extends BaseArrowPythonRunner( funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId, largeVarTypes, workerConf, - pythonMetrics, jobArtifactUUID) { + 
Some(pythonMetrics), jobArtifactUUID) { override protected def writeUDF(dataOut: DataOutputStream): Unit = PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala index 87d1ccb25776..df2e89128124 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala @@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner( protected override val timeZoneId: String, protected override val largeVarTypes: Boolean, protected override val workerConf: Map[String, String], - val pythonMetrics: Map[String, SQLMetric], + override val pythonMetrics: Option[Map[String, SQLMetric]], jobArtifactUUID: Option[String]) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, Array(argMetas.map(_.offset)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala index eb56298bfbee..70bd1ce82e2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala @@ -46,13 +46,15 @@ class CoGroupedArrowPythonRunner( rightSchema: StructType, timeZoneId: String, conf: Map[String, String], - val pythonMetrics: Map[String, SQLMetric], + pyMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends BasePythonRunner[ (Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch]( funcs, evalType, argOffsets, jobArtifactUUID) with BasicPythonArrowOutput { + override val pythonMetrics: Option[Map[String, SQLMetric]] = Some(pyMetrics) + override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( funcs.head.funcs.head.pythonExec) @@ -93,7 +95,7 @@ class CoGroupedArrowPythonRunner( writeGroup(nextRight, rightSchema, dataOut, "right") val deltaData = dataOut.size() - startData - pythonMetrics("pythonDataSent") += deltaData + pythonMetrics.foreach(_("pythonDataSent") += deltaData) true } else { dataOut.writeInt(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala index 0c18206a825a..e5a00e2cc8ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala @@ -88,7 +88,7 @@ trait FlatMapGroupsInPythonExec extends SparkPlan with UnaryExecNode with Python sessionLocalTimeZone, largeVarTypes, pythonRunnerConf, - pythonMetrics, + Some(pythonMetrics), jobArtifactUUID) executePython(data, output, runner) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala index 316c543ea807..00990ee46ea5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala @@ 
-36,7 +36,7 @@ class MapInBatchEvaluatorFactory( sessionLocalTimeZone: String, largeVarTypes: Boolean, pythonRunnerConf: Map[String, String], - pythonMetrics: Map[String, SQLMetric], + pythonMetrics: Option[Map[String, SQLMetric]], jobArtifactUUID: Option[String]) extends PartitionEvaluatorFactory[InternalRow, InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala index 8db389f02667..6db6c96b426a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala @@ -57,7 +57,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics { conf.sessionLocalTimeZone, conf.arrowUseLargeVarTypes, pythonRunnerConf, - pythonMetrics, + Some(pythonMetrics), jobArtifactUUID) if (isBarrier) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 1e075cab9224..6d0f31f35ff7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -46,7 +46,7 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected val largeVarTypes: Boolean - protected def pythonMetrics: Map[String, SQLMetric] + protected def pythonMetrics: Option[Map[String, SQLMetric]] protected def writeNextInputToArrowStream( root: VectorSchemaRoot, @@ -132,7 +132,7 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In writer.writeBatch() arrowWriter.reset() val deltaData = dataOut.size() - startData - pythonMetrics("pythonDataSent") += deltaData + pythonMetrics.foreach(_("pythonDataSent") += deltaData) true } else { super[PythonArrowInput].close() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala index 90922d89ad10..82e8e7aa4f64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column */ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] => - protected def pythonMetrics: Map[String, SQLMetric] + protected def pythonMetrics: Option[Map[String, SQLMetric]] protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { } @@ -91,8 +91,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[ val rowCount = root.getRowCount batch.setNumRows(root.getRowCount) val bytesReadEnd = reader.bytesRead() - pythonMetrics("pythonNumRowsReceived") += rowCount - pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart + pythonMetrics.foreach(_("pythonNumRowsReceived") += rowCount) + pythonMetrics.foreach(_("pythonDataReceived") += bytesReadEnd - bytesReadStart) deserializeColumnarBatch(batch, schema) } else { reader.close(false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index 2c8e1b942727..fcad92f815fa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -20,19 +20,142 @@ package org.apache.spark.sql.execution.python import java.io.{DataInputStream, DataOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ import net.razorvine.pickle.Pickler -import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths} -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSource} +import org.apache.spark.JobArtifactSet +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.PythonUDF +import org.apache.spark.sql.catalyst.plans.logical.PythonDataSourcePartitions import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE} +import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ + +/** + * Data Source V2 wrapper for Python Data Source. 
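+ *
+ * The wrapper is instantiated with the registered short name and lazily looks up the
+ * corresponding [[UserDefinedPythonDataSource]] in the session's DataSourceManager.
+ * A rough usage sketch (assuming `pythonDataSource` is a [[UserDefinedPythonDataSource]]
+ * registered under the hypothetical name "my_source"):
+ *
+ * {{{
+ *   spark.dataSource.registerPython("my_source", pythonDataSource)
+ *   spark.read.format("my_source").schema("id INT, partition INT").load()
+ * }}}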
+ */ +class PythonTableProvider(shortName: String) extends TableProvider { + private lazy val source: UserDefinedPythonDataSource = + SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName) + override def inferSchema(options: CaseInsensitiveStringMap): StructType = + source.inferSchema(shortName, options) + + override def getTable( + schema: StructType, + partitioning: Array[Transform], + properties: java.util.Map[String, String]): Table = { + new PythonTable(shortName, source, schema) + } +} + +class PythonTable(shortName: String, source: UserDefinedPythonDataSource, givenSchema: StructType) + extends Table with SupportsRead { + override def name(): String = shortName + + override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of( + BATCH_READ, BATCH_WRITE) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new ScanBuilder with Batch with Scan { + + private lazy val pythonFunc: PythonFunction = source.createPythonFunction( + shortName, options, Some(givenSchema)) + + private lazy val info: PythonDataSourceReadInfo = + new UserDefinedPythonDataSourceReadRunner( + pythonFunc, PythonDataSourcePartitions.schema, givenSchema).runInPython() + + override def build(): Scan = this + + override def toBatch: Batch = this + + override def readSchema(): StructType = givenSchema + + override def planInputPartitions(): Array[InputPartition] = + info.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray + + override def createReaderFactory(): PartitionReaderFactory = + new PythonPartitionReaderFactory(info, pythonFunc, givenSchema) + } + } + + override def schema(): StructType = givenSchema +} + +case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition + +class PythonPartitionReaderFactory( + info: PythonDataSourceReadInfo, dataSource: PythonFunction, schema: StructType) + extends PartitionReaderFactory { + + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val partitionInfo = partition.asInstanceOf[PythonInputPartition] + + val readerFunc = SimplePythonFunction( + command = info.func.toImmutableArraySeq, + envVars = dataSource.envVars, + pythonIncludes = dataSource.pythonIncludes, + pythonExec = dataSource.pythonExec, + pythonVer = dataSource.pythonVer, + broadcastVars = dataSource.broadcastVars, + accumulator = dataSource.accumulator) + + val partitionPlan = PythonDataSourcePartitions( + PythonDataSourcePartitions.getOutputAttrs, info.partitions) + + val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF + + val pythonUDF = PythonUDF( + name = "read_from_data_source", + func = readerFunc, + dataType = schema, + children = partitionPlan.output, + evalType = pythonEvalType, + udfDeterministic = false) + + val conf = SQLConf.get + + val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) + val evaluatorFactory = new MapInBatchEvaluatorFactory( + toAttributes(schema), + Seq(ChainedPythonFunctions(Seq(pythonUDF.func))), + PythonDataSourcePartitions.schema, + conf.arrowMaxRecordsPerBatch, + pythonEvalType, + conf.sessionLocalTimeZone, + conf.arrowUseLargeVarTypes, + pythonRunnerConf, + None, + jobArtifactUUID) + + new PartitionReader[InternalRow] { + + private val outputIter = evaluatorFactory.createEvaluator().eval( + partitionInfo.index, Iterator.single(InternalRow(partitionInfo.pickedPartition))) + + override def 
next(): Boolean = outputIter.hasNext + + override def get(): InternalRow = outputIter.next() + + override def close(): Unit = {} + } + } +} + /** * A user-defined Python data source. This is used by the Python API. * @@ -40,19 +163,36 @@ import org.apache.spark.util.ArrayImplicits._ */ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { - def builder( - sparkSession: SparkSession, - provider: String, - userSpecifiedSchema: Option[StructType], - options: CaseInsensitiveMap[String]): LogicalPlan = { + private var pythonResult: PythonDataSourceCreationResult = _ + private def getOrCreatePythonResult( + shortName: String, + options: CaseInsensitiveStringMap, + userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult = { + if (pythonResult != null) return pythonResult val runner = new UserDefinedPythonDataSourceRunner( - dataSourceCls, provider, userSpecifiedSchema, options) + dataSourceCls, + shortName, + userSpecifiedSchema, + CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)) + pythonResult = runner.runInPython() + pythonResult + } - val result = runner.runInPython() - val pickledDataSourceInstance = result.dataSource + def inferSchema( + shortName: String, + options: CaseInsensitiveStringMap): StructType = { + getOrCreatePythonResult(shortName, options, None).schema + } + + def createPythonFunction( + shortName: String, + options: CaseInsensitiveStringMap, + userSpecifiedSchema: Option[StructType]): PythonFunction = { + val pickledDataSourceInstance = getOrCreatePythonResult( + shortName, options, userSpecifiedSchema).dataSource - val dataSource = SimplePythonFunction( + SimplePythonFunction( command = pickledDataSourceInstance.toImmutableArraySeq, envVars = dataSourceCls.envVars, pythonIncludes = dataSourceCls.pythonIncludes, @@ -60,18 +200,6 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { pythonVer = dataSourceCls.pythonVer, broadcastVars = dataSourceCls.broadcastVars, accumulator = dataSourceCls.accumulator) - val schema = result.schema - - PythonDataSource(dataSource, schema, output = toAttributes(schema)) - } - - def apply( - sparkSession: SparkSession, - provider: String, - userSpecifiedSchema: Option[StructType] = None, - options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = { - val plan = builder(sparkSession, provider, userSpecifiedSchema, options) - Dataset.ofRows(sparkSession, plan) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 1a69678c2f54..c93ca632d3c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -156,8 +156,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo extraOptions + ("path" -> path.get) } - val ds = DataSource.lookupDataSource(source, sparkSession.sessionState.conf). - getConstructor().newInstance() + val ds = DataSource.newDataSourceInstance( + source, + DataSource.lookupDataSource(source, sparkSession.sessionState.conf)) // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. // We can't be sure at this point whether we'll actually want to use V2, since we don't know the // writer or whether the query is continuous. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 95aa2f8c7a4e..7202f69ab1bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -382,7 +382,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) { - val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] + val provider = DataSource.newDataSourceInstance(source, cls).asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( source = provider, conf = df.sparkSession.sessionState.conf) val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index dcf06b3311e2..3bb7ebd90a5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -18,8 +18,6 @@ package org.apache.spark.sql.execution.python import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row} -import org.apache.spark.sql.catalyst.plans.logical.{PythonDataSourcePartitions, PythonMapInArrow} -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType @@ -53,14 +51,10 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { val schema = StructType.fromDDL("id INT, partition INT") val dataSource = createUserDefinedPythonDataSource( name = dataSourceName, pythonScript = dataSourceScript) - val df = dataSource.apply( - spark, provider = dataSourceName, userSpecifiedSchema = Some(schema)) + spark.dataSource.registerPython(dataSourceName, dataSource) + val df = spark.read.format(dataSourceName).schema(schema).load() assert(df.rdd.getNumPartitions == 2) val plan = df.queryExecution.optimizedPlan - plan match { - case PythonMapInArrow(_, _, _: PythonDataSourcePartitions, _) => - case _ => fail(s"Plan did not match the expected pattern. 
Actual plan:\n$plan") - } checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1))) } @@ -79,7 +73,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { | return SimpleDataSourceReader() |""".stripMargin val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) - val df = dataSource(spark, provider = dataSourceName) + spark.dataSource.registerPython(dataSourceName, dataSource) + val df = spark.read.format(dataSourceName).load() checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1))) } @@ -102,7 +97,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { | return SimpleDataSourceReader() |""".stripMargin val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) - val df = dataSource(spark, provider = dataSourceName) + spark.dataSource.registerPython(dataSourceName, dataSource) + val df = spark.read.format(dataSourceName).load() checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1))) } @@ -121,13 +117,14 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { | return SimpleDataSourceReader() |""".stripMargin val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) + spark.dataSource.registerPython(dataSourceName, dataSource) checkError( - exception = intercept[AnalysisException](dataSource(spark, provider = dataSourceName)), + exception = intercept[AnalysisException](spark.read.format(dataSourceName).load()), errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE", parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\"")) } - test("register data source") { + test("test dataSourceExists") { assume(shouldTestPandasUDFs) val dataSourceScript = s""" @@ -145,34 +142,6 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) spark.dataSource.registerPython(dataSourceName, dataSource) assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName)) - val ds1 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName) - checkAnswer( - ds1(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)), - Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1))) - - // Should be able to override an already registered data source. 
- val newScript = - s""" - |from pyspark.sql.datasource import DataSource, DataSourceReader - |class SimpleDataSourceReader(DataSourceReader): - | def read(self, partition): - | yield (0, ) - | - |class $dataSourceName(DataSource): - | def schema(self) -> str: - | return "id INT" - | - | def reader(self, schema): - | return SimpleDataSourceReader() - |""".stripMargin - val newDataSource = createUserDefinedPythonDataSource(dataSourceName, newScript) - spark.dataSource.registerPython(dataSourceName, newDataSource) - assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName)) - - val ds2 = spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName) - checkAnswer( - ds2(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)), - Seq(Row(0))) } test("load data source") { @@ -215,12 +184,10 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { |""".stripMargin val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) spark.dataSource.registerPython("test", dataSource) - checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1))) checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1))) checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), Row("2", 1))) - // Test SQL withTable("tblA") { sql("CREATE TABLE tblA USING test") // The path will be the actual temp path. @@ -239,8 +206,9 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { val schema = StructType.fromDDL("id INT, partition INT") val dataSource = createUserDefinedPythonDataSource( name = dataSourceName, pythonScript = dataSourceScript) + spark.dataSource.registerPython(dataSourceName, dataSource) val err = intercept[AnalysisException] { - dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect() + spark.read.format(dataSourceName).schema(schema).load().collect() } assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON") assert(err.getMessage.contains("PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED")) @@ -258,8 +226,9 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { val schema = StructType.fromDDL("id INT, partition INT") val dataSource = createUserDefinedPythonDataSource( name = dataSourceName, pythonScript = dataSourceScript) + spark.dataSource.registerPython(dataSourceName, dataSource) val err = intercept[AnalysisException] { - dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect() + spark.read.format(dataSourceName).schema(schema).load().collect() } assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON") assert(err.getMessage.contains("PYTHON_DATA_SOURCE_CREATE_ERROR")) @@ -277,8 +246,9 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { val schema = StructType.fromDDL("id INT, partition INT") val dataSource = createUserDefinedPythonDataSource( name = dataSourceName, pythonScript = dataSourceScript) + spark.dataSource.registerPython(dataSourceName, dataSource) val err = intercept[AnalysisException] { - dataSource(spark, dataSourceName, userSpecifiedSchema = Some(schema)).collect() + spark.read.format(dataSourceName).schema(schema).load().collect() } assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON") assert(err.getMessage.contains("PYTHON_DATA_SOURCE_TYPE_MISMATCH")) @@ -312,7 +282,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { | return SimpleDataSourceReader() |""".stripMargin val dataSource = 
createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) - val df = dataSource(spark, provider = dataSourceName) + spark.dataSource.registerPython(dataSourceName, dataSource) + val df = spark.read.format(dataSourceName).load() checkAnswer(df, Seq(Row(1), Row(3))) } @@ -339,7 +310,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { | return SimpleDataSourceReader() |""".stripMargin val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) - val df = dataSource(spark, provider = dataSourceName) + spark.dataSource.registerPython(dataSourceName, dataSource) + val df = spark.read.format(dataSourceName).load() checkAnswer(df, Row("success")) } @@ -386,8 +358,9 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { | return SimpleDataSourceReader() |""".stripMargin val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) + spark.dataSource.registerPython(dataSourceName, dataSource) val err = intercept[AnalysisException]( - dataSource(spark, provider = dataSourceName).collect()) + spark.read.format(dataSourceName).load().collect()) assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON") assert(err.getMessage.contains("PYTHON_DATA_SOURCE_CREATE_ERROR")) } From 5497b9f13998299d16e3818bf1c409e6e1b40c52 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 12 Dec 2023 22:36:31 -0800 Subject: [PATCH 3/8] Fix --- .../python/UserDefinedPythonDataSource.scala | 28 +++++++++---------- .../python/PythonDataSourceSuite.scala | 21 ++++++++------ 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index fcad92f815fa..caa3d5ac95b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -56,8 +56,11 @@ class PythonTableProvider(shortName: String) extends TableProvider { schema: StructType, partitioning: Array[Transform], properties: java.util.Map[String, String]): Table = { + assert(partitioning.isEmpty) new PythonTable(shortName, source, schema) } + + override def supportsExternalMetadata(): Boolean = true } class PythonTable(shortName: String, source: UserDefinedPythonDataSource, givenSchema: StructType) @@ -70,8 +73,8 @@ class PythonTable(shortName: String, source: UserDefinedPythonDataSource, givenS override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new ScanBuilder with Batch with Scan { - private lazy val pythonFunc: PythonFunction = source.createPythonFunction( - shortName, options, Some(givenSchema)) + private lazy val pythonFunc: PythonFunction = + source.createPythonFunction(shortName, options, givenSchema) private lazy val info: PythonDataSourceReadInfo = new UserDefinedPythonDataSourceReadRunner( @@ -163,37 +166,32 @@ class PythonPartitionReaderFactory( */ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { - private var pythonResult: PythonDataSourceCreationResult = _ - - private def getOrCreatePythonResult( + private def createPythonResult( shortName: String, options: CaseInsensitiveStringMap, userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult = { - if (pythonResult != null) return pythonResult - val runner = new 
UserDefinedPythonDataSourceRunner( + new UserDefinedPythonDataSourceRunner( dataSourceCls, shortName, userSpecifiedSchema, - CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)) - pythonResult = runner.runInPython() - pythonResult + CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)).runInPython() } def inferSchema( shortName: String, options: CaseInsensitiveStringMap): StructType = { - getOrCreatePythonResult(shortName, options, None).schema + createPythonResult(shortName, options, None).schema } def createPythonFunction( shortName: String, options: CaseInsensitiveStringMap, - userSpecifiedSchema: Option[StructType]): PythonFunction = { - val pickledDataSourceInstance = getOrCreatePythonResult( - shortName, options, userSpecifiedSchema).dataSource + givenSchema: StructType): PythonFunction = { + val dataSource = createPythonResult( + shortName, options, Some(givenSchema)).dataSource SimplePythonFunction( - command = pickledDataSourceInstance.toImmutableArraySeq, + command = dataSource.toImmutableArraySeq, envVars = dataSourceCls.envVars, pythonIncludes = dataSourceCls.pythonIncludes, pythonExec = dataSourceCls.pythonExec, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index 3bb7ebd90a5a..734a4ef315c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.python import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType @@ -55,6 +56,10 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { val df = spark.read.format(dataSourceName).schema(schema).load() assert(df.rdd.getNumPartitions == 2) val plan = df.queryExecution.optimizedPlan + plan match { + case s: DataSourceV2ScanRelation if s.relation.table.isInstanceOf[PythonTable] => + case _ => fail(s"Plan did not match the expected pattern. 
Actual plan:\n$plan") + } checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1))) } @@ -164,12 +169,12 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { | paths = [] | return [InputPartition(p) for p in paths] | - | def read(self, path): - | if path is not None: - | assert isinstance(path, InputPartition) - | yield (path.value, 1) + | def read(self, part): + | if part is not None: + | assert isinstance(part, InputPartition) + | yield (part.value, 1) | else: - | yield (path, 1) + | yield (part, 1) | |class $dataSourceName(DataSource): | @classmethod @@ -256,7 +261,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { } test("data source read with custom partitions") { - assume(shouldTestPythonUDFs) + assume(shouldTestPandasUDFs) val dataSourceScript = s""" |from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition @@ -288,7 +293,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { } test("data source read with empty partitions") { - assume(shouldTestPythonUDFs) + assume(shouldTestPandasUDFs) val dataSourceScript = s""" |from pyspark.sql.datasource import DataSource, DataSourceReader @@ -316,7 +321,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { } test("data source read with invalid partitions") { - assume(shouldTestPythonUDFs) + assume(shouldTestPandasUDFs) val reader1 = s""" |class SimpleDataSourceReader(DataSourceReader): From 6e1a9f1f34816cfeeb6b62fee5e631de33ef2597 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 12 Dec 2023 22:58:10 -0800 Subject: [PATCH 4/8] fix --- .../python/UserDefinedPythonDataSource.scala | 94 +++++++++---------- .../python/PythonDataSourceSuite.scala | 3 +- 2 files changed, 44 insertions(+), 53 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index caa3d5ac95b6..855cb63103b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -47,54 +47,68 @@ import org.apache.spark.util.ArrayImplicits._ * Data Source V2 wrapper for Python Data Source. 
*/ class PythonTableProvider(shortName: String) extends TableProvider { + private var pythonResult: PythonDataSourceCreationResult = _ private lazy val source: UserDefinedPythonDataSource = SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName) - override def inferSchema(options: CaseInsensitiveStringMap): StructType = - source.inferSchema(shortName, options) + override def inferSchema(options: CaseInsensitiveStringMap): StructType = { + if (pythonResult == null) { + pythonResult = source.createPythonResult(shortName, options, None) + } + pythonResult.schema + } override def getTable( schema: StructType, partitioning: Array[Transform], properties: java.util.Map[String, String]): Table = { assert(partitioning.isEmpty) - new PythonTable(shortName, source, schema) - } - - override def supportsExternalMetadata(): Boolean = true -} + val givenSchema = schema + new Table with SupportsRead { + override def name(): String = shortName -class PythonTable(shortName: String, source: UserDefinedPythonDataSource, givenSchema: StructType) - extends Table with SupportsRead { - override def name(): String = shortName + override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of( + BATCH_READ, BATCH_WRITE) - override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of( - BATCH_READ, BATCH_WRITE) + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new ScanBuilder with Batch with Scan { - override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new ScanBuilder with Batch with Scan { + private lazy val pythonFunc: PythonFunction = { + if (pythonResult == null) { + pythonResult = source.createPythonResult(shortName, options, Some(givenSchema)) + } + SimplePythonFunction( + command = pythonResult.dataSource.toImmutableArraySeq, + envVars = source.dataSourceCls.envVars, + pythonIncludes = source.dataSourceCls.pythonIncludes, + pythonExec = source.dataSourceCls.pythonExec, + pythonVer = source.dataSourceCls.pythonVer, + broadcastVars = source.dataSourceCls.broadcastVars, + accumulator = source.dataSourceCls.accumulator) + } - private lazy val pythonFunc: PythonFunction = - source.createPythonFunction(shortName, options, givenSchema) + private lazy val info: PythonDataSourceReadInfo = + new UserDefinedPythonDataSourceReadRunner( + pythonFunc, PythonDataSourcePartitions.schema, givenSchema).runInPython() - private lazy val info: PythonDataSourceReadInfo = - new UserDefinedPythonDataSourceReadRunner( - pythonFunc, PythonDataSourcePartitions.schema, givenSchema).runInPython() + override def build(): Scan = this - override def build(): Scan = this + override def toBatch: Batch = this - override def toBatch: Batch = this + override def readSchema(): StructType = givenSchema - override def readSchema(): StructType = givenSchema + override def planInputPartitions(): Array[InputPartition] = + info.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray - override def planInputPartitions(): Array[InputPartition] = - info.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray + override def createReaderFactory(): PartitionReaderFactory = + new PythonPartitionReaderFactory(info, pythonFunc, givenSchema) + } + } - override def createReaderFactory(): PartitionReaderFactory = - new PythonPartitionReaderFactory(info, pythonFunc, givenSchema) + override def schema(): StructType = givenSchema } } - override def schema(): StructType = givenSchema + override def 
supportsExternalMetadata(): Boolean = true } case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition @@ -165,8 +179,7 @@ class PythonPartitionReaderFactory( * @param dataSourceCls The Python data source class. */ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { - - private def createPythonResult( + def createPythonResult( shortName: String, options: CaseInsensitiveStringMap, userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult = { @@ -176,29 +189,6 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { userSpecifiedSchema, CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)).runInPython() } - - def inferSchema( - shortName: String, - options: CaseInsensitiveStringMap): StructType = { - createPythonResult(shortName, options, None).schema - } - - def createPythonFunction( - shortName: String, - options: CaseInsensitiveStringMap, - givenSchema: StructType): PythonFunction = { - val dataSource = createPythonResult( - shortName, options, Some(givenSchema)).dataSource - - SimplePythonFunction( - command = dataSource.toImmutableArraySeq, - envVars = dataSourceCls.envVars, - pythonIncludes = dataSourceCls.pythonIncludes, - pythonExec = dataSourceCls.pythonExec, - pythonVer = dataSourceCls.pythonVer, - broadcastVars = dataSourceCls.broadcastVars, - accumulator = dataSourceCls.accumulator) - } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index 734a4ef315c9..16ac9b3b0edb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -57,7 +57,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { assert(df.rdd.getNumPartitions == 2) val plan = df.queryExecution.optimizedPlan plan match { - case s: DataSourceV2ScanRelation if s.relation.table.isInstanceOf[PythonTable] => + case s: DataSourceV2ScanRelation + if s.relation.table.getClass.toString.contains("PythonTable") => case _ => fail(s"Plan did not match the expected pattern. Actual plan:\n$plan") } checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1))) From 02513140142bc46963fd519bb4eacbcbf9d22fb2 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 13 Dec 2023 11:51:56 -0800 Subject: [PATCH 5/8] Refactoring --- .../python/UserDefinedPythonDataSource.scala | 167 ++++++++++-------- 1 file changed, 96 insertions(+), 71 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index 855cb63103b4..0776fa99783b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -42,19 +42,19 @@ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ - /** * Data Source V2 wrapper for Python Data Source. 
*/ class PythonTableProvider(shortName: String) extends TableProvider { - private var pythonResult: PythonDataSourceCreationResult = _ + private var dataSourceInPython: PythonDataSourceCreationResult = _ + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) private lazy val source: UserDefinedPythonDataSource = SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName) override def inferSchema(options: CaseInsensitiveStringMap): StructType = { - if (pythonResult == null) { - pythonResult = source.createPythonResult(shortName, options, None) + if (dataSourceInPython == null) { + dataSourceInPython = source.createDataSourceInPython(shortName, options, None) } - pythonResult.schema + dataSourceInPython.schema } override def getTable( @@ -62,7 +62,7 @@ class PythonTableProvider(shortName: String) extends TableProvider { partitioning: Array[Transform], properties: java.util.Map[String, String]): Table = { assert(partitioning.isEmpty) - val givenSchema = schema + val outputSchema = schema new Table with SupportsRead { override def name(): String = shortName @@ -72,39 +72,32 @@ class PythonTableProvider(shortName: String) extends TableProvider { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new ScanBuilder with Batch with Scan { - private lazy val pythonFunc: PythonFunction = { - if (pythonResult == null) { - pythonResult = source.createPythonResult(shortName, options, Some(givenSchema)) + private lazy val infoInPython: PythonDataSourceReadInfo = { + if (dataSourceInPython == null) { + dataSourceInPython = source + .createDataSourceInPython(shortName, options, Some(outputSchema)) } - SimplePythonFunction( - command = pythonResult.dataSource.toImmutableArraySeq, - envVars = source.dataSourceCls.envVars, - pythonIncludes = source.dataSourceCls.pythonIncludes, - pythonExec = source.dataSourceCls.pythonExec, - pythonVer = source.dataSourceCls.pythonVer, - broadcastVars = source.dataSourceCls.broadcastVars, - accumulator = source.dataSourceCls.accumulator) + source.createReadInfoInPython(dataSourceInPython, outputSchema) } - private lazy val info: PythonDataSourceReadInfo = - new UserDefinedPythonDataSourceReadRunner( - pythonFunc, PythonDataSourcePartitions.schema, givenSchema).runInPython() - override def build(): Scan = this override def toBatch: Batch = this - override def readSchema(): StructType = givenSchema + override def readSchema(): StructType = outputSchema override def planInputPartitions(): Array[InputPartition] = - info.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray + infoInPython.partitions.zipWithIndex.map(p => PythonInputPartition(p._2, p._1)).toArray - override def createReaderFactory(): PartitionReaderFactory = - new PythonPartitionReaderFactory(info, pythonFunc, givenSchema) + override def createReaderFactory(): PartitionReaderFactory = { + val readerFunc = infoInPython.func + new PythonPartitionReaderFactory( + source, readerFunc, outputSchema, jobArtifactUUID) + } } } - override def schema(): StructType = givenSchema + override def schema(): StructType = outputSchema } } @@ -114,33 +107,81 @@ class PythonTableProvider(shortName: String) extends TableProvider { case class PythonInputPartition(index: Int, pickedPartition: Array[Byte]) extends InputPartition class PythonPartitionReaderFactory( - info: PythonDataSourceReadInfo, dataSource: PythonFunction, schema: StructType) - extends PartitionReaderFactory { - - private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+    source: UserDefinedPythonDataSource,
+    pickledReadFunc: Array[Byte],
+    outputSchema: StructType,
+    jobArtifactUUID: Option[String])
+  extends PartitionReaderFactory {
 
   override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
-    val partitionInfo = partition.asInstanceOf[PythonInputPartition]
+    new PartitionReader[InternalRow] {
+      private val outputIter = source.createPartitionReadIteratorInPython(
+        partition.asInstanceOf[PythonInputPartition],
+        pickledReadFunc,
+        outputSchema,
+        jobArtifactUUID)
+
+      override def next(): Boolean = outputIter.hasNext
+
+      override def get(): InternalRow = outputIter.next()
+
+      override def close(): Unit = {}
+    }
+  }
+}
+
+/**
+ * A user-defined Python data source. This is used by the Python API.
+ * Defines the interaction between Python and JVM.
+ *
+ * @param dataSourceCls The Python data source class.
+ */
+case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
+
+  /**
+   * (Driver-side) Run Python process, and get the pickled Python Data Source
+   * instance and its schema.
+   */
+  def createDataSourceInPython(
+      shortName: String,
+      options: CaseInsensitiveStringMap,
+      userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult = {
+    new UserDefinedPythonDataSourceRunner(
+      dataSourceCls,
+      shortName,
+      userSpecifiedSchema,
+      CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)).runInPython()
+  }
 
-    val readerFunc = SimplePythonFunction(
-      command = info.func.toImmutableArraySeq,
-      envVars = dataSource.envVars,
-      pythonIncludes = dataSource.pythonIncludes,
-      pythonExec = dataSource.pythonExec,
-      pythonVer = dataSource.pythonVer,
-      broadcastVars = dataSource.broadcastVars,
-      accumulator = dataSource.accumulator)
+  /**
+   * (Driver-side) Run Python process, and get the partition read function and
+   * partition information.
+   */
+  def createReadInfoInPython(
+      pythonResult: PythonDataSourceCreationResult,
+      outputSchema: StructType): PythonDataSourceReadInfo = {
+    new UserDefinedPythonDataSourceReadRunner(
+      createPythonFunction(
+        pythonResult.dataSource), PythonDataSourcePartitions.schema, outputSchema).runInPython()
+  }
 
-    val partitionPlan = PythonDataSourcePartitions(
-      PythonDataSourcePartitions.getOutputAttrs, info.partitions)
+  /**
+   * (Executor-side) Create an iterator that reads the input partitions. 
+ */ + def createPartitionReadIteratorInPython( + partition: PythonInputPartition, + pickledReadFunc: Array[Byte], + outputSchema: StructType, + jobArtifactUUID: Option[String]): Iterator[InternalRow] = { + val readerFunc = createPythonFunction(pickledReadFunc) val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF val pythonUDF = PythonUDF( name = "read_from_data_source", func = readerFunc, - dataType = schema, - children = partitionPlan.output, + dataType = outputSchema, + children = PythonDataSourcePartitions.getOutputAttrs, evalType = pythonEvalType, udfDeterministic = false) @@ -148,7 +189,7 @@ class PythonPartitionReaderFactory( val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) val evaluatorFactory = new MapInBatchEvaluatorFactory( - toAttributes(schema), + toAttributes(outputSchema), Seq(ChainedPythonFunctions(Seq(pythonUDF.func))), PythonDataSourcePartitions.schema, conf.arrowMaxRecordsPerBatch, @@ -159,35 +200,19 @@ class PythonPartitionReaderFactory( None, jobArtifactUUID) - new PartitionReader[InternalRow] { - - private val outputIter = evaluatorFactory.createEvaluator().eval( - partitionInfo.index, Iterator.single(InternalRow(partitionInfo.pickedPartition))) - - override def next(): Boolean = outputIter.hasNext - - override def get(): InternalRow = outputIter.next() - - override def close(): Unit = {} - } + evaluatorFactory.createEvaluator().eval( + partition.index, Iterator.single(InternalRow(partition.pickedPartition))) } -} -/** - * A user-defined Python data source. This is used by the Python API. - * - * @param dataSourceCls The Python data source class. - */ -case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { - def createPythonResult( - shortName: String, - options: CaseInsensitiveStringMap, - userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult = { - new UserDefinedPythonDataSourceRunner( - dataSourceCls, - shortName, - userSpecifiedSchema, - CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)).runInPython() + private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = { + SimplePythonFunction( + command = pickledFunc.toImmutableArraySeq, + envVars = dataSourceCls.envVars, + pythonIncludes = dataSourceCls.pythonIncludes, + pythonExec = dataSourceCls.pythonExec, + pythonVer = dataSourceCls.pythonVer, + broadcastVars = dataSourceCls.broadcastVars, + accumulator = dataSourceCls.accumulator) } } From a9d1e1cd17aa5b3bc0eb329fa6f108f1f6fdfd19 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 13 Dec 2023 12:50:30 -0800 Subject: [PATCH 6/8] Recover test case --- .../python/PythonDataSourceSuite.scala | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index 16ac9b3b0edb..53a54abf8392 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -130,7 +130,7 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\"")) } - test("test dataSourceExists") { + test("register data source") { assume(shouldTestPandasUDFs) val dataSourceScript = s""" @@ -148,6 +148,31 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { val dataSource = 
createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) spark.dataSource.registerPython(dataSourceName, dataSource) assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName)) + checkAnswer( + spark.read.format(dataSourceName).load(), + Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1))) + + // Should be able to override an already registered data source. + val newScript = + s""" + |from pyspark.sql.datasource import DataSource, DataSourceReader + |class SimpleDataSourceReader(DataSourceReader): + | def read(self, partition): + | yield (0, ) + | + |class $dataSourceName(DataSource): + | def schema(self) -> str: + | return "id INT" + | + | def reader(self, schema): + | return SimpleDataSourceReader() + |""".stripMargin + val newDataSource = createUserDefinedPythonDataSource(dataSourceName, newScript) + spark.dataSource.registerPython(dataSourceName, newDataSource) + assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName)) + checkAnswer( + spark.read.format(dataSourceName).load(), + Seq(Row(0))) } test("load data source") { From 8846bf5b0568ab05dbb2f40041e053ef0fce3d4a Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 13 Dec 2023 12:54:57 -0800 Subject: [PATCH 7/8] Further refactoring and cleanup --- .../logical/pythonLogicalOperators.scala | 40 +--------- .../spark/sql/execution/SparkStrategies.scala | 2 - .../PythonDataSourcePartitionsExec.scala | 80 ------------------- .../python/UserDefinedPythonDataSource.scala | 11 +-- 4 files changed, 7 insertions(+), 126 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index fb8b06eb41bc..f5930c5272a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.api.python.PythonFunction import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, PythonUDTF} import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.StructType /** * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. @@ -103,42 +101,6 @@ case class PythonMapInArrow( copy(child = newChild) } -/** - * Represents a Python data source. - */ -case class PythonDataSource( - dataSource: PythonFunction, - outputSchema: StructType, - override val output: Seq[Attribute]) extends LeafNode { - require(output.forall(_.resolved), - "Unresolved attributes found when constructing PythonDataSource.") - override protected def stringArgs: Iterator[Any] = { - Iterator(output) - } - final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_DATA_SOURCE) -} - -/** - * Represents a list of Python data source partitions. 
- */ -case class PythonDataSourcePartitions( - output: Seq[Attribute], - partitions: Seq[Array[Byte]]) extends LeafNode { - override protected def stringArgs: Iterator[Any] = { - if (partitions.isEmpty) { - Iterator("", output) - } else { - Iterator(output) - } - } -} - -object PythonDataSourcePartitions { - def schema: StructType = new StructType().add("partition", BinaryType) - - def getOutputAttrs: Seq[Attribute] = toAttributes(schema) -} - /** * Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe -> pandas.Dataframe * This is used by DataFrame.groupby().cogroup().apply(). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2d24f997d105..35070ac1d562 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -753,8 +753,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ArrowEvalPythonUDTF(udtf, requiredChildOutput, resultAttrs, child, evalType) => ArrowEvalPythonUDTFExec( udtf, requiredChildOutput, resultAttrs, planLater(child), evalType) :: Nil - case PythonDataSourcePartitions(output, partitions) => - PythonDataSourcePartitionsExec(output, partitions) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala deleted file mode 100644 index 8f1595cfdd71..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.python - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.{InputRDDCodegen, LeafExecNode, SQLExecution} -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.ArrayImplicits._ - -/** - * A physical plan node for scanning data from a list of data source partition values. - * - * It creates a RDD with number of partitions equal to size of the partition value list and - * each partition contains a single row with a serialized partition value. 
- */ -case class PythonDataSourcePartitionsExec( - output: Seq[Attribute], - partitions: Seq[Array[Byte]]) extends LeafExecNode with InputRDDCodegen { - - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - - @transient private lazy val unsafeRows: Array[InternalRow] = { - if (partitions.isEmpty) { - Array.empty - } else { - val proj = UnsafeProjection.create(output, output) - partitions.map(p => proj(InternalRow(p)).copy()).toArray - } - } - - @transient private lazy val rdd: RDD[InternalRow] = { - val numPartitions = partitions.size - if (numPartitions == 0) { - sparkContext.emptyRDD - } else { - sparkContext.parallelize(unsafeRows.toImmutableArraySeq, numPartitions) - } - } - - override def inputRDD: RDD[InternalRow] = rdd - - override protected val createUnsafeProjection: Boolean = false - - protected override def doExecute(): RDD[InternalRow] = { - longMetric("numOutputRows").add(partitions.size) - sendDriverMetrics() - rdd - } - - override protected def stringArgs: Iterator[Any] = { - if (partitions.isEmpty) { - Iterator("", output) - } else { - Iterator(output) - } - } - - private def sendDriverMetrics(): Unit = { - val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) - SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index 0776fa99783b..4977b59dab40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -29,7 +29,6 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, Pyth import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.PythonUDF -import org.apache.spark.sql.catalyst.plans.logical.PythonDataSourcePartitions import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} @@ -38,7 +37,7 @@ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{BinaryType, DataType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ @@ -138,6 +137,8 @@ class PythonPartitionReaderFactory( */ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { + private val inputSchema: StructType = new StructType().add("partition", BinaryType) + /** * (Driver-side) Run Python process, and get the pickled Python Data Source * instance and its schema. 
@@ -162,7 +163,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { outputSchema: StructType): PythonDataSourceReadInfo = { new UserDefinedPythonDataSourceReadRunner( createPythonFunction( - pythonResult.dataSource), PythonDataSourcePartitions.schema, outputSchema).runInPython() + pythonResult.dataSource), inputSchema, outputSchema).runInPython() } /** @@ -181,7 +182,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { name = "read_from_data_source", func = readerFunc, dataType = outputSchema, - children = PythonDataSourcePartitions.getOutputAttrs, + children = toAttributes(inputSchema), evalType = pythonEvalType, udfDeterministic = false) @@ -191,7 +192,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { val evaluatorFactory = new MapInBatchEvaluatorFactory( toAttributes(outputSchema), Seq(ChainedPythonFunctions(Seq(pythonUDF.func))), - PythonDataSourcePartitions.schema, + inputSchema, conf.arrowMaxRecordsPerBatch, pythonEvalType, conf.sessionLocalTimeZone, From 7a23cb4a9a3694411abcabe90c05da22ee692906 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 14 Dec 2023 08:52:02 -0800 Subject: [PATCH 8/8] Address comments --- .../sql/execution/python/UserDefinedPythonDataSource.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index 4977b59dab40..7c850d1e2890 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.PythonUDF import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} -import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE} +import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -60,13 +60,12 @@ class PythonTableProvider(shortName: String) extends TableProvider { schema: StructType, partitioning: Array[Transform], properties: java.util.Map[String, String]): Table = { - assert(partitioning.isEmpty) val outputSchema = schema new Table with SupportsRead { override def name(): String = shortName override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of( - BATCH_READ, BATCH_WRITE) + BATCH_READ) override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new ScanBuilder with Batch with Scan {
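
Once a Python data source is registered it behaves like any other batch-read-only DSv2 provider (after this patch the capability set is just BATCH_READ), so it can also be used from SQL, as the load test earlier in this series does with CREATE TABLE ... USING. A rough usage sketch with illustrative names, assuming a data source value built as in the previous sketch:

    // `dataSource` is assumed to come from createUserDefinedPythonDataSource,
    // as in PythonDataSourceSuite; the short name "test" mirrors the suite.
    spark.dataSource.registerPython("test", dataSource)

    // load() paths reach the Python reader through the options, as exercised by
    // spark.read.format("test").load("1") in the suite.
    spark.read.format("test").load("1").show()

    // The registered short name can also back a SQL table definition.
    spark.sql("CREATE TABLE tblA USING test")
    spark.sql("SELECT * FROM tblA").show()
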