From 71170e74df5c7ec657f61154212d1dc2ba7d0613 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 15 Feb 2019 14:57:23 +0800 Subject: [PATCH 01/19] [SPARK-26871][SQL] File Source V2: avoid creating unnecessary FileIndex in the write path ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/23383, the file source V2 framework is implemented. In the PR, `FileIndex` is created as a member of `FileTable`, so that we can implement partition pruning like https://github.com/apache/spark/commit/0f9fcabb4ac2e8afec14d010e86467372a85d334 in the future(As data source V2 catalog is under development, partition pruning is removed from the PR) However, after write path of file source V2 is implemented, I find that a simple write will create an unnecessary `FileIndex`, which is required by `FileTable`. This is a sort of regression. And we can see there is a warning message when writing to ORC files ``` WARN InMemoryFileIndex: The directory file:/tmp/foo was not found. Was it deleted very recently? ``` This PR is to make `FileIndex` as a lazy value in `FileTable`, so that we can avoid creating unnecessary `FileIndex` in the write path. ## How was this patch tested? Existing unit test Closes #23774 from gengliangwang/moveFileIndexInV2. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../datasources/FallbackOrcDataSourceV2.scala | 2 +- .../datasources/v2/FileDataSourceV2.scala | 20 ++----------------- .../execution/datasources/v2/FileTable.scala | 17 +++++++++++++--- .../datasources/v2/orc/OrcDataSourceV2.scala | 6 ++---- .../datasources/v2/orc/OrcTable.scala | 5 ++--- 5 files changed, 21 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala index 254c09001f7e..e22d6a6d399a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala @@ -35,7 +35,7 @@ class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPl override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(d @DataSourceV2Relation(table: OrcTable, _, _), _, _, _, _) => val v1FileFormat = new OrcFileFormat - val relation = HadoopFsRelation(table.getFileIndex, table.getFileIndex.partitionSchema, + val relation = HadoopFsRelation(table.fileIndex, table.fileIndex.partitionSchema, table.schema(), None, v1FileFormat, d.options)(sparkSession) i.copy(table = LogicalRelation(relation)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index a0c932cbb0e0..06c57066aa24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -16,13 +16,10 @@ */ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, 
DataSourceV2, SupportsBatchRead, TableProvider} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.TableProvider /** * A base interface for data source v2 implementations of the built-in file-based data sources. @@ -38,17 +35,4 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { def fallBackFileFormat: Class[_ <: FileFormat] lazy val sparkSession = SparkSession.active - - def getFileIndex( - options: DataSourceOptions, - userSpecifiedSchema: Option[StructType]): PartitioningAwareFileIndex = { - val filePaths = options.paths() - val hadoopConf = - sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) - val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(filePaths, hadoopConf, - checkEmptyGlobPath = true, checkFilesExist = options.checkFilesExist()) - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - new InMemoryFileIndex(sparkSession, rootPathsSpecified, - options.asMap().asScala.toMap, userSpecifiedSchema, fileStatusCache) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 0dbef145f732..21d3e5e29cfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -16,19 +16,30 @@ */ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead, SupportsBatchWrite, Table} import org.apache.spark.sql.types.StructType abstract class FileTable( sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, + options: DataSourceOptions, userSpecifiedSchema: Option[StructType]) extends Table with SupportsBatchRead with SupportsBatchWrite { - def getFileIndex: PartitioningAwareFileIndex = this.fileIndex + lazy val fileIndex: PartitioningAwareFileIndex = { + val filePaths = options.paths() + val hadoopConf = + sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) + val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(filePaths, hadoopConf, + checkEmptyGlobPath = true, checkFilesExist = options.checkFilesExist()) + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex(sparkSession, rootPathsSpecified, + options.asMap().asScala.toMap, userSpecifiedSchema, fileStatusCache) + } lazy val dataSchema: StructType = userSpecifiedSchema.orElse { inferSchema(fileIndex.allFiles()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index db1f2f793422..74739b4fe2d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -34,13 +34,11 @@ class OrcDataSourceV2 extends FileDataSourceV2 { override def getTable(options: DataSourceOptions): Table = { val tableName = getTableName(options) - val 
fileIndex = getFileIndex(options, None) - OrcTable(tableName, sparkSession, fileIndex, None) + OrcTable(tableName, sparkSession, options, None) } override def getTable(options: DataSourceOptions, schema: StructType): Table = { val tableName = getTableName(options) - val fileIndex = getFileIndex(options, Some(schema)) - OrcTable(tableName, sparkSession, fileIndex, Some(schema)) + OrcTable(tableName, sparkSession, options, Some(schema)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index b467e505f1ba..249df8b8622f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.v2.orc import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcUtils import org.apache.spark.sql.execution.datasources.v2.FileTable import org.apache.spark.sql.sources.v2.DataSourceOptions @@ -29,9 +28,9 @@ import org.apache.spark.sql.types.StructType case class OrcTable( name: String, sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, + options: DataSourceOptions, userSpecifiedSchema: Option[StructType]) - extends FileTable(sparkSession, fileIndex, userSpecifiedSchema) { + extends FileTable(sparkSession, options, userSpecifiedSchema) { override def newScanBuilder(options: DataSourceOptions): OrcScanBuilder = new OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) From b6c68755715e36f199c172443862ca4bfde3ced5 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Fri, 15 Feb 2019 12:44:14 -0800 Subject: [PATCH 02/19] [SPARK-26790][CORE] Change approach for retrieving executor logs and attributes: self-retrieve ## What changes were proposed in this pull request? This patch proposes to change the approach on extracting log urls as well as attributes from YARN executor: - AS-IS: extract information from `Container` API and include them to container launch context - TO-BE: let YARN executor self-extracting information This approach leads us to populate more attributes like nodemanager's IPC port which can let us configure custom log url to JHS log url directly. ## How was this patch tested? Existing unit tests. Closes #23706 from HeartSaVioR/SPARK-26790. 
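To illustrate the "TO-BE" approach in isolation: YARN already exports the node manager details into every container's environment, so the executor can read them itself instead of receiving them through the container launch context. The object and method names in the sketch below are invented for illustration only; the patch's actual implementation is `YarnCoarseGrainedExecutorBackend` together with `YarnContainerInfoHelper` further down.

```scala
object SelfReportedNodeManagerInfo {
  // Sketch only: read the standard YARN container environment variables directly.
  // This is what lets the executor populate attributes such as the node manager's
  // IPC port without the driver shipping them in the launch context.
  def apply(): Map[String, String] = {
    def env(name: String): String = Option(System.getenv(name)).getOrElse("")
    Map(
      "NM_HOST" -> env("NM_HOST"),
      "NM_PORT" -> env("NM_PORT"),
      "NM_HTTP_PORT" -> env("NM_HTTP_PORT"))
  }
}
```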
Authored-by: Jungtaek Lim (HeartSaVioR) Signed-off-by: Marcelo Vanzin --- .../CoarseGrainedExecutorBackend.scala | 52 ++++++++----- docs/running-on-yarn.md | 18 +++++ .../spark/deploy/yarn/ExecutorRunnable.scala | 17 +--- .../YarnCoarseGrainedExecutorBackend.scala | 77 +++++++++++++++++++ .../spark/util/YarnContainerInfoHelper.scala | 22 +++++- .../spark/deploy/yarn/YarnClusterSuite.scala | 3 + 6 files changed, 151 insertions(+), 38 deletions(-) create mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 2865c3bc86e4..645f58716de6 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -181,33 +181,47 @@ private[spark] class CoarseGrainedExecutorBackend( private[spark] object CoarseGrainedExecutorBackend extends Logging { - private def run( + case class Arguments( driverUrl: String, executorId: String, hostname: String, cores: Int, appId: String, workerUrl: Option[String], - userClassPath: Seq[URL]) { + userClassPath: mutable.ListBuffer[URL]) + + def main(args: Array[String]): Unit = { + val createFn: (RpcEnv, Arguments, SparkEnv) => + CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env) => + new CoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId, + arguments.hostname, arguments.cores, arguments.userClassPath, env) + } + run(parseArguments(args, this.getClass.getCanonicalName.stripSuffix("$")), createFn) + System.exit(0) + } + + def run( + arguments: Arguments, + backendCreateFn: (RpcEnv, Arguments, SparkEnv) => CoarseGrainedExecutorBackend): Unit = { Utils.initDaemon(log) SparkHadoopUtil.get.runAsSparkUser { () => // Debug code - Utils.checkHost(hostname) + Utils.checkHost(arguments.hostname) // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val fetcher = RpcEnv.create( "driverPropsFetcher", - hostname, + arguments.hostname, -1, executorConf, new SecurityManager(executorConf), clientMode = true) - val driver = fetcher.setupEndpointRefByURI(driverUrl) + val driver = fetcher.setupEndpointRefByURI(arguments.driverUrl) val cfg = driver.askSync[SparkAppConfig](RetrieveSparkAppConfig) - val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId)) + val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", arguments.appId)) fetcher.shutdown() // Create SparkEnv using properties we fetched from the driver. 
@@ -225,19 +239,18 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) } - val env = SparkEnv.createExecutorEnv( - driverConf, executorId, hostname, cores, cfg.ioEncryptionKey, isLocal = false) + val env = SparkEnv.createExecutorEnv(driverConf, arguments.executorId, arguments.hostname, + arguments.cores, cfg.ioEncryptionKey, isLocal = false) - env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( - env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) - workerUrl.foreach { url => + env.rpcEnv.setupEndpoint("Executor", backendCreateFn(env.rpcEnv, arguments, env)) + arguments.workerUrl.foreach { url => env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.rpcEnv.awaitTermination() } } - def main(args: Array[String]) { + def parseArguments(args: Array[String], classNameForEntry: String): Arguments = { var driverUrl: String = null var executorId: String = null var hostname: String = null @@ -276,24 +289,24 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // scalastyle:off println System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") // scalastyle:on println - printUsageAndExit() + printUsageAndExit(classNameForEntry) } } if (driverUrl == null || executorId == null || hostname == null || cores <= 0 || appId == null) { - printUsageAndExit() + printUsageAndExit(classNameForEntry) } - run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) - System.exit(0) + Arguments(driverUrl, executorId, hostname, cores, appId, workerUrl, + userClassPath) } - private def printUsageAndExit() = { + private def printUsageAndExit(classNameForEntry: String): Unit = { // scalastyle:off println System.err.println( - """ - |Usage: CoarseGrainedExecutorBackend [options] + s""" + |Usage: $classNameForEntry [options] | | Options are: | --driver-url @@ -307,5 +320,4 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // scalastyle:on println System.exit(1) } - } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 51a13d342a67..8f1a12726b06 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -480,6 +480,18 @@ To use a custom metrics.properties for the application master and executors, upd {{HTTP_SCHEME}} `http://` or `https://` according to YARN HTTP policy. (Configured via `yarn.http.policy`) + + {{NM_HOST}} + The "host" of node where container was run. + + + {{NM_PORT}} + The "port" of node manager where container was run. + + + {{NM_HTTP_PORT}} + The "port" of node manager's http server where container was run. + {{NM_HTTP_ADDRESS}} Http URI of the node on which the container is allocated. @@ -502,6 +514,12 @@ To use a custom metrics.properties for the application master and executors, upd +For example, suppose you would like to point log url link to Job History Server directly instead of let NodeManager http server redirects it, you can configure `spark.history.custom.executor.log.url` as below: + + `{{HTTP_SCHEME}}:/jobhistory/logs/{{NM_HOST}}:{{NM_PORT}}/{{CONTAINER_ID}}/{{CONTAINER_ID}}/{{USER}}/{{FILE_NAME}}?start=-4096` + + NOTE: you need to replace `` and `` with actual value. + # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. 
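The JHS template above is resolved by substituting the attributes that the executor now self-reports. The sketch below is a simplified illustration of that substitution; the object and method names are invented here, and Spark's real URL handling lives elsewhere and may differ in detail.

```scala
object CustomLogUrlResolution {
  // Simplified sketch: once the executor has self-reported NM_HOST, NM_PORT and friends,
  // resolving a spark.history.custom.executor.log.url template is plain substitution of
  // the {{...}} placeholders. Illustration only, not Spark's actual implementation.
  def resolve(pattern: String, attributes: Map[String, String], fileName: String): String = {
    val withAttributes = attributes.foldLeft(pattern) { case (url, (key, value)) =>
      url.replace(s"{{$key}}", value)
    }
    withAttributes.replace("{{FILE_NAME}}", fileName)
  }
}
```

Calling `resolve` with the attribute map an executor reports and a log file name such as `stderr` would yield a direct Job History Server link of the shape shown in the docs example above.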
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 0b909d15c2fa..2f8f2a0a119c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -202,7 +202,7 @@ private[yarn] class ExecutorRunnable( val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$$() + "/bin/java", "-server") ++ javaOpts ++ - Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", + Seq("org.apache.spark.executor.YarnCoarseGrainedExecutorBackend", "--driver-url", masterAddress, "--executor-id", executorId, "--hostname", hostname, @@ -235,21 +235,6 @@ private[yarn] class ExecutorRunnable( } } - // Add log urls, as well as executor attributes - container.foreach { c => - YarnContainerInfoHelper.getLogUrls(conf, Some(c)).foreach { m => - m.foreach { case (fileName, url) => - env("SPARK_LOG_URL_" + fileName.toUpperCase(Locale.ROOT)) = url - } - } - - YarnContainerInfoHelper.getAttributes(conf, Some(c)).foreach { m => - m.foreach { case (attr, value) => - env("SPARK_EXECUTOR_ATTRIBUTE_" + attr.toUpperCase(Locale.ROOT)) = value - } - } - } - env } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala new file mode 100644 index 000000000000..53e99d992db8 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.net.URL + +import org.apache.spark.SparkEnv +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.YarnContainerInfoHelper + +/** + * Custom implementation of CoarseGrainedExecutorBackend for YARN resource manager. + * This class extracts executor log URLs and executor attributes from system environment which + * properties are available for container being set via YARN. 
+ */ +private[spark] class YarnCoarseGrainedExecutorBackend( + rpcEnv: RpcEnv, + driverUrl: String, + executorId: String, + hostname: String, + cores: Int, + userClassPath: Seq[URL], + env: SparkEnv) + extends CoarseGrainedExecutorBackend( + rpcEnv, + driverUrl, + executorId, + hostname, + cores, + userClassPath, + env) with Logging { + + private lazy val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(env.conf) + + override def extractLogUrls: Map[String, String] = { + YarnContainerInfoHelper.getLogUrls(hadoopConfiguration, container = None) + .getOrElse(Map()) + } + + override def extractAttributes: Map[String, String] = { + YarnContainerInfoHelper.getAttributes(hadoopConfiguration, container = None) + .getOrElse(Map()) + } +} + +private[spark] object YarnCoarseGrainedExecutorBackend extends Logging { + + def main(args: Array[String]): Unit = { + val createFn: (RpcEnv, CoarseGrainedExecutorBackend.Arguments, SparkEnv) => + CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env) => + new YarnCoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId, + arguments.hostname, arguments.cores, arguments.userClassPath, env) + } + val backendArgs = CoarseGrainedExecutorBackend.parseArguments(args, + this.getClass.getCanonicalName.stripSuffix("$")) + CoarseGrainedExecutorBackend.run(backendArgs, createFn) + System.exit(0) + } + +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/util/YarnContainerInfoHelper.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/util/YarnContainerInfoHelper.scala index 96350cdece55..5e39422e868b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/util/YarnContainerInfoHelper.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/util/YarnContainerInfoHelper.scala @@ -59,6 +59,9 @@ private[spark] object YarnContainerInfoHelper extends Logging { val yarnConf = new YarnConfiguration(conf) Some(Map( "HTTP_SCHEME" -> getYarnHttpScheme(yarnConf), + "NM_HOST" -> getNodeManagerHost(container), + "NM_PORT" -> getNodeManagerPort(container), + "NM_HTTP_PORT" -> getNodeManagerHttpPort(container), "NM_HTTP_ADDRESS" -> getNodeManagerHttpAddress(container), "CLUSTER_ID" -> getClusterId(yarnConf).getOrElse(""), "CONTAINER_ID" -> ConverterUtils.toString(getContainerId(container)), @@ -97,7 +100,22 @@ private[spark] object YarnContainerInfoHelper extends Logging { def getNodeManagerHttpAddress(container: Option[Container]): String = container match { case Some(c) => c.getNodeHttpAddress - case None => System.getenv(Environment.NM_HOST.name()) + ":" + - System.getenv(Environment.NM_HTTP_PORT.name()) + case None => getNodeManagerHost(None) + ":" + getNodeManagerHttpPort(None) } + + def getNodeManagerHost(container: Option[Container]): String = container match { + case Some(c) => c.getNodeHttpAddress.split(":")(0) + case None => System.getenv(Environment.NM_HOST.name()) + } + + def getNodeManagerHttpPort(container: Option[Container]): String = container match { + case Some(c) => c.getNodeHttpAddress.split(":")(1) + case None => System.getenv(Environment.NM_HTTP_PORT.name()) + } + + def getNodeManagerPort(container: Option[Container]): String = container match { + case Some(_) => "-1" // Just return invalid port given we cannot retrieve the value + case None => System.getenv(Environment.NM_PORT.name()) + } + } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index b3c5bbd263ed..56b7dfc13699 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -477,6 +477,9 @@ private object YarnClusterDriver extends Logging with Matchers { val driverAttributes = listener.driverAttributes.get val expectationAttributes = Map( "HTTP_SCHEME" -> YarnContainerInfoHelper.getYarnHttpScheme(yarnConf), + "NM_HOST" -> YarnContainerInfoHelper.getNodeManagerHost(container = None), + "NM_PORT" -> YarnContainerInfoHelper.getNodeManagerPort(container = None), + "NM_HTTP_PORT" -> YarnContainerInfoHelper.getNodeManagerHttpPort(container = None), "NM_HTTP_ADDRESS" -> YarnContainerInfoHelper.getNodeManagerHttpAddress(container = None), "CLUSTER_ID" -> YarnContainerInfoHelper.getClusterId(yarnConf).getOrElse(""), "CONTAINER_ID" -> ConverterUtils.toString(containerId), From 28ced387b9ef0b4c9d3b72913b839786fa0bfa38 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Fri, 15 Feb 2019 14:43:13 -0800 Subject: [PATCH 03/19] [SPARK-26772][YARN] Delete ServiceCredentialProvider and make HadoopDelegationTokenProvider a developer API ## What changes were proposed in this pull request? `HadoopDelegationTokenProvider` has basically the same functionality just like `ServiceCredentialProvider` so the interfaces can be merged. `YARNHadoopDelegationTokenManager` now loads `ServiceCredentialProvider`s in one step. The drawback of this if one provider fails all others are not loaded. `HadoopDelegationTokenManager` loads `HadoopDelegationTokenProvider`s independently so it provides more robust behaviour. In this PR I've I've made the following changes: * Deleted `YARNHadoopDelegationTokenManager` and `ServiceCredentialProvider` * Made `HadoopDelegationTokenProvider` a `DeveloperApi` ## How was this patch tested? Existing unit tests. Closes #23686 from gaborgsomogyi/SPARK-26772. 
Authored-by: Gabor Somogyi Signed-off-by: Marcelo Vanzin --- ...rk.security.HadoopDelegationTokenProvider} | 0 .../HBaseDelegationTokenProvider.scala | 2 +- .../HadoopDelegationTokenManager.scala | 2 +- .../HadoopFSDelegationTokenProvider.scala | 1 + .../HadoopDelegationTokenProvider.scala | 8 +- ...rk.security.HadoopDelegationTokenProvider} | 0 .../HadoopDelegationTokenManagerSuite.scala | 1 + docs/running-on-yarn.md | 7 -- docs/security.md | 5 ++ ...rk.security.HadoopDelegationTokenProvider} | 0 .../KafkaDelegationTokenProvider.scala | 2 +- .../org/apache/spark/deploy/yarn/Client.scala | 9 +-- .../security/ServiceCredentialProvider.scala | 58 -------------- .../YARNHadoopDelegationTokenManager.scala | 75 ------------------- .../cluster/YarnSchedulerBackend.scala | 4 +- ...oy.yarn.security.ServiceCredentialProvider | 1 - ...ARNHadoopDelegationTokenManagerSuite.scala | 51 ------------- ...rk.security.HadoopDelegationTokenProvider} | 0 .../HiveDelegationTokenProvider.scala | 3 +- 19 files changed, 21 insertions(+), 208 deletions(-) rename core/src/main/resources/META-INF/services/{org.apache.spark.deploy.security.HadoopDelegationTokenProvider => org.apache.spark.security.HadoopDelegationTokenProvider} (100%) rename core/src/main/scala/org/apache/spark/{deploy => }/security/HadoopDelegationTokenProvider.scala (92%) rename core/src/test/resources/META-INF/services/{org.apache.spark.deploy.security.HadoopDelegationTokenProvider => org.apache.spark.security.HadoopDelegationTokenProvider} (100%) rename external/kafka-0-10-token-provider/src/main/resources/META-INF/services/{org.apache.spark.deploy.security.HadoopDelegationTokenProvider => org.apache.spark.security.HadoopDelegationTokenProvider} (100%) delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala delete mode 100644 resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala delete mode 100644 resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider delete mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala rename sql/hive/src/main/resources/META-INF/services/{org.apache.spark.deploy.security.HadoopDelegationTokenProvider => org.apache.spark.security.HadoopDelegationTokenProvider} (100%) diff --git a/core/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider b/core/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider similarity index 100% rename from core/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider rename to core/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index e56d03401dca..2e21adac86a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -23,12 +23,12 @@ import scala.reflect.runtime.universe import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import 
org.apache.hadoop.security.Credentials import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.security.HadoopDelegationTokenProvider import org.apache.spark.util.Utils private[security] class HBaseDelegationTokenProvider diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 6a18a8dd33d1..4db86ba8f172 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -26,7 +26,6 @@ import java.util.concurrent.{ScheduledExecutorService, TimeUnit} import scala.collection.mutable import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.SparkConf @@ -35,6 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.security.HadoopDelegationTokenProvider import org.apache.spark.ui.UIUtils import org.apache.spark.util.{ThreadUtils, Utils} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index 725eefbda897..ac432e7581e9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdenti import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.security.HadoopDelegationTokenProvider private[deploy] class HadoopFSDelegationTokenProvider extends HadoopDelegationTokenProvider with Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/security/HadoopDelegationTokenProvider.scala similarity index 92% rename from core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala rename to core/src/main/scala/org/apache/spark/security/HadoopDelegationTokenProvider.scala index 3dc952d54e73..cff8d81443ef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/security/HadoopDelegationTokenProvider.scala @@ -15,18 +15,20 @@ * limitations under the License. */ -package org.apache.spark.deploy.security +package org.apache.spark.security import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials import org.apache.spark.SparkConf +import org.apache.spark.annotation.DeveloperApi /** + * ::DeveloperApi:: * Hadoop delegation token provider. */ -private[spark] trait HadoopDelegationTokenProvider { +@DeveloperApi +trait HadoopDelegationTokenProvider { /** * Name of the service to provide delegation tokens. This name should be unique. 
Spark will diff --git a/core/src/test/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider b/core/src/test/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider similarity index 100% rename from core/src/test/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider rename to core/src/test/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala index 2f36dba05c64..70174f7ff939 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.security.Credentials import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.security.HadoopDelegationTokenProvider private class ExceptionThrowingDelegationTokenProvider extends HadoopDelegationTokenProvider { ExceptionThrowingDelegationTokenProvider.constructed = true diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 8f1a12726b06..6ee4b3d4103b 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -538,13 +538,6 @@ for: filesystem if `spark.yarn.stagingDir` is not set); - if Hadoop federation is enabled, all the federated filesystems in the configuration. -The YARN integration also supports custom delegation token providers using the Java Services -mechanism (see `java.util.ServiceLoader`). Implementations of -`org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` can be made available to Spark -by listing their names in the corresponding file in the jar's `META-INF/services` directory. These -providers can be disabled individually by setting `spark.security.credentials.{service}.enabled` to -`false`, where `{service}` is the name of the credential provider. - ## YARN-specific Kerberos Configuration diff --git a/docs/security.md b/docs/security.md index d2cff41eb0f7..20492d871b1a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -756,6 +756,11 @@ If an application needs to interact with other secure Hadoop filesystems, their explicitly provided to Spark at launch time. This is done by listing them in the `spark.kerberos.access.hadoopFileSystems` property, described in the configuration section below. +Spark also supports custom delegation token providers using the Java Services +mechanism (see `java.util.ServiceLoader`). Implementations of +`org.apache.spark.security.HadoopDelegationTokenProvider` can be made available to Spark +by listing their names in the corresponding file in the jar's `META-INF/services` directory. + Delegation token support is currently only supported in YARN and Mesos modes. Consult the deployment-specific page for more information. 
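To make this extension point concrete, the sketch below shows a minimal custom provider. The package, class, and service name are invented for illustration, and the method signatures follow the `org.apache.spark.security.HadoopDelegationTokenProvider` trait as it stands after this change; consult the trait itself for the authoritative shape.

```scala
package com.example.security

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.security.{Credentials, UserGroupInformation}

import org.apache.spark.SparkConf
import org.apache.spark.security.HadoopDelegationTokenProvider

// Hypothetical provider for a fictional "my-service". It becomes visible to Spark by
// listing its fully qualified class name in
// META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider.
class MyServiceDelegationTokenProvider extends HadoopDelegationTokenProvider {

  override def serviceName: String = "my-service"

  override def delegationTokensRequired(
      sparkConf: SparkConf,
      hadoopConf: Configuration): Boolean = {
    // Only bother fetching tokens on a secured cluster.
    UserGroupInformation.isSecurityEnabled
  }

  override def obtainDelegationTokens(
      hadoopConf: Configuration,
      sparkConf: SparkConf,
      creds: Credentials): Option[Long] = {
    // Obtain a token from the external service, add it to `creds`, and return the
    // next renewal time in milliseconds (None if the token is not renewable).
    None
  }
}
```

Individual providers can typically still be turned off with `spark.security.credentials.{service}.enabled=false`, where `{service}` is the provider's `serviceName`.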
diff --git a/external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider b/external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider similarity index 100% rename from external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider rename to external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider diff --git a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaDelegationTokenProvider.scala b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaDelegationTokenProvider.scala index c69e8a320059..cba4b40ca7f4 100644 --- a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaDelegationTokenProvider.scala +++ b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaDelegationTokenProvider.scala @@ -25,9 +25,9 @@ import org.apache.hadoop.security.Credentials import org.apache.kafka.common.security.auth.SecurityProtocol.{SASL_PLAINTEXT, SASL_SSL, SSL} import org.apache.spark.SparkConf -import org.apache.spark.deploy.security.HadoopDelegationTokenProvider import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Kafka +import org.apache.spark.security.HadoopDelegationTokenProvider private[spark] class KafkaDelegationTokenProvider extends HadoopDelegationTokenProvider with Logging { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7523e3c42c53..6ca81fb97c75 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,7 +21,6 @@ import java.io.{FileSystem => _, _} import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets -import java.security.PrivilegedExceptionAction import java.util.{Locale, Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -34,9 +33,9 @@ import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission -import org.apache.hadoop.io.{DataOutputBuffer, Text} +import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.MRJobConfig -import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -50,8 +49,8 @@ import org.apache.hadoop.yarn.util.Records import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{SparkApplication, SparkHadoopUtil} +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Python._ @@ -315,7 +314,7 @@ private[spark] class Client( val credentials = currentUser.getCredentials() if (isClusterMode) { - val 
credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf, null) + val credentialManager = new HadoopDelegationTokenManager(sparkConf, hadoopConf, null) credentialManager.obtainDelegationTokens(credentials) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala deleted file mode 100644 index cc24ac4d9bcf..000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.{Credentials, UserGroupInformation} - -import org.apache.spark.SparkConf - -/** - * A credential provider for a service. User must implement this if they need to access a - * secure service from Spark. - */ -trait ServiceCredentialProvider { - - /** - * Name of the service to provide credentials. This name should unique, Spark internally will - * use this name to differentiate credential provider. - */ - def serviceName: String - - /** - * Returns true if credentials are required by this service. By default, it is based on whether - * Hadoop security is enabled. - */ - def credentialsRequired(hadoopConf: Configuration): Boolean = { - UserGroupInformation.isSecurityEnabled - } - - /** - * Obtain credentials for this service and get the time of the next renewal. - * - * @param hadoopConf Configuration of current Hadoop Compatible system. - * @param sparkConf Spark configuration. - * @param creds Credentials to add tokens and security keys to. - * @return If this Credential is renewable and can be renewed, return the time of the next - * renewal, otherwise None should be returned. - */ - def obtainCredentials( - hadoopConf: Configuration, - sparkConf: SparkConf, - creds: Credentials): Option[Long] -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala deleted file mode 100644 index fc1f75254c57..000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import java.util.ServiceLoader - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.Credentials - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.util.Utils - -/** - * This class loads delegation token providers registered under the YARN-specific - * [[ServiceCredentialProvider]] interface, as well as the builtin providers defined - * in [[HadoopDelegationTokenManager]]. - */ -private[spark] class YARNHadoopDelegationTokenManager( - _sparkConf: SparkConf, - _hadoopConf: Configuration, - _schedulerRef: RpcEndpointRef) - extends HadoopDelegationTokenManager(_sparkConf, _hadoopConf, _schedulerRef) { - - private val credentialProviders = { - ServiceLoader.load(classOf[ServiceCredentialProvider], Utils.getContextOrSparkClassLoader) - .asScala - .toList - .filter { p => isServiceEnabled(p.serviceName) } - .map { p => (p.serviceName, p) } - .toMap - } - if (credentialProviders.nonEmpty) { - logDebug("Using the following YARN-specific credential providers: " + - s"${credentialProviders.keys.mkString(", ")}.") - } - - override def obtainDelegationTokens(creds: Credentials): Long = { - val superInterval = super.obtainDelegationTokens(creds) - - credentialProviders.values.flatMap { provider => - if (provider.credentialsRequired(hadoopConf)) { - provider.obtainCredentials(hadoopConf, sparkConf, creds) - } else { - logDebug(s"Service ${provider.serviceName} does not require a token." + - s" Check your configuration to see if security is disabled or not.") - None - } - }.foldLeft(superInterval)(math.min) - } - - // For testing. 
- override def isProviderLoaded(serviceName: String): Boolean = { - credentialProviders.contains(serviceName) || super.isProviderLoaded(serviceName) - } - -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 821fbcd956d5..78cd6a200ac2 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -31,14 +31,12 @@ import org.eclipse.jetty.servlet.{FilterHolder, FilterMapping} import org.apache.spark.SparkContext import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.internal.config.UI._ import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.ui.JettyUtils import org.apache.spark.util.{RpcUtils, ThreadUtils} /** @@ -223,7 +221,7 @@ private[spark] abstract class YarnSchedulerBackend( } override protected def createTokenManager(): Option[HadoopDelegationTokenManager] = { - Some(new YARNHadoopDelegationTokenManager(sc.conf, sc.hadoopConfiguration, driverEndpoint)) + Some(new HadoopDelegationTokenManager(sc.conf, sc.hadoopConfiguration, driverEndpoint)) } /** diff --git a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider deleted file mode 100644 index f31c23269313..000000000000 --- a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +++ /dev/null @@ -1 +0,0 @@ -org.apache.spark.deploy.yarn.security.YARNTestCredentialProvider diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala deleted file mode 100644 index f00453cb9c59..000000000000 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.deploy.yarn.security - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.Credentials - -import org.apache.spark.{SparkConf, SparkFunSuite} - -class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite { - private var credentialManager: YARNHadoopDelegationTokenManager = null - private var sparkConf: SparkConf = null - private var hadoopConf: Configuration = null - - override def beforeAll(): Unit = { - super.beforeAll() - sparkConf = new SparkConf() - hadoopConf = new Configuration() - } - - test("Correctly loads credential providers") { - credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf, null) - assert(credentialManager.isProviderLoaded("yarn-test")) - } -} - -class YARNTestCredentialProvider extends ServiceCredentialProvider { - override def serviceName: String = "yarn-test" - - override def credentialsRequired(conf: Configuration): Boolean = true - - override def obtainCredentials( - hadoopConf: Configuration, - sparkConf: SparkConf, - creds: Credentials): Option[Long] = None -} diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider similarity index 100% rename from sql/hive/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider rename to sql/hive/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/security/HiveDelegationTokenProvider.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/security/HiveDelegationTokenProvider.scala index c0c46187b13a..faee405d70cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/security/HiveDelegationTokenProvider.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/security/HiveDelegationTokenProvider.scala @@ -23,7 +23,6 @@ import java.security.PrivilegedExceptionAction import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.metadata.Hive @@ -33,9 +32,9 @@ import org.apache.hadoop.security.token.Token import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.security.HadoopDelegationTokenProvider import org.apache.spark.internal.Logging import org.apache.spark.internal.config.KEYTAB +import org.apache.spark.security.HadoopDelegationTokenProvider import org.apache.spark.util.Utils private[spark] class HiveDelegationTokenProvider From 3d6066e9b6bcd24d1ece46f80689b1ff0fcddea3 Mon Sep 17 00:00:00 2001 From: Peter Parente Date: Fri, 15 Feb 2019 18:08:06 -0800 Subject: [PATCH 04/19] [SPARK-21094][PYTHON] Add popen_kwargs to launch_gateway ## What changes were proposed in this pull request? Allow the caller to customize the py4j JVM subprocess pipes and buffers for programmatic capturing of its output. https://issues.apache.org/jira/browse/SPARK-21094 has more detail about the use case. ## How was this patch tested? Tested by running the pyspark unit tests locally. Closes #18339 from parente/feature/SPARK-21094-popen-args. 
Lead-authored-by: Peter Parente Co-authored-by: Peter Parente Signed-off-by: Holden Karau --- python/pyspark/java_gateway.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index d8315c63a8fc..5a55401db53d 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -36,15 +36,21 @@ from pyspark.util import _exception_message -def launch_gateway(conf=None): +def launch_gateway(conf=None, popen_kwargs=None): """ launch jvm gateway :param conf: spark configuration passed to spark-submit + :param popen_kwargs: Dictionary of kwargs to pass to Popen when spawning + the py4j JVM. This is a developer feature intended for use in + customizing how pyspark interacts with the py4j JVM (e.g., capturing + stdout/stderr). :return: """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] + # Process already exists + proc = None else: SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the @@ -75,15 +81,20 @@ def launch_gateway(conf=None): env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file # Launch the Java gateway. + popen_kwargs = {} if popen_kwargs is None else popen_kwargs # We open a pipe to stdin so that the Java gateway can die when the pipe is broken + popen_kwargs['stdin'] = PIPE + # We always set the necessary environment variables. + popen_kwargs['env'] = env if not on_windows: # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) + popen_kwargs['preexec_fn'] = preexec_func + proc = Popen(command, **popen_kwargs) else: # preexec_fn not supported on Windows - proc = Popen(command, stdin=PIPE, env=env) + proc = Popen(command, **popen_kwargs) # Wait for the file to appear, or for the process to exit, whichever happens first. while not proc.poll() and not os.path.isfile(conn_info_file): @@ -118,6 +129,8 @@ def killChild(): gateway = JavaGateway( gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret, auto_convert=True)) + # Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr) + gateway.proc = proc # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") From 4cabab81716463c526fc24b385aa046951898151 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 16 Feb 2019 14:44:37 +0800 Subject: [PATCH 05/19] [SPARK-26673][FOLLOWUP][SQL] File source V2: remove duplicated broadcast object in FileWriterFactory ## What changes were proposed in this pull request? This is a followup PR to fix two issues in #23601: 1. the class `FileWriterFactory` contains `conf: SerializableConfiguration` as a member, which is duplicated with `WriteJobDescription. serializableHadoopConf `. By removing it we can reduce the broadcast task binary size by around 70KB 2. The test suite `OrcV1QuerySuite`/`OrcV1QuerySuite`/`OrcV1PartitionDiscoverySuite` didn't change the configuration `SQLConf.USE_V1_SOURCE_WRITER_LIST` to `"orc"`. We should set the conf. ## How was this patch tested? Unit test Closes #23800 from gengliangwang/reduceWriteTaskSize. 
Authored-by: Gengliang Wang Signed-off-by: Hyukjin Kwon --- .../spark/sql/execution/datasources/v2/FileBatchWrite.scala | 3 +-- .../sql/execution/datasources/v2/FileWriterFactory.scala | 5 ++--- .../datasources/orc/OrcPartitionDiscoverySuite.scala | 5 ++++- .../spark/sql/execution/datasources/orc/OrcQuerySuite.scala | 5 ++++- .../sql/execution/datasources/orc/OrcV1FilterSuite.scala | 5 ++++- 5 files changed, 15 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala index 7b836111a59b..db31927fa73b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala @@ -46,8 +46,7 @@ class FileBatchWrite( } override def createBatchWriterFactory(): DataWriterFactory = { - val conf = new SerializableConfiguration(job.getConfiguration) - FileWriterFactory(description, committer, conf) + FileWriterFactory(description, committer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala index be9c180a901b..eb573b317142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala @@ -29,8 +29,7 @@ import org.apache.spark.util.SerializableConfiguration case class FileWriterFactory ( description: WriteJobDescription, - committer: FileCommitProtocol, - conf: SerializableConfiguration) extends DataWriterFactory { + committer: FileCommitProtocol) extends DataWriterFactory { override def createWriter(partitionId: Int, realTaskId: Long): DataWriter[InternalRow] = { val taskAttemptContext = createTaskAttemptContext(partitionId) committer.setupTask(taskAttemptContext) @@ -46,7 +45,7 @@ case class FileWriterFactory ( val taskId = new TaskID(jobId, TaskType.MAP, partitionId) val taskAttemptId = new TaskAttemptID(taskId, 0) // Set up the configuration object - val hadoopConf = conf.value + val hadoopConf = description.serializableHadoopConf.value hadoopConf.set("mapreduce.job.id", jobId.toString) hadoopConf.set("mapreduce.task.id", taskId.toString) hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala index 4a695ac74c47..bc5a30e97fae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala @@ -232,5 +232,8 @@ class OrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQ class OrcV1PartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQLContext { override protected def sparkConf: SparkConf = - super.sparkConf.set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index c0d26011e791..9bf122766687 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -690,5 +690,8 @@ class OrcQuerySuite extends OrcQueryTest with SharedSQLContext { class OrcV1QuerySuite extends OrcQuerySuite { override protected def sparkConf: SparkConf = - super.sparkConf.set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala index cf5bbb3fff70..5a1bf9b43756 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala @@ -31,7 +31,10 @@ import org.apache.spark.sql.internal.SQLConf class OrcV1FilterSuite extends OrcFilterSuite { override protected def sparkConf: SparkConf = - super.sparkConf.set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc") override def checkFilterPredicate( df: DataFrame, From 4dce45a5992e6a89a26b5a0739b33cfeaf979208 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 16 Feb 2019 17:11:36 +0800 Subject: [PATCH 06/19] [SPARK-26744][SQL] Support schema validation in FileDataSourceV2 framework ## What changes were proposed in this pull request? The file source has a schema validation feature, which validates 2 schemas: 1. the user-specified schema when reading. 2. the schema of input data when writing. If a file source doesn't support the schema, we can fail the query earlier. This PR is to implement the same feature in the `FileDataSourceV2` framework. Comparing to `FileFormat`, `FileDataSourceV2` has multiple layers. The API is added in two places: 1. Read path: the table schema is determined in `TableProvider.getTable`. The actual read schema can be a subset of the table schema. This PR proposes to validate the actual read schema in `FileScan`. 2. Write path: validate the actual output schema in `FileWriteBuilder`. ## How was this patch tested? Unit test Closes #23714 from gengliangwang/schemaValidationV2. 
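To make the shape of the new hook concrete, here is a minimal, self-contained sketch of the pattern (trait and object names are my own, and it throws `IllegalArgumentException` instead of Spark's `AnalysisException` so that it compiles outside the `org.apache.spark.sql` packages):

```scala
import org.apache.spark.sql.types._

trait FormatValidation {
  // Overridden by each file format; by default every type is supported.
  def supportsDataType(dataType: DataType): Boolean = true

  // Human-readable format name used in the error message, e.g. "ORC".
  def formatName: String

  // Roughly what FileScan.toBatch and FileWriteBuilder.validateInputs do with their schema:
  // fail fast on the first unsupported field (the real code throws AnalysisException).
  def validateSchema(schema: StructType): Unit = schema.foreach { field =>
    if (!supportsDataType(field.dataType)) {
      throw new IllegalArgumentException(
        s"$formatName data source does not support ${field.dataType.catalogString} data type.")
    }
  }
}

// A toy format that rejects interval and null columns, nested or not.
object ExampleFormat extends FormatValidation {
  override def formatName: String = "EXAMPLE"
  override def supportsDataType(dataType: DataType): Boolean = dataType match {
    case CalendarIntervalType | NullType => false
    case ArrayType(elementType, _) => supportsDataType(elementType)
    case StructType(fields) => fields.forall(f => supportsDataType(f.dataType))
    case _ => true
  }
}

// ExampleFormat.validateSchema(new StructType().add("i", IntegerType))          // passes
// ExampleFormat.validateSchema(new StructType().add("x", CalendarIntervalType)) // throws
```

A concrete format such as ORC then only overrides `supportsDataType` and `formatName`, and the read path (`FileScan`) and write path (`FileWriteBuilder`) share the same check.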
Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../execution/datasources/v2/FileScan.scala | 33 +++- .../datasources/v2/FileWriteBuilder.scala | 24 ++- .../datasources/v2/orc/OrcDataSourceV2.scala | 19 ++- .../datasources/v2/orc/OrcScan.scala | 10 +- .../datasources/v2/orc/OrcWriteBuilder.scala | 6 + .../spark/sql/FileBasedDataSourceSuite.scala | 152 ++++++++++-------- 6 files changed, 167 insertions(+), 77 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 3615b15be6fd..bdd6a48df20c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -18,15 +18,16 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.hadoop.fs.Path -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.PartitionedFileUtil import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} abstract class FileScan( sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex) extends Scan with Batch { + fileIndex: PartitioningAwareFileIndex, + readSchema: StructType) extends Scan with Batch { /** * Returns whether a file with `path` could be split or not. */ @@ -34,6 +35,22 @@ abstract class FileScan( false } + /** + * Returns whether this format supports the given [[DataType]] in write path. + * By default all data types are supported. + */ + def supportsDataType(dataType: DataType): Boolean = true + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. 
For example: + * + * {{{ + * override def formatName(): String = "ORC" + * }}} + */ + def formatName: String + protected def partitions: Seq[FilePartition] = { val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty) val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions) @@ -57,5 +74,13 @@ abstract class FileScan( partitions.toArray } - override def toBatch: Batch = this + override def toBatch: Batch = { + readSchema.foreach { field => + if (!supportsDataType(field.dataType)) { + throw new AnalysisException( + s"$formatName data source does not support ${field.dataType.catalogString} data type.") + } + } + this + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index ce9b52f29d7b..6a94248a6f0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration abstract class FileWriteBuilder(options: DataSourceOptions) @@ -104,12 +104,34 @@ abstract class FileWriteBuilder(options: DataSourceOptions) options: Map[String, String], dataSchema: StructType): OutputWriterFactory + /** + * Returns whether this format supports the given [[DataType]] in write path. + * By default all data types are supported. + */ + def supportsDataType(dataType: DataType): Boolean = true + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. 
For example: + * + * {{{ + * override def formatName(): String = "ORC" + * }}} + */ + def formatName: String + private def validateInputs(): Unit = { assert(schema != null, "Missing input data schema") assert(queryId != null, "Missing query ID") assert(mode != null, "Missing save mode") assert(options.paths().length == 1) DataSource.validateSchema(schema) + schema.foreach { field => + if (!supportsDataType(field.dataType)) { + throw new AnalysisException( + s"$formatName data source does not support ${field.dataType.catalogString} data type.") + } + } } private def getJobInstance(hadoopConf: Configuration, path: Path): Job = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index 74739b4fe2d4..f279af49ba9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.sources.v2.{DataSourceOptions, Table} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ class OrcDataSourceV2 extends FileDataSourceV2 { @@ -42,3 +42,20 @@ class OrcDataSourceV2 extends FileDataSourceV2 { OrcTable(tableName, sparkSession, options, Some(schema)) } } + +object OrcDataSourceV2 { + def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportsDataType(f.dataType) } + + case ArrayType(elementType, _) => supportsDataType(elementType) + + case MapType(keyType, valueType, _) => + supportsDataType(keyType) && supportsDataType(valueType) + + case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) + + case _ => false + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index a792ad318b39..3c5dc1f50d7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration case class OrcScan( @@ -31,7 +31,7 @@ case class OrcScan( hadoopConf: Configuration, fileIndex: PartitioningAwareFileIndex, dataSchema: StructType, - readSchema: StructType) extends FileScan(sparkSession, fileIndex) { + readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) { override def isSplitable(path: Path): Boolean = true override def createReaderFactory(): PartitionReaderFactory = { @@ -40,4 +40,10 @@ case class OrcScan( OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, fileIndex.partitionSchema, readSchema) } + + override def supportsDataType(dataType: DataType): 
Boolean = { + OrcDataSourceV2.supportsDataType(dataType) + } + + override def formatName: String = "ORC" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala index 80429d91d5e4..1aec4d872a64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala @@ -63,4 +63,10 @@ class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(optio } } } + + override def supportsDataType(dataType: DataType): Boolean = { + OrcDataSourceV2.supportsDataType(dataType) + } + + override def formatName: String = "ORC" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index fc87b0462bc2..e0c0484593d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -329,83 +329,97 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { withTempDir { dir => val tempDir = new File(dir, "files").getCanonicalPath - // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well. - withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") { - // write path - Seq("csv", "json", "parquet", "orc").foreach { format => - var msg = intercept[AnalysisException] { - sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.contains("Cannot save interval data type into external storage.")) - - msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new IntervalData()) - sql("select testType()").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + Seq(true, false).foreach { useV1 => + val useV1List = if (useV1) { + "orc" + } else { + "" } + def errorMessage(format: String, isWrite: Boolean): String = { + if (isWrite && (useV1 || format != "orc")) { + "cannot save interval data type into external storage." + } else { + s"$format data source does not support calendarinterval data type." 
+ } + } + + withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) { + // write path + Seq("csv", "json", "parquet", "orc").foreach { format => + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true))) + } - // read path - Seq("parquet", "csv").foreach { format => - var msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) - - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + // read path + Seq("parquet", "csv").foreach { format => + var msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false))) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false))) + } } } } } test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { - // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well. 
- withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc", - SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") { - withTempDir { dir => - val tempDir = new File(dir, "files").getCanonicalPath - - Seq("parquet", "csv", "orc").foreach { format => - // write path - var msg = intercept[AnalysisException] { - sql("select null").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new NullData()) - sql("select testType()").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - // read path - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", NullType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) + Seq(true, false).foreach { useV1 => + val useV1List = if (useV1) { + "orc" + } else { + "" + } + def errorMessage(format: String): String = { + s"$format data source does not support null data type." + } + withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1List, + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + Seq("parquet", "csv", "orc").foreach { format => + // write path + var msg = intercept[AnalysisException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + // read path + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + } } } } From 5d8a934c13420dcce9d68cbf1f5f30381978d32e Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 16 Feb 2019 16:51:01 -0600 Subject: [PATCH 07/19] [SPARK-26721][ML] Avoid per-tree normalization in featureImportance for GBT ## What changes were proposed in this pull request? 
Our feature importance calculation is taken from sklearn's, which has been recently fixed (in https://github.com/scikit-learn/scikit-learn/pull/11176). Citing the description of that PR: > Because the feature importances are (currently, by default) normalized and then averaged, feature importances from later stages are overweighted. The PR performs a fix similar to sklearn's one. The per-tree normalization of the feature importance is skipped for GBT. Credits to Daniel Jumper for clearly pointing out the issue and the sklearn PR. ## How was this patch tested? Modified UT; checked that the computed `featureImportance` in that test is similar to sklearn's one (it can't be the same, because the trees may be slightly different) Closes #23773 from mgaido91/SPARK-26721. Authored-by: Marco Gaido Signed-off-by: Sean Owen --- .../ml/classification/GBTClassifier.scala | 5 ++-- .../spark/ml/regression/GBTRegressor.scala | 3 ++- .../org/apache/spark/ml/tree/treeModels.scala | 23 +++++++++++++++---- .../classification/GBTClassifierSuite.scala | 3 ++- .../ml/regression/GBTRegressorSuite.scala | 3 ++- 5 files changed, 28 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index abe2d1febfdf..a5ed4a38a886 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -341,11 +341,12 @@ class GBTClassificationModel private[ml]( * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) * and follows the implementation from scikit-learn. - + * * See `DecisionTreeClassificationModel.featureImportances` */ @Since("2.0.0") - lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + lazy val featureImportances: Vector = + TreeEnsembleModel.featureImportances(trees, numFeatures, perTreeNormalization = false) /** Raw prediction for the positive class.
*/ private def margin(features: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 9a5b7d59e9ae..9f0f567a5b53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -285,7 +285,8 @@ class GBTRegressionModel private[ml]( * @see `DecisionTreeRegressionModel.featureImportances` */ @Since("2.0.0") - lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + lazy val featureImportances: Vector = + TreeEnsembleModel.featureImportances(trees, numFeatures, perTreeNormalization = false) /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldGBTModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 51d5d5c58c57..e95c55f6048f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -135,7 +135,7 @@ private[ml] object TreeEnsembleModel { * - Average over trees: * - importance(feature j) = sum (over nodes which split on feature j) of the gain, * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree to sum to 1. + * - Normalize importances for tree to sum to 1 (only if `perTreeNormalization` is `true`). * - Normalize feature importance vector to sum to 1. * * References: @@ -145,9 +145,15 @@ private[ml] object TreeEnsembleModel { * @param numFeatures Number of features in model (even if not all are explicitly used by * the model). * If -1, then numFeatures is set based on the max feature index in all trees. + * @param perTreeNormalization By default this is set to `true` and it means that the importances + * of each tree are normalized before being summed. If set to `false`, + * the normalization is skipped. * @return Feature importance values, of length numFeatures. */ - def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = { + def featureImportances[M <: DecisionTreeModel]( + trees: Array[M], + numFeatures: Int, + perTreeNormalization: Boolean = true): Vector = { val totalImportances = new OpenHashMap[Int, Double]() trees.foreach { tree => // Aggregate feature importance vector for this tree @@ -155,10 +161,19 @@ private[ml] object TreeEnsembleModel { computeFeatureImportance(tree.rootNode, importances) // Normalize importance vector for this tree, and add it to total. // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count? 
- val treeNorm = importances.map(_._2).sum + val treeNorm = if (perTreeNormalization) { + importances.map(_._2).sum + } else { + // We won't use it + Double.NaN + } if (treeNorm != 0) { importances.foreach { case (idx, impt) => - val normImpt = impt / treeNorm + val normImpt = if (perTreeNormalization) { + impt / treeNorm + } else { + impt + } totalImportances.changeValue(idx, normImpt, _ + normImpt) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index cedbaf1858ef..cd59900c521c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -363,7 +363,8 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances val mostIF = importanceFeatures.argmax - assert(mostImportantFeature !== mostIF) + assert(mostIF === 1) + assert(importances(mostImportantFeature) !== importanceFeatures(mostIF)) } test("model evaluateEachIteration") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index b145c7a3dc95..46fa3767efdc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -200,7 +200,8 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances val mostIF = importanceFeatures.argmax - assert(mostImportantFeature !== mostIF) + assert(mostIF === 1) + assert(importances(mostImportantFeature) !== importanceFeatures(mostIF)) } test("model evaluateEachIteration") { From dcdbd06b687fafbf29df504949db0a5f77608c8e Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 18 Feb 2019 08:05:49 +0900 Subject: [PATCH 08/19] [SPARK-26897][SQL][TEST] Update Spark 2.3.x testing from HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? The maintenance release of `branch-2.3` (v2.3.3) vote passed, so this issue updates PROCESS_TABLES.testingVersions in HiveExternalCatalogVersionsSuite ## How was this patch tested? Pass the Jenkins. Closes #23807 from maropu/SPARK-26897. Authored-by: Takeshi Yamamuro Signed-off-by: Takeshi Yamamuro --- .../spark/sql/hive/HiveExternalCatalogVersionsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index dd0e1bd0fe30..8086f75b4bae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -206,7 +206,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. 
- val testingVersions = Seq("2.3.2", "2.4.0") + val testingVersions = Seq("2.3.3", "2.4.0") protected var spark: SparkSession = _ From 36902e10c6395cb378eb8743fe94ccd0aa33e616 Mon Sep 17 00:00:00 2001 From: Ala Luszczak Date: Mon, 18 Feb 2019 10:39:31 +0800 Subject: [PATCH 09/19] [SPARK-26878] QueryTest.compare() does not handle maps with array keys correctly ## What changes were proposed in this pull request? The previous strategy for comparing Maps leveraged sorting (key, value) tuples by their _.toString. However, the _.toString representation of an array has nothing to do with its content. If a map has array keys, its (key, value) pairs would be compared with other maps essentially at random. This could result in false negatives in tests. This change first compares keys together to find the matching ones, and then compares associated values. ## How was this patch tested? New unit test added. Closes #23789 from ala/compare-map. Authored-by: Ala Luszczak Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/DatasetSuite.scala | 37 +++++++++++++++++++ .../org/apache/spark/sql/QueryTest.scala | 6 +-- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8c34e47314db..64c4aabd4cdc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} +import org.scalatest.exceptions.TestFailedException + import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.sql.catalyst.ScroogeLikeExample import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} @@ -67,6 +69,41 @@ class DatasetSuite extends QueryTest with SharedSQLContext { data: _*) } + test("toDS should compare map with byte array keys correctly") { + // Choose the order of arrays in such way, that sorting keys of different maps by _.toString + // will not incidentally put equal keys together.
+ val arrays = (1 to 5).map(_ => Array[Byte](0.toByte, 0.toByte)).sortBy(_.toString).toArray + arrays(0)(1) = 1.toByte + arrays(1)(1) = 2.toByte + arrays(2)(1) = 2.toByte + arrays(3)(1) = 1.toByte + + val mapA = Map(arrays(0) -> "one", arrays(2) -> "two") + val subsetOfA = Map(arrays(0) -> "one") + val equalToA = Map(arrays(1) -> "two", arrays(3) -> "one") + val notEqualToA1 = Map(arrays(1) -> "two", arrays(3) -> "not one") + val notEqualToA2 = Map(arrays(1) -> "two", arrays(4) -> "one") + + // Comparing map with itself + checkDataset(Seq(mapA).toDS(), mapA) + + // Comparing map with equivalent map + checkDataset(Seq(equalToA).toDS(), mapA) + checkDataset(Seq(mapA).toDS(), equalToA) + + // Comparing map with it's subset + intercept[TestFailedException](checkDataset(Seq(subsetOfA).toDS(), mapA)) + intercept[TestFailedException](checkDataset(Seq(mapA).toDS(), subsetOfA)) + + // Comparing map with another map differing by single value + intercept[TestFailedException](checkDataset(Seq(notEqualToA1).toDS(), mapA)) + intercept[TestFailedException](checkDataset(Seq(mapA).toDS(), notEqualToA1)) + + // Comparing map with another map differing by single key + intercept[TestFailedException](checkDataset(Seq(notEqualToA2).toDS(), mapA)) + intercept[TestFailedException](checkDataset(Seq(mapA).toDS(), notEqualToA2)) + } + test("toDS with RDD") { val ds = sparkContext.makeRDD(Seq("a", "b", "c"), 3).toDS() checkDataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d83deb17a090..f8298c9da97e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -341,9 +341,9 @@ object QueryTest { case (a: Array[_], b: Array[_]) => a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} case (a: Map[_, _], b: Map[_, _]) => - val entries1 = a.iterator.toSeq.sortBy(_.toString()) - val entries2 = b.iterator.toSeq.sortBy(_.toString()) - compare(entries1, entries2) + a.size == b.size && a.keys.forall { aKey => + b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey))) + } case (a: Iterable[_], b: Iterable[_]) => a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} case (a: Product, b: Product) => From e2b8cc65cd579374ddbd70b93c9fcefe9b8873d9 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Mon, 18 Feb 2019 11:24:36 +0800 Subject: [PATCH 10/19] [SPARK-26897][SQL][TEST][FOLLOW-UP] Remove workaround for 2.2.0 and 2.1.x in HiveExternalCatalogVersionsSuite ## What changes were proposed in this pull request? This pr just removed workaround for 2.2.0 and 2.1.x in HiveExternalCatalogVersionsSuite. ## How was this patch tested? Pass the Jenkins. Closes #23817 from maropu/SPARK-26607-FOLLOWUP. 
Authored-by: Takeshi Yamamuro Signed-off-by: Hyukjin Kwon --- .../hive/HiveExternalCatalogVersionsSuite.scala | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 8086f75b4bae..1dd60c6757d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -260,19 +260,10 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // SPARK-22356: overlapped columns between data and partition schema in data source tables val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" - // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0, 2.2.1, 2.3+ - if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { - spark.sql("msck repair table " + tbl_with_col_overlap) - assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) - checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) - assert(sql("desc " + tbl_with_col_overlap).select("col_name") - .as[String].collect().mkString(",").contains("i,j,p")) - } else { - assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j")) - checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) - assert(sql("desc " + tbl_with_col_overlap).select("col_name") - .as[String].collect().mkString(",").contains("i,p,j")) - } + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) } } } From 4a4e7aeca79738d5788628d67d97d704f067e8d7 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 18 Feb 2019 11:48:10 +0800 Subject: [PATCH 11/19] [SPARK-26887][SQL][PYTHON][NS] Create datetime.date directly instead of creating datetime64 as intermediate data. ## What changes were proposed in this pull request? Currently `DataFrame.toPandas()` with arrow enabled or `ArrowStreamPandasSerializer` for pandas UDF with pyarrow<0.12 creates `datetime64[ns]` type series as intermediate data and then convert to `datetime.date` series, but the intermediate `datetime64[ns]` might cause an overflow even if the date is valid. ``` >>> import datetime >>> >>> t = [datetime.date(2262, 4, 12), datetime.date(2263, 4, 12)] >>> >>> df = spark.createDataFrame(t, 'date') >>> df.show() +----------+ | value| +----------+ |2262-04-12| |2263-04-12| +----------+ >>> >>> spark.conf.set("spark.sql.execution.arrow.enabled", "true") >>> >>> df.toPandas() value 0 1677-09-21 1 1678-09-21 ``` We should avoid creating such intermediate data and create `datetime.date` series directly instead. ## How was this patch tested? Modified some tests to include the date which overflow caused by the intermediate conversion. Run tests with pyarrow 0.8, 0.10, 0.11, 0.12 in my local environment. Closes #23795 from ueshin/issues/SPARK-26887/date_as_object. 
Authored-by: Takuya UESHIN Signed-off-by: Hyukjin Kwon --- python/pyspark/serializers.py | 5 +- python/pyspark/sql/dataframe.py | 5 +- python/pyspark/sql/tests/test_arrow.py | 5 +- .../sql/tests/test_pandas_udf_scalar.py | 3 +- python/pyspark/sql/types.py | 54 ++++++++++++------- 5 files changed, 44 insertions(+), 28 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 3db259551fa8..a2c59fedfc8c 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -311,10 +311,9 @@ def __init__(self, timezone, safecheck): def arrow_to_pandas(self, arrow_column): from pyspark.sql.types import from_arrow_type, \ - _check_series_convert_date, _check_series_localize_timestamps + _arrow_column_to_pandas, _check_series_localize_timestamps - s = arrow_column.to_pandas() - s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _arrow_column_to_pandas(arrow_column, from_arrow_type(arrow_column.type)) s = _check_series_localize_timestamps(s, self._timezone) return s diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a1056d0b787e..472d2969b3e1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2107,14 +2107,13 @@ def toPandas(self): # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. if use_arrow: try: - from pyspark.sql.types import _check_dataframe_convert_date, \ + from pyspark.sql.types import _arrow_table_to_pandas, \ _check_dataframe_localize_timestamps import pyarrow batches = self._collectAsArrow() if len(batches) > 0: table = pyarrow.Table.from_batches(batches) - pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) + pdf = _arrow_table_to_pandas(table, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 8a62500b17f2..38a6402c0132 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -68,7 +68,9 @@ def setUpClass(cls): (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)), + (u"d", 4, 40, 1.0, 8.0, Decimal("8.0"), + date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))] # TODO: remove version check once minimum pyarrow version is 0.10.0 if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): @@ -76,6 +78,7 @@ def setUpClass(cls): cls.data[0] = cls.data[0] + (bytearray(b"a"),) cls.data[1] = cls.data[1] + (bytearray(b"bb"),) cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + cls.data[3] = cls.data[3] + (bytearray(b"dddd"),) @classmethod def tearDownClass(cls): diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 6a6865a9fb16..28ef98d7b3f1 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -349,7 +349,8 @@ def test_vectorized_udf_dates(self): data = [(0, date(1969, 1, 1),), (1, date(2012, 2, 2),), (2, None,), - (3, date(2100, 4, 4),)] + (3, date(2100, 4, 4),), + (4, date(2262, 4, 12),)] df = self.spark.createDataFrame(data, schema=schema) date_copy = pandas_udf(lambda t: t, returnType=DateType()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py 
index 4b8f2efff4ac..348cb5b11859 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1681,38 +1681,52 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) -def _check_series_convert_date(series, data_type): - """ - Cast the series to datetime.date if it's a date type, otherwise returns the original series. +def _arrow_column_to_pandas(column, data_type): + """ Convert Arrow Column to pandas Series. - :param series: pandas.Series - :param data_type: a Spark data type for the series + :param series: pyarrow.lib.Column + :param data_type: a Spark data type for the column """ - import pyarrow + import pandas as pd + import pyarrow as pa from distutils.version import LooseVersion - # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910 - if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0") and type(data_type) == DateType: - return series.dt.date + # If the given column is a date type column, creates a series of datetime.date directly instead + # of creating datetime64[ns] as intermediate data to avoid overflow caused by datetime64[ns] + # type handling. + if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + if type(data_type) == DateType: + return pd.Series(column.to_pylist(), name=column.name) + else: + return column.to_pandas() else: - return series + # Since Arrow 0.11.0, support date_as_object to return datetime.date instead of + # np.datetime64. + return column.to_pandas(date_as_object=True) -def _check_dataframe_convert_date(pdf, schema): - """ Correct date type value to use datetime.date. +def _arrow_table_to_pandas(table, schema): + """ Convert Arrow Table to pandas DataFrame. Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should use datetime.date to match the behavior with when Arrow optimization is disabled. - :param pdf: pandas.DataFrame - :param schema: a Spark schema of the pandas.DataFrame + :param table: pyarrow.lib.Table + :param schema: a Spark schema of the pyarrow.lib.Table """ - import pyarrow + import pandas as pd + import pyarrow as pa from distutils.version import LooseVersion - # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910 - if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0"): - for field in schema: - pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) - return pdf + # If the given table contains a date type column, use `_arrow_column_to_pandas` for pyarrow<0.11 + # or use `date_as_object` option for pyarrow>=0.11 to avoid creating datetime64[ns] as + # intermediate data. + if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + if any(type(field.dataType) == DateType for field in schema): + return pd.concat([_arrow_column_to_pandas(column, field.dataType) + for column, field in zip(table.itercolumns(), schema)], axis=1) + else: + return table.to_pandas() + else: + return table.to_pandas(date_as_object=True) def _get_local_timezone(): From 60caa92deaf6941f58da82dcc0962ebf3a598ced Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 18 Feb 2019 13:16:28 +0800 Subject: [PATCH 12/19] [SPARK-26666][SQL] Support DSv2 overwrite and dynamic partition overwrite. ## What changes were proposed in this pull request? This adds two logical plans that implement the ReplaceData operation from the [logical plans SPIP](https://docs.google.com/document/d/1gYm5Ji2Mge3QBdOliFV5gSPTKlX4q1DCBXIkiyMv62A/edit?ts=5a987801#heading=h.m45webtwxf2d). 
These two plans will be used to implement Spark's `INSERT OVERWRITE` behavior for v2. Specific changes: * Add `SupportsTruncate`, `SupportsOverwrite`, and `SupportsDynamicOverwrite` to DSv2 write API * Add `OverwriteByExpression` and `OverwritePartitionsDynamic` plans (logical and physical) * Add new plans to DSv2 write validation rule `ResolveOutputRelation` * Refactor `WriteToDataSourceV2Exec` into trait used by all DSv2 write exec nodes ## How was this patch tested? * The v2 analysis suite has been updated to validate the new overwrite plans * The analysis suite for `OverwriteByExpression` checks that the delete expression is resolved using the table's columns * Existing tests validate that overwrite exec plan works * Updated existing v2 test because schema is used to validate overwrite Closes #23606 from rdblue/SPARK-26666-add-overwrite. Authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 27 ++- .../plans/logical/basicLogicalOperators.scala | 69 ++++++- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../analysis/DataSourceV2AnalysisSuite.scala | 191 +++++++++++++----- .../v2/reader/SupportsPushDownFilters.java | 3 + .../v2/writer/SupportsDynamicOverwrite.java | 37 ++++ .../sources/v2/writer/SupportsOverwrite.java | 45 +++++ .../sources/v2/writer/SupportsTruncate.java | 32 +++ .../apache/spark/sql/DataFrameWriter.scala | 54 +++-- .../datasources/DataSourceStrategy.scala | 6 + .../v2/DataSourceV2Implicits.scala | 49 +++++ .../datasources/v2/DataSourceV2Relation.scala | 24 +-- .../datasources/v2/DataSourceV2Strategy.scala | 35 ++-- .../v2/WriteToDataSourceV2Exec.scala | 135 ++++++++++++- .../apache/spark/sql/sources/filters.scala | 26 ++- .../sql/sources/v2/DataSourceV2Suite.scala | 8 +- 16 files changed, 613 insertions(+), 130 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 793c337ffcb1..42904c5c04c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -978,6 +978,11 @@ class Analyzer( case a @ Aggregate(groupingExprs, aggExprs, appendColumns: AppendColumns) => a.mapExpressions(resolveExpressionTopDown(_, appendColumns)) + case o: OverwriteByExpression if !o.outputResolved => + // do not resolve expression attributes until the query attributes are resolved against the + // table by ResolveOutputRelation. that rule will alias the attributes to the table's names. 
+ o + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}") q.mapExpressions(resolveExpressionTopDown(_, q)) @@ -2246,7 +2251,7 @@ class Analyzer( object ResolveOutputRelation extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case append @ AppendData(table, query, isByName) - if table.resolved && query.resolved && !append.resolved => + if table.resolved && query.resolved && !append.outputResolved => val projection = resolveOutputColumns(table.name, table.output, query, isByName) if (projection != query) { @@ -2254,6 +2259,26 @@ class Analyzer( } else { append } + + case overwrite @ OverwriteByExpression(table, _, query, isByName) + if table.resolved && query.resolved && !overwrite.outputResolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + overwrite.copy(query = projection) + } else { + overwrite + } + + case overwrite @ OverwritePartitionsDynamic(table, query, isByName) + if table.resolved && query.resolved && !overwrite.outputResolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + overwrite.copy(query = projection) + } else { + overwrite + } } def resolveOutputColumns( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 639d68f4ecd7..f7f701cea51f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -365,16 +365,17 @@ case class Join( } /** - * Append data to an existing table. + * Base trait for DataSourceV2 write commands */ -case class AppendData( - table: NamedRelation, - query: LogicalPlan, - isByName: Boolean) extends LogicalPlan { +trait V2WriteCommand extends Command { + def table: NamedRelation + def query: LogicalPlan + override def children: Seq[LogicalPlan] = Seq(query) - override def output: Seq[Attribute] = Seq.empty - override lazy val resolved: Boolean = { + override lazy val resolved: Boolean = outputResolved + + def outputResolved: Boolean = { table.resolved && query.resolved && query.output.size == table.output.size && query.output.zip(table.output).forall { case (inAttr, outAttr) => @@ -386,16 +387,66 @@ case class AppendData( } } +/** + * Append data to an existing table. + */ +case class AppendData( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand + object AppendData { def byName(table: NamedRelation, df: LogicalPlan): AppendData = { - new AppendData(table, df, true) + new AppendData(table, df, isByName = true) } def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = { - new AppendData(table, query, false) + new AppendData(table, query, isByName = false) } } +/** + * Overwrite data matching a filter in an existing table. 
+ */ +case class OverwriteByExpression( + table: NamedRelation, + deleteExpr: Expression, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand { + override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved +} + +object OverwriteByExpression { + def byName( + table: NamedRelation, df: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, df, isByName = true) + } + + def byPosition( + table: NamedRelation, query: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, query, isByName = false) + } +} + +/** + * Dynamically overwrite partitions in an existing table. + */ +case class OverwritePartitionsDynamic( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand + +object OverwritePartitionsDynamic { + def byName(table: NamedRelation, df: LogicalPlan): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, df, isByName = true) + } + + def byPosition(table: NamedRelation, query: LogicalPlan): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, query, isByName = false) + } +} + + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d285e007dac1..0b7b67ed56d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1452,7 +1452,7 @@ object SQLConf { " register class names for which data source V2 write paths are disabled. 
Writes from these" + " sources will fall back to the V1 sources.") .stringConf - .createWithDefault("") + .createWithDefault("orc") val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .doc("A comma-separated list of fully qualified data source register class names for which" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala index 6c899b610ac5..0c4854861426 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -19,15 +19,92 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, UpCast} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, LessThanOrEqual, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project} import org.apache.spark.sql.types.{DoubleType, FloatType, StructField, StructType} +class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byPosition(table, query) + } +} + +class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byPosition(table, query) + } +} + +class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwriteByExpression.byName(table, query, Literal(true)) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwriteByExpression.byPosition(table, query, Literal(true)) + } + + test("delete expression is resolved using table fields") { + val table = TestRelation(StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("a", DoubleType, nullable = false), + StructField("b", DoubleType))).toAttributes) + + val a = query.output.head + val b = query.output.last + val x = table.output.head + + val parsedPlan = OverwriteByExpression.byPosition(table, query, + LessThanOrEqual(UnresolvedAttribute(Seq("x")), Literal(15.0d))) + + val expectedPlan = OverwriteByExpression.byPosition(table, + Project(Seq( + Alias(Cast(a, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(b, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), + query), + LessThanOrEqual( + AttributeReference("x", DoubleType, nullable = false)(x.exprId), + Literal(15.0d))) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("delete expression is not resolved using query 
fields") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("a", DoubleType, nullable = false), + StructField("b", DoubleType))).toAttributes) + + // the write is resolved (checked above). this test plan is not because of the expression. + val parsedPlan = OverwriteByExpression.byPosition(xRequiredTable, query, + LessThanOrEqual(UnresolvedAttribute(Seq("a")), Literal(15.0d))) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq("cannot resolve", "`a`", "given input columns", "x, y")) + } +} + case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { override def name: String = "table-name" } -class DataSourceV2AnalysisSuite extends AnalysisTest { +abstract class DataSourceV2AnalysisSuite extends AnalysisTest { val table = TestRelation(StructType(Seq( StructField("x", FloatType), StructField("y", FloatType))).toAttributes) @@ -40,21 +117,25 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("y", DoubleType))).toAttributes) - test("Append.byName: basic behavior") { + def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan + + def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan + + test("byName: basic behavior") { val query = TestRelation(table.schema.toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) checkAnalysis(parsedPlan, parsedPlan) assertResolved(parsedPlan) } - test("Append.byName: does not match by position") { + test("byName: does not match by position") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -62,12 +143,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'", "'y'")) } - test("Append.byName: case sensitive column resolution") { + test("byName: case sensitive column resolution") { val query = TestRelation(StructType(Seq( StructField("X", FloatType), // doesn't match case! StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -76,7 +157,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { caseSensitive = true) } - test("Append.byName: case insensitive column resolution") { + test("byName: case insensitive column resolution") { val query = TestRelation(StructType(Seq( StructField("X", FloatType), // doesn't match case! 
StructField("y", FloatType))).toAttributes) @@ -84,8 +165,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val X = query.output.head val y = query.output.last - val parsedPlan = AppendData.byName(table, query) - val expectedPlan = AppendData.byName(table, + val parsedPlan = byName(table, query) + val expectedPlan = byName(table, Project(Seq( Alias(Cast(toLower(X), FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -96,7 +177,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: data columns are reordered by name") { + test("byName: data columns are reordered by name") { // out of order val query = TestRelation(StructType(Seq( StructField("y", FloatType), @@ -105,8 +186,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val y = query.output.head val x = query.output.last - val parsedPlan = AppendData.byName(table, query) - val expectedPlan = AppendData.byName(table, + val parsedPlan = byName(table, query) + val expectedPlan = byName(table, Project(Seq( Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -117,26 +198,26 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: fail nullable data written to required columns") { - val parsedPlan = AppendData.byName(requiredTable, table) + test("byName: fail nullable data written to required columns") { + val parsedPlan = byName(requiredTable, table) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", "Cannot write nullable values to non-null column", "'x'", "'y'")) } - test("Append.byName: allow required data written to nullable columns") { - val parsedPlan = AppendData.byName(table, requiredTable) + test("byName: allow required data written to nullable columns") { + val parsedPlan = byName(table, requiredTable) assertResolved(parsedPlan) checkAnalysis(parsedPlan, parsedPlan) } - test("Append.byName: missing required columns cause failure and are identified by name") { + test("byName: missing required columns cause failure and are identified by name") { // missing required field x val query = TestRelation(StructType(Seq( StructField("y", FloatType, nullable = false))).toAttributes) - val parsedPlan = AppendData.byName(requiredTable, query) + val parsedPlan = byName(requiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -144,12 +225,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("Append.byName: missing optional columns cause failure and are identified by name") { + test("byName: missing optional columns cause failure and are identified by name") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -157,8 +238,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("Append.byName: fail canWrite check") { - val parsedPlan = AppendData.byName(table, widerTable) + test("byName: fail canWrite check") { + val parsedPlan = byName(table, widerTable) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, 
Seq( @@ -166,12 +247,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) } - test("Append.byName: insert safe cast") { + test("byName: insert safe cast") { val x = table.output.head val y = table.output.last - val parsedPlan = AppendData.byName(widerTable, table) - val expectedPlan = AppendData.byName(widerTable, + val parsedPlan = byName(widerTable, table) + val expectedPlan = byName(widerTable, Project(Seq( Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -182,13 +263,13 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: fail extra data fields") { + test("byName: fail extra data fields") { val query = TestRelation(StructType(Seq( StructField("x", FloatType), StructField("y", FloatType), StructField("z", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -197,7 +278,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'x', 'y', 'z'")) } - test("Append.byName: multiple field errors are reported") { + test("byName: multiple field errors are reported") { val xRequiredTable = TestRelation(StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", DoubleType))).toAttributes) @@ -206,7 +287,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(xRequiredTable, query) + val parsedPlan = byName(xRequiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -216,7 +297,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'y'")) } - test("Append.byPosition: basic behavior") { + test("byPosition: basic behavior") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType))).toAttributes) @@ -224,8 +305,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val a = query.output.head val b = query.output.last - val parsedPlan = AppendData.byPosition(table, query) - val expectedPlan = AppendData.byPosition(table, + val parsedPlan = byPosition(table, query) + val expectedPlan = byPosition(table, Project(Seq( Alias(Cast(a, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(b, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -236,7 +317,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: data columns are not reordered") { + test("byPosition: data columns are not reordered") { // out of order val query = TestRelation(StructType(Seq( StructField("y", FloatType), @@ -245,8 +326,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val y = query.output.head val x = query.output.last - val parsedPlan = AppendData.byPosition(table, query) - val expectedPlan = AppendData.byPosition(table, + val parsedPlan = byPosition(table, query) + val expectedPlan = byPosition(table, Project(Seq( Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -257,26 +338,26 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: fail nullable data written to required 
columns") { - val parsedPlan = AppendData.byPosition(requiredTable, table) + test("byPosition: fail nullable data written to required columns") { + val parsedPlan = byPosition(requiredTable, table) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", "Cannot write nullable values to non-null column", "'x'", "'y'")) } - test("Append.byPosition: allow required data written to nullable columns") { - val parsedPlan = AppendData.byPosition(table, requiredTable) + test("byPosition: allow required data written to nullable columns") { + val parsedPlan = byPosition(table, requiredTable) assertResolved(parsedPlan) checkAnalysis(parsedPlan, parsedPlan) } - test("Append.byPosition: missing required columns cause failure") { + test("byPosition: missing required columns cause failure") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType, nullable = false))).toAttributes) - val parsedPlan = AppendData.byPosition(requiredTable, query) + val parsedPlan = byPosition(requiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -285,12 +366,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("Append.byPosition: missing optional columns cause failure") { + test("byPosition: missing optional columns cause failure") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byPosition(table, query) + val parsedPlan = byPosition(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -299,12 +380,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("Append.byPosition: fail canWrite check") { + test("byPosition: fail canWrite check") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), StructField("b", DoubleType))).toAttributes) - val parsedPlan = AppendData.byPosition(table, widerTable) + val parsedPlan = byPosition(table, widerTable) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -312,7 +393,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) } - test("Append.byPosition: insert safe cast") { + test("byPosition: insert safe cast") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), StructField("b", DoubleType))).toAttributes) @@ -320,8 +401,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val x = table.output.head val y = table.output.last - val parsedPlan = AppendData.byPosition(widerTable, table) - val expectedPlan = AppendData.byPosition(widerTable, + val parsedPlan = byPosition(widerTable, table) + val expectedPlan = byPosition(widerTable, Project(Seq( Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "a")(), Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "b")()), @@ -332,13 +413,13 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: fail extra data fields") { + test("byPosition: fail extra data fields") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType), StructField("c", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -347,7 +428,7 @@ class 
DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'a', 'b', 'c'")) } - test("Append.byPosition: multiple field errors are reported") { + test("byPosition: multiple field errors are reported") { val xRequiredTable = TestRelation(StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", DoubleType))).toAttributes) @@ -356,7 +437,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byPosition(xRequiredTable, query) + val parsedPlan = byPosition(xRequiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 296d3e47e732..f10fd884daab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -29,6 +29,9 @@ public interface SupportsPushDownFilters extends ScanBuilder { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. + *

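For illustration, a source-side ScanBuilder might satisfy this contract along the lines of the sketch below (not from the patch); `ExampleScanBuilder`, its capability check, and the stub schema are made-up names.

```
import org.apache.spark.sql.sources.{Filter, GreaterThan, IsNotNull}
import org.apache.spark.sql.sources.v2.reader.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.types.StructType

// Hypothetical source that can only evaluate IsNotNull and GreaterThan natively.
class ExampleScanBuilder extends SupportsPushDownFilters {
  private var pushed: Array[Filter] = Array.empty

  private def supportedNatively(f: Filter): Boolean = f match {
    case _: IsNotNull | _: GreaterThan => true
    case _ => false
  }

  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    val (supported, postScan) = filters.partition(supportedNatively)
    pushed = supported
    // Hand back everything this source cannot guarantee, so Spark re-evaluates it after the scan.
    postScan
  }

  override def pushedFilters(): Array[Filter] = pushed

  override def build(): Scan = new Scan {
    // Stub schema for the sketch; a real source would report its actual read schema.
    override def readSchema(): StructType = new StructType().add("id", "long")
  }
}
```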
+ * Rows should be returned from the data source if and only if all of the filters match. That is, + * filters must be interpreted as ANDed together. */ Filter[] pushFilters(Filter[] filters); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java new file mode 100644 index 000000000000..8058964b662b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +/** + * Write builder trait for tables that support dynamic partition overwrite. + *

+ * A write that dynamically overwrites partitions removes all existing data in each logical + * partition for which the write will commit new data. Any existing logical partition for which the + * write does not contain data will remain unchanged. + *

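For intuition, this matches the behavior users already see through the SQL/DataFrame path that this interface is meant to back. A hedged sketch, assuming a SparkSession `spark`, a table `logs` partitioned by `day`, and a DataFrame `updates` holding rows for a single day:

```
// Only the partitions present in `updates` are rewritten; all other partitions of `logs` are kept.
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
updates.write.mode("overwrite").insertInto("logs")
```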
+ * This is provided to implement SQL compatible with Hive table operations but is not recommended. + * Instead, use the {@link SupportsOverwrite overwrite by filter API} to explicitly replace data. + */ +public interface SupportsDynamicOverwrite extends WriteBuilder { + /** + * Configures a write to dynamically replace partitions with data committed in the write. + * + * @return this write builder for method chaining + */ + WriteBuilder overwriteDynamicPartitions(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java new file mode 100644 index 000000000000..b443b3c3aeb4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.sql.sources.AlwaysTrue$; +import org.apache.spark.sql.sources.Filter; + +/** + * Write builder trait for tables that support overwrite by filter. + *

+ * Overwriting data by filter will delete any data that matches the filter and replace it with data + * that is committed in the write. + */ +public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate { + /** + * Configures a write to replace data matching the filters with data committed in the write. + *

+ * Rows must be deleted from the data source if and only if all of the filters match. That is, + * filters must be interpreted as ANDed together. + * + * @param filters filters used to match data to overwrite + * @return this write builder for method chaining + */ + WriteBuilder overwrite(Filter[] filters); + + @Override + default WriteBuilder truncate() { + return overwrite(new Filter[] { AlwaysTrue$.MODULE$ }); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java new file mode 100644 index 000000000000..69c2ba5e01a4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +/** + * Write builder trait for tables that support truncation. + *

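The three write-builder mix-ins added here (`SupportsDynamicOverwrite`, `SupportsOverwrite`, `SupportsTruncate`) compose on a single builder. Below is a minimal, hypothetical sketch (none of these names are in the patch); the actual `BatchWrite` construction is left unimplemented.

```
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsDynamicOverwrite, SupportsOverwrite, WriteBuilder}
import org.apache.spark.sql.types.StructType

class ExampleWriteBuilder extends SupportsOverwrite with SupportsDynamicOverwrite {
  private var schema: StructType = _
  private var deleteFilters: Array[Filter] = Array.empty
  private var dynamicOverwrite = false

  override def withInputDataSchema(schema: StructType): WriteBuilder = {
    this.schema = schema
    this
  }

  // Overwrite by filter: remember which rows must be deleted before the new data is committed.
  override def overwrite(filters: Array[Filter]): WriteBuilder = {
    deleteFilters = filters
    this
  }

  // Dynamic mode: replace only the logical partitions touched by the incoming data.
  override def overwriteDynamicPartitions(): WriteBuilder = {
    dynamicOverwrite = true
    this
  }

  // No truncate() override is needed: SupportsOverwrite's default method rewrites truncation as
  // overwrite(Array(AlwaysTrue)), i.e. delete everything and append the new data.

  override def buildForBatch(): BatchWrite = {
    // A real sink would build its BatchWrite here from schema, deleteFilters and dynamicOverwrite.
    ???
  }
}
```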
+ * Truncation removes all data in a table and replaces it with data that is committed in the write. + */ +public interface SupportsTruncate extends WriteBuilder { + /** + * Configures a write to replace all existing data with data committed in the write. + * + * @return this write builder for method chaining + */ + WriteBuilder truncate(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e5f947337c94..450828172b93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan, OverwriteByExpression} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} @@ -264,29 +265,38 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val dsOptions = new DataSourceOptions(options.asJava) provider.getTable(dsOptions) match { case table: SupportsBatchWrite => - if (mode == SaveMode.Append) { - val relation = DataSourceV2Relation.create(table, options) - runCommand(df.sparkSession, "save") { - AppendData.byName(relation, df.logicalPlan) - } - } else { - val writeBuilder = table.newWriteBuilder(dsOptions) - .withQueryId(UUID.randomUUID().toString) - .withInputDataSchema(df.logicalPlan.schema) - writeBuilder match { - case s: SupportsSaveMode => - val write = s.mode(mode).buildForBatch() - // It can only return null with `SupportsSaveMode`. We can clean it up after - // removing `SupportsSaveMode`. - if (write != null) { - runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(write, df.logicalPlan) + lazy val relation = DataSourceV2Relation.create(table, options) + mode match { + case SaveMode.Append => + runCommand(df.sparkSession, "save") { + AppendData.byName(relation, df.logicalPlan) + } + + case SaveMode.Overwrite => + // truncate the table + runCommand(df.sparkSession, "save") { + OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) + } + + case _ => + table.newWriteBuilder(dsOptions) match { + case writeBuilder: SupportsSaveMode => + val write = writeBuilder.mode(mode) + .withQueryId(UUID.randomUUID().toString) + .withInputDataSchema(df.logicalPlan.schema) + .buildForBatch() + // It can only return null with `SupportsSaveMode`. We can clean it up after + // removing `SupportsSaveMode`. + if (write != null) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(write, df.logicalPlan) + } } - } - case _ => throw new AnalysisException( - s"data source ${table.name} does not support SaveMode $mode") - } + case _ => + throw new AnalysisException( + s"data source ${table.name} does not support SaveMode $mode") + } } // Streaming also uses the data source V2 API. 
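At the DataFrame API level, the `mode match` above maps save modes onto the v2 logical plans roughly as follows. A sketch, assuming a SparkSession `spark` and a v2 batch source registered under the made-up name "exampleV2":

```
val df = spark.range(5).toDF("id")

// SaveMode.Append plans AppendData.byName(relation, df.logicalPlan).
df.write.format("exampleV2").option("path", "/tmp/example").mode("append").save()

// SaveMode.Overwrite plans OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true));
// the always-true condition later becomes the AlwaysTrue source filter, i.e. truncate and append.
df.write.format("exampleV2").option("path", "/tmp/example").mode("overwrite").save()

// Other modes still go through the SupportsSaveMode fallback shown above.
df.write.format("exampleV2").option("path", "/tmp/example").mode("ignore").save()
```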
So it may be that the data source implements diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 273cc3b19302..b73dc30d6f23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -529,6 +529,12 @@ object DataSourceStrategy { case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => Some(sources.StringContains(a.name, v.toString)) + case expressions.Literal(true, BooleanType) => + Some(sources.AlwaysTrue) + + case expressions.Literal(false, BooleanType) => + Some(sources.AlwaysFalse) + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala new file mode 100644 index 000000000000..c8542bfe5e59 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead, SupportsBatchWrite, Table} + +object DataSourceV2Implicits { + implicit class TableHelper(table: Table) { + def asBatchReadable: SupportsBatchRead = { + table match { + case support: SupportsBatchRead => + support + case _ => + throw new AnalysisException(s"Table does not support batch reads: ${table.name}") + } + } + + def asBatchWritable: SupportsBatchWrite = { + table match { + case support: SupportsBatchWrite => + support + case _ => + throw new AnalysisException(s"Table does not support batch writes: ${table.name}") + } + } + } + + implicit class OptionsHelper(options: Map[String, String]) { + def toDataSourceOptions: DataSourceOptions = new DataSourceOptions(options.asJava) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 47cf26dc9481..53677782c95f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,11 +17,6 @@ package org.apache.spark.sql.execution.datasources.v2 -import java.util.UUID - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} @@ -30,7 +25,6 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.types.StructType /** * A logical plan representing a data source v2 table. 
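A hypothetical caller-side sketch (the `ExamplePlanning` object is not in the patch) of how the implicits above are meant to be used by relation and planner code:

```
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation}
import org.apache.spark.sql.sources.v2.reader.ScanBuilder
import org.apache.spark.sql.sources.v2.writer.WriteBuilder

object ExamplePlanning {
  import DataSourceV2Implicits._

  // Each helper throws a uniform AnalysisException if the table lacks the capability.
  def scanBuilderFor(relation: DataSourceV2Relation): ScanBuilder =
    relation.table.asBatchReadable.newScanBuilder(relation.options.toDataSourceOptions)

  def writeBuilderFor(relation: DataSourceV2Relation): WriteBuilder =
    relation.table.asBatchWritable.newWriteBuilder(relation.options.toDataSourceOptions)
}
```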
@@ -45,26 +39,16 @@ case class DataSourceV2Relation( options: Map[String, String]) extends LeafNode with MultiInstanceRelation with NamedRelation { + import DataSourceV2Implicits._ + override def name: String = table.name() override def simpleString(maxFields: Int): String = { s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name" } - def newScanBuilder(): ScanBuilder = table match { - case s: SupportsBatchRead => - val dsOptions = new DataSourceOptions(options.asJava) - s.newScanBuilder(dsOptions) - case _ => throw new AnalysisException(s"Table is not readable: ${table.name()}") - } - - def newWriteBuilder(schema: StructType): WriteBuilder = table match { - case s: SupportsBatchWrite => - val dsOptions = new DataSourceOptions(options.asJava) - s.newWriteBuilder(dsOptions) - .withQueryId(UUID.randomUUID().toString) - .withInputDataSchema(schema) - case _ => throw new AnalysisException(s"Table is not writable: ${table.name()}") + def newScanBuilder(): ScanBuilder = { + table.asBatchReadable.newScanBuilder(options.toDataSourceOptions) } override def computeStats(): Statistics = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d6d17d6df7b1..55d7b0a18cbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -19,18 +19,18 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.{sources, AnalysisException, SaveMode, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, SubqueryExpression} +import org.apache.spark.sql.{AnalysisException, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode -object DataSourceV2Strategy extends Strategy { +object DataSourceV2Strategy extends Strategy with PredicateHelper { /** * Pushes down filters to the data source reader @@ -100,6 +100,7 @@ object DataSourceV2Strategy extends Strategy { } } + import DataSourceV2Implicits._ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => @@ -146,14 +147,22 @@ object DataSourceV2Strategy extends Strategy { WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil case AppendData(r: DataSourceV2Relation, query, _) => - val writeBuilder = r.newWriteBuilder(query.schema) - writeBuilder match { 
- case s: SupportsSaveMode => - val write = s.mode(SaveMode.Append).buildForBatch() - assert(write != null) - WriteToDataSourceV2Exec(write, planLater(query)) :: Nil - case _ => throw new AnalysisException(s"data source ${r.name} does not support SaveMode") - } + AppendDataExec( + r.table.asBatchWritable, r.options.toDataSourceOptions, planLater(query)) :: Nil + + case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) => + // fail if any filter cannot be converted. correctness depends on removing all matching data. + val filters = splitConjunctivePredicates(deleteExpr).map { + filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse( + throw new AnalysisException(s"Cannot translate expression to source filter: $filter")) + }.toArray + + OverwriteByExpressionExec( + r.table.asBatchWritable, filters, r.options.toDataSourceOptions, planLater(query)) :: Nil + + case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) => + OverwritePartitionsDynamicExec(r.table.asBatchWritable, + r.options.toDataSourceOptions, planLater(query)) :: Nil case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 50c5e4f2ad7d..d7cb2457433b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -17,17 +17,22 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.UUID + import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.{AlwaysTrue, Filter} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchWrite} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsSaveMode, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.util.{LongAccumulator, Utils} /** @@ -42,17 +47,137 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan) } /** - * The physical plan for writing data into data source v2. + * Physical plan node for append into a v2 table. + * + * Rows in the output data set are appended. + */ +case class AppendDataExec( + table: SupportsBatchWrite, + writeOptions: DataSourceOptions, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsSaveMode => + builder.mode(SaveMode.Append).buildForBatch() + + case builder => + builder.buildForBatch() + } + doWrite(batchWrite) + } +} + +/** + * Physical plan node for overwrite into a v2 table. + * + * Overwrites data in a table matched by a set of filters. 
Rows matching all of the filters will be + * deleted and rows in the output data set are appended. + * + * This plan is used to implement SaveMode.Overwrite. The behavior of SaveMode.Overwrite is to + * truncate the table -- delete all rows -- and append the output data set. This uses the filter + * AlwaysTrue to delete all rows. */ -case class WriteToDataSourceV2Exec(batchWrite: BatchWrite, query: SparkPlan) - extends UnaryExecNode { +case class OverwriteByExpressionExec( + table: SupportsBatchWrite, + deleteWhere: Array[Filter], + writeOptions: DataSourceOptions, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + private def isTruncate(filters: Array[Filter]): Boolean = { + filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] + } + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsTruncate if isTruncate(deleteWhere) => + builder.truncate().buildForBatch() + + case builder: SupportsSaveMode if isTruncate(deleteWhere) => + builder.mode(SaveMode.Overwrite).buildForBatch() + + case builder: SupportsOverwrite => + builder.overwrite(deleteWhere).buildForBatch() + + case _ => + throw new SparkException(s"Table does not support dynamic partition overwrite: $table") + } + + doWrite(batchWrite) + } +} + +/** + * Physical plan node for dynamic partition overwrite into a v2 table. + * + * Dynamic partition overwrite is the behavior of Hive INSERT OVERWRITE ... PARTITION queries, and + * Spark INSERT OVERWRITE queries when spark.sql.sources.partitionOverwriteMode=dynamic. Each + * partition in the output data set replaces the corresponding existing partition in the table or + * creates a new partition. Existing partitions for which there is no data in the output data set + * are not modified. + */ +case class OverwritePartitionsDynamicExec( + table: SupportsBatchWrite, + writeOptions: DataSourceOptions, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsDynamicOverwrite => + builder.overwriteDynamicPartitions().buildForBatch() + + case builder: SupportsSaveMode => + builder.mode(SaveMode.Overwrite).buildForBatch() + + case _ => + throw new SparkException(s"Table does not support dynamic partition overwrite: $table") + } + + doWrite(batchWrite) + } +} + +case class WriteToDataSourceV2Exec( + batchWrite: BatchWrite, + query: SparkPlan + ) extends V2TableWriteExec { + + import DataSourceV2Implicits._ + + def writeOptions: DataSourceOptions = Map.empty[String, String].toDataSourceOptions + + override protected def doExecute(): RDD[InternalRow] = { + doWrite(batchWrite) + } +} + +/** + * Helper for physical plans that build batch writes. + */ +trait BatchWriteHelper { + def table: SupportsBatchWrite + def query: SparkPlan + def writeOptions: DataSourceOptions + + def newWriteBuilder(): WriteBuilder = { + table.newWriteBuilder(writeOptions) + .withInputDataSchema(query.schema) + .withQueryId(UUID.randomUUID().toString) + } +} + +/** + * The base physical plan for writing data into data source v2. 
+ */ +trait V2TableWriteExec extends UnaryExecNode { + def query: SparkPlan var commitProgress: Option[StreamWriterCommitProgress] = None override def child: SparkPlan = query override def output: Seq[Attribute] = Nil - override protected def doExecute(): RDD[InternalRow] = { + protected def doWrite(batchWrite: BatchWrite): RDD[InternalRow] = { val writerFactory = batchWrite.createBatchWriterFactory() val useCommitCoordinator = batchWrite.useCommitCoordinator val rdd = query.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 3f941cc6e107..a1ab55a7185c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.Stable +import org.apache.spark.annotation.{Evolving, Stable} //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. @@ -218,3 +218,27 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } + +/** + * A filter that always evaluates to `true`. + */ +@Evolving +case class AlwaysTrue() extends Filter { + override def references: Array[String] = Array.empty +} + +@Evolving +object AlwaysTrue extends AlwaysTrue { +} + +/** + * A filter that always evaluates to `false`. + */ +@Evolving +case class AlwaysFalse() extends Filter { + override def references: Array[String] = Array.empty +} + +@Evolving +object AlwaysFalse extends AlwaysFalse { +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 511fdfe5c23a..6b5c45e40ab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -351,19 +351,21 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("SPARK-25700: do not read schema when writing in other modes except append mode") { + test("SPARK-25700: do not read schema when writing in other modes except append and overwrite") { withTempPath { file => val cls = classOf[SimpleWriteOnlyDataSource] val path = file.getCanonicalPath val df = spark.range(5).select('id as 'i, -'id as 'j) // non-append mode should not throw exception, as they don't access schema. df.write.format(cls.getName).option("path", path).mode("error").save() - df.write.format(cls.getName).option("path", path).mode("overwrite").save() df.write.format(cls.getName).option("path", path).mode("ignore").save() - // append mode will access schema and should throw exception. + // append and overwrite modes will access the schema and should throw exception. 
intercept[SchemaReadAttemptException] { df.write.format(cls.getName).option("path", path).mode("append").save() } + intercept[SchemaReadAttemptException] { + df.write.format(cls.getName).option("path", path).mode("overwrite").save() + } } } } From 7f53116f77bac6302bb727769b4a4c684b6b0b5b Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sun, 17 Feb 2019 23:35:45 -0800 Subject: [PATCH 13/19] [SPARK-24570][SQL] Implement Spark own GetTablesOperation to fix SQL client tools cannot show tables ## What changes were proposed in this pull request? For SQL client tools([DBeaver](https://dbeaver.io/))'s Navigator use [`GetTablesOperation`](https://github.com/apache/spark/blob/a7444570764b0a08b7e908dc7931744f9dbdf3c6/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java) to obtain table names. We should use [`metadataHive`](https://github.com/apache/spark/blob/95d172da2b370ff6257bfd6fcd102ac553f6f6af/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala#L52-L53), but it use [`executionHive`](https://github.com/apache/spark/blob/24f5bbd770033dacdea62555488bfffb61665279/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala#L93-L95). This PR implement Spark own `GetTablesOperation` to use `metadataHive`. ## How was this patch tested? unit test and manual tests ![image](https://user-images.githubusercontent.com/5399861/47430696-acf77980-d7cc-11e8-824d-f28d78f60a00.png) ![image](https://user-images.githubusercontent.com/5399861/47440576-09649400-d7e1-11e8-97a8-a96f73f70361.png) Closes #22794 from wangyum/SPARK-24570. Authored-by: Yuming Wang Signed-off-by: gatorsmile --- .../cli/operation/GetTablesOperation.java | 2 +- .../SparkGetTablesOperation.scala | 99 +++++++++++++++++++ .../server/SparkSQLOperationManager.scala | 22 ++++- .../HiveThriftServer2Suites.scala | 2 +- .../SparkMetadataOperationSuite.scala | 87 +++++++++++++++- 5 files changed, 206 insertions(+), 6 deletions(-) create mode 100644 sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java index 1a7ca79163d7..2af17a662a29 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java @@ -46,7 +46,7 @@ public class GetTablesOperation extends MetadataOperation { private final String schemaName; private final String tableName; private final List tableTypes = new ArrayList(); - private final RowSet rowSet; + protected final RowSet rowSet; private final TableTypeMapping tableTypeMapping; diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala new file mode 100644 index 000000000000..369650047b10 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.{List => JList} + +import scala.collection.JavaConverters.seqAsJavaListConverter + +import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType +import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObjectUtils +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.GetTablesOperation +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ + +/** + * Spark's own GetTablesOperation + * + * @param sqlContext SQLContext to use + * @param parentSession a HiveSession from SessionManager + * @param catalogName catalog name. null if not applicable + * @param schemaName database name, null or a concrete database name + * @param tableName table name pattern + * @param tableTypes list of allowed table types, e.g. "TABLE", "VIEW" + */ +private[hive] class SparkGetTablesOperation( + sqlContext: SQLContext, + parentSession: HiveSession, + catalogName: String, + schemaName: String, + tableName: String, + tableTypes: JList[String]) + extends GetTablesOperation(parentSession, catalogName, schemaName, tableName, tableTypes) { + + if (tableTypes != null) { + this.tableTypes.addAll(tableTypes) + } + + override def runInternal(): Unit = { + setState(OperationState.RUNNING) + // Always use the latest class loader provided by executionHive's state. 
+ val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader + Thread.currentThread().setContextClassLoader(executionHiveClassLoader) + + val catalog = sqlContext.sessionState.catalog + val schemaPattern = convertSchemaPattern(schemaName) + val matchingDbs = catalog.listDatabases(schemaPattern) + + if (isAuthV2Enabled) { + val privObjs = + HivePrivilegeObjectUtils.getHivePrivDbObjects(seqAsJavaListConverter(matchingDbs).asJava) + val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName" + authorizeMetaGets(HiveOperationType.GET_TABLES, privObjs, cmdStr) + } + + val tablePattern = convertIdentifierPattern(tableName, true) + matchingDbs.foreach { dbName => + catalog.listTables(dbName, tablePattern).foreach { tableIdentifier => + val catalogTable = catalog.getTableMetadata(tableIdentifier) + val tableType = tableTypeString(catalogTable.tableType) + if (tableTypes == null || tableTypes.isEmpty || tableTypes.contains(tableType)) { + val rowData = Array[AnyRef]( + "", + catalogTable.database, + catalogTable.identifier.table, + tableType, + catalogTable.comment.getOrElse("")) + rowSet.addRow(rowData) + } + } + } + setState(OperationState.FINISHED) + } + + private def tableTypeString(tableType: CatalogTableType): String = tableType match { + case EXTERNAL | MANAGED => "TABLE" + case VIEW => "VIEW" + case t => + throw new IllegalArgumentException(s"Unknown table type is found at showCreateHiveTable: $t") + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 85b6c7134755..7947d1785a8f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.hive.thriftserver.server -import java.util.{Map => JMap} +import java.util.{List => JList, Map => JMap} import java.util.concurrent.ConcurrentHashMap import org.apache.hive.service.cli._ -import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, GetSchemasOperation, Operation, OperationManager} +import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, GetSchemasOperation, MetadataOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation, SparkGetSchemasOperation} +import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation, SparkGetSchemasOperation, SparkGetTablesOperation} import org.apache.spark.sql.internal.SQLConf /** @@ -76,6 +76,22 @@ private[thriftserver] class SparkSQLOperationManager() operation } + override def newGetTablesOperation( + parentSession: HiveSession, + catalogName: String, + schemaName: String, + tableName: String, + tableTypes: JList[String]): MetadataOperation = synchronized { + val sqlContext = sessionToContexts.get(parentSession.getSessionHandle) + require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + + " initialized or had already closed.") + val operation = new SparkGetTablesOperation(sqlContext, parentSession, + catalogName, 
schemaName, tableName, tableTypes) + handleToOperation.put(operation.getHandle, operation) + logDebug(s"Created GetTablesOperation with session=$parentSession.") + operation + } + def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = { val iterator = confMap.entrySet().iterator() while (iterator.hasNext) { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index f9509aed4aaa..0f53fcd327f1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -280,7 +280,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { var defaultV2: String = null var data: ArrayBuffer[Int] = null - withMultipleConnectionJdbcStatement("test_map")( + withMultipleConnectionJdbcStatement("test_map", "db1.test_map2")( // create table { statement => diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala index 9a997ae01df9..bf9982388d6b 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.Properties +import java.util.{Arrays => JArrays, List => JList, Properties} import org.apache.hive.jdbc.{HiveConnection, HiveQueryResultSet, Utils => JdbcUtils} import org.apache.hive.service.auth.PlainSaslHelper @@ -100,4 +100,89 @@ class SparkMetadataOperationSuite extends HiveThriftJdbcTest { } } } + + test("Spark's own GetTablesOperation(SparkGetTablesOperation)") { + def testGetTablesOperation( + schema: String, + tableNamePattern: String, + tableTypes: JList[String])(f: HiveQueryResultSet => Unit): Unit = { + val rawTransport = new TSocket("localhost", serverPort) + val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val client = new TCLIService.Client(new TBinaryProtocol(transport)) + transport.open() + + var rs: HiveQueryResultSet = null + + try { + val openResp = client.OpenSession(new TOpenSessionReq) + val sessHandle = openResp.getSessionHandle + + val getTableReq = new TGetTablesReq(sessHandle) + getTableReq.setSchemaName(schema) + getTableReq.setTableName(tableNamePattern) + getTableReq.setTableTypes(tableTypes) + + val getTableResp = client.GetTables(getTableReq) + + JdbcUtils.verifySuccess(getTableResp.getStatus) + + rs = new HiveQueryResultSet.Builder(connection) + .setClient(client) + .setSessionHandle(sessHandle) + .setStmtHandle(getTableResp.getOperationHandle) + .build() + + f(rs) + } finally { + rs.close() + connection.close() + transport.close() + rawTransport.close() + } + } + + def checkResult(tableNames: Seq[String], rs: HiveQueryResultSet): Unit = { + if (tableNames.nonEmpty) { + for (i <- tableNames.indices) { + assert(rs.next()) + assert(rs.getString("TABLE_NAME") === tableNames(i)) + } + } else { + assert(!rs.next()) 
+ } + } + + withJdbcStatement("table1", "table2") { statement => + Seq( + "CREATE TABLE table1(key INT, val STRING)", + "CREATE TABLE table2(key INT, val STRING)", + "CREATE VIEW view1 AS SELECT * FROM table2").foreach(statement.execute) + + testGetTablesOperation("%", "%", null) { rs => + checkResult(Seq("table1", "table2", "view1"), rs) + } + + testGetTablesOperation("%", "table1", null) { rs => + checkResult(Seq("table1"), rs) + } + + testGetTablesOperation("%", "table_not_exist", null) { rs => + checkResult(Seq.empty, rs) + } + + testGetTablesOperation("%", "%", JArrays.asList("TABLE")) { rs => + checkResult(Seq("table1", "table2"), rs) + } + + testGetTablesOperation("%", "%", JArrays.asList("VIEW")) { rs => + checkResult(Seq("view1"), rs) + } + + testGetTablesOperation("%", "%", JArrays.asList("TABLE", "VIEW")) { rs => + checkResult(Seq("table1", "table2", "view1"), rs) + } + } + } } From 8290e5eccb22e6b865075a0fe62772ce072f0a08 Mon Sep 17 00:00:00 2001 From: liuxian Date: Mon, 18 Feb 2019 17:20:58 +0800 Subject: [PATCH 14/19] [SPARK-26353][SQL] Add typed aggregate functions(max/min) to the example module. ## What changes were proposed in this pull request? Add typed aggregate functions(max/min) to the example module. ## How was this patch tested? Manual testing: running typed minimum: ``` +-----+----------------------+ |value|TypedMin(scala.Tuple2)| +-----+----------------------+ | 0| [0.0]| | 2| [2.0]| | 1| [1.0]| +-----+----------------------+ ``` running typed maximum: ``` +-----+----------------------+ |value|TypedMax(scala.Tuple2)| +-----+----------------------+ | 0| [18]| | 2| [17]| | 1| [19]| +-----+----------------------+ ``` Closes #23304 from 10110346/typedminmax. Authored-by: liuxian Signed-off-by: Hyukjin Kwon --- .../examples/sql/SimpleTypedAggregator.scala | 74 +++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala index f8af91980bde..5510f0019353 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala @@ -44,6 +44,12 @@ object SimpleTypedAggregator { println("running typed average:") ds.groupByKey(_._1).agg(new TypedAverage[(Long, Long)](_._2.toDouble).toColumn).show() + println("running typed minimum:") + ds.groupByKey(_._1).agg(new TypedMin[(Long, Long)](_._2.toDouble).toColumn).show() + + println("running typed maximum:") + ds.groupByKey(_._1).agg(new TypedMax[(Long, Long)](_._2).toColumn).show() + spark.stop() } } @@ -84,3 +90,71 @@ class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long } override def outputEncoder: Encoder[Double] = Encoders.scalaDouble } + +class TypedMin[IN](val f: IN => Double) extends Aggregator[IN, MutableDouble, Option[Double]] { + override def zero: MutableDouble = null + override def reduce(b: MutableDouble, a: IN): MutableDouble = { + if (b == null) { + new MutableDouble(f(a)) + } else { + b.value = math.min(b.value, f(a)) + b + } + } + override def merge(b1: MutableDouble, b2: MutableDouble): MutableDouble = { + if (b1 == null) { + b2 + } else if (b2 == null) { + b1 + } else { + b1.value = math.min(b1.value, b2.value) + b1 + } + } + override def finish(reduction: MutableDouble): Option[Double] = { + if (reduction != null) { + Some(reduction.value) + } else { + None + } + } + + override def bufferEncoder: 
Encoder[MutableDouble] = Encoders.kryo[MutableDouble] + override def outputEncoder: Encoder[Option[Double]] = Encoders.product[Option[Double]] +} + +class TypedMax[IN](val f: IN => Long) extends Aggregator[IN, MutableLong, Option[Long]] { + override def zero: MutableLong = null + override def reduce(b: MutableLong, a: IN): MutableLong = { + if (b == null) { + new MutableLong(f(a)) + } else { + b.value = math.max(b.value, f(a)) + b + } + } + override def merge(b1: MutableLong, b2: MutableLong): MutableLong = { + if (b1 == null) { + b2 + } else if (b2 == null) { + b1 + } else { + b1.value = math.max(b1.value, b2.value) + b1 + } + } + override def finish(reduction: MutableLong): Option[Long] = { + if (reduction != null) { + Some(reduction.value) + } else { + None + } + } + + override def bufferEncoder: Encoder[MutableLong] = Encoders.kryo[MutableLong] + override def outputEncoder: Encoder[Option[Long]] = Encoders.product[Option[Long]] +} + +class MutableLong(var value: Long) extends Serializable + +class MutableDouble(var value: Double) extends Serializable From 59eb34b82c023ac56dcd08a4ceccdf612bfa7f29 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 18 Feb 2019 17:22:06 +0800 Subject: [PATCH 15/19] [SPARK-26889][SS][DOCS] Fix timestamp type in Structured Streaming + Kafka Integration Guide ## What changes were proposed in this pull request? ``` $ spark-shell --packages org.apache.spark:spark-sql-kafka-0-10_2.11:3.0.0-SNAPSHOT ... scala> val df = spark.read.format("kafka").option("kafka.bootstrap.servers", "foo").option("subscribe", "bar").load().printSchema() root |-- key: binary (nullable = true) |-- value: binary (nullable = true) |-- topic: string (nullable = true) |-- partition: integer (nullable = true) |-- offset: long (nullable = true) |-- timestamp: timestamp (nullable = true) |-- timestampType: integer (nullable = true) df: Unit = () ``` In the doc timestamp type is `long` and in this PR I've changed it to `timestamp`. ## How was this patch tested? cd docs/ SKIP_API=1 jekyll build Manual webpage check. Closes #23796 from gaborgsomogyi/SPARK-26889. Authored-by: Gabor Somogyi Signed-off-by: Hyukjin Kwon --- docs/structured-streaming-kafka-integration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index c19aa5c504b0..425110a8bfa5 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -265,7 +265,7 @@ Each row in the source has the following schema:
-  <td>long</td>
+  <td>timestamp</td>
- + From a0e81fcfe8a6dbb246f8b170b6f5e203ab194d7e Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 18 Feb 2019 21:13:00 +0800 Subject: [PATCH 16/19] [SPARK-26744][SPARK-26744][SQL][HOTFOX] Disable schema validation tests for FileDataSourceV2 (partially revert ) ## What changes were proposed in this pull request? This PR partially revert SPARK-26744. https://github.com/apache/spark/commit/60caa92deaf6941f58da82dcc0962ebf3a598ced and https://github.com/apache/spark/commit/4dce45a5992e6a89a26b5a0739b33cfeaf979208 were merged at similar time range independently. So the test failures were not caught. - https://github.com/apache/spark/commit/60caa92deaf6941f58da82dcc0962ebf3a598ced happened to add a schema reading logic in writing path for overwrite mode as well. - https://github.com/apache/spark/commit/4dce45a5992e6a89a26b5a0739b33cfeaf979208 added some tests with overwrite modes with migrated ORC v2. And the tests looks starting to fail. I guess the discussion won't be short (see https://github.com/apache/spark/pull/23606#discussion_r257675083) and this PR proposes to disable the tests added at https://github.com/apache/spark/commit/4dce45a5992e6a89a26b5a0739b33cfeaf979208 to unblock other PRs for now. ## How was this patch tested? Existing tests. Closes #23828 from HyukjinKwon/SPARK-26744. Authored-by: Hyukjin Kwon Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index e0c0484593d9..58522f7b1376 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -329,7 +329,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { withTempDir { dir => val tempDir = new File(dir, "files").getCanonicalPath - Seq(true, false).foreach { useV1 => + Seq(true).foreach { useV1 => val useV1List = if (useV1) { "orc" } else { @@ -374,7 +374,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { - Seq(true, false).foreach { useV1 => + Seq(true).foreach { useV1 => val useV1List = if (useV1) { "orc" } else { From f85ed9a3e55083b0de0e20a37775efa92d248a4f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Feb 2019 16:17:24 -0800 Subject: [PATCH 17/19] [SPARK-26785][SQL] data source v2 API refactor: streaming write ## What changes were proposed in this pull request? Continue the API refactor for streaming write, according to the [doc](https://docs.google.com/document/d/1vI26UEuDpVuOjWw4WPoH2T6y8WAekwtI7qoowhOFnI4/edit?usp=sharing). The major changes: 1. rename `StreamingWriteSupport` to `StreamingWrite` 2. add `WriteBuilder.buildForStreaming` 3. update existing sinks, to move the creation of `StreamingWrite` to `Table` ## How was this patch tested? existing tests Closes #23702 from cloud-fan/stream-write. 
Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- .../sql/kafka010/KafkaSourceProvider.scala | 42 +++++---- ...upport.scala => KafkaStreamingWrite.scala} | 8 +- .../sql/sources/v2/SessionConfigSupport.java | 4 +- .../v2/StreamingWriteSupportProvider.java | 54 ------------ .../sql/sources/v2/SupportsBatchWrite.java | 2 +- .../sources/v2/SupportsStreamingWrite.java | 33 +++++++ .../spark/sql/sources/v2/TableProvider.java | 3 +- .../sql/sources/v2/writer/WriteBuilder.java | 9 +- .../v2/writer/WriterCommitMessage.java | 4 +- .../streaming/StreamingDataWriterFactory.java | 2 +- ...gWriteSupport.java => StreamingWrite.java} | 21 ++++- .../streaming/SupportsOutputMode.java} | 17 ++-- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../datasources/noop/NoopDataSource.scala | 26 ++---- .../v2/DataSourceV2StringFormat.scala | 88 ------------------- .../datasources/v2/DataSourceV2Utils.scala | 43 ++++----- .../streaming/MicroBatchExecution.scala | 20 +++-- .../streaming/StreamingRelation.scala | 6 +- .../sql/execution/streaming/console.scala | 43 ++++++--- .../continuous/ContinuousExecution.scala | 25 +++--- .../continuous/EpochCoordinator.scala | 6 +- .../WriteToContinuousDataSource.scala | 6 +- .../WriteToContinuousDataSourceExec.scala | 13 +-- ...eWriteSupport.scala => ConsoleWrite.scala} | 6 +- ...rovider.scala => ForeachWriterTable.scala} | 76 +++++++++------- .../streaming/sources/MicroBatchWrite.scala | 4 +- .../sources/RateStreamProvider.scala | 3 +- .../sources/TextSocketSourceProvider.scala | 3 +- .../streaming/sources/memoryV2.scala | 42 ++++++--- .../sql/streaming/DataStreamReader.scala | 2 +- .../sql/streaming/DataStreamWriter.scala | 50 ++++++----- .../sql/streaming/StreamingQueryManager.scala | 4 +- ...pache.spark.sql.sources.DataSourceRegister | 2 +- .../streaming/MemorySinkV2Suite.scala | 6 +- .../sources/v2/DataSourceV2UtilsSuite.scala | 4 +- .../sources/v2/SimpleWritableDataSource.scala | 3 +- .../ContinuousQueuedDataReaderSuite.scala | 4 +- .../continuous/EpochCoordinatorSuite.scala | 6 +- .../sources/StreamingDataSourceV2Suite.scala | 70 +++++++++------ 39 files changed, 373 insertions(+), 389 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{KafkaStreamingWriteSupport.scala => KafkaStreamingWrite.scala} (95%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/{StreamingWriteSupport.java => StreamingWrite.java} (73%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{DataSourceV2.java => writer/streaming/SupportsOutputMode.java} (67%) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{ConsoleWriteSupport.scala => ConsoleWrite.scala} (94%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{ForeachWriteSupportProvider.scala => ForeachWriterTable.scala} (66%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 9238899b0c00..6994517b27d6 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ 
b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.WriteBuilder +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -47,7 +48,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with StreamingWriteSupportProvider with TableProvider with Logging { import KafkaSourceProvider._ @@ -180,20 +180,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - import scala.collection.JavaConverters._ - - val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) - // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. - val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - - new KafkaStreamingWriteSupport(topic, producerParams, schema) - } - private def strategy(caseInsensitiveParams: Map[String, String]) = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { case ("assign", value) => @@ -365,7 +351,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } class KafkaTable(strategy: => ConsumerStrategy) extends Table - with SupportsMicroBatchRead with SupportsContinuousRead { + with SupportsMicroBatchRead with SupportsContinuousRead with SupportsStreamingWrite { override def name(): String = s"Kafka $strategy" @@ -374,6 +360,28 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { override def build(): Scan = new KafkaScan(options) } + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def outputMode(mode: OutputMode): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + import scala.collection.JavaConverters._ + + assert(inputSchema != null) + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + new KafkaStreamingWrite(topic, producerParams, inputSchema) + } + } + } } class KafkaScan(options: DataSourceOptions) extends Scan { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala similarity index 95% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala index 
0d831c388460..e3101e157208 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** @@ -33,18 +33,18 @@ import org.apache.spark.sql.types.StructType case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[StreamingWriteSupport]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamingWrite]] for Kafka writing. Responsible for generating the writer factory. * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. */ -class KafkaStreamingWriteSupport( +class KafkaStreamingWrite( topic: Option[String], producerParams: ju.Map[String, Object], schema: StructType) - extends StreamingWriteSupport { + extends StreamingWrite { validateQuery(schema.toAttributes, producerParams, topic) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index c00abd9b685b..d27fbfdd1461 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -20,12 +20,12 @@ import org.apache.spark.annotation.Evolving; /** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * A mix-in interface for {@link TableProvider}. Data sources can implement this interface to * propagate session configs with the specified key-prefix to all data source operations in this * session. */ @Evolving -public interface SessionConfigSupport extends DataSourceV2 { +public interface SessionConfigSupport extends TableProvider { /** * Key prefix of the session configs to propagate, which is usually the data source name. Spark diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java deleted file mode 100644 index 8ac9c5175086..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for structured streaming. - * - * This interface is used to create {@link StreamingWriteSupport} instances when end users run - * {@code Dataset.writeStream.format(...).option(...).start()}. - */ -@Evolving -public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { - - /** - * Creates a {@link StreamingWriteSupport} instance to save the data to this data source, which is - * called by Spark at the beginning of each streaming query. - * - * @param queryId A unique string for the writing query. It's possible that there are many - * writing queries running at the same time, and the returned - * {@link StreamingWriteSupport} can use this id to distinguish itself from others. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive epoch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - StreamingWriteSupport createStreamingWriteSupport( - String queryId, - StructType schema, - OutputMode mode, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java index 08caadd5308e..b2cd97a2f533 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java @@ -24,7 +24,7 @@ * An empty mix-in interface for {@link Table}, to indicate this table supports batch write. *

* If a {@link Table} implements this interface, the - * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} + * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} * with {@link WriteBuilder#buildForBatch()} implemented. *

*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java new file mode 100644 index 000000000000..1050d35250c1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.writer.WriteBuilder; + +/** + * An empty mix-in interface for {@link Table}, to indicate this table supports streaming write. + *

+ * If a {@link Table} implements this interface, the + * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} + * with {@link WriteBuilder#buildForStreaming()} implemented. + *

+ */ +@Evolving +public interface SupportsStreamingWrite extends SupportsWrite, BaseStreamingSink { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java index 855d5efe0c69..a9b83b6de995 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java @@ -29,8 +29,7 @@ *

*/ @Evolving -// TODO: do not extend `DataSourceV2`, after we finish the API refactor completely. -public interface TableProvider extends DataSourceV2 { +public interface TableProvider { /** * Return a {@link Table} instance to do read/write with user-specified options. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java index e861c72af9e6..07529fe1dee9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java @@ -20,6 +20,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.SupportsBatchWrite; import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite; import org.apache.spark.sql.types.StructType; /** @@ -64,6 +65,12 @@ default WriteBuilder withInputDataSchema(StructType schema) { * {@link SupportsSaveMode}. */ default BatchWrite buildForBatch() { - throw new UnsupportedOperationException("Batch scans are not supported"); + throw new UnsupportedOperationException(getClass().getName() + + " does not support batch write"); + } + + default StreamingWrite buildForStreaming() { + throw new UnsupportedOperationException(getClass().getName() + + " does not support streaming write"); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 6334c8f64309..23e8580c404d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -20,12 +20,12 @@ import java.io.Serializable; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side * as the input parameter of {@link BatchWrite#commit(WriterCommitMessage[])} or - * {@link StreamingWriteSupport#commit(long, WriterCommitMessage[])}. + * {@link StreamingWrite#commit(long, WriterCommitMessage[])}. * * This is an empty interface, data sources should define their own message class and use it when * generating messages at executor side and handling the messages at driver side. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java index 7d3d21cb2b63..af2f03c9d419 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -26,7 +26,7 @@ /** * A factory of {@link DataWriter} returned by - * {@link StreamingWriteSupport#createStreamingWriterFactory()}, which is responsible for creating + * {@link StreamingWrite#createStreamingWriterFactory()}, which is responsible for creating * and initializing the actual data writer at executor side. 
* * Note that, the writer factory will be serialized and sent to executors, then the data writer diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java index 84cfbf2dda48..5617f1cdc0ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java @@ -22,13 +22,26 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * An interface that defines how to write the data to data source for streaming processing. + * An interface that defines how to write the data to data source in streaming queries. * - * Streaming queries are divided into intervals of data called epochs, with a monotonically - * increasing numeric ID. This writer handles commits and aborts for each successive epoch. + * The writing procedure is: + * 1. Create a writer factory by {@link #createStreamingWriterFactory()}, serialize and send it to + * all the partitions of the input data(RDD). + * 2. For each epoch in each partition, create the data writer, and write the data of the epoch in + * the partition with this writer. If all the data are written successfully, call + * {@link DataWriter#commit()}. If exception happens during the writing, call + * {@link DataWriter#abort()}. + * 3. If writers in all partitions of one epoch are successfully committed, call + * {@link #commit(long, WriterCommitMessage[])}. If some writers are aborted, or the job failed + * with an unknown reason, call {@link #abort(long, WriterCommitMessage[])}. + * + * While Spark will retry failed writing tasks, Spark won't retry failed writing jobs. Users should + * do it manually in their Spark applications if they want to retry. + * + * Please refer to the documentation of commit/abort methods for detailed specifications. */ @Evolving -public interface StreamingWriteSupport { +public interface StreamingWrite { /** * Creates a writer factory which will be serialized and sent to executors. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java similarity index 67% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java index 43bdcca70cb0..832dcfa145d1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java @@ -15,12 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.writer.streaming; -import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Unstable; +import org.apache.spark.sql.sources.v2.writer.WriteBuilder; +import org.apache.spark.sql.streaming.OutputMode; -/** - * TODO: remove it when we finish the API refactor for streaming write side. 
- */ -@Evolving -public interface DataSourceV2 {} +// TODO: remove it when we have `SupportsTruncate` +@Unstable +public interface SupportsOutputMode extends WriteBuilder { + + WriteBuilder outputMode(OutputMode mode); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 713c9a9faa74..e757785e5664 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -205,7 +205,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (classOf[TableProvider].isAssignableFrom(cls)) { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = provider, conf = sparkSession.sessionState.conf) + source = provider, conf = sparkSession.sessionState.conf) val pathsOption = { val objectMapper = new ObjectMapper() DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 452ebbbeb99c..8f2072c586a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -30,30 +30,23 @@ import org.apache.spark.sql.types.StructType * This is no-op datasource. It does not do anything besides consuming its input. * This can be useful for benchmarking or to cache data without any additional overhead. 
*/ -class NoopDataSource - extends DataSourceV2 - with TableProvider - with DataSourceRegister - with StreamingWriteSupportProvider { - +class NoopDataSource extends TableProvider with DataSourceRegister { override def shortName(): String = "noop" override def getTable(options: DataSourceOptions): Table = NoopTable - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = NoopStreamingWriteSupport } -private[noop] object NoopTable extends Table with SupportsBatchWrite { +private[noop] object NoopTable extends Table with SupportsBatchWrite with SupportsStreamingWrite { override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() } -private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsSaveMode { - override def buildForBatch(): BatchWrite = NoopBatchWrite +private[noop] object NoopWriteBuilder extends WriteBuilder + with SupportsSaveMode with SupportsOutputMode { override def mode(mode: SaveMode): WriteBuilder = this + override def outputMode(mode: OutputMode): WriteBuilder = this + override def buildForBatch(): BatchWrite = NoopBatchWrite + override def buildForStreaming(): StreamingWrite = NoopStreamingWrite } private[noop] object NoopBatchWrite extends BatchWrite { @@ -72,7 +65,7 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] { override def abort(): Unit = {} } -private[noop] object NoopStreamingWriteSupport extends StreamingWriteSupport { +private[noop] object NoopStreamingWrite extends StreamingWrite { override def createStreamingWriterFactory(): StreamingDataWriterFactory = NoopStreamingDataWriterFactory override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} @@ -85,4 +78,3 @@ private[noop] object NoopStreamingDataWriterFactory extends StreamingDataWriterF taskId: Long, epochId: Long): DataWriter[InternalRow] = NoopWriter } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala deleted file mode 100644 index f11703c8a277..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.util.Utils - -/** - * A trait that can be used by data source v2 related query plans(both logical and physical), to - * provide a string format of the data source information for explain. - */ -trait DataSourceV2StringFormat { - - /** - * The instance of this data source implementation. Note that we only consider its class in - * equals/hashCode, not the instance itself. - */ - def source: DataSourceV2 - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The options for this data source reader. - */ - def options: Map[String, String] - - /** - * The filters which have been pushed to the data source. - */ - def pushedFilters: Seq[Expression] - - private def sourceName: String = source match { - case registered: DataSourceRegister => registered.shortName() - // source.getClass.getSimpleName can cause Malformed class name error, - // call safer `Utils.getSimpleName` instead - case _ => Utils.getSimpleName(source.getClass) - } - - def metadataString(maxFields: Int): String = { - val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - - if (pushedFilters.nonEmpty) { - entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") - } - - // TODO: we should only display some standard options like path, table, etc. - if (options.nonEmpty) { - entries += "Options" -> Utils.redact(options).map { - case (k, v) => s"$k=$v" - }.mkString("[", ",", "]") - } - - val outputStr = truncatedString(output, "[", ", ", "]", maxFields) - - val entriesStr = if (entries.nonEmpty) { - truncatedString(entries.map { - case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) - }, " (", ", ", ")", maxFields) - } else { - "" - } - - s"$sourceName$outputStr$entriesStr" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index e9cc3991155c..30897d86f817 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,8 +21,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} +import org.apache.spark.sql.sources.v2.{SessionConfigSupport, TableProvider} private[sql] object DataSourceV2Utils extends Logging { @@ -34,34 +33,28 @@ private[sql] object DataSourceV2Utils extends Logging { * `spark.datasource.$keyPrefix`. A session config `spark.datasource.$keyPrefix.xxx -> yyy` will * be transformed into `xxx -> yyy`. * - * @param ds a [[DataSourceV2]] object + * @param source a [[TableProvider]] object * @param conf the session conf * @return an immutable map that contains all the extracted and transformed k/v pairs. 
*/ - def extractSessionConfigs(ds: DataSourceV2, conf: SQLConf): Map[String, String] = ds match { - case cs: SessionConfigSupport => - val keyPrefix = cs.keyPrefix() - require(keyPrefix != null, "The data source config key prefix can't be null.") - - val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)") - - conf.getAllConfs.flatMap { case (key, value) => - val m = pattern.matcher(key) - if (m.matches() && m.groupCount() > 0) { - Seq((m.group(1), value)) - } else { - Seq.empty + def extractSessionConfigs(source: TableProvider, conf: SQLConf): Map[String, String] = { + source match { + case cs: SessionConfigSupport => + val keyPrefix = cs.keyPrefix() + require(keyPrefix != null, "The data source config key prefix can't be null.") + + val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)") + + conf.getAllConfs.flatMap { case (key, value) => + val m = pattern.matcher(key) + if (m.matches() && m.groupCount() > 0) { + Seq((m.group(1), value)) + } else { + Seq.empty + } } - } - - case _ => Map.empty - } - def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = { - val name = ds match { - case register: DataSourceRegister => register.shortName() - case _ => ds.getClass.getName + case _ => Map.empty } - throw new UnsupportedOperationException(name + " source does not support user-specified schema") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 2c339759f95b..cca279030dfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, RateCo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.streaming.SupportsOutputMode import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -513,13 +514,16 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamingWriteSupportProvider => - val writer = s.createStreamingWriteSupport( - s"$runId", - newAttributePlan.schema, - outputMode, - new DataSourceOptions(extraOptions.asJava)) - WriteToDataSourceV2(new MicroBatchWrite(currentBatchId, writer), newAttributePlan) + case s: SupportsStreamingWrite => + // TODO: we should translate OutputMode to concrete write actions like truncate, but + // the truncate action is being developed in SPARK-26666. 
+ val writeBuilder = s.newWriteBuilder(new DataSourceOptions(extraOptions.asJava)) + .withQueryId(runId.toString) + .withInputDataSchema(newAttributePlan.schema) + val streamingWrite = writeBuilder.asInstanceOf[SupportsOutputMode] + .outputMode(outputMode) + .buildForStreaming() + WriteToDataSourceV2(new MicroBatchWrite(currentBatchId, streamingWrite), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -549,7 +553,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamingWriteSupportProvider => + case _: SupportsStreamingWrite => // This doesn't accumulate any data - it just forces execution of the microbatch writer. nextBatch.collect() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 83d38dcade7e..1b7aa548e6d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{DataSourceV2, Table} +import org.apache.spark.sql.sources.v2.{Table, TableProvider} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -86,13 +86,13 @@ case class StreamingExecutionRelation( // know at read time whether the query is continuous or not, so we need to be able to // swap a V1 relation back in. /** - * Used to link a [[DataSourceV2]] into a streaming + * Used to link a [[TableProvider]] into a streaming * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]], * and should be converted before passing to [[StreamExecution]]. 
*/ case class StreamingRelationV2( - dataSource: DataSourceV2, + source: TableProvider, sourceName: String, table: Table, extraOptions: Map[String, String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 9c5c16f4f5d1..348bc767b2c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport +import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.writer.WriteBuilder +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -30,17 +31,12 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) override def schema: StructType = data.schema } -class ConsoleSinkProvider extends DataSourceV2 - with StreamingWriteSupportProvider +class ConsoleSinkProvider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new ConsoleWriteSupport(schema, options) + override def getTable(options: DataSourceOptions): Table = { + ConsoleTable } def createRelation( @@ -60,3 +56,28 @@ class ConsoleSinkProvider extends DataSourceV2 def shortName(): String = "console" } + +object ConsoleTable extends Table with SupportsStreamingWrite { + + override def name(): String = "console" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def outputMode(mode: OutputMode): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + assert(inputSchema != null) + new ConsoleWrite(inputSchema, options) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index b22795d20776..20101c7fda32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -32,8 +32,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import 
org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider, SupportsContinuousRead} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsContinuousRead, SupportsStreamingWrite} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.SupportsOutputMode import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -42,7 +43,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamingWriteSupportProvider, + sink: SupportsStreamingWrite, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -174,12 +175,15 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamingWriteSupport( - s"$runId", - withNewSources.schema, - outputMode, - new DataSourceOptions(extraOptions.asJava)) - val planWithSink = WriteToContinuousDataSource(writer, withNewSources) + // TODO: we should translate OutputMode to concrete write actions like truncate, but + // the truncate action is being developed in SPARK-26666. + val writeBuilder = sink.newWriteBuilder(new DataSourceOptions(extraOptions.asJava)) + .withQueryId(runId.toString) + .withInputDataSchema(withNewSources.schema) + val streamingWrite = writeBuilder.asInstanceOf[SupportsOutputMode] + .outputMode(outputMode) + .buildForStreaming() + val planWithSink = WriteToContinuousDataSource(streamingWrite, withNewSources) reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( @@ -214,9 +218,8 @@ class ContinuousExecution( trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. - val epochEndpoint = - EpochCoordinatorRef.create( - writer, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + val epochEndpoint = EpochCoordinatorRef.create( + streamingWrite, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index d1bda79f4b6e..a99842220424 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeR import org.apache.spark.sql.SparkSession import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,7 +82,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. 
*/ def create( - writeSupport: StreamingWriteSupport, + writeSupport: StreamingWrite, stream: ContinuousStream, query: ContinuousExecution, epochCoordinatorId: String, @@ -115,7 +115,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writeSupport: StreamingWriteSupport, + writeSupport: StreamingWrite, stream: ContinuousStream, query: ContinuousExecution, startEpoch: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 7ad21cc304e7..54f484c4adae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite /** * The logical plan for writing data in a continuous stream. */ -case class WriteToContinuousDataSource( - writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { +case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan) + extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index 2178466d6314..2f3af6a6544c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -26,21 +26,22 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite /** - * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. + * The physical plan for writing data into a continuous processing [[StreamingWrite]]. */ -case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) - extends UnaryExecNode with Logging { +case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPlan) + extends UnaryExecNode with Logging { + override def child: SparkPlan = query override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writeSupport.createStreamingWriterFactory() + val writerFactory = write.createStreamingWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) - logInfo(s"Start processing data source write support: $writeSupport. 
" + + logInfo(s"Start processing data source write support: $write. " + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala index 833e62f35ede..f2ff30bcf1be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala @@ -22,12 +22,12 @@ import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) - extends StreamingWriteSupport with Logging { +class ConsoleWrite(schema: StructType, options: DataSourceOptions) + extends StreamingWrite with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala similarity index 66% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index 4218fd51ad20..6fbb59c43625 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -22,63 +22,73 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite, Table} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** - * A 
[[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified - * [[ForeachWriter]]. + * A write-only table for forwarding data into the specified [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. * @param converter An object to convert internal rows to target type T. Either it can be * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriteSupportProvider[T]( +case class ForeachWriterTable[T]( writer: ForeachWriter[T], converter: Either[ExpressionEncoder[T], InternalRow => T]) - extends StreamingWriteSupportProvider { - - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new StreamingWriteSupport { - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - - override def createStreamingWriterFactory(): StreamingDataWriterFactory = { - val rowConverter: InternalRow => T = converter match { - case Left(enc) => - val boundEnc = enc.resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - boundEnc.fromRow - case Right(func) => - func - } - ForeachWriterFactory(writer, rowConverter) + extends Table with SupportsStreamingWrite { + + override def name(): String = "ForeachSink" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this } - override def toString: String = "ForeachSink" + override def outputMode(mode: OutputMode): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + new StreamingWrite { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + inputSchema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) + } + } + } } } } -object ForeachWriteSupportProvider { +object ForeachWriterTable { def apply[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { + encoder: ExpressionEncoder[T]): ForeachWriterTable[_] = { writer match { case pythonWriter: PythonForeachWriter => - new ForeachWriteSupportProvider[UnsafeRow]( + new ForeachWriterTable[UnsafeRow]( pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) case _ => - new ForeachWriteSupportProvider[T](writer, Left(encoder)) + new ForeachWriterTable[T](writer, Left(encoder)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala index 143235efee81..f3951897ea74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} /** * A [[BatchWrite]] used to hook V2 stream writers into a microbatch plan. It implements * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped * streaming write support. */ -class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWriteSupport) extends BatchWrite { +class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWrite) extends BatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = { writeSupport.commit(eppchId, messages) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 075c6b9362ba..3a0082536512 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -40,8 +40,7 @@ import org.apache.spark.sql.types._ * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ -class RateStreamProvider extends DataSourceV2 - with TableProvider with DataSourceRegister { +class RateStreamProvider extends TableProvider with DataSourceRegister { import RateStreamProvider._ override def getTable(options: DataSourceOptions): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index c3b24a8f65dd..8ac5bfc307aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} -class TextSocketSourceProvider extends DataSourceV2 - with TableProvider with DataSourceRegister with Logging { +class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! 
" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index c50dc7bcb8da..3fc2cbe0fde5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -42,15 +42,31 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider - with MemorySinkBase with Logging { - - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new MemoryStreamingWriteSupport(this, mode, schema) +class MemorySinkV2 extends SupportsStreamingWrite with MemorySinkBase with Logging { + + override def name(): String = "MemorySinkV2" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var mode: OutputMode = _ + private var inputSchema: StructType = _ + + override def outputMode(mode: OutputMode): WriteBuilder = { + this.mode = mode + this + } + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def buildForStreaming(): StreamingWrite = { + new MemoryStreamingWrite(MemorySinkV2.this, mode, inputSchema) + } + } } private case class AddedData(batchId: Long, data: Array[Row]) @@ -122,9 +138,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryStreamingWriteSupport( +class MemoryStreamingWrite( val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamingWriteSupport { + extends StreamingWrite { override def createStreamingWriterFactory: MemoryWriterFactory = { MemoryWriterFactory(outputMode, schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 866681838af8..ef21caa3ac29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -173,7 +173,7 @@ 
final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo ds match { case provider: TableProvider => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = provider, conf = sparkSession.sessionState.conf) + source = provider, conf = sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions val dsOptions = new DataSourceOptions(options.asJava) val table = userSpecifiedSchema match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index ea596ba728c1..984199488fa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite, TableProvider} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -278,7 +278,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) + val sink = ForeachWriterTable[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -304,30 +304,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") - var options = extraOptions.toMap - val sink = ds.getConstructor().newInstance() match { - case w: StreamingWriteSupportProvider - if !disabledSources.contains(w.getClass.getCanonicalName) => - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - w, df.sparkSession.sessionState.conf) - options = sessionOptions ++ extraOptions - w - case _ => - val ds = DataSource( - df.sparkSession, - className = source, - options = options, - partitionColumns = normalizedParCols.getOrElse(Nil)) - ds.createSink(outputMode) + val useV1Source = disabledSources.contains(cls.getCanonicalName) + + val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) { + val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + source = provider, conf = df.sparkSession.sessionState.conf) + val options = sessionOptions ++ extraOptions + val dsOptions = new DataSourceOptions(options.asJava) + provider.getTable(dsOptions) match { + case s: SupportsStreamingWrite => s + case _ => createV1Sink() + } + } else { + createV1Sink() } df.sparkSession.sessionState.streamingQueryManager.startQuery( - options.get("queryName"), - options.get("checkpointLocation"), + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), df, - options, + extraOptions.toMap, 
sink, outputMode, useTempCheckpointLocation = source == "console" || source == "noop", @@ -336,6 +335,15 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } } + private def createV1Sink(): BaseStreamingSink = { + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + /** * Sets the output of the streaming query to be processed using the provided writer object. * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 0bd8a9299ef4..a7fa800a49d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.SupportsStreamingWrite import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -261,7 +261,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => + case (v2Sink: SupportsStreamingWrite, trigger: ContinuousTrigger) => if (operationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index a36b0cfa6ff1..914af589384d 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider +org.apache.spark.sql.streaming.sources.FakeWriteOnly org.apache.spark.sql.streaming.sources.FakeNoWrite org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 61857365ac98..e80437754051 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -43,9 +43,9 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("streaming writer") { val sink = new MemorySinkV2 - val writeSupport = new MemoryStreamingWriteSupport( + val write = new MemoryStreamingWrite( sink, OutputMode.Append(), new StructType().add("i", "int")) - writeSupport.commit(0, 
+ write.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -53,7 +53,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - writeSupport.commit(19, + write.commit(19, Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index f903c17923d0..0b1e3b5fb076 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -33,8 +33,8 @@ class DataSourceV2UtilsSuite extends SparkFunSuite { conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false") conf.setConfString("spark.datasource.another.config.name", "123") conf.setConfString(s"spark.datasource.$keyPrefix.", "123") - val cs = classOf[DataSourceV2WithSessionConfig].getConstructor().newInstance() - val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf) + val source = new DataSourceV2WithSessionConfig + val confs = DataSourceV2Utils.extractSessionConfigs(source, conf) assert(confs.size == 2) assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index daca65fd1ad2..c56a54598cd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -38,8 +38,7 @@ import org.apache.spark.util.SerializableConfiguration * Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`. * Each job moves files from `target/_temporary/uniqueId/` to `target`. 
*/ -class SimpleWritableDataSource extends DataSourceV2 - with TableProvider with SessionConfigSupport { +class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { private val tableSchema = new StructType().add("i", "long").add("j", "long") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index d3d210c02e90..bad22590807a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousStream, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType, StructType} @@ -43,7 +43,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { override def beforeEach(): Unit = { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( - mock[StreamingWriteSupport], + mock[StreamingWrite], mock[ContinuousStream], mock[ContinuousExecution], coordinatorId, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index a0b56ec17f0b..f74285f4b0fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.test.TestSparkSession class EpochCoordinatorSuite @@ -40,13 +40,13 @@ class EpochCoordinatorSuite private var epochCoordinator: RpcEndpointRef = _ - private var writeSupport: StreamingWriteSupport = _ + private var writeSupport: StreamingWrite = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ override def beforeEach(): Unit = { val stream = mock[ContinuousStream] - writeSupport = mock[StreamingWriteSupport] + writeSupport = mock[StreamingWrite] query = mock[ContinuousExecution] orderVerifier = inOrder(writeSupport, query) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 62f166602941..c841793fdd4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.WriteBuilder import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -71,13 +71,10 @@ trait FakeContinuousReadTable extends Table with SupportsContinuousRead { override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new FakeScanBuilder } -trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - LastWriteOptions.options = options +trait FakeStreamingWriteTable extends Table with SupportsStreamingWrite { + override def name(): String = "fake" + override def schema(): StructType = StructType(Seq()) + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { throw new IllegalStateException("fake sink - cannot actually write") } } @@ -129,20 +126,33 @@ class FakeReadNeitherMode extends DataSourceRegister with TableProvider { } } -class FakeWriteSupportProvider +class FakeWriteOnly extends DataSourceRegister - with FakeStreamingWriteSupportProvider + with TableProvider with SessionConfigSupport { override def shortName(): String = "fake-write-microbatch-continuous" override def keyPrefix: String = shortName() + + override def getTable(options: DataSourceOptions): Table = { + LastWriteOptions.options = options + new Table with FakeStreamingWriteTable { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } -class FakeNoWrite extends DataSourceRegister { +class FakeNoWrite extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-write-neither-mode" + override def getTable(options: DataSourceOptions): Table = { + new Table { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } - case class FakeWriteV1FallbackException() extends Exception class FakeSink extends Sink { @@ -150,17 +160,24 @@ class FakeSink extends Sink { } class FakeWriteSupportProviderV1Fallback extends DataSourceRegister - with FakeStreamingWriteSupportProvider with StreamSinkProvider { + with TableProvider with StreamSinkProvider { override def createSink( - sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode): Sink = { + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { new FakeSink() } override def shortName(): String = "fake-write-v1-fallback" + + override def getTable(options: DataSourceOptions): Table = { + new Table with FakeStreamingWriteTable { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } object LastReadOptions { @@ -260,7 +277,7 @@ class StreamingDataSourceV2Suite extends StreamTest { testPositiveCaseWithQuery( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v2Query => 
assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeWriteSupportProviderV1Fallback]) + .isInstanceOf[Table]) } // Ensure we create a V1 sink with the config. Note the config is a comma separated @@ -319,19 +336,20 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { - val table = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() + val sourceTable = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() + .newInstance().asInstanceOf[TableProvider].getTable(DataSourceOptions.empty()) + + val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor() .newInstance().asInstanceOf[TableProvider].getTable(DataSourceOptions.empty()) - val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf). - getConstructor().newInstance() - (table, writeSource, trigger) match { + (sourceTable, sinkTable, trigger) match { // Valid microbatch queries. - case (_: SupportsMicroBatchRead, _: StreamingWriteSupportProvider, t) + case (_: SupportsMicroBatchRead, _: SupportsStreamingWrite, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: SupportsContinuousRead, _: StreamingWriteSupportProvider, + case (_: SupportsContinuousRead, _: SupportsStreamingWrite, _: ContinuousTrigger) => testPositiveCase(read, write, trigger) @@ -342,12 +360,12 @@ class StreamingDataSourceV2Suite extends StreamTest { s"Data source $read does not support streamed reading") // Invalid - can't write - case (_, w, _) if !w.isInstanceOf[StreamingWriteSupportProvider] => + case (_, w, _) if !w.isInstanceOf[SupportsStreamingWrite] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") // Invalid - trigger is continuous but reader is not - case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger) + case (r, _: SupportsStreamingWrite, _: ContinuousTrigger) if !r.isInstanceOf[SupportsContinuousRead] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") From 865c88f9c735b15dd1a0d275533f086665e8abd8 Mon Sep 17 00:00:00 2001 From: "Jungtaek Lim (HeartSaVioR)" Date: Tue, 19 Feb 2019 09:42:21 +0800 Subject: [PATCH 18/19] [MINOR][DOC] Add note regarding proper usage of QueryExecution.toRdd ## What changes were proposed in this pull request? This proposes adding a note on `QueryExecution.toRdd` regarding Spark's internal optimization callers would need to indicate. ## How was this patch tested? This patch is a documentation change. Closes #23822 from HeartSaVioR/MINOR-doc-add-note-query-execution-to-rdd. Authored-by: Jungtaek Lim (HeartSaVioR) Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/execution/QueryExecution.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 72499aa936a5..49d6acf65dd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -85,7 +85,16 @@ class QueryExecution( prepareForExecution(sparkPlan) } - /** Internal version of the RDD. Avoids copies and has no schema */ + /** + * Internal version of the RDD. 
Avoids copies and has no schema. + * Note for callers: Spark may apply various optimization including reusing object: this means + * the row is valid only for the iteration it is retrieved. You should avoid storing row and + * accessing after iteration. (Calling `collect()` is one of known bad usage.) + * If you want to store these rows into collection, please apply some converter or copy row + * which produces new object per iteration. + * Given QueryExecution is not a public class, end users are discouraged to use this: please + * use `Dataset.rdd` instead where conversion will be applied. + */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() /** From 743b73daf7fbbb6cd0f763955ed331ac3889ba6f Mon Sep 17 00:00:00 2001 From: yucai Date: Tue, 19 Feb 2019 13:01:10 +0800 Subject: [PATCH 19/19] [SPARK-26909][FOLLOWUP][SQL] use unsafeRow.hashCode() as hash value in HashAggregate ## What changes were proposed in this pull request? This is a followup PR for #21149. New way uses unsafeRow.hashCode() as hash value in HashAggregate. The unsafe row has [null bit set] etc., so the hash should be different from shuffle hash, and then we don't need a special seed. ## How was this patch tested? UTs. Closes #23821 from yucai/unsafe_hash. Authored-by: yucai Signed-off-by: Wenchen Fan --- .../execution/aggregate/HashAggregateExec.scala | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 17cc7fde42bb..23ae1f0e2590 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -742,6 +742,7 @@ case class HashAggregateExec( val fastRowKeys = ctx.generateExpressions( bindReferences[Expression](groupingExpressions, child.output)) val unsafeRowKeys = unsafeRowKeyCode.value + val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash") val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer") val fastRowBuffer = ctx.freshName("fastAggBuffer") @@ -755,13 +756,6 @@ case class HashAggregateExec( } } - // generate hash code for key - // SPARK-24076: HashAggregate uses the same hash algorithm on the same expressions - // as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n, - // pick a different seed to avoid this conflict - val hashExpr = Murmur3Hash(groupingExpressions, 48) - val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) - val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") @@ -777,11 +771,11 @@ case class HashAggregateExec( s""" |// generate grouping key |${unsafeRowKeyCode.code} - |${hashEval.code} + |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode(); |if ($checkFallbackForBytesToBytesMap) { | // try to get the buffer from hash map | $unsafeRowBuffer = - | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); |} |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based |// aggregation after processing all input rows. 
@@ -795,7 +789,7 @@ case class HashAggregateExec( | // the hash map had be spilled, it should have enough memory now, | // try to allocate buffer again. | $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( - | $unsafeRowKeys, ${hashEval.value}); + | $unsafeRowKeys, $unsafeRowKeyHash); | if ($unsafeRowBuffer == null) { | // failed to allocate the first page | throw new $oomeClassName("No enough memory for aggregation");
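As an aside on the hunks above, here is a minimal, hypothetical sketch (not part of the patch, using illustrative names and Spark's internal catalyst API) of the idea behind [SPARK-26909]: an `UnsafeRow`'s own `hashCode()` covers the raw row bytes, null bits included, so it already differs from the column-wise Murmur3 hash used for shuffle partitioning and no special seed is required.

```
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash, UnsafeProjection}
import org.apache.spark.sql.types.{DataType, IntegerType}

object UnsafeRowHashSketch {
  def main(args: Array[String]): Unit = {
    // Build an UnsafeRow holding a single int grouping key.
    val proj = UnsafeProjection.create(Array[DataType](IntegerType))
    val unsafeKey = proj.apply(InternalRow(1))

    // The hash HashAggregate now uses: the row's own hashCode over its raw bytes.
    val aggHash = unsafeKey.hashCode()

    // A column-wise Murmur3 hash with the default seed, the kind of hash shuffle uses.
    val shuffleHash = Murmur3Hash(Seq(Literal(1)), 42).eval()

    println(s"UnsafeRow.hashCode = $aggHash, column-wise Murmur3 = $shuffleHash")
  }
}
```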
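Similarly, for the `QueryExecution.toRdd` note added in the [MINOR][DOC] patch above, a minimal sketch of the copy-before-collect guidance, assuming a local SparkSession; the object name and setup are illustrative, not from the patch:

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow

object ToRddCopySketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("toRdd-copy-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(1, 2, 3).toDF("i")

    // toRdd exposes InternalRow objects that Spark may reuse across an iteration,
    // so copy each row before holding on to it past the current element.
    val copied: Array[InternalRow] =
      df.queryExecution.toRdd.mapPartitions(_.map(_.copy())).collect()
    copied.foreach(row => println(row.getInt(0)))

    // Preferred for end users: Dataset.rdd already applies the conversion.
    df.rdd.collect().foreach(println)

    spark.stop()
  }
}
```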