diff --git a/core/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider b/core/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider similarity index 100% rename from core/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider rename to core/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index e56d03401dca..2e21adac86a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -23,12 +23,12 @@ import scala.reflect.runtime.universe import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.token.{Token, TokenIdentifier} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.security.HadoopDelegationTokenProvider import org.apache.spark.util.Utils private[security] class HBaseDelegationTokenProvider diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index 6a18a8dd33d1..4db86ba8f172 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -26,7 +26,6 @@ import java.util.concurrent.{ScheduledExecutorService, TimeUnit} import scala.collection.mutable import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark.SparkConf @@ -35,6 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.UpdateDelegationTokens +import org.apache.spark.security.HadoopDelegationTokenProvider import org.apache.spark.ui.UIUtils import org.apache.spark.util.{ThreadUtils, Utils} diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index 725eefbda897..ac432e7581e9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -30,6 +30,7 @@ import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdenti import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.security.HadoopDelegationTokenProvider private[deploy] class HadoopFSDelegationTokenProvider extends HadoopDelegationTokenProvider with Logging { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 2865c3bc86e4..645f58716de6 
100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -181,33 +181,47 @@ private[spark] class CoarseGrainedExecutorBackend( private[spark] object CoarseGrainedExecutorBackend extends Logging { - private def run( + case class Arguments( driverUrl: String, executorId: String, hostname: String, cores: Int, appId: String, workerUrl: Option[String], - userClassPath: Seq[URL]) { + userClassPath: mutable.ListBuffer[URL]) + + def main(args: Array[String]): Unit = { + val createFn: (RpcEnv, Arguments, SparkEnv) => + CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env) => + new CoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId, + arguments.hostname, arguments.cores, arguments.userClassPath, env) + } + run(parseArguments(args, this.getClass.getCanonicalName.stripSuffix("$")), createFn) + System.exit(0) + } + + def run( + arguments: Arguments, + backendCreateFn: (RpcEnv, Arguments, SparkEnv) => CoarseGrainedExecutorBackend): Unit = { Utils.initDaemon(log) SparkHadoopUtil.get.runAsSparkUser { () => // Debug code - Utils.checkHost(hostname) + Utils.checkHost(arguments.hostname) // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val fetcher = RpcEnv.create( "driverPropsFetcher", - hostname, + arguments.hostname, -1, executorConf, new SecurityManager(executorConf), clientMode = true) - val driver = fetcher.setupEndpointRefByURI(driverUrl) + val driver = fetcher.setupEndpointRefByURI(arguments.driverUrl) val cfg = driver.askSync[SparkAppConfig](RetrieveSparkAppConfig) - val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", appId)) + val props = cfg.sparkProperties ++ Seq[(String, String)](("spark.app.id", arguments.appId)) fetcher.shutdown() // Create SparkEnv using properties we fetched from the driver. 
@@ -225,19 +239,18 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) } - val env = SparkEnv.createExecutorEnv( - driverConf, executorId, hostname, cores, cfg.ioEncryptionKey, isLocal = false) + val env = SparkEnv.createExecutorEnv(driverConf, arguments.executorId, arguments.hostname, + arguments.cores, cfg.ioEncryptionKey, isLocal = false) - env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( - env.rpcEnv, driverUrl, executorId, hostname, cores, userClassPath, env)) - workerUrl.foreach { url => + env.rpcEnv.setupEndpoint("Executor", backendCreateFn(env.rpcEnv, arguments, env)) + arguments.workerUrl.foreach { url => env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.rpcEnv.awaitTermination() } } - def main(args: Array[String]) { + def parseArguments(args: Array[String], classNameForEntry: String): Arguments = { var driverUrl: String = null var executorId: String = null var hostname: String = null @@ -276,24 +289,24 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // scalastyle:off println System.err.println(s"Unrecognized options: ${tail.mkString(" ")}") // scalastyle:on println - printUsageAndExit() + printUsageAndExit(classNameForEntry) } } if (driverUrl == null || executorId == null || hostname == null || cores <= 0 || appId == null) { - printUsageAndExit() + printUsageAndExit(classNameForEntry) } - run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath) - System.exit(0) + Arguments(driverUrl, executorId, hostname, cores, appId, workerUrl, + userClassPath) } - private def printUsageAndExit() = { + private def printUsageAndExit(classNameForEntry: String): Unit = { // scalastyle:off println System.err.println( - """ - |Usage: CoarseGrainedExecutorBackend [options] + s""" + |Usage: $classNameForEntry [options] | | Options are: | --driver-url @@ -307,5 +320,4 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // scalastyle:on println System.exit(1) } - } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/security/HadoopDelegationTokenProvider.scala similarity index 92% rename from core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala rename to core/src/main/scala/org/apache/spark/security/HadoopDelegationTokenProvider.scala index 3dc952d54e73..cff8d81443ef 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/security/HadoopDelegationTokenProvider.scala @@ -15,18 +15,20 @@ * limitations under the License. */ -package org.apache.spark.deploy.security +package org.apache.spark.security import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.security.Credentials import org.apache.spark.SparkConf +import org.apache.spark.annotation.DeveloperApi /** + * ::DeveloperApi:: * Hadoop delegation token provider. */ -private[spark] trait HadoopDelegationTokenProvider { +@DeveloperApi +trait HadoopDelegationTokenProvider { /** * Name of the service to provide delegation tokens. This name should be unique. 
Spark will diff --git a/core/src/test/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider b/core/src/test/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider similarity index 100% rename from core/src/test/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider rename to core/src/test/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala index 2f36dba05c64..70174f7ff939 100644 --- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala @@ -21,6 +21,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.security.Credentials import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.security.HadoopDelegationTokenProvider private class ExceptionThrowingDelegationTokenProvider extends HadoopDelegationTokenProvider { ExceptionThrowingDelegationTokenProvider.constructed = true diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 51a13d342a67..6ee4b3d4103b 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -480,6 +480,18 @@ To use a custom metrics.properties for the application master and executors, upd {{HTTP_SCHEME}} `http://` or `https://` according to YARN HTTP policy. (Configured via `yarn.http.policy`) + + {{NM_HOST}} + The host of the node on which the container runs. + + + {{NM_PORT}} + The port of the NodeManager on which the container runs. + + + {{NM_HTTP_PORT}} + The port of the NodeManager's HTTP server on which the container runs. + {{NM_HTTP_ADDRESS}} Http URI of the node on which the container is allocated. @@ -502,6 +514,12 @@ To use a custom metrics.properties for the application master and executors, upd +For example, suppose you would like to point the log URL link directly to the Job History Server, instead of letting the NodeManager HTTP server redirect it; you can then configure `spark.history.custom.executor.log.url` as below: + + `{{HTTP_SCHEME}}<JHS_HOST>:<JHS_PORT>/jobhistory/logs/{{NM_HOST}}:{{NM_PORT}}/{{CONTAINER_ID}}/{{CONTAINER_ID}}/{{USER}}/{{FILE_NAME}}?start=-4096` + + NOTE: you need to replace `<JHS_HOST>` and `<JHS_PORT>` with actual values. + # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. @@ -520,13 +538,6 @@ for: filesystem if `spark.yarn.stagingDir` is not set); - if Hadoop federation is enabled, all the federated filesystems in the configuration. -The YARN integration also supports custom delegation token providers using the Java Services -mechanism (see `java.util.ServiceLoader`). Implementations of -`org.apache.spark.deploy.yarn.security.ServiceCredentialProvider` can be made available to Spark -by listing their names in the corresponding file in the jar's `META-INF/services` directory. These -providers can be disabled individually by setting `spark.security.credentials.{service}.enabled` to -`false`, where `{service}` is the name of the credential provider.
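With `HadoopDelegationTokenProvider` now exposed as a `@DeveloperApi` under `org.apache.spark.security` (see above) and the YARN-only `ServiceCredentialProvider` path being removed, a custom provider wired in through `java.util.ServiceLoader` might look roughly like the sketch below. This is illustrative only and not part of the patch: the class and service names are hypothetical, and the method signatures are assumed to mirror the builtin providers touched elsewhere in this diff.

```scala
package com.example.security

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.security.{Credentials, UserGroupInformation}

import org.apache.spark.SparkConf
import org.apache.spark.security.HadoopDelegationTokenProvider

// Hypothetical provider for some external secure service ("my-service").
class MyServiceDelegationTokenProvider extends HadoopDelegationTokenProvider {

  // Unique name; also the {service} part of spark.security.credentials.{service}.enabled.
  override def serviceName: String = "my-service"

  // Only ask for tokens when the cluster actually runs with Hadoop security enabled.
  override def delegationTokensRequired(
      sparkConf: SparkConf,
      hadoopConf: Configuration): Boolean = {
    UserGroupInformation.isSecurityEnabled
  }

  // Obtain service-specific tokens, add them to `creds`, and optionally return
  // the next renewal time in milliseconds (None if the token is not renewable).
  override def obtainDelegationTokens(
      hadoopConf: Configuration,
      sparkConf: SparkConf,
      creds: Credentials): Option[Long] = {
    // creds.addToken(alias, token)  // service-specific token acquisition would go here
    None
  }
}
```

Registration would list the implementation's fully qualified class name in a `META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider` file, matching the renamed service files in this diff; such a provider should still be individually disableable via `spark.security.credentials.my-service.enabled=false`.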
- ## YARN-specific Kerberos Configuration diff --git a/docs/security.md b/docs/security.md index d2cff41eb0f7..20492d871b1a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -756,6 +756,11 @@ If an application needs to interact with other secure Hadoop filesystems, their explicitly provided to Spark at launch time. This is done by listing them in the `spark.kerberos.access.hadoopFileSystems` property, described in the configuration section below. +Spark also supports custom delegation token providers using the Java Services +mechanism (see `java.util.ServiceLoader`). Implementations of +`org.apache.spark.security.HadoopDelegationTokenProvider` can be made available to Spark +by listing their names in the corresponding file in the jar's `META-INF/services` directory. + Delegation token support is currently only supported in YARN and Mesos modes. Consult the deployment-specific page for more information. diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md index c19aa5c504b0..425110a8bfa5 100644 --- a/docs/structured-streaming-kafka-integration.md +++ b/docs/structured-streaming-kafka-integration.md @@ -265,7 +265,7 @@ Each row in the source has the following schema: - + diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala index f8af91980bde..5510f0019353 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala @@ -44,6 +44,12 @@ object SimpleTypedAggregator { println("running typed average:") ds.groupByKey(_._1).agg(new TypedAverage[(Long, Long)](_._2.toDouble).toColumn).show() + println("running typed minimum:") + ds.groupByKey(_._1).agg(new TypedMin[(Long, Long)](_._2.toDouble).toColumn).show() + + println("running typed maximum:") + ds.groupByKey(_._1).agg(new TypedMax[(Long, Long)](_._2).toColumn).show() + spark.stop() } } @@ -84,3 +90,71 @@ class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long } override def outputEncoder: Encoder[Double] = Encoders.scalaDouble } + +class TypedMin[IN](val f: IN => Double) extends Aggregator[IN, MutableDouble, Option[Double]] { + override def zero: MutableDouble = null + override def reduce(b: MutableDouble, a: IN): MutableDouble = { + if (b == null) { + new MutableDouble(f(a)) + } else { + b.value = math.min(b.value, f(a)) + b + } + } + override def merge(b1: MutableDouble, b2: MutableDouble): MutableDouble = { + if (b1 == null) { + b2 + } else if (b2 == null) { + b1 + } else { + b1.value = math.min(b1.value, b2.value) + b1 + } + } + override def finish(reduction: MutableDouble): Option[Double] = { + if (reduction != null) { + Some(reduction.value) + } else { + None + } + } + + override def bufferEncoder: Encoder[MutableDouble] = Encoders.kryo[MutableDouble] + override def outputEncoder: Encoder[Option[Double]] = Encoders.product[Option[Double]] +} + +class TypedMax[IN](val f: IN => Long) extends Aggregator[IN, MutableLong, Option[Long]] { + override def zero: MutableLong = null + override def reduce(b: MutableLong, a: IN): MutableLong = { + if (b == null) { + new MutableLong(f(a)) + } else { + b.value = math.max(b.value, f(a)) + b + } + } + override def merge(b1: MutableLong, b2: MutableLong): MutableLong = { + if (b1 == null) { + b2 + } else if (b2 == null) { + b1 + } else { + b1.value = math.max(b1.value, b2.value) 
+ b1 + } + } + override def finish(reduction: MutableLong): Option[Long] = { + if (reduction != null) { + Some(reduction.value) + } else { + None + } + } + + override def bufferEncoder: Encoder[MutableLong] = Encoders.kryo[MutableLong] + override def outputEncoder: Encoder[Option[Long]] = Encoders.product[Option[Long]] +} + +class MutableLong(var value: Long) extends Serializable + +class MutableDouble(var value: Double) extends Serializable diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 9238899b0c00..6994517b27d6 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.WriteBuilder +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -47,7 +48,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with StreamingWriteSupportProvider with TableProvider with Logging { import KafkaSourceProvider._ @@ -180,20 +180,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - import scala.collection.JavaConverters._ - - val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) - // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. 
- val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - - new KafkaStreamingWriteSupport(topic, producerParams, schema) - } - private def strategy(caseInsensitiveParams: Map[String, String]) = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { case ("assign", value) => @@ -365,7 +351,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } class KafkaTable(strategy: => ConsumerStrategy) extends Table - with SupportsMicroBatchRead with SupportsContinuousRead { + with SupportsMicroBatchRead with SupportsContinuousRead with SupportsStreamingWrite { override def name(): String = s"Kafka $strategy" @@ -374,6 +360,28 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { override def build(): Scan = new KafkaScan(options) } + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def outputMode(mode: OutputMode): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + import scala.collection.JavaConverters._ + + assert(inputSchema != null) + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + new KafkaStreamingWrite(topic, producerParams, inputSchema) + } + } + } } class KafkaScan(options: DataSourceOptions) extends Scan { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala similarity index 95% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala index 0d831c388460..e3101e157208 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** @@ -33,18 +33,18 @@ import org.apache.spark.sql.types.StructType case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[StreamingWriteSupport]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamingWrite]] for Kafka writing. Responsible for generating the writer factory. * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. 
*/ -class KafkaStreamingWriteSupport( +class KafkaStreamingWrite( topic: Option[String], producerParams: ju.Map[String, Object], schema: StructType) - extends StreamingWriteSupport { + extends StreamingWrite { validateQuery(schema.toAttributes, producerParams, topic) diff --git a/external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider b/external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider similarity index 100% rename from external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider rename to external/kafka-0-10-token-provider/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider diff --git a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaDelegationTokenProvider.scala b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaDelegationTokenProvider.scala index c69e8a320059..cba4b40ca7f4 100644 --- a/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaDelegationTokenProvider.scala +++ b/external/kafka-0-10-token-provider/src/main/scala/org/apache/spark/kafka010/KafkaDelegationTokenProvider.scala @@ -25,9 +25,9 @@ import org.apache.hadoop.security.Credentials import org.apache.kafka.common.security.auth.SecurityProtocol.{SASL_PLAINTEXT, SASL_SSL, SSL} import org.apache.spark.SparkConf -import org.apache.spark.deploy.security.HadoopDelegationTokenProvider import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Kafka +import org.apache.spark.security.HadoopDelegationTokenProvider private[spark] class KafkaDelegationTokenProvider extends HadoopDelegationTokenProvider with Logging { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index abe2d1febfdf..a5ed4a38a886 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -341,11 +341,12 @@ class GBTClassificationModel private[ml]( * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) * and follows the implementation from scikit-learn. - + * * See `DecisionTreeClassificationModel.featureImportances` */ @Since("2.0.0") - lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + lazy val featureImportances: Vector = + TreeEnsembleModel.featureImportances(trees, numFeatures, perTreeNormalization = false) /** Raw prediction for the positive class. 
*/ private def margin(features: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 9a5b7d59e9ae..9f0f567a5b53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -285,7 +285,8 @@ class GBTRegressionModel private[ml]( * @see `DecisionTreeRegressionModel.featureImportances` */ @Since("2.0.0") - lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures) + lazy val featureImportances: Vector = + TreeEnsembleModel.featureImportances(trees, numFeatures, perTreeNormalization = false) /** (private[ml]) Convert to a model in the old API */ private[ml] def toOld: OldGBTModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 51d5d5c58c57..e95c55f6048f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -135,7 +135,7 @@ private[ml] object TreeEnsembleModel { * - Average over trees: * - importance(feature j) = sum (over nodes which split on feature j) of the gain, * where gain is scaled by the number of instances passing through node - * - Normalize importances for tree to sum to 1. + * - Normalize importances for tree to sum to 1 (only if `perTreeNormalization` is `true`). * - Normalize feature importance vector to sum to 1. * * References: @@ -145,9 +145,15 @@ private[ml] object TreeEnsembleModel { * @param numFeatures Number of features in model (even if not all are explicitly used by * the model). * If -1, then numFeatures is set based on the max feature index in all trees. + * @param perTreeNormalization By default this is set to `true` and it means that the importances + * of each tree are normalized before being summed. If set to `false`, + * the normalization is skipped. * @return Feature importance values, of length numFeatures. */ - def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = { + def featureImportances[M <: DecisionTreeModel]( + trees: Array[M], + numFeatures: Int, + perTreeNormalization: Boolean = true): Vector = { val totalImportances = new OpenHashMap[Int, Double]() trees.foreach { tree => // Aggregate feature importance vector for this tree @@ -155,10 +161,19 @@ private[ml] object TreeEnsembleModel { computeFeatureImportance(tree.rootNode, importances) // Normalize importance vector for this tree, and add it to total. // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count? 
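As an aside to the `treeModels.scala` change above: the difference between summing per-tree-normalized importances (the default) and summing raw gains (`perTreeNormalization = false`, now used by the GBT models) can be seen with a small standalone sketch on toy gain maps. Names and numbers are purely illustrative; this is not Spark code.

```scala
// Standalone illustration of the two aggregation modes, not Spark's TreeEnsembleModel API.
object FeatureImportanceSketch {

  // Hypothetical accumulated gains per tree: feature index -> gain.
  val treeGains: Seq[Map[Int, Double]] = Seq(
    Map(0 -> 2.0, 1 -> 2.0), // a tree where both features contribute equally
    Map(1 -> 8.0))           // a tree dominated by feature 1

  def aggregate(perTreeNormalization: Boolean): Map[Int, Double] = {
    val summed = treeGains
      .map { gains =>
        if (perTreeNormalization) {
          // Default behavior: each tree contributes a total of 1.
          val norm = gains.values.sum
          gains.map { case (k, v) => k -> v / norm }
        } else {
          // New GBT behavior: keep raw gains, so trees with larger total gain count for more.
          gains
        }
      }
      .foldLeft(Map.empty[Int, Double]) { (acc, gains) =>
        gains.foldLeft(acc) { case (a, (k, v)) => a.updated(k, a.getOrElse(k, 0.0) + v) }
      }
    // Final normalization so the importances sum to 1, as in the doc comment above.
    val total = summed.values.sum
    summed.map { case (k, v) => k -> v / total }
  }

  def main(args: Array[String]): Unit = {
    println(aggregate(perTreeNormalization = true))  // Map(0 -> 0.25, 1 -> 0.75)
    println(aggregate(perTreeNormalization = false)) // approx. Map(0 -> 0.167, 1 -> 0.833)
  }
}
```

With per-tree normalization, each tree contributes equally regardless of its total gain; with raw gains, trees with larger total gain dominate, which is closer to the scikit-learn behavior the GBT scaladoc refers to.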
- val treeNorm = importances.map(_._2).sum + val treeNorm = if (perTreeNormalization) { + importances.map(_._2).sum + } else { + // We won't use it + Double.NaN + } if (treeNorm != 0) { importances.foreach { case (idx, impt) => - val normImpt = impt / treeNorm + val normImpt = if (perTreeNormalization) { + impt / treeNorm + } else { + impt + } totalImportances.changeValue(idx, normImpt, _ + normImpt) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index cedbaf1858ef..cd59900c521c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -363,7 +363,8 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest { val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances val mostIF = importanceFeatures.argmax - assert(mostImportantFeature !== mostIF) + assert(mostIF === 1) + assert(importances(mostImportantFeature) !== importanceFeatures(mostIF)) } test("model evaluateEachIteration") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index b145c7a3dc95..46fa3767efdc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -200,7 +200,8 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest { val gbtWithFeatureSubset = gbt.setFeatureSubsetStrategy("1") val importanceFeatures = gbtWithFeatureSubset.fit(df).featureImportances val mostIF = importanceFeatures.argmax - assert(mostImportantFeature !== mostIF) + assert(mostIF === 1) + assert(importances(mostImportantFeature) !== importanceFeatures(mostIF)) } test("model evaluateEachIteration") { diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index d8315c63a8fc..5a55401db53d 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -36,15 +36,21 @@ from pyspark.util import _exception_message -def launch_gateway(conf=None): +def launch_gateway(conf=None, popen_kwargs=None): """ launch jvm gateway :param conf: spark configuration passed to spark-submit + :param popen_kwargs: Dictionary of kwargs to pass to Popen when spawning + the py4j JVM. This is a developer feature intended for use in + customizing how pyspark interacts with the py4j JVM (e.g., capturing + stdout/stderr). :return: """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] + # Process already exists + proc = None else: SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the @@ -75,15 +81,20 @@ def launch_gateway(conf=None): env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file # Launch the Java gateway. + popen_kwargs = {} if popen_kwargs is None else popen_kwargs # We open a pipe to stdin so that the Java gateway can die when the pipe is broken + popen_kwargs['stdin'] = PIPE + # We always set the necessary environment variables. 
+ popen_kwargs['env'] = env if not on_windows: # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) + popen_kwargs['preexec_fn'] = preexec_func + proc = Popen(command, **popen_kwargs) else: # preexec_fn not supported on Windows - proc = Popen(command, stdin=PIPE, env=env) + proc = Popen(command, **popen_kwargs) # Wait for the file to appear, or for the process to exit, whichever happens first. while not proc.poll() and not os.path.isfile(conn_info_file): @@ -118,6 +129,8 @@ def killChild(): gateway = JavaGateway( gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret, auto_convert=True)) + # Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr) + gateway.proc = proc # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 3db259551fa8..a2c59fedfc8c 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -311,10 +311,9 @@ def __init__(self, timezone, safecheck): def arrow_to_pandas(self, arrow_column): from pyspark.sql.types import from_arrow_type, \ - _check_series_convert_date, _check_series_localize_timestamps + _arrow_column_to_pandas, _check_series_localize_timestamps - s = arrow_column.to_pandas() - s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _arrow_column_to_pandas(arrow_column, from_arrow_type(arrow_column.type)) s = _check_series_localize_timestamps(s, self._timezone) return s diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a1056d0b787e..472d2969b3e1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2107,14 +2107,13 @@ def toPandas(self): # of PyArrow is found, if 'spark.sql.execution.arrow.enabled' is enabled. 
if use_arrow: try: - from pyspark.sql.types import _check_dataframe_convert_date, \ + from pyspark.sql.types import _arrow_table_to_pandas, \ _check_dataframe_localize_timestamps import pyarrow batches = self._collectAsArrow() if len(batches) > 0: table = pyarrow.Table.from_batches(batches) - pdf = table.to_pandas() - pdf = _check_dataframe_convert_date(pdf, self.schema) + pdf = _arrow_table_to_pandas(table, self.schema) return _check_dataframe_localize_timestamps(pdf, timezone) else: return pd.DataFrame.from_records([], columns=self.columns) diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 8a62500b17f2..38a6402c0132 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -68,7 +68,9 @@ def setUpClass(cls): (u"b", 2, 20, 0.4, 4.0, Decimal("4.0"), date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), (u"c", 3, 30, 0.8, 6.0, Decimal("6.0"), - date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)), + (u"d", 4, 40, 1.0, 8.0, Decimal("8.0"), + date(2262, 4, 12), datetime(2262, 3, 3, 3, 3, 3))] # TODO: remove version check once minimum pyarrow version is 0.10.0 if LooseVersion("0.10.0") <= LooseVersion(pa.__version__): @@ -76,6 +78,7 @@ def setUpClass(cls): cls.data[0] = cls.data[0] + (bytearray(b"a"),) cls.data[1] = cls.data[1] + (bytearray(b"bb"),) cls.data[2] = cls.data[2] + (bytearray(b"ccc"),) + cls.data[3] = cls.data[3] + (bytearray(b"dddd"),) @classmethod def tearDownClass(cls): diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 6a6865a9fb16..28ef98d7b3f1 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -349,7 +349,8 @@ def test_vectorized_udf_dates(self): data = [(0, date(1969, 1, 1),), (1, date(2012, 2, 2),), (2, None,), - (3, date(2100, 4, 4),)] + (3, date(2100, 4, 4),), + (4, date(2262, 4, 12),)] df = self.spark.createDataFrame(data, schema=schema) date_copy = pandas_udf(lambda t: t, returnType=DateType()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 4b8f2efff4ac..348cb5b11859 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1681,38 +1681,52 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) -def _check_series_convert_date(series, data_type): - """ - Cast the series to datetime.date if it's a date type, otherwise returns the original series. +def _arrow_column_to_pandas(column, data_type): + """ Convert Arrow Column to pandas Series. - :param series: pandas.Series - :param data_type: a Spark data type for the series + :param column: pyarrow.lib.Column + :param data_type: a Spark data type for the column """ - import pyarrow + import pandas as pd + import pyarrow as pa from distutils.version import LooseVersion - # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910 - if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0") and type(data_type) == DateType: - return series.dt.date + # If the given column is a date type column, creates a series of datetime.date directly instead + # of creating datetime64[ns] as intermediate data to avoid overflow caused by datetime64[ns] + # type handling.
+ if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + if type(data_type) == DateType: + return pd.Series(column.to_pylist(), name=column.name) + else: + return column.to_pandas() else: - return series + # Since Arrow 0.11.0, support date_as_object to return datetime.date instead of + # np.datetime64. + return column.to_pandas(date_as_object=True) -def _check_dataframe_convert_date(pdf, schema): - """ Correct date type value to use datetime.date. +def _arrow_table_to_pandas(table, schema): + """ Convert Arrow Table to pandas DataFrame. Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should use datetime.date to match the behavior with when Arrow optimization is disabled. - :param pdf: pandas.DataFrame - :param schema: a Spark schema of the pandas.DataFrame + :param table: pyarrow.lib.Table + :param schema: a Spark schema of the pyarrow.lib.Table """ - import pyarrow + import pandas as pd + import pyarrow as pa from distutils.version import LooseVersion - # As of Arrow 0.12.0, date_as_objects is True by default, see ARROW-3910 - if LooseVersion(pyarrow.__version__) < LooseVersion("0.12.0"): - for field in schema: - pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) - return pdf + # If the given table contains a date type column, use `_arrow_column_to_pandas` for pyarrow<0.11 + # or use `date_as_object` option for pyarrow>=0.11 to avoid creating datetime64[ns] as + # intermediate data. + if LooseVersion(pa.__version__) < LooseVersion("0.11.0"): + if any(type(field.dataType) == DateType for field in schema): + return pd.concat([_arrow_column_to_pandas(column, field.dataType) + for column, field in zip(table.itercolumns(), schema)], axis=1) + else: + return table.to_pandas() + else: + return table.to_pandas(date_as_object=True) def _get_local_timezone(): diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7523e3c42c53..6ca81fb97c75 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -21,7 +21,6 @@ import java.io.{FileSystem => _, _} import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets -import java.security.PrivilegedExceptionAction import java.util.{Locale, Properties, UUID} import java.util.zip.{ZipEntry, ZipOutputStream} @@ -34,9 +33,9 @@ import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission -import org.apache.hadoop.io.{DataOutputBuffer, Text} +import org.apache.hadoop.io.Text import org.apache.hadoop.mapreduce.MRJobConfig -import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.StringUtils import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment @@ -50,8 +49,8 @@ import org.apache.hadoop.yarn.util.Records import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{SparkApplication, SparkHadoopUtil} +import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import 
org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Python._ @@ -315,7 +314,7 @@ private[spark] class Client( val credentials = currentUser.getCredentials() if (isClusterMode) { - val credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf, null) + val credentialManager = new HadoopDelegationTokenManager(sparkConf, hadoopConf, null) credentialManager.obtainDelegationTokens(credentials) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 0b909d15c2fa..2f8f2a0a119c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -202,7 +202,7 @@ private[yarn] class ExecutorRunnable( val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$$() + "/bin/java", "-server") ++ javaOpts ++ - Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend", + Seq("org.apache.spark.executor.YarnCoarseGrainedExecutorBackend", "--driver-url", masterAddress, "--executor-id", executorId, "--hostname", hostname, @@ -235,21 +235,6 @@ private[yarn] class ExecutorRunnable( } } - // Add log urls, as well as executor attributes - container.foreach { c => - YarnContainerInfoHelper.getLogUrls(conf, Some(c)).foreach { m => - m.foreach { case (fileName, url) => - env("SPARK_LOG_URL_" + fileName.toUpperCase(Locale.ROOT)) = url - } - } - - YarnContainerInfoHelper.getAttributes(conf, Some(c)).foreach { m => - m.foreach { case (attr, value) => - env("SPARK_EXECUTOR_ATTRIBUTE_" + attr.toUpperCase(Locale.ROOT)) = value - } - } - } - env } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala deleted file mode 100644 index cc24ac4d9bcf..000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/ServiceCredentialProvider.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.{Credentials, UserGroupInformation} - -import org.apache.spark.SparkConf - -/** - * A credential provider for a service. User must implement this if they need to access a - * secure service from Spark. - */ -trait ServiceCredentialProvider { - - /** - * Name of the service to provide credentials. 
This name should unique, Spark internally will - * use this name to differentiate credential provider. - */ - def serviceName: String - - /** - * Returns true if credentials are required by this service. By default, it is based on whether - * Hadoop security is enabled. - */ - def credentialsRequired(hadoopConf: Configuration): Boolean = { - UserGroupInformation.isSecurityEnabled - } - - /** - * Obtain credentials for this service and get the time of the next renewal. - * - * @param hadoopConf Configuration of current Hadoop Compatible system. - * @param sparkConf Spark configuration. - * @param creds Credentials to add tokens and security keys to. - * @return If this Credential is renewable and can be renewed, return the time of the next - * renewal, otherwise None should be returned. - */ - def obtainCredentials( - hadoopConf: Configuration, - sparkConf: SparkConf, - creds: Credentials): Option[Long] -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala deleted file mode 100644 index fc1f75254c57..000000000000 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManager.scala +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import java.util.ServiceLoader - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.Credentials - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.util.Utils - -/** - * This class loads delegation token providers registered under the YARN-specific - * [[ServiceCredentialProvider]] interface, as well as the builtin providers defined - * in [[HadoopDelegationTokenManager]]. 
- */ -private[spark] class YARNHadoopDelegationTokenManager( - _sparkConf: SparkConf, - _hadoopConf: Configuration, - _schedulerRef: RpcEndpointRef) - extends HadoopDelegationTokenManager(_sparkConf, _hadoopConf, _schedulerRef) { - - private val credentialProviders = { - ServiceLoader.load(classOf[ServiceCredentialProvider], Utils.getContextOrSparkClassLoader) - .asScala - .toList - .filter { p => isServiceEnabled(p.serviceName) } - .map { p => (p.serviceName, p) } - .toMap - } - if (credentialProviders.nonEmpty) { - logDebug("Using the following YARN-specific credential providers: " + - s"${credentialProviders.keys.mkString(", ")}.") - } - - override def obtainDelegationTokens(creds: Credentials): Long = { - val superInterval = super.obtainDelegationTokens(creds) - - credentialProviders.values.flatMap { provider => - if (provider.credentialsRequired(hadoopConf)) { - provider.obtainCredentials(hadoopConf, sparkConf, creds) - } else { - logDebug(s"Service ${provider.serviceName} does not require a token." + - s" Check your configuration to see if security is disabled or not.") - None - } - }.foldLeft(superInterval)(math.min) - } - - // For testing. - override def isProviderLoaded(serviceName: String): Boolean = { - credentialProviders.contains(serviceName) || super.isProviderLoaded(serviceName) - } - -} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala new file mode 100644 index 000000000000..53e99d992db8 --- /dev/null +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/executor/YarnCoarseGrainedExecutorBackend.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.net.URL + +import org.apache.spark.SparkEnv +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.YarnContainerInfoHelper + +/** + * Custom implementation of CoarseGrainedExecutorBackend for YARN resource manager. + * This class extracts executor log URLs and executor attributes from the system environment + * properties that YARN sets for the container.
+ */ +private[spark] class YarnCoarseGrainedExecutorBackend( + rpcEnv: RpcEnv, + driverUrl: String, + executorId: String, + hostname: String, + cores: Int, + userClassPath: Seq[URL], + env: SparkEnv) + extends CoarseGrainedExecutorBackend( + rpcEnv, + driverUrl, + executorId, + hostname, + cores, + userClassPath, + env) with Logging { + + private lazy val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(env.conf) + + override def extractLogUrls: Map[String, String] = { + YarnContainerInfoHelper.getLogUrls(hadoopConfiguration, container = None) + .getOrElse(Map()) + } + + override def extractAttributes: Map[String, String] = { + YarnContainerInfoHelper.getAttributes(hadoopConfiguration, container = None) + .getOrElse(Map()) + } +} + +private[spark] object YarnCoarseGrainedExecutorBackend extends Logging { + + def main(args: Array[String]): Unit = { + val createFn: (RpcEnv, CoarseGrainedExecutorBackend.Arguments, SparkEnv) => + CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env) => + new YarnCoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId, + arguments.hostname, arguments.cores, arguments.userClassPath, env) + } + val backendArgs = CoarseGrainedExecutorBackend.parseArguments(args, + this.getClass.getCanonicalName.stripSuffix("$")) + CoarseGrainedExecutorBackend.run(backendArgs, createFn) + System.exit(0) + } + +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 821fbcd956d5..78cd6a200ac2 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -31,14 +31,12 @@ import org.eclipse.jetty.servlet.{FilterHolder, FilterMapping} import org.apache.spark.SparkContext import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.deploy.yarn.security.YARNHadoopDelegationTokenManager import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.internal.config.UI._ import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.ui.JettyUtils import org.apache.spark.util.{RpcUtils, ThreadUtils} /** @@ -223,7 +221,7 @@ private[spark] abstract class YarnSchedulerBackend( } override protected def createTokenManager(): Option[HadoopDelegationTokenManager] = { - Some(new YARNHadoopDelegationTokenManager(sc.conf, sc.hadoopConfiguration, driverEndpoint)) + Some(new HadoopDelegationTokenManager(sc.conf, sc.hadoopConfiguration, driverEndpoint)) } /** diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/util/YarnContainerInfoHelper.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/util/YarnContainerInfoHelper.scala index 96350cdece55..5e39422e868b 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/util/YarnContainerInfoHelper.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/util/YarnContainerInfoHelper.scala @@ -59,6 +59,9 @@ private[spark] object YarnContainerInfoHelper extends Logging { val yarnConf = new YarnConfiguration(conf) Some(Map( "HTTP_SCHEME" -> getYarnHttpScheme(yarnConf), + "NM_HOST" -> getNodeManagerHost(container), + "NM_PORT" -> getNodeManagerPort(container), + "NM_HTTP_PORT" -> 
getNodeManagerHttpPort(container), "NM_HTTP_ADDRESS" -> getNodeManagerHttpAddress(container), "CLUSTER_ID" -> getClusterId(yarnConf).getOrElse(""), "CONTAINER_ID" -> ConverterUtils.toString(getContainerId(container)), @@ -97,7 +100,22 @@ private[spark] object YarnContainerInfoHelper extends Logging { def getNodeManagerHttpAddress(container: Option[Container]): String = container match { case Some(c) => c.getNodeHttpAddress - case None => System.getenv(Environment.NM_HOST.name()) + ":" + - System.getenv(Environment.NM_HTTP_PORT.name()) + case None => getNodeManagerHost(None) + ":" + getNodeManagerHttpPort(None) } + + def getNodeManagerHost(container: Option[Container]): String = container match { + case Some(c) => c.getNodeHttpAddress.split(":")(0) + case None => System.getenv(Environment.NM_HOST.name()) + } + + def getNodeManagerHttpPort(container: Option[Container]): String = container match { + case Some(c) => c.getNodeHttpAddress.split(":")(1) + case None => System.getenv(Environment.NM_HTTP_PORT.name()) + } + + def getNodeManagerPort(container: Option[Container]): String = container match { + case Some(_) => "-1" // Just return invalid port given we cannot retrieve the value + case None => System.getenv(Environment.NM_PORT.name()) + } + } diff --git a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider b/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider deleted file mode 100644 index f31c23269313..000000000000 --- a/resource-managers/yarn/src/test/resources/META-INF/services/org.apache.spark.deploy.yarn.security.ServiceCredentialProvider +++ /dev/null @@ -1 +0,0 @@ -org.apache.spark.deploy.yarn.security.YARNTestCredentialProvider diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index b3c5bbd263ed..56b7dfc13699 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -477,6 +477,9 @@ private object YarnClusterDriver extends Logging with Matchers { val driverAttributes = listener.driverAttributes.get val expectationAttributes = Map( "HTTP_SCHEME" -> YarnContainerInfoHelper.getYarnHttpScheme(yarnConf), + "NM_HOST" -> YarnContainerInfoHelper.getNodeManagerHost(container = None), + "NM_PORT" -> YarnContainerInfoHelper.getNodeManagerPort(container = None), + "NM_HTTP_PORT" -> YarnContainerInfoHelper.getNodeManagerHttpPort(container = None), "NM_HTTP_ADDRESS" -> YarnContainerInfoHelper.getNodeManagerHttpAddress(container = None), "CLUSTER_ID" -> YarnContainerInfoHelper.getClusterId(yarnConf).getOrElse(""), "CONTAINER_ID" -> ConverterUtils.toString(containerId), diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala deleted file mode 100644 index f00453cb9c59..000000000000 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/security/YARNHadoopDelegationTokenManagerSuite.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.yarn.security - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.security.Credentials - -import org.apache.spark.{SparkConf, SparkFunSuite} - -class YARNHadoopDelegationTokenManagerSuite extends SparkFunSuite { - private var credentialManager: YARNHadoopDelegationTokenManager = null - private var sparkConf: SparkConf = null - private var hadoopConf: Configuration = null - - override def beforeAll(): Unit = { - super.beforeAll() - sparkConf = new SparkConf() - hadoopConf = new Configuration() - } - - test("Correctly loads credential providers") { - credentialManager = new YARNHadoopDelegationTokenManager(sparkConf, hadoopConf, null) - assert(credentialManager.isProviderLoaded("yarn-test")) - } -} - -class YARNTestCredentialProvider extends ServiceCredentialProvider { - override def serviceName: String = "yarn-test" - - override def credentialsRequired(conf: Configuration): Boolean = true - - override def obtainCredentials( - hadoopConf: Configuration, - sparkConf: SparkConf, - creds: Credentials): Option[Long] = None -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 793c337ffcb1..42904c5c04c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -978,6 +978,11 @@ class Analyzer( case a @ Aggregate(groupingExprs, aggExprs, appendColumns: AppendColumns) => a.mapExpressions(resolveExpressionTopDown(_, appendColumns)) + case o: OverwriteByExpression if !o.outputResolved => + // do not resolve expression attributes until the query attributes are resolved against the + // table by ResolveOutputRelation. that rule will alias the attributes to the table's names. 
+ o + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}") q.mapExpressions(resolveExpressionTopDown(_, q)) @@ -2246,7 +2251,7 @@ class Analyzer( object ResolveOutputRelation extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case append @ AppendData(table, query, isByName) - if table.resolved && query.resolved && !append.resolved => + if table.resolved && query.resolved && !append.outputResolved => val projection = resolveOutputColumns(table.name, table.output, query, isByName) if (projection != query) { @@ -2254,6 +2259,26 @@ class Analyzer( } else { append } + + case overwrite @ OverwriteByExpression(table, _, query, isByName) + if table.resolved && query.resolved && !overwrite.outputResolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + overwrite.copy(query = projection) + } else { + overwrite + } + + case overwrite @ OverwritePartitionsDynamic(table, query, isByName) + if table.resolved && query.resolved && !overwrite.outputResolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + overwrite.copy(query = projection) + } else { + overwrite + } } def resolveOutputColumns( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 639d68f4ecd7..f7f701cea51f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -365,16 +365,17 @@ case class Join( } /** - * Append data to an existing table. + * Base trait for DataSourceV2 write commands */ -case class AppendData( - table: NamedRelation, - query: LogicalPlan, - isByName: Boolean) extends LogicalPlan { +trait V2WriteCommand extends Command { + def table: NamedRelation + def query: LogicalPlan + override def children: Seq[LogicalPlan] = Seq(query) - override def output: Seq[Attribute] = Seq.empty - override lazy val resolved: Boolean = { + override lazy val resolved: Boolean = outputResolved + + def outputResolved: Boolean = { table.resolved && query.resolved && query.output.size == table.output.size && query.output.zip(table.output).forall { case (inAttr, outAttr) => @@ -386,16 +387,66 @@ case class AppendData( } } +/** + * Append data to an existing table. + */ +case class AppendData( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand + object AppendData { def byName(table: NamedRelation, df: LogicalPlan): AppendData = { - new AppendData(table, df, true) + new AppendData(table, df, isByName = true) } def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = { - new AppendData(table, query, false) + new AppendData(table, query, isByName = false) } } +/** + * Overwrite data matching a filter in an existing table. 
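Illustration (editor's sketch, not part of the patch): the factory objects defined in this hunk are the intended entry points for building the new write plans. A minimal Scala sketch, assuming `relation` is a resolved NamedRelation (for example a DataSourceV2Relation) and `query` is the LogicalPlan that produces the data to write:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.{LessThanOrEqual, Literal}
    import org.apache.spark.sql.catalyst.plans.logical.{AppendData, OverwriteByExpression, OverwritePartitionsDynamic}

    val append   = AppendData.byName(relation, query)                                  // plain append
    val truncate = OverwriteByExpression.byName(relation, query, Literal(true))        // replace everything
    val replace  = OverwriteByExpression.byPosition(relation, query,
      LessThanOrEqual(UnresolvedAttribute(Seq("x")), Literal(15.0d)))                  // replace rows where x <= 15.0
    val dynamic  = OverwritePartitionsDynamic.byName(relation, query)                  // Hive-style dynamic overwrite

Passing `Literal(true)` as the delete expression is the truncate-and-append form that DataFrameWriter uses for SaveMode.Overwrite later in this diff.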
+ */ +case class OverwriteByExpression( + table: NamedRelation, + deleteExpr: Expression, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand { + override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved +} + +object OverwriteByExpression { + def byName( + table: NamedRelation, df: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, df, isByName = true) + } + + def byPosition( + table: NamedRelation, query: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, query, isByName = false) + } +} + +/** + * Dynamically overwrite partitions in an existing table. + */ +case class OverwritePartitionsDynamic( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand + +object OverwritePartitionsDynamic { + def byName(table: NamedRelation, df: LogicalPlan): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, df, isByName = true) + } + + def byPosition(table: NamedRelation, query: LogicalPlan): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, query, isByName = false) + } +} + + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d285e007dac1..0b7b67ed56d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1452,7 +1452,7 @@ object SQLConf { " register class names for which data source V2 write paths are disabled. 
Writes from these" + " sources will fall back to the V1 sources.") .stringConf - .createWithDefault("") + .createWithDefault("orc") val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .doc("A comma-separated list of fully qualified data source register class names for which" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala index 6c899b610ac5..0c4854861426 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -19,15 +19,92 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, UpCast} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, LessThanOrEqual, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project} import org.apache.spark.sql.types.{DoubleType, FloatType, StructField, StructType} +class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byPosition(table, query) + } +} + +class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byPosition(table, query) + } +} + +class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwriteByExpression.byName(table, query, Literal(true)) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwriteByExpression.byPosition(table, query, Literal(true)) + } + + test("delete expression is resolved using table fields") { + val table = TestRelation(StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("a", DoubleType, nullable = false), + StructField("b", DoubleType))).toAttributes) + + val a = query.output.head + val b = query.output.last + val x = table.output.head + + val parsedPlan = OverwriteByExpression.byPosition(table, query, + LessThanOrEqual(UnresolvedAttribute(Seq("x")), Literal(15.0d))) + + val expectedPlan = OverwriteByExpression.byPosition(table, + Project(Seq( + Alias(Cast(a, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(b, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), + query), + LessThanOrEqual( + AttributeReference("x", DoubleType, nullable = false)(x.exprId), + Literal(15.0d))) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("delete expression is not resolved using query 
fields") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("a", DoubleType, nullable = false), + StructField("b", DoubleType))).toAttributes) + + // the write is resolved (checked above). this test plan is not because of the expression. + val parsedPlan = OverwriteByExpression.byPosition(xRequiredTable, query, + LessThanOrEqual(UnresolvedAttribute(Seq("a")), Literal(15.0d))) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq("cannot resolve", "`a`", "given input columns", "x, y")) + } +} + case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { override def name: String = "table-name" } -class DataSourceV2AnalysisSuite extends AnalysisTest { +abstract class DataSourceV2AnalysisSuite extends AnalysisTest { val table = TestRelation(StructType(Seq( StructField("x", FloatType), StructField("y", FloatType))).toAttributes) @@ -40,21 +117,25 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("y", DoubleType))).toAttributes) - test("Append.byName: basic behavior") { + def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan + + def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan + + test("byName: basic behavior") { val query = TestRelation(table.schema.toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) checkAnalysis(parsedPlan, parsedPlan) assertResolved(parsedPlan) } - test("Append.byName: does not match by position") { + test("byName: does not match by position") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -62,12 +143,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'", "'y'")) } - test("Append.byName: case sensitive column resolution") { + test("byName: case sensitive column resolution") { val query = TestRelation(StructType(Seq( StructField("X", FloatType), // doesn't match case! StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -76,7 +157,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { caseSensitive = true) } - test("Append.byName: case insensitive column resolution") { + test("byName: case insensitive column resolution") { val query = TestRelation(StructType(Seq( StructField("X", FloatType), // doesn't match case! 
StructField("y", FloatType))).toAttributes) @@ -84,8 +165,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val X = query.output.head val y = query.output.last - val parsedPlan = AppendData.byName(table, query) - val expectedPlan = AppendData.byName(table, + val parsedPlan = byName(table, query) + val expectedPlan = byName(table, Project(Seq( Alias(Cast(toLower(X), FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -96,7 +177,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: data columns are reordered by name") { + test("byName: data columns are reordered by name") { // out of order val query = TestRelation(StructType(Seq( StructField("y", FloatType), @@ -105,8 +186,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val y = query.output.head val x = query.output.last - val parsedPlan = AppendData.byName(table, query) - val expectedPlan = AppendData.byName(table, + val parsedPlan = byName(table, query) + val expectedPlan = byName(table, Project(Seq( Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -117,26 +198,26 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: fail nullable data written to required columns") { - val parsedPlan = AppendData.byName(requiredTable, table) + test("byName: fail nullable data written to required columns") { + val parsedPlan = byName(requiredTable, table) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", "Cannot write nullable values to non-null column", "'x'", "'y'")) } - test("Append.byName: allow required data written to nullable columns") { - val parsedPlan = AppendData.byName(table, requiredTable) + test("byName: allow required data written to nullable columns") { + val parsedPlan = byName(table, requiredTable) assertResolved(parsedPlan) checkAnalysis(parsedPlan, parsedPlan) } - test("Append.byName: missing required columns cause failure and are identified by name") { + test("byName: missing required columns cause failure and are identified by name") { // missing required field x val query = TestRelation(StructType(Seq( StructField("y", FloatType, nullable = false))).toAttributes) - val parsedPlan = AppendData.byName(requiredTable, query) + val parsedPlan = byName(requiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -144,12 +225,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("Append.byName: missing optional columns cause failure and are identified by name") { + test("byName: missing optional columns cause failure and are identified by name") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -157,8 +238,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("Append.byName: fail canWrite check") { - val parsedPlan = AppendData.byName(table, widerTable) + test("byName: fail canWrite check") { + val parsedPlan = byName(table, widerTable) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, 
Seq( @@ -166,12 +247,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) } - test("Append.byName: insert safe cast") { + test("byName: insert safe cast") { val x = table.output.head val y = table.output.last - val parsedPlan = AppendData.byName(widerTable, table) - val expectedPlan = AppendData.byName(widerTable, + val parsedPlan = byName(widerTable, table) + val expectedPlan = byName(widerTable, Project(Seq( Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -182,13 +263,13 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: fail extra data fields") { + test("byName: fail extra data fields") { val query = TestRelation(StructType(Seq( StructField("x", FloatType), StructField("y", FloatType), StructField("z", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -197,7 +278,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'x', 'y', 'z'")) } - test("Append.byName: multiple field errors are reported") { + test("byName: multiple field errors are reported") { val xRequiredTable = TestRelation(StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", DoubleType))).toAttributes) @@ -206,7 +287,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(xRequiredTable, query) + val parsedPlan = byName(xRequiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -216,7 +297,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'y'")) } - test("Append.byPosition: basic behavior") { + test("byPosition: basic behavior") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType))).toAttributes) @@ -224,8 +305,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val a = query.output.head val b = query.output.last - val parsedPlan = AppendData.byPosition(table, query) - val expectedPlan = AppendData.byPosition(table, + val parsedPlan = byPosition(table, query) + val expectedPlan = byPosition(table, Project(Seq( Alias(Cast(a, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(b, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -236,7 +317,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: data columns are not reordered") { + test("byPosition: data columns are not reordered") { // out of order val query = TestRelation(StructType(Seq( StructField("y", FloatType), @@ -245,8 +326,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val y = query.output.head val x = query.output.last - val parsedPlan = AppendData.byPosition(table, query) - val expectedPlan = AppendData.byPosition(table, + val parsedPlan = byPosition(table, query) + val expectedPlan = byPosition(table, Project(Seq( Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -257,26 +338,26 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: fail nullable data written to required 
columns") { - val parsedPlan = AppendData.byPosition(requiredTable, table) + test("byPosition: fail nullable data written to required columns") { + val parsedPlan = byPosition(requiredTable, table) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", "Cannot write nullable values to non-null column", "'x'", "'y'")) } - test("Append.byPosition: allow required data written to nullable columns") { - val parsedPlan = AppendData.byPosition(table, requiredTable) + test("byPosition: allow required data written to nullable columns") { + val parsedPlan = byPosition(table, requiredTable) assertResolved(parsedPlan) checkAnalysis(parsedPlan, parsedPlan) } - test("Append.byPosition: missing required columns cause failure") { + test("byPosition: missing required columns cause failure") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType, nullable = false))).toAttributes) - val parsedPlan = AppendData.byPosition(requiredTable, query) + val parsedPlan = byPosition(requiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -285,12 +366,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("Append.byPosition: missing optional columns cause failure") { + test("byPosition: missing optional columns cause failure") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byPosition(table, query) + val parsedPlan = byPosition(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -299,12 +380,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("Append.byPosition: fail canWrite check") { + test("byPosition: fail canWrite check") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), StructField("b", DoubleType))).toAttributes) - val parsedPlan = AppendData.byPosition(table, widerTable) + val parsedPlan = byPosition(table, widerTable) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -312,7 +393,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) } - test("Append.byPosition: insert safe cast") { + test("byPosition: insert safe cast") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), StructField("b", DoubleType))).toAttributes) @@ -320,8 +401,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val x = table.output.head val y = table.output.last - val parsedPlan = AppendData.byPosition(widerTable, table) - val expectedPlan = AppendData.byPosition(widerTable, + val parsedPlan = byPosition(widerTable, table) + val expectedPlan = byPosition(widerTable, Project(Seq( Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "a")(), Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "b")()), @@ -332,13 +413,13 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: fail extra data fields") { + test("byPosition: fail extra data fields") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType), StructField("c", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -347,7 +428,7 @@ class 
DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'a', 'b', 'c'")) } - test("Append.byPosition: multiple field errors are reported") { + test("byPosition: multiple field errors are reported") { val xRequiredTable = TestRelation(StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", DoubleType))).toAttributes) @@ -356,7 +437,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byPosition(xRequiredTable, query) + val parsedPlan = byPosition(xRequiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index c00abd9b685b..d27fbfdd1461 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -20,12 +20,12 @@ import org.apache.spark.annotation.Evolving; /** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * A mix-in interface for {@link TableProvider}. Data sources can implement this interface to * propagate session configs with the specified key-prefix to all data source operations in this * session. */ @Evolving -public interface SessionConfigSupport extends DataSourceV2 { +public interface SessionConfigSupport extends TableProvider { /** * Key prefix of the session configs to propagate, which is usually the data source name. Spark diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java deleted file mode 100644 index 8ac9c5175086..000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for structured streaming. - * - * This interface is used to create {@link StreamingWriteSupport} instances when end users run - * {@code Dataset.writeStream.format(...).option(...).start()}. 
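Illustration (editor's sketch, not part of the patch): with the SessionConfigSupport change above, the mix-in now hangs off TableProvider instead of the removed DataSourceV2 marker. A hedged sketch of a provider that opts in; `ExampleSource` and the `example` prefix are hypothetical, and session configs are typically forwarded to the source as options under the `spark.datasource.<keyPrefix>.` namespace:

    import org.apache.spark.sql.sources.v2.{DataSourceOptions, SessionConfigSupport, Table, TableProvider}

    class ExampleSource extends TableProvider with SessionConfigSupport {
      // e.g. spark.conf.set("spark.datasource.example.compression", "zstd") arrives as option "compression"
      override def keyPrefix(): String = "example"

      override def getTable(options: DataSourceOptions): Table = ???  // return this source's Table implementation
    }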
- */ -@Evolving -public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { - - /** - * Creates a {@link StreamingWriteSupport} instance to save the data to this data source, which is - * called by Spark at the beginning of each streaming query. - * - * @param queryId A unique string for the writing query. It's possible that there are many - * writing queries running at the same time, and the returned - * {@link StreamingWriteSupport} can use this id to distinguish itself from others. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive epoch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - StreamingWriteSupport createStreamingWriteSupport( - String queryId, - StructType schema, - OutputMode mode, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java index 08caadd5308e..b2cd97a2f533 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java @@ -24,7 +24,7 @@ * An empty mix-in interface for {@link Table}, to indicate this table supports batch write. *
<p>
* If a {@link Table} implements this interface, the - * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} + * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} * with {@link WriteBuilder#buildForBatch()} implemented. *
</p>
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java new file mode 100644 index 000000000000..1050d35250c1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.writer.WriteBuilder; + +/** + * An empty mix-in interface for {@link Table}, to indicate this table supports streaming write. + *
<p>
+ * If a {@link Table} implements this interface, the + * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} + * with {@link WriteBuilder#buildForStreaming()} implemented. + *
</p>
+ */ +@Evolving +public interface SupportsStreamingWrite extends SupportsWrite, BaseStreamingSink { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java index 855d5efe0c69..a9b83b6de995 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java @@ -29,8 +29,7 @@ *
</p>
*/ @Evolving -// TODO: do not extend `DataSourceV2`, after we finish the API refactor completely. -public interface TableProvider extends DataSourceV2 { +public interface TableProvider { /** * Return a {@link Table} instance to do read/write with user-specified options. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 296d3e47e732..f10fd884daab 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -29,6 +29,9 @@ public interface SupportsPushDownFilters extends ScanBuilder { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. + *
<p>
+ * Rows should be returned from the data source if and only if all of the filters match. That is, + * filters must be interpreted as ANDed together. */ Filter[] pushFilters(Filter[] filters); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java new file mode 100644 index 000000000000..8058964b662b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +/** + * Write builder trait for tables that support dynamic partition overwrite. + *
<p>
+ * A write that dynamically overwrites partitions removes all existing data in each logical + * partition for which the write will commit new data. Any existing logical partition for which the + * write does not contain data will remain unchanged. + *
<p>
+ * This is provided to implement SQL compatible with Hive table operations but is not recommended. + * Instead, use the {@link SupportsOverwrite overwrite by filter API} to explicitly replace data. + */ +public interface SupportsDynamicOverwrite extends WriteBuilder { + /** + * Configures a write to dynamically replace partitions with data committed in the write. + * + * @return this write builder for method chaining + */ + WriteBuilder overwriteDynamicPartitions(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java new file mode 100644 index 000000000000..b443b3c3aeb4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.sql.sources.AlwaysTrue$; +import org.apache.spark.sql.sources.Filter; + +/** + * Write builder trait for tables that support overwrite by filter. + *
<p>
+ * Overwriting data by filter will delete any data that matches the filter and replace it with data + * that is committed in the write. + */ +public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate { + /** + * Configures a write to replace data matching the filters with data committed in the write. + *
<p>
+ * Rows must be deleted from the data source if and only if all of the filters match. That is, + * filters must be interpreted as ANDed together. + * + * @param filters filters used to match data to overwrite + * @return this write builder for method chaining + */ + WriteBuilder overwrite(Filter[] filters); + + @Override + default WriteBuilder truncate() { + return overwrite(new Filter[] { AlwaysTrue$.MODULE$ }); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java new file mode 100644 index 000000000000..69c2ba5e01a4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +/** + * Write builder trait for tables that support truncation. + *
<p>
+ * Truncation removes all data in a table and replaces it with data that is committed in the write. + */ +public interface SupportsTruncate extends WriteBuilder { + /** + * Configures a write to replace all existing data with data committed in the write. + * + * @return this write builder for method chaining + */ + WriteBuilder truncate(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java index e861c72af9e6..07529fe1dee9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java @@ -20,6 +20,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.SupportsBatchWrite; import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite; import org.apache.spark.sql.types.StructType; /** @@ -64,6 +65,12 @@ default WriteBuilder withInputDataSchema(StructType schema) { * {@link SupportsSaveMode}. */ default BatchWrite buildForBatch() { - throw new UnsupportedOperationException("Batch scans are not supported"); + throw new UnsupportedOperationException(getClass().getName() + + " does not support batch write"); + } + + default StreamingWrite buildForStreaming() { + throw new UnsupportedOperationException(getClass().getName() + + " does not support streaming write"); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 6334c8f64309..23e8580c404d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -20,12 +20,12 @@ import java.io.Serializable; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side * as the input parameter of {@link BatchWrite#commit(WriterCommitMessage[])} or - * {@link StreamingWriteSupport#commit(long, WriterCommitMessage[])}. + * {@link StreamingWrite#commit(long, WriterCommitMessage[])}. * * This is an empty interface, data sources should define their own message class and use it when * generating messages at executor side and handling the messages at driver side. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java index 7d3d21cb2b63..af2f03c9d419 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -26,7 +26,7 @@ /** * A factory of {@link DataWriter} returned by - * {@link StreamingWriteSupport#createStreamingWriterFactory()}, which is responsible for creating + * {@link StreamingWrite#createStreamingWriterFactory()}, which is responsible for creating * and initializing the actual data writer at executor side. 
* * Note that, the writer factory will be serialized and sent to executors, then the data writer diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java index 84cfbf2dda48..5617f1cdc0ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java @@ -22,13 +22,26 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * An interface that defines how to write the data to data source for streaming processing. + * An interface that defines how to write the data to data source in streaming queries. * - * Streaming queries are divided into intervals of data called epochs, with a monotonically - * increasing numeric ID. This writer handles commits and aborts for each successive epoch. + * The writing procedure is: + * 1. Create a writer factory by {@link #createStreamingWriterFactory()}, serialize and send it to + * all the partitions of the input data(RDD). + * 2. For each epoch in each partition, create the data writer, and write the data of the epoch in + * the partition with this writer. If all the data are written successfully, call + * {@link DataWriter#commit()}. If exception happens during the writing, call + * {@link DataWriter#abort()}. + * 3. If writers in all partitions of one epoch are successfully committed, call + * {@link #commit(long, WriterCommitMessage[])}. If some writers are aborted, or the job failed + * with an unknown reason, call {@link #abort(long, WriterCommitMessage[])}. + * + * While Spark will retry failed writing tasks, Spark won't retry failed writing jobs. Users should + * do it manually in their Spark applications if they want to retry. + * + * Please refer to the documentation of commit/abort methods for detailed specifications. */ @Evolving -public interface StreamingWriteSupport { +public interface StreamingWrite { /** * Creates a writer factory which will be serialized and sent to executors. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java similarity index 67% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java index 43bdcca70cb0..832dcfa145d1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java @@ -15,12 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.writer.streaming; -import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Unstable; +import org.apache.spark.sql.sources.v2.writer.WriteBuilder; +import org.apache.spark.sql.streaming.OutputMode; -/** - * TODO: remove it when we finish the API refactor for streaming write side. 
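Illustration (editor's sketch, not part of the patch): a minimal StreamingWrite following the epoch commit/abort contract described above; the writer factory is left as a placeholder because it is entirely source-specific:

    import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
    import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite}

    object ExampleStreamingWrite extends StreamingWrite {
      // Serialized and sent to executors; asked for one DataWriter per partition and epoch.
      override def createStreamingWriterFactory(): StreamingDataWriterFactory = ???  // source-specific

      // Called once per epoch after all partition writers committed: make the epoch's data visible.
      override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

      // Called when an epoch fails: discard whatever that epoch's writers produced.
      override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
    }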
- */ -@Evolving -public interface DataSourceV2 {} +// TODO: remove it when we have `SupportsTruncate` +@Unstable +public interface SupportsOutputMode extends WriteBuilder { + + WriteBuilder outputMode(OutputMode mode); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 713c9a9faa74..e757785e5664 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -205,7 +205,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (classOf[TableProvider].isAssignableFrom(cls)) { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = provider, conf = sparkSession.sessionState.conf) + source = provider, conf = sparkSession.sessionState.conf) val pathsOption = { val objectMapper = new ObjectMapper() DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index e5f947337c94..450828172b93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan, OverwriteByExpression} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} @@ -264,29 +265,38 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val dsOptions = new DataSourceOptions(options.asJava) provider.getTable(dsOptions) match { case table: SupportsBatchWrite => - if (mode == SaveMode.Append) { - val relation = DataSourceV2Relation.create(table, options) - runCommand(df.sparkSession, "save") { - AppendData.byName(relation, df.logicalPlan) - } - } else { - val writeBuilder = table.newWriteBuilder(dsOptions) - .withQueryId(UUID.randomUUID().toString) - .withInputDataSchema(df.logicalPlan.schema) - writeBuilder match { - case s: SupportsSaveMode => - val write = s.mode(mode).buildForBatch() - // It can only return null with `SupportsSaveMode`. We can clean it up after - // removing `SupportsSaveMode`. 
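Illustration (editor's sketch, not part of the patch): pulling the new write-builder mix-ins introduced earlier in this diff together, a hypothetical builder that supports overwrite by filter and dynamic partition overwrite. `ExampleWriteBuilder` and its fields are illustrative, and the BatchWrite construction is left as a placeholder:

    import org.apache.spark.sql.sources.Filter
    import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsDynamicOverwrite, SupportsOverwrite, WriteBuilder}

    class ExampleWriteBuilder extends SupportsOverwrite with SupportsDynamicOverwrite {
      private var deleteFilters: Array[Filter] = Array.empty
      private var dynamicOverwrite = false

      // Replace exactly the rows matching all of the given filters (they are ANDed together).
      override def overwrite(filters: Array[Filter]): WriteBuilder = {
        deleteFilters = filters
        this
      }

      // Replace only the logical partitions for which this write commits data.
      override def overwriteDynamicPartitions(): WriteBuilder = {
        dynamicOverwrite = true
        this
      }

      // No truncate() override needed: SupportsOverwrite's default delegates to overwrite(AlwaysTrue).
      override def buildForBatch(): BatchWrite = ???  // construct the source-specific BatchWrite here
    }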
- if (write != null) { - runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(write, df.logicalPlan) + lazy val relation = DataSourceV2Relation.create(table, options) + mode match { + case SaveMode.Append => + runCommand(df.sparkSession, "save") { + AppendData.byName(relation, df.logicalPlan) + } + + case SaveMode.Overwrite => + // truncate the table + runCommand(df.sparkSession, "save") { + OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) + } + + case _ => + table.newWriteBuilder(dsOptions) match { + case writeBuilder: SupportsSaveMode => + val write = writeBuilder.mode(mode) + .withQueryId(UUID.randomUUID().toString) + .withInputDataSchema(df.logicalPlan.schema) + .buildForBatch() + // It can only return null with `SupportsSaveMode`. We can clean it up after + // removing `SupportsSaveMode`. + if (write != null) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(write, df.logicalPlan) + } } - } - case _ => throw new AnalysisException( - s"data source ${table.name} does not support SaveMode $mode") - } + case _ => + throw new AnalysisException( + s"data source ${table.name} does not support SaveMode $mode") + } } // Streaming also uses the data source V2 API. So it may be that the data source implements diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 72499aa936a5..49d6acf65dd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -85,7 +85,16 @@ class QueryExecution( prepareForExecution(sparkPlan) } - /** Internal version of the RDD. Avoids copies and has no schema */ + /** + * Internal version of the RDD. Avoids copies and has no schema. + * Note for callers: Spark may apply various optimization including reusing object: this means + * the row is valid only for the iteration it is retrieved. You should avoid storing row and + * accessing after iteration. (Calling `collect()` is one of known bad usage.) + * If you want to store these rows into collection, please apply some converter or copy row + * which produces new object per iteration. + * Given QueryExecution is not a public class, end users are discouraged to use this: please + * use `Dataset.rdd` instead where conversion will be applied. 
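Illustration (editor's sketch, not part of the patch) of the toRdd caveat documented above, assuming `df` is a DataFrame and `qe` is its QueryExecution:

    val copied = qe.toRdd.map(_.copy()).collect()  // copy each reused InternalRow before storing it
    val rows   = df.rdd.collect()                  // or use the public API, which converts rows for you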
+ */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 17cc7fde42bb..23ae1f0e2590 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -742,6 +742,7 @@ case class HashAggregateExec( val fastRowKeys = ctx.generateExpressions( bindReferences[Expression](groupingExpressions, child.output)) val unsafeRowKeys = unsafeRowKeyCode.value + val unsafeRowKeyHash = ctx.freshName("unsafeRowKeyHash") val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer") val fastRowBuffer = ctx.freshName("fastAggBuffer") @@ -755,13 +756,6 @@ case class HashAggregateExec( } } - // generate hash code for key - // SPARK-24076: HashAggregate uses the same hash algorithm on the same expressions - // as ShuffleExchange, it may lead to bad hash conflict when shuffle.partitions=8192*n, - // pick a different seed to avoid this conflict - val hashExpr = Murmur3Hash(groupingExpressions, 48) - val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) - val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, incCounter) = if (testFallbackStartsAt.isDefined) { val countTerm = ctx.addMutableState(CodeGenerator.JAVA_INT, "fallbackCounter") @@ -777,11 +771,11 @@ case class HashAggregateExec( s""" |// generate grouping key |${unsafeRowKeyCode.code} - |${hashEval.code} + |int $unsafeRowKeyHash = ${unsafeRowKeyCode.value}.hashCode(); |if ($checkFallbackForBytesToBytesMap) { | // try to get the buffer from hash map | $unsafeRowBuffer = - | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, $unsafeRowKeyHash); |} |// Can't allocate buffer from the hash map. Spill the map and fallback to sort-based |// aggregation after processing all input rows. @@ -795,7 +789,7 @@ case class HashAggregateExec( | // the hash map had be spilled, it should have enough memory now, | // try to allocate buffer again. 
| $unsafeRowBuffer = $hashMapTerm.getAggregationBufferFromUnsafeRow( - | $unsafeRowKeys, ${hashEval.value}); + | $unsafeRowKeys, $unsafeRowKeyHash); | if ($unsafeRowBuffer == null) { | // failed to allocate the first page | throw new $oomeClassName("No enough memory for aggregation"); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 273cc3b19302..b73dc30d6f23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -529,6 +529,12 @@ object DataSourceStrategy { case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => Some(sources.StringContains(a.name, v.toString)) + case expressions.Literal(true, BooleanType) => + Some(sources.AlwaysTrue) + + case expressions.Literal(false, BooleanType) => + Some(sources.AlwaysFalse) + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala index 254c09001f7e..e22d6a6d399a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala @@ -35,7 +35,7 @@ class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPl override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(d @DataSourceV2Relation(table: OrcTable, _, _), _, _, _, _) => val v1FileFormat = new OrcFileFormat - val relation = HadoopFsRelation(table.getFileIndex, table.getFileIndex.partitionSchema, + val relation = HadoopFsRelation(table.fileIndex, table.fileIndex.partitionSchema, table.schema(), None, v1FileFormat, d.options)(sparkSession) i.copy(table = LogicalRelation(relation)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 452ebbbeb99c..8f2072c586a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -30,30 +30,23 @@ import org.apache.spark.sql.types.StructType * This is no-op datasource. It does not do anything besides consuming its input. * This can be useful for benchmarking or to cache data without any additional overhead. 
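Illustration (editor's sketch, not part of the patch): typical use of the noop sink described above, assuming `df` is an existing DataFrame; every row is fully consumed and then discarded, which is handy for benchmarking the read side:

    df.write.format("noop").mode("append").save()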
*/ -class NoopDataSource - extends DataSourceV2 - with TableProvider - with DataSourceRegister - with StreamingWriteSupportProvider { - +class NoopDataSource extends TableProvider with DataSourceRegister { override def shortName(): String = "noop" override def getTable(options: DataSourceOptions): Table = NoopTable - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = NoopStreamingWriteSupport } -private[noop] object NoopTable extends Table with SupportsBatchWrite { +private[noop] object NoopTable extends Table with SupportsBatchWrite with SupportsStreamingWrite { override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() } -private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsSaveMode { - override def buildForBatch(): BatchWrite = NoopBatchWrite +private[noop] object NoopWriteBuilder extends WriteBuilder + with SupportsSaveMode with SupportsOutputMode { override def mode(mode: SaveMode): WriteBuilder = this + override def outputMode(mode: OutputMode): WriteBuilder = this + override def buildForBatch(): BatchWrite = NoopBatchWrite + override def buildForStreaming(): StreamingWrite = NoopStreamingWrite } private[noop] object NoopBatchWrite extends BatchWrite { @@ -72,7 +65,7 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] { override def abort(): Unit = {} } -private[noop] object NoopStreamingWriteSupport extends StreamingWriteSupport { +private[noop] object NoopStreamingWrite extends StreamingWrite { override def createStreamingWriterFactory(): StreamingDataWriterFactory = NoopStreamingDataWriterFactory override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} @@ -85,4 +78,3 @@ private[noop] object NoopStreamingDataWriterFactory extends StreamingDataWriterF taskId: Long, epochId: Long): DataWriter[InternalRow] = NoopWriter } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala new file mode 100644 index 000000000000..c8542bfe5e59 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead, SupportsBatchWrite, Table} + +object DataSourceV2Implicits { + implicit class TableHelper(table: Table) { + def asBatchReadable: SupportsBatchRead = { + table match { + case support: SupportsBatchRead => + support + case _ => + throw new AnalysisException(s"Table does not support batch reads: ${table.name}") + } + } + + def asBatchWritable: SupportsBatchWrite = { + table match { + case support: SupportsBatchWrite => + support + case _ => + throw new AnalysisException(s"Table does not support batch writes: ${table.name}") + } + } + } + + implicit class OptionsHelper(options: Map[String, String]) { + def toDataSourceOptions: DataSourceOptions = new DataSourceOptions(options.asJava) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 47cf26dc9481..53677782c95f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,11 +17,6 @@ package org.apache.spark.sql.execution.datasources.v2 -import java.util.UUID - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} @@ -30,7 +25,6 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.types.StructType /** * A logical plan representing a data source v2 table. 
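The new implicits above are what the planner and write plans later in this patch import; a small sketch of what they provide for an arbitrary v2 Table (the method name is made up).

import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table}

def requireBatchCapable(table: Table, options: Map[String, String]): Unit = {
  // Each helper throws AnalysisException naming the table if the capability is missing.
  val readable: SupportsBatchRead  = table.asBatchReadable
  val writable: SupportsBatchWrite = table.asBatchWritable
  // Convenience conversion used by the new write plans instead of the manual asJava dance.
  val dsOptions = options.toDataSourceOptions
  println(s"${table.name} supports batch read/write with ${dsOptions.asMap().size()} options")
}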
@@ -45,26 +39,16 @@ case class DataSourceV2Relation( options: Map[String, String]) extends LeafNode with MultiInstanceRelation with NamedRelation { + import DataSourceV2Implicits._ + override def name: String = table.name() override def simpleString(maxFields: Int): String = { s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name" } - def newScanBuilder(): ScanBuilder = table match { - case s: SupportsBatchRead => - val dsOptions = new DataSourceOptions(options.asJava) - s.newScanBuilder(dsOptions) - case _ => throw new AnalysisException(s"Table is not readable: ${table.name()}") - } - - def newWriteBuilder(schema: StructType): WriteBuilder = table match { - case s: SupportsBatchWrite => - val dsOptions = new DataSourceOptions(options.asJava) - s.newWriteBuilder(dsOptions) - .withQueryId(UUID.randomUUID().toString) - .withInputDataSchema(schema) - case _ => throw new AnalysisException(s"Table is not writable: ${table.name()}") + def newScanBuilder(): ScanBuilder = { + table.asBatchReadable.newScanBuilder(options.toDataSourceOptions) } override def computeStats(): Statistics = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d6d17d6df7b1..55d7b0a18cbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -19,18 +19,18 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.{sources, AnalysisException, SaveMode, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, SubqueryExpression} +import org.apache.spark.sql.{AnalysisException, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode -object DataSourceV2Strategy extends Strategy { +object DataSourceV2Strategy extends Strategy with PredicateHelper { /** * Pushes down filters to the data source reader @@ -100,6 +100,7 @@ object DataSourceV2Strategy extends Strategy { } } + import DataSourceV2Implicits._ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => @@ -146,14 +147,22 @@ object DataSourceV2Strategy extends Strategy { WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil case AppendData(r: DataSourceV2Relation, query, _) => - val writeBuilder = r.newWriteBuilder(query.schema) - writeBuilder match { 
- case s: SupportsSaveMode => - val write = s.mode(SaveMode.Append).buildForBatch() - assert(write != null) - WriteToDataSourceV2Exec(write, planLater(query)) :: Nil - case _ => throw new AnalysisException(s"data source ${r.name} does not support SaveMode") - } + AppendDataExec( + r.table.asBatchWritable, r.options.toDataSourceOptions, planLater(query)) :: Nil + + case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) => + // fail if any filter cannot be converted. correctness depends on removing all matching data. + val filters = splitConjunctivePredicates(deleteExpr).map { + filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse( + throw new AnalysisException(s"Cannot translate expression to source filter: $filter")) + }.toArray + + OverwriteByExpressionExec( + r.table.asBatchWritable, filters, r.options.toDataSourceOptions, planLater(query)) :: Nil + + case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) => + OverwritePartitionsDynamicExec(r.table.asBatchWritable, + r.options.toDataSourceOptions, planLater(query)) :: Nil case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala deleted file mode 100644 index f11703c8a277..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.util.Utils - -/** - * A trait that can be used by data source v2 related query plans(both logical and physical), to - * provide a string format of the data source information for explain. - */ -trait DataSourceV2StringFormat { - - /** - * The instance of this data source implementation. Note that we only consider its class in - * equals/hashCode, not the instance itself. - */ - def source: DataSourceV2 - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The options for this data source reader. - */ - def options: Map[String, String] - - /** - * The filters which have been pushed to the data source. 
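A sketch (attribute names invented) of the shape of delete condition that the OverwriteByExpression planning above hands to OverwriteByExpressionExec: conjuncts are split and each must translate to a source Filter, otherwise analysis fails.

import org.apache.spark.sql.sources.{AlwaysTrue, EqualTo, Filter}

// p = 1 AND q = 'x'  ->  one source filter per conjunct
val overwriteWhere: Array[Filter] = Array(EqualTo("p", 1), EqualTo("q", "x"))
// SaveMode.Overwrite / full-table INSERT OVERWRITE  ->  the single catch-all filter
val truncateAll: Array[Filter] = Array(AlwaysTrue)
// Anything that cannot be translated raises:
//   "Cannot translate expression to source filter: <expression>"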
- */ - def pushedFilters: Seq[Expression] - - private def sourceName: String = source match { - case registered: DataSourceRegister => registered.shortName() - // source.getClass.getSimpleName can cause Malformed class name error, - // call safer `Utils.getSimpleName` instead - case _ => Utils.getSimpleName(source.getClass) - } - - def metadataString(maxFields: Int): String = { - val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - - if (pushedFilters.nonEmpty) { - entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") - } - - // TODO: we should only display some standard options like path, table, etc. - if (options.nonEmpty) { - entries += "Options" -> Utils.redact(options).map { - case (k, v) => s"$k=$v" - }.mkString("[", ",", "]") - } - - val outputStr = truncatedString(output, "[", ", ", "]", maxFields) - - val entriesStr = if (entries.nonEmpty) { - truncatedString(entries.map { - case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) - }, " (", ", ", ")", maxFields) - } else { - "" - } - - s"$sourceName$outputStr$entriesStr" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index e9cc3991155c..30897d86f817 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,8 +21,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} +import org.apache.spark.sql.sources.v2.{SessionConfigSupport, TableProvider} private[sql] object DataSourceV2Utils extends Logging { @@ -34,34 +33,28 @@ private[sql] object DataSourceV2Utils extends Logging { * `spark.datasource.$keyPrefix`. A session config `spark.datasource.$keyPrefix.xxx -> yyy` will * be transformed into `xxx -> yyy`. * - * @param ds a [[DataSourceV2]] object + * @param source a [[TableProvider]] object * @param conf the session conf * @return an immutable map that contains all the extracted and transformed k/v pairs. 
*/ - def extractSessionConfigs(ds: DataSourceV2, conf: SQLConf): Map[String, String] = ds match { - case cs: SessionConfigSupport => - val keyPrefix = cs.keyPrefix() - require(keyPrefix != null, "The data source config key prefix can't be null.") - - val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)") - - conf.getAllConfs.flatMap { case (key, value) => - val m = pattern.matcher(key) - if (m.matches() && m.groupCount() > 0) { - Seq((m.group(1), value)) - } else { - Seq.empty + def extractSessionConfigs(source: TableProvider, conf: SQLConf): Map[String, String] = { + source match { + case cs: SessionConfigSupport => + val keyPrefix = cs.keyPrefix() + require(keyPrefix != null, "The data source config key prefix can't be null.") + + val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)") + + conf.getAllConfs.flatMap { case (key, value) => + val m = pattern.matcher(key) + if (m.matches() && m.groupCount() > 0) { + Seq((m.group(1), value)) + } else { + Seq.empty + } } - } - - case _ => Map.empty - } - def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = { - val name = ds match { - case register: DataSourceRegister => register.shortName() - case _ => ds.getClass.getName + case _ => Map.empty } - throw new UnsupportedOperationException(name + " source does not support user-specified schema") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala index 7b836111a59b..db31927fa73b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala @@ -46,8 +46,7 @@ class FileBatchWrite( } override def createBatchWriterFactory(): DataWriterFactory = { - val conf = new SerializableConfiguration(job.getConfiguration) - FileWriterFactory(description, committer, conf) + FileWriterFactory(description, committer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index a0c932cbb0e0..06c57066aa24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -16,13 +16,10 @@ */ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, SupportsBatchRead, TableProvider} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.TableProvider /** * A base interface for data source v2 implementations of the built-in file-based data sources. 
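The refactored extractSessionConfigs above keeps its behavior; a sketch of the key rewriting it performs, assuming a SparkSession named spark and a TableProvider whose SessionConfigSupport.keyPrefix is "myds" (both invented).

// Session confs of the form spark.datasource.<keyPrefix>.xxx are exposed to the source as xxx.
spark.conf.set("spark.datasource.myds.host", "localhost")
spark.conf.set("spark.datasource.myds.port", "5432")
// DataSourceV2Utils.extractSessionConfigs(source, spark.sessionState.conf)
//   == Map("host" -> "localhost", "port" -> "5432")
// Sources without SessionConfigSupport get Map.empty.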
@@ -38,17 +35,4 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { def fallBackFileFormat: Class[_ <: FileFormat] lazy val sparkSession = SparkSession.active - - def getFileIndex( - options: DataSourceOptions, - userSpecifiedSchema: Option[StructType]): PartitioningAwareFileIndex = { - val filePaths = options.paths() - val hadoopConf = - sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) - val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(filePaths, hadoopConf, - checkEmptyGlobPath = true, checkFilesExist = options.checkFilesExist()) - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - new InMemoryFileIndex(sparkSession, rootPathsSpecified, - options.asMap().asScala.toMap, userSpecifiedSchema, fileStatusCache) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 3615b15be6fd..bdd6a48df20c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -18,15 +18,16 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.hadoop.fs.Path -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.PartitionedFileUtil import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} abstract class FileScan( sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex) extends Scan with Batch { + fileIndex: PartitioningAwareFileIndex, + readSchema: StructType) extends Scan with Batch { /** * Returns whether a file with `path` could be split or not. */ @@ -34,6 +35,22 @@ abstract class FileScan( false } + /** + * Returns whether this format supports the given [[DataType]] in write path. + * By default all data types are supported. + */ + def supportsDataType(dataType: DataType): Boolean = true + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. 
For example: + * + * {{{ + * override def formatName(): String = "ORC" + * }}} + */ + def formatName: String + protected def partitions: Seq[FilePartition] = { val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty) val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions) @@ -57,5 +74,13 @@ abstract class FileScan( partitions.toArray } - override def toBatch: Batch = this + override def toBatch: Batch = { + readSchema.foreach { field => + if (!supportsDataType(field.dataType)) { + throw new AnalysisException( + s"$formatName data source does not support ${field.dataType.catalogString} data type.") + } + } + this + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 0dbef145f732..21d3e5e29cfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -16,19 +16,30 @@ */ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead, SupportsBatchWrite, Table} import org.apache.spark.sql.types.StructType abstract class FileTable( sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, + options: DataSourceOptions, userSpecifiedSchema: Option[StructType]) extends Table with SupportsBatchRead with SupportsBatchWrite { - def getFileIndex: PartitioningAwareFileIndex = this.fileIndex + lazy val fileIndex: PartitioningAwareFileIndex = { + val filePaths = options.paths() + val hadoopConf = + sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) + val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(filePaths, hadoopConf, + checkEmptyGlobPath = true, checkFilesExist = options.checkFilesExist()) + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex(sparkSession, rootPathsSpecified, + options.asMap().asScala.toMap, userSpecifiedSchema, fileStatusCache) + } lazy val dataSchema: StructType = userSpecifiedSchema.orElse { inferSchema(fileIndex.allFiles()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index ce9b52f29d7b..6a94248a6f0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration abstract class FileWriteBuilder(options: DataSourceOptions) @@ -104,12 +104,34 @@ abstract class FileWriteBuilder(options: DataSourceOptions) options: Map[String, 
String], dataSchema: StructType): OutputWriterFactory + /** + * Returns whether this format supports the given [[DataType]] in write path. + * By default all data types are supported. + */ + def supportsDataType(dataType: DataType): Boolean = true + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def formatName(): String = "ORC" + * }}} + */ + def formatName: String + private def validateInputs(): Unit = { assert(schema != null, "Missing input data schema") assert(queryId != null, "Missing query ID") assert(mode != null, "Missing save mode") assert(options.paths().length == 1) DataSource.validateSchema(schema) + schema.foreach { field => + if (!supportsDataType(field.dataType)) { + throw new AnalysisException( + s"$formatName data source does not support ${field.dataType.catalogString} data type.") + } + } } private def getJobInstance(hadoopConf: Configuration, path: Path): Job = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala index be9c180a901b..eb573b317142 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala @@ -29,8 +29,7 @@ import org.apache.spark.util.SerializableConfiguration case class FileWriterFactory ( description: WriteJobDescription, - committer: FileCommitProtocol, - conf: SerializableConfiguration) extends DataWriterFactory { + committer: FileCommitProtocol) extends DataWriterFactory { override def createWriter(partitionId: Int, realTaskId: Long): DataWriter[InternalRow] = { val taskAttemptContext = createTaskAttemptContext(partitionId) committer.setupTask(taskAttemptContext) @@ -46,7 +45,7 @@ case class FileWriterFactory ( val taskId = new TaskID(jobId, TaskType.MAP, partitionId) val taskAttemptId = new TaskAttemptID(taskId, 0) // Set up the configuration object - val hadoopConf = conf.value + val hadoopConf = description.serializableHadoopConf.value hadoopConf.set("mapreduce.job.id", jobId.toString) hadoopConf.set("mapreduce.task.id", taskId.toString) hadoopConf.set("mapreduce.task.attempt.id", taskAttemptId.toString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 50c5e4f2ad7d..d7cb2457433b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -17,17 +17,22 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.UUID + import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.sources.v2.writer._ 
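A hypothetical example (the "FOO" format is made up) of how a file source plugs into the two hooks added above to both FileScan and FileWriteBuilder; the concrete ORC overrides appear further down in this patch.

import org.apache.spark.sql.types._

trait FooFormatSupport {
  // This made-up format handles a few primitives plus structs built from them.
  def supportsDataType(dataType: DataType): Boolean = dataType match {
    case IntegerType | LongType | DoubleType | StringType => true
    case st: StructType => st.forall(f => supportsDataType(f.dataType))
    case _ => false
  }
  def formatName: String = "FOO"
}
// Columns failing the check are rejected on both the read and write path with:
//   "FOO data source does not support <type> data type."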
+import org.apache.spark.sql.sources.{AlwaysTrue, Filter} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchWrite} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsSaveMode, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.util.{LongAccumulator, Utils} /** @@ -42,17 +47,137 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan) } /** - * The physical plan for writing data into data source v2. + * Physical plan node for append into a v2 table. + * + * Rows in the output data set are appended. + */ +case class AppendDataExec( + table: SupportsBatchWrite, + writeOptions: DataSourceOptions, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsSaveMode => + builder.mode(SaveMode.Append).buildForBatch() + + case builder => + builder.buildForBatch() + } + doWrite(batchWrite) + } +} + +/** + * Physical plan node for overwrite into a v2 table. + * + * Overwrites data in a table matched by a set of filters. Rows matching all of the filters will be + * deleted and rows in the output data set are appended. + * + * This plan is used to implement SaveMode.Overwrite. The behavior of SaveMode.Overwrite is to + * truncate the table -- delete all rows -- and append the output data set. This uses the filter + * AlwaysTrue to delete all rows. */ -case class WriteToDataSourceV2Exec(batchWrite: BatchWrite, query: SparkPlan) - extends UnaryExecNode { +case class OverwriteByExpressionExec( + table: SupportsBatchWrite, + deleteWhere: Array[Filter], + writeOptions: DataSourceOptions, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + private def isTruncate(filters: Array[Filter]): Boolean = { + filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] + } + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsTruncate if isTruncate(deleteWhere) => + builder.truncate().buildForBatch() + + case builder: SupportsSaveMode if isTruncate(deleteWhere) => + builder.mode(SaveMode.Overwrite).buildForBatch() + + case builder: SupportsOverwrite => + builder.overwrite(deleteWhere).buildForBatch() + + case _ => + throw new SparkException(s"Table does not support dynamic partition overwrite: $table") + } + + doWrite(batchWrite) + } +} + +/** + * Physical plan node for dynamic partition overwrite into a v2 table. + * + * Dynamic partition overwrite is the behavior of Hive INSERT OVERWRITE ... PARTITION queries, and + * Spark INSERT OVERWRITE queries when spark.sql.sources.partitionOverwriteMode=dynamic. Each + * partition in the output data set replaces the corresponding existing partition in the table or + * creates a new partition. Existing partitions for which there is no data in the output data set + * are not modified. 
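A hypothetical sink-side WriteBuilder (names invented) showing which capability each of the physical nodes above and below probes for before falling back to an error; the interface names come from the imports in this file, and the bodies are left as placeholders.

import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsOverwrite, SupportsTruncate, WriteBuilder}

class FooWriteBuilder extends WriteBuilder with SupportsTruncate with SupportsOverwrite {
  // AppendDataExec: a plain buildForBatch() appends the incoming rows.
  override def buildForBatch(): BatchWrite = ???                     // sink-specific append write
  // OverwriteByExpressionExec with deleteWhere == Array(AlwaysTrue): truncate, then append.
  override def truncate(): WriteBuilder = ???                        // arrange a full truncate
  // OverwriteByExpressionExec with arbitrary filters: delete matching rows, then append.
  override def overwrite(filters: Array[Filter]): WriteBuilder = ???
  // OverwritePartitionsDynamicExec instead looks for SupportsDynamicOverwrite.overwriteDynamicPartitions().
}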
+ */ +case class OverwritePartitionsDynamicExec( + table: SupportsBatchWrite, + writeOptions: DataSourceOptions, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsDynamicOverwrite => + builder.overwriteDynamicPartitions().buildForBatch() + + case builder: SupportsSaveMode => + builder.mode(SaveMode.Overwrite).buildForBatch() + + case _ => + throw new SparkException(s"Table does not support dynamic partition overwrite: $table") + } + + doWrite(batchWrite) + } +} + +case class WriteToDataSourceV2Exec( + batchWrite: BatchWrite, + query: SparkPlan + ) extends V2TableWriteExec { + + import DataSourceV2Implicits._ + + def writeOptions: DataSourceOptions = Map.empty[String, String].toDataSourceOptions + + override protected def doExecute(): RDD[InternalRow] = { + doWrite(batchWrite) + } +} + +/** + * Helper for physical plans that build batch writes. + */ +trait BatchWriteHelper { + def table: SupportsBatchWrite + def query: SparkPlan + def writeOptions: DataSourceOptions + + def newWriteBuilder(): WriteBuilder = { + table.newWriteBuilder(writeOptions) + .withInputDataSchema(query.schema) + .withQueryId(UUID.randomUUID().toString) + } +} + +/** + * The base physical plan for writing data into data source v2. + */ +trait V2TableWriteExec extends UnaryExecNode { + def query: SparkPlan var commitProgress: Option[StreamWriterCommitProgress] = None override def child: SparkPlan = query override def output: Seq[Attribute] = Nil - override protected def doExecute(): RDD[InternalRow] = { + protected def doWrite(batchWrite: BatchWrite): RDD[InternalRow] = { val writerFactory = batchWrite.createBatchWriterFactory() val useCommitCoordinator = batchWrite.useCommitCoordinator val rdd = query.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index db1f2f793422..f279af49ba9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.sources.v2.{DataSourceOptions, Table} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ class OrcDataSourceV2 extends FileDataSourceV2 { @@ -34,13 +34,28 @@ class OrcDataSourceV2 extends FileDataSourceV2 { override def getTable(options: DataSourceOptions): Table = { val tableName = getTableName(options) - val fileIndex = getFileIndex(options, None) - OrcTable(tableName, sparkSession, fileIndex, None) + OrcTable(tableName, sparkSession, options, None) } override def getTable(options: DataSourceOptions, schema: StructType): Table = { val tableName = getTableName(options) - val fileIndex = getFileIndex(options, Some(schema)) - OrcTable(tableName, sparkSession, fileIndex, Some(schema)) + OrcTable(tableName, sparkSession, options, Some(schema)) + } +} + +object OrcDataSourceV2 { + def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportsDataType(f.dataType) } + + case 
ArrayType(elementType, _) => supportsDataType(elementType) + + case MapType(keyType, valueType, _) => + supportsDataType(keyType) && supportsDataType(valueType) + + case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) + + case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index a792ad318b39..3c5dc1f50d7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration case class OrcScan( @@ -31,7 +31,7 @@ case class OrcScan( hadoopConf: Configuration, fileIndex: PartitioningAwareFileIndex, dataSchema: StructType, - readSchema: StructType) extends FileScan(sparkSession, fileIndex) { + readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) { override def isSplitable(path: Path): Boolean = true override def createReaderFactory(): PartitionReaderFactory = { @@ -40,4 +40,10 @@ case class OrcScan( OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, fileIndex.partitionSchema, readSchema) } + + override def supportsDataType(dataType: DataType): Boolean = { + OrcDataSourceV2.supportsDataType(dataType) + } + + override def formatName: String = "ORC" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index b467e505f1ba..249df8b8622f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.v2.orc import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcUtils import org.apache.spark.sql.execution.datasources.v2.FileTable import org.apache.spark.sql.sources.v2.DataSourceOptions @@ -29,9 +28,9 @@ import org.apache.spark.sql.types.StructType case class OrcTable( name: String, sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, + options: DataSourceOptions, userSpecifiedSchema: Option[StructType]) - extends FileTable(sparkSession, fileIndex, userSpecifiedSchema) { + extends FileTable(sparkSession, options, userSpecifiedSchema) { override def newScanBuilder(options: DataSourceOptions): OrcScanBuilder = new OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala index 80429d91d5e4..1aec4d872a64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala @@ -63,4 +63,10 @@ class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(optio } } } + + override def supportsDataType(dataType: DataType): Boolean = { + OrcDataSourceV2.supportsDataType(dataType) + } + + override def formatName: String = "ORC" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 2c339759f95b..cca279030dfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, RateCo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.streaming.SupportsOutputMode import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -513,13 +514,16 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamingWriteSupportProvider => - val writer = s.createStreamingWriteSupport( - s"$runId", - newAttributePlan.schema, - outputMode, - new DataSourceOptions(extraOptions.asJava)) - WriteToDataSourceV2(new MicroBatchWrite(currentBatchId, writer), newAttributePlan) + case s: SupportsStreamingWrite => + // TODO: we should translate OutputMode to concrete write actions like truncate, but + // the truncate action is being developed in SPARK-26666. + val writeBuilder = s.newWriteBuilder(new DataSourceOptions(extraOptions.asJava)) + .withQueryId(runId.toString) + .withInputDataSchema(newAttributePlan.schema) + val streamingWrite = writeBuilder.asInstanceOf[SupportsOutputMode] + .outputMode(outputMode) + .buildForStreaming() + WriteToDataSourceV2(new MicroBatchWrite(currentBatchId, streamingWrite), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -549,7 +553,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamingWriteSupportProvider => + case _: SupportsStreamingWrite => // This doesn't accumulate any data - it just forces execution of the microbatch writer. 
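The recursive type check added above in OrcDataSourceV2.scala accepts any nesting of atomic leaves; a few illustrative evaluations (REPL-style, expected results in comments).

import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
import org.apache.spark.sql.types._

OrcDataSourceV2.supportsDataType(IntegerType)                                 // true: atomic
OrcDataSourceV2.supportsDataType(ArrayType(StringType))                       // true: atomic element
OrcDataSourceV2.supportsDataType(MapType(StringType, ArrayType(DoubleType)))  // true: recurses into key and value
OrcDataSourceV2.supportsDataType(CalendarIntervalType)                        // false: "ORC data source does not support ..."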
nextBatch.collect() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 83d38dcade7e..1b7aa548e6d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{DataSourceV2, Table} +import org.apache.spark.sql.sources.v2.{Table, TableProvider} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -86,13 +86,13 @@ case class StreamingExecutionRelation( // know at read time whether the query is continuous or not, so we need to be able to // swap a V1 relation back in. /** - * Used to link a [[DataSourceV2]] into a streaming + * Used to link a [[TableProvider]] into a streaming * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]], * and should be converted before passing to [[StreamExecution]]. */ case class StreamingRelationV2( - dataSource: DataSourceV2, + source: TableProvider, sourceName: String, table: Table, extraOptions: Map[String, String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 9c5c16f4f5d1..348bc767b2c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport +import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.writer.WriteBuilder +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -30,17 +31,12 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) override def schema: StructType = data.schema } -class ConsoleSinkProvider extends DataSourceV2 - with StreamingWriteSupportProvider +class ConsoleSinkProvider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new ConsoleWriteSupport(schema, options) + override def getTable(options: DataSourceOptions): Table = { + ConsoleTable } def createRelation( @@ -60,3 +56,28 @@ class ConsoleSinkProvider extends DataSourceV2 def 
shortName(): String = "console" } + +object ConsoleTable extends Table with SupportsStreamingWrite { + + override def name(): String = "console" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def outputMode(mode: OutputMode): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + assert(inputSchema != null) + new ConsoleWrite(inputSchema, options) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index b22795d20776..20101c7fda32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -32,8 +32,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider, SupportsContinuousRead} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsContinuousRead, SupportsStreamingWrite} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.SupportsOutputMode import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -42,7 +43,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamingWriteSupportProvider, + sink: SupportsStreamingWrite, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -174,12 +175,15 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamingWriteSupport( - s"$runId", - withNewSources.schema, - outputMode, - new DataSourceOptions(extraOptions.asJava)) - val planWithSink = WriteToContinuousDataSource(writer, withNewSources) + // TODO: we should translate OutputMode to concrete write actions like truncate, but + // the truncate action is being developed in SPARK-26666. + val writeBuilder = sink.newWriteBuilder(new DataSourceOptions(extraOptions.asJava)) + .withQueryId(runId.toString) + .withInputDataSchema(withNewSources.schema) + val streamingWrite = writeBuilder.asInstanceOf[SupportsOutputMode] + .outputMode(outputMode) + .buildForStreaming() + val planWithSink = WriteToContinuousDataSource(streamingWrite, withNewSources) reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( @@ -214,9 +218,8 @@ class ContinuousExecution( trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. 
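A usage sketch (assuming a streaming DataFrame named events): the user-facing console sink is unchanged, but the query now resolves ConsoleTable, builds a WriteBuilder with SupportsOutputMode, and ends up in the renamed ConsoleWrite.

val consoleQuery = events.writeStream
  .format("console")
  .outputMode("update")
  .option("numRows", 50)   // still read by ConsoleWrite through DataSourceOptions
  .start()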
- val epochEndpoint = - EpochCoordinatorRef.create( - writer, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + val epochEndpoint = EpochCoordinatorRef.create( + streamingWrite, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index d1bda79f4b6e..a99842220424 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeR import org.apache.spark.sql.SparkSession import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,7 +82,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. */ def create( - writeSupport: StreamingWriteSupport, + writeSupport: StreamingWrite, stream: ContinuousStream, query: ContinuousExecution, epochCoordinatorId: String, @@ -115,7 +115,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writeSupport: StreamingWriteSupport, + writeSupport: StreamingWrite, stream: ContinuousStream, query: ContinuousExecution, startEpoch: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 7ad21cc304e7..54f484c4adae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite /** * The logical plan for writing data in a continuous stream. 
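A continuous-mode sketch (assuming a SparkSession named spark; the checkpoint path is arbitrary): with this change the continuous plan carries a StreamingWrite built from the sink's WriteBuilder rather than the removed StreamingWriteSupport.

import org.apache.spark.sql.streaming.Trigger

val continuousQuery = spark.readStream.format("rate").load()
  .writeStream
  .format("console")
  .trigger(Trigger.Continuous("1 second"))
  .option("checkpointLocation", "/tmp/continuous-ckpt")  // illustrative path
  .start()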
*/ -case class WriteToContinuousDataSource( - writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { +case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan) + extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index 2178466d6314..2f3af6a6544c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -26,21 +26,22 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite /** - * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. + * The physical plan for writing data into a continuous processing [[StreamingWrite]]. */ -case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) - extends UnaryExecNode with Logging { +case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPlan) + extends UnaryExecNode with Logging { + override def child: SparkPlan = query override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writeSupport.createStreamingWriterFactory() + val writerFactory = write.createStreamingWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) - logInfo(s"Start processing data source write support: $writeSupport. " + + logInfo(s"Start processing data source write support: $write. 
" + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala index 833e62f35ede..f2ff30bcf1be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala @@ -22,12 +22,12 @@ import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) - extends StreamingWriteSupport with Logging { +class ConsoleWrite(schema: StructType, options: DataSourceOptions) + extends StreamingWrite with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala similarity index 66% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index 4218fd51ad20..6fbb59c43625 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -22,63 +22,73 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite, Table} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** - * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified - * 
[[ForeachWriter]]. + * A write-only table for forwarding data into the specified [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. * @param converter An object to convert internal rows to target type T. Either it can be * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriteSupportProvider[T]( +case class ForeachWriterTable[T]( writer: ForeachWriter[T], converter: Either[ExpressionEncoder[T], InternalRow => T]) - extends StreamingWriteSupportProvider { - - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new StreamingWriteSupport { - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - - override def createStreamingWriterFactory(): StreamingDataWriterFactory = { - val rowConverter: InternalRow => T = converter match { - case Left(enc) => - val boundEnc = enc.resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - boundEnc.fromRow - case Right(func) => - func - } - ForeachWriterFactory(writer, rowConverter) + extends Table with SupportsStreamingWrite { + + override def name(): String = "ForeachSink" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this } - override def toString: String = "ForeachSink" + override def outputMode(mode: OutputMode): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + new StreamingWrite { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + inputSchema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) + } + } + } } } } -object ForeachWriteSupportProvider { +object ForeachWriterTable { def apply[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { + encoder: ExpressionEncoder[T]): ForeachWriterTable[_] = { writer match { case pythonWriter: PythonForeachWriter => - new ForeachWriteSupportProvider[UnsafeRow]( + new ForeachWriterTable[UnsafeRow]( pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) case _ => - new ForeachWriteSupportProvider[T](writer, Left(encoder)) + new ForeachWriterTable[T](writer, Left(encoder)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala index 143235efee81..f3951897ea74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala @@ -19,14 +19,14 @@ package 
org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} /** * A [[BatchWrite]] used to hook V2 stream writers into a microbatch plan. It implements * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped * streaming write support. */ -class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWriteSupport) extends BatchWrite { +class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWrite) extends BatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = { writeSupport.commit(eppchId, messages) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 075c6b9362ba..3a0082536512 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -40,8 +40,7 @@ import org.apache.spark.sql.types._ * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ -class RateStreamProvider extends DataSourceV2 - with TableProvider with DataSourceRegister { +class RateStreamProvider extends TableProvider with DataSourceRegister { import RateStreamProvider._ override def getTable(options: DataSourceOptions): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index c3b24a8f65dd..8ac5bfc307aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} -class TextSocketSourceProvider extends DataSourceV2 - with TableProvider with DataSourceRegister with Logging { +class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! 
" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index c50dc7bcb8da..3fc2cbe0fde5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -42,15 +42,31 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider - with MemorySinkBase with Logging { - - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new MemoryStreamingWriteSupport(this, mode, schema) +class MemorySinkV2 extends SupportsStreamingWrite with MemorySinkBase with Logging { + + override def name(): String = "MemorySinkV2" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var mode: OutputMode = _ + private var inputSchema: StructType = _ + + override def outputMode(mode: OutputMode): WriteBuilder = { + this.mode = mode + this + } + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def buildForStreaming(): StreamingWrite = { + new MemoryStreamingWrite(MemorySinkV2.this, mode, inputSchema) + } + } } private case class AddedData(batchId: Long, data: Array[Row]) @@ -122,9 +138,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryStreamingWriteSupport( +class MemoryStreamingWrite( val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamingWriteSupport { + extends StreamingWrite { override def createStreamingWriterFactory: MemoryWriterFactory = { MemoryWriterFactory(outputMode, schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 3f941cc6e107..a1ab55a7185c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import 
org.apache.spark.annotation.Stable +import org.apache.spark.annotation.{Evolving, Stable} //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. @@ -218,3 +218,27 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } + +/** + * A filter that always evaluates to `true`. + */ +@Evolving +case class AlwaysTrue() extends Filter { + override def references: Array[String] = Array.empty +} + +@Evolving +object AlwaysTrue extends AlwaysTrue { +} + +/** + * A filter that always evaluates to `false`. + */ +@Evolving +case class AlwaysFalse() extends Filter { + override def references: Array[String] = Array.empty +} + +@Evolving +object AlwaysFalse extends AlwaysFalse { +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 866681838af8..ef21caa3ac29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -173,7 +173,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo ds match { case provider: TableProvider => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = provider, conf = sparkSession.sessionState.conf) + source = provider, conf = sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions val dsOptions = new DataSourceOptions(options.asJava) val table = userSpecifiedSchema match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index ea596ba728c1..984199488fa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite, TableProvider} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -278,7 +278,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) + val sink = ForeachWriterTable[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -304,30 +304,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") - var options = extraOptions.toMap - val sink = ds.getConstructor().newInstance() match { - case w: StreamingWriteSupportProvider - if !disabledSources.contains(w.getClass.getCanonicalName) => - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - w, df.sparkSession.sessionState.conf) - options = sessionOptions ++ extraOptions - w - case _ => - val ds = DataSource( - df.sparkSession, - className = source, - options = options, - partitionColumns = normalizedParCols.getOrElse(Nil)) - ds.createSink(outputMode) + val useV1Source = disabledSources.contains(cls.getCanonicalName) + + val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) { + val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + source = provider, conf = df.sparkSession.sessionState.conf) + val options = sessionOptions ++ extraOptions + val dsOptions = new DataSourceOptions(options.asJava) + provider.getTable(dsOptions) match { + case s: SupportsStreamingWrite => s + case _ => createV1Sink() + } + } else { + createV1Sink() } df.sparkSession.sessionState.streamingQueryManager.startQuery( - options.get("queryName"), - options.get("checkpointLocation"), + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), df, - options, + extraOptions.toMap, sink, outputMode, useTempCheckpointLocation = source == "console" || source == "noop", @@ -336,6 +335,15 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } } + private def createV1Sink(): BaseStreamingSink = { + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + /** * Sets the output of the streaming query to be processed using the provided writer object. * object. 
See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 0bd8a9299ef4..a7fa800a49d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.SupportsStreamingWrite import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -261,7 +261,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => + case (v2Sink: SupportsStreamingWrite, trigger: ContinuousTrigger) => if (operationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index a36b0cfa6ff1..914af589384d 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider +org.apache.spark.sql.streaming.sources.FakeWriteOnly org.apache.spark.sql.streaming.sources.FakeNoWrite org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 8c34e47314db..64c4aabd4cdc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} +import org.scalatest.exceptions.TestFailedException + import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.sql.catalyst.ScroogeLikeExample import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} @@ -67,6 +69,41 @@ class DatasetSuite extends QueryTest with SharedSQLContext { data: _*) } + test("toDS should compare map with byte array keys correctly") { + // Choose the order of arrays in such way, that sorting keys of different maps by _.toString + // will not incidentally put equal keys together. 
+ val arrays = (1 to 5).map(_ => Array[Byte](0.toByte, 0.toByte)).sortBy(_.toString).toArray + arrays(0)(1) = 1.toByte + arrays(1)(1) = 2.toByte + arrays(2)(1) = 2.toByte + arrays(3)(1) = 1.toByte + + val mapA = Map(arrays(0) -> "one", arrays(2) -> "two") + val subsetOfA = Map(arrays(0) -> "one") + val equalToA = Map(arrays(1) -> "two", arrays(3) -> "one") + val notEqualToA1 = Map(arrays(1) -> "two", arrays(3) -> "not one") + val notEqualToA2 = Map(arrays(1) -> "two", arrays(4) -> "one") + + // Comparing map with itself + checkDataset(Seq(mapA).toDS(), mapA) + + // Comparing map with equivalent map + checkDataset(Seq(equalToA).toDS(), mapA) + checkDataset(Seq(mapA).toDS(), equalToA) + + // Comparing map with it's subset + intercept[TestFailedException](checkDataset(Seq(subsetOfA).toDS(), mapA)) + intercept[TestFailedException](checkDataset(Seq(mapA).toDS(), subsetOfA)) + + // Comparing map with another map differing by single value + intercept[TestFailedException](checkDataset(Seq(notEqualToA1).toDS(), mapA)) + intercept[TestFailedException](checkDataset(Seq(mapA).toDS(), notEqualToA1)) + + // Comparing map with another map differing by single key + intercept[TestFailedException](checkDataset(Seq(notEqualToA2).toDS(), mapA)) + intercept[TestFailedException](checkDataset(Seq(mapA).toDS(), notEqualToA2)) + } + test("toDS with RDD") { val ds = sparkContext.makeRDD(Seq("a", "b", "c"), 3).toDS() checkDataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index fc87b0462bc2..58522f7b1376 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -329,83 +329,97 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { withTempDir { dir => val tempDir = new File(dir, "files").getCanonicalPath - // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well. - withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") { - // write path - Seq("csv", "json", "parquet", "orc").foreach { format => - var msg = intercept[AnalysisException] { - sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.contains("Cannot save interval data type into external storage.")) - - msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new IntervalData()) - sql("select testType()").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + Seq(true).foreach { useV1 => + val useV1List = if (useV1) { + "orc" + } else { + "" } + def errorMessage(format: String, isWrite: Boolean): String = { + if (isWrite && (useV1 || format != "orc")) { + "cannot save interval data type into external storage." + } else { + s"$format data source does not support calendarinterval data type." 
+ } + } + + withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) { + // write path + Seq("csv", "json", "parquet", "orc").foreach { format => + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true))) + } - // read path - Seq("parquet", "csv").foreach { format => - var msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) - - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + // read path + Seq("parquet", "csv").foreach { format => + var msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false))) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false))) + } } } } } test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { - // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well. 
- withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc", - SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") { - withTempDir { dir => - val tempDir = new File(dir, "files").getCanonicalPath - - Seq("parquet", "csv", "orc").foreach { format => - // write path - var msg = intercept[AnalysisException] { - sql("select null").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new NullData()) - sql("select testType()").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - // read path - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", NullType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) + Seq(true).foreach { useV1 => + val useV1List = if (useV1) { + "orc" + } else { + "" + } + def errorMessage(format: String): String = { + s"$format data source does not support null data type." + } + withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1List, + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + Seq("parquet", "csv", "orc").foreach { format => + // write path + var msg = intercept[AnalysisException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + // read path + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d83deb17a090..f8298c9da97e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -341,9 +341,9 @@ object QueryTest { case (a: 
Array[_], b: Array[_]) => a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} case (a: Map[_, _], b: Map[_, _]) => - val entries1 = a.iterator.toSeq.sortBy(_.toString()) - val entries2 = b.iterator.toSeq.sortBy(_.toString()) - compare(entries1, entries2) + a.size == b.size && a.keys.forall { aKey => + b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey))) + } case (a: Iterable[_], b: Iterable[_]) => a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} case (a: Product, b: Product) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala index 4a695ac74c47..bc5a30e97fae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala @@ -232,5 +232,8 @@ class OrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQ class OrcV1PartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQLContext { override protected def sparkConf: SparkConf = - super.sparkConf.set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index c0d26011e791..9bf122766687 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -690,5 +690,8 @@ class OrcQuerySuite extends OrcQueryTest with SharedSQLContext { class OrcV1QuerySuite extends OrcQuerySuite { override protected def sparkConf: SparkConf = - super.sparkConf.set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala index cf5bbb3fff70..5a1bf9b43756 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala @@ -31,7 +31,10 @@ import org.apache.spark.sql.internal.SQLConf class OrcV1FilterSuite extends OrcFilterSuite { override protected def sparkConf: SparkConf = - super.sparkConf.set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + super + .sparkConf + .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc") + .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc") override def checkFilterPredicate( df: DataFrame, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 61857365ac98..e80437754051 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -43,9 +43,9 @@ class MemorySinkV2Suite 
extends StreamTest with BeforeAndAfter { test("streaming writer") { val sink = new MemorySinkV2 - val writeSupport = new MemoryStreamingWriteSupport( + val write = new MemoryStreamingWrite( sink, OutputMode.Append(), new StructType().add("i", "int")) - writeSupport.commit(0, + write.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -53,7 +53,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - writeSupport.commit(19, + write.commit(19, Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 511fdfe5c23a..6b5c45e40ab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -351,19 +351,21 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("SPARK-25700: do not read schema when writing in other modes except append mode") { + test("SPARK-25700: do not read schema when writing in other modes except append and overwrite") { withTempPath { file => val cls = classOf[SimpleWriteOnlyDataSource] val path = file.getCanonicalPath val df = spark.range(5).select('id as 'i, -'id as 'j) // non-append mode should not throw exception, as they don't access schema. df.write.format(cls.getName).option("path", path).mode("error").save() - df.write.format(cls.getName).option("path", path).mode("overwrite").save() df.write.format(cls.getName).option("path", path).mode("ignore").save() - // append mode will access schema and should throw exception. + // append and overwrite modes will access the schema and should throw exception. 
intercept[SchemaReadAttemptException] { df.write.format(cls.getName).option("path", path).mode("append").save() } + intercept[SchemaReadAttemptException] { + df.write.format(cls.getName).option("path", path).mode("overwrite").save() + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index f903c17923d0..0b1e3b5fb076 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -33,8 +33,8 @@ class DataSourceV2UtilsSuite extends SparkFunSuite { conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false") conf.setConfString("spark.datasource.another.config.name", "123") conf.setConfString(s"spark.datasource.$keyPrefix.", "123") - val cs = classOf[DataSourceV2WithSessionConfig].getConstructor().newInstance() - val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf) + val source = new DataSourceV2WithSessionConfig + val confs = DataSourceV2Utils.extractSessionConfigs(source, conf) assert(confs.size == 2) assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index daca65fd1ad2..c56a54598cd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -38,8 +38,7 @@ import org.apache.spark.util.SerializableConfiguration * Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`. * Each job moves files from `target/_temporary/uniqueId/` to `target`. 
*/ -class SimpleWritableDataSource extends DataSourceV2 - with TableProvider with SessionConfigSupport { +class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { private val tableSchema = new StructType().add("i", "long").add("j", "long") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index d3d210c02e90..bad22590807a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousStream, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType, StructType} @@ -43,7 +43,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { override def beforeEach(): Unit = { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( - mock[StreamingWriteSupport], + mock[StreamingWrite], mock[ContinuousStream], mock[ContinuousExecution], coordinatorId, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index a0b56ec17f0b..f74285f4b0fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.test.TestSparkSession class EpochCoordinatorSuite @@ -40,13 +40,13 @@ class EpochCoordinatorSuite private var epochCoordinator: RpcEndpointRef = _ - private var writeSupport: StreamingWriteSupport = _ + private var writeSupport: StreamingWrite = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ override def beforeEach(): Unit = { val stream = mock[ContinuousStream] - writeSupport = mock[StreamingWriteSupport] + writeSupport = mock[StreamingWrite] query = mock[ContinuousExecution] orderVerifier = inOrder(writeSupport, query) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 62f166602941..c841793fdd4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.WriteBuilder import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -71,13 +71,10 @@ trait FakeContinuousReadTable extends Table with SupportsContinuousRead { override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new FakeScanBuilder } -trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - LastWriteOptions.options = options +trait FakeStreamingWriteTable extends Table with SupportsStreamingWrite { + override def name(): String = "fake" + override def schema(): StructType = StructType(Seq()) + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { throw new IllegalStateException("fake sink - cannot actually write") } } @@ -129,20 +126,33 @@ class FakeReadNeitherMode extends DataSourceRegister with TableProvider { } } -class FakeWriteSupportProvider +class FakeWriteOnly extends DataSourceRegister - with FakeStreamingWriteSupportProvider + with TableProvider with SessionConfigSupport { override def shortName(): String = "fake-write-microbatch-continuous" override def keyPrefix: String = shortName() + + override def getTable(options: DataSourceOptions): Table = { + LastWriteOptions.options = options + new Table with FakeStreamingWriteTable { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } -class FakeNoWrite extends DataSourceRegister { +class FakeNoWrite extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-write-neither-mode" + override def getTable(options: DataSourceOptions): Table = { + new Table { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } - case class FakeWriteV1FallbackException() extends Exception class FakeSink extends Sink { @@ -150,17 +160,24 @@ class FakeSink extends Sink { } class FakeWriteSupportProviderV1Fallback extends DataSourceRegister - with FakeStreamingWriteSupportProvider with StreamSinkProvider { + with TableProvider with StreamSinkProvider { override def createSink( - sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode): Sink = { + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { new FakeSink() } override def shortName(): String = "fake-write-v1-fallback" + + override def getTable(options: DataSourceOptions): Table = { + new Table with FakeStreamingWriteTable { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } object LastReadOptions { @@ -260,7 +277,7 @@ class StreamingDataSourceV2Suite extends StreamTest { testPositiveCaseWithQuery( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v2Query => 
assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeWriteSupportProviderV1Fallback]) + .isInstanceOf[Table]) } // Ensure we create a V1 sink with the config. Note the config is a comma separated @@ -319,19 +336,20 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { - val table = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() + val sourceTable = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() + .newInstance().asInstanceOf[TableProvider].getTable(DataSourceOptions.empty()) + + val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor() .newInstance().asInstanceOf[TableProvider].getTable(DataSourceOptions.empty()) - val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf). - getConstructor().newInstance() - (table, writeSource, trigger) match { + (sourceTable, sinkTable, trigger) match { // Valid microbatch queries. - case (_: SupportsMicroBatchRead, _: StreamingWriteSupportProvider, t) + case (_: SupportsMicroBatchRead, _: SupportsStreamingWrite, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: SupportsContinuousRead, _: StreamingWriteSupportProvider, + case (_: SupportsContinuousRead, _: SupportsStreamingWrite, _: ContinuousTrigger) => testPositiveCase(read, write, trigger) @@ -342,12 +360,12 @@ class StreamingDataSourceV2Suite extends StreamTest { s"Data source $read does not support streamed reading") // Invalid - can't write - case (_, w, _) if !w.isInstanceOf[StreamingWriteSupportProvider] => + case (_, w, _) if !w.isInstanceOf[SupportsStreamingWrite] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") // Invalid - trigger is continuous but reader is not - case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger) + case (r, _: SupportsStreamingWrite, _: ContinuousTrigger) if !r.isInstanceOf[SupportsContinuousRead] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java index 1a7ca79163d7..2af17a662a29 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/operation/GetTablesOperation.java @@ -46,7 +46,7 @@ public class GetTablesOperation extends MetadataOperation { private final String schemaName; private final String tableName; private final List tableTypes = new ArrayList(); - private final RowSet rowSet; + protected final RowSet rowSet; private final TableTypeMapping tableTypeMapping; diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala new file mode 100644 index 000000000000..369650047b10 --- /dev/null +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkGetTablesOperation.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.util.{List => JList} + +import scala.collection.JavaConverters.seqAsJavaListConverter + +import org.apache.hadoop.hive.ql.security.authorization.plugin.HiveOperationType +import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObjectUtils +import org.apache.hive.service.cli._ +import org.apache.hive.service.cli.operation.GetTablesOperation +import org.apache.hive.service.cli.session.HiveSession + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ + +/** + * Spark's own GetTablesOperation + * + * @param sqlContext SQLContext to use + * @param parentSession a HiveSession from SessionManager + * @param catalogName catalog name. null if not applicable + * @param schemaName database name, null or a concrete database name + * @param tableName table name pattern + * @param tableTypes list of allowed table types, e.g. "TABLE", "VIEW" + */ +private[hive] class SparkGetTablesOperation( + sqlContext: SQLContext, + parentSession: HiveSession, + catalogName: String, + schemaName: String, + tableName: String, + tableTypes: JList[String]) + extends GetTablesOperation(parentSession, catalogName, schemaName, tableName, tableTypes) { + + if (tableTypes != null) { + this.tableTypes.addAll(tableTypes) + } + + override def runInternal(): Unit = { + setState(OperationState.RUNNING) + // Always use the latest class loader provided by executionHive's state. 
+ val executionHiveClassLoader = sqlContext.sharedState.jarClassLoader + Thread.currentThread().setContextClassLoader(executionHiveClassLoader) + + val catalog = sqlContext.sessionState.catalog + val schemaPattern = convertSchemaPattern(schemaName) + val matchingDbs = catalog.listDatabases(schemaPattern) + + if (isAuthV2Enabled) { + val privObjs = + HivePrivilegeObjectUtils.getHivePrivDbObjects(seqAsJavaListConverter(matchingDbs).asJava) + val cmdStr = s"catalog : $catalogName, schemaPattern : $schemaName" + authorizeMetaGets(HiveOperationType.GET_TABLES, privObjs, cmdStr) + } + + val tablePattern = convertIdentifierPattern(tableName, true) + matchingDbs.foreach { dbName => + catalog.listTables(dbName, tablePattern).foreach { tableIdentifier => + val catalogTable = catalog.getTableMetadata(tableIdentifier) + val tableType = tableTypeString(catalogTable.tableType) + if (tableTypes == null || tableTypes.isEmpty || tableTypes.contains(tableType)) { + val rowData = Array[AnyRef]( + "", + catalogTable.database, + catalogTable.identifier.table, + tableType, + catalogTable.comment.getOrElse("")) + rowSet.addRow(rowData) + } + } + } + setState(OperationState.FINISHED) + } + + private def tableTypeString(tableType: CatalogTableType): String = tableType match { + case EXTERNAL | MANAGED => "TABLE" + case VIEW => "VIEW" + case t => + throw new IllegalArgumentException(s"Unknown table type is found at showCreateHiveTable: $t") + } +} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 85b6c7134755..7947d1785a8f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.hive.thriftserver.server -import java.util.{Map => JMap} +import java.util.{List => JList, Map => JMap} import java.util.concurrent.ConcurrentHashMap import org.apache.hive.service.cli._ -import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, GetSchemasOperation, Operation, OperationManager} +import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, GetSchemasOperation, MetadataOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.hive.HiveUtils -import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation, SparkGetSchemasOperation} +import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation, SparkGetSchemasOperation, SparkGetTablesOperation} import org.apache.spark.sql.internal.SQLConf /** @@ -76,6 +76,22 @@ private[thriftserver] class SparkSQLOperationManager() operation } + override def newGetTablesOperation( + parentSession: HiveSession, + catalogName: String, + schemaName: String, + tableName: String, + tableTypes: JList[String]): MetadataOperation = synchronized { + val sqlContext = sessionToContexts.get(parentSession.getSessionHandle) + require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" + + " initialized or had already closed.") + val operation = new SparkGetTablesOperation(sqlContext, parentSession, + catalogName, 
schemaName, tableName, tableTypes) + handleToOperation.put(operation.getHandle, operation) + logDebug(s"Created GetTablesOperation with session=$parentSession.") + operation + } + def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = { val iterator = confMap.entrySet().iterator() while (iterator.hasNext) { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index f9509aed4aaa..0f53fcd327f1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -280,7 +280,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { var defaultV2: String = null var data: ArrayBuffer[Int] = null - withMultipleConnectionJdbcStatement("test_map")( + withMultipleConnectionJdbcStatement("test_map", "db1.test_map2")( // create table { statement => diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala index 9a997ae01df9..bf9982388d6b 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/SparkMetadataOperationSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.Properties +import java.util.{Arrays => JArrays, List => JList, Properties} import org.apache.hive.jdbc.{HiveConnection, HiveQueryResultSet, Utils => JdbcUtils} import org.apache.hive.service.auth.PlainSaslHelper @@ -100,4 +100,89 @@ class SparkMetadataOperationSuite extends HiveThriftJdbcTest { } } } + + test("Spark's own GetTablesOperation(SparkGetTablesOperation)") { + def testGetTablesOperation( + schema: String, + tableNamePattern: String, + tableTypes: JList[String])(f: HiveQueryResultSet => Unit): Unit = { + val rawTransport = new TSocket("localhost", serverPort) + val connection = new HiveConnection(s"jdbc:hive2://localhost:$serverPort", new Properties) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val client = new TCLIService.Client(new TBinaryProtocol(transport)) + transport.open() + + var rs: HiveQueryResultSet = null + + try { + val openResp = client.OpenSession(new TOpenSessionReq) + val sessHandle = openResp.getSessionHandle + + val getTableReq = new TGetTablesReq(sessHandle) + getTableReq.setSchemaName(schema) + getTableReq.setTableName(tableNamePattern) + getTableReq.setTableTypes(tableTypes) + + val getTableResp = client.GetTables(getTableReq) + + JdbcUtils.verifySuccess(getTableResp.getStatus) + + rs = new HiveQueryResultSet.Builder(connection) + .setClient(client) + .setSessionHandle(sessHandle) + .setStmtHandle(getTableResp.getOperationHandle) + .build() + + f(rs) + } finally { + rs.close() + connection.close() + transport.close() + rawTransport.close() + } + } + + def checkResult(tableNames: Seq[String], rs: HiveQueryResultSet): Unit = { + if (tableNames.nonEmpty) { + for (i <- tableNames.indices) { + assert(rs.next()) + assert(rs.getString("TABLE_NAME") === tableNames(i)) + } + } else { + assert(!rs.next()) 
+ } + } + + withJdbcStatement("table1", "table2") { statement => + Seq( + "CREATE TABLE table1(key INT, val STRING)", + "CREATE TABLE table2(key INT, val STRING)", + "CREATE VIEW view1 AS SELECT * FROM table2").foreach(statement.execute) + + testGetTablesOperation("%", "%", null) { rs => + checkResult(Seq("table1", "table2", "view1"), rs) + } + + testGetTablesOperation("%", "table1", null) { rs => + checkResult(Seq("table1"), rs) + } + + testGetTablesOperation("%", "table_not_exist", null) { rs => + checkResult(Seq.empty, rs) + } + + testGetTablesOperation("%", "%", JArrays.asList("TABLE")) { rs => + checkResult(Seq("table1", "table2"), rs) + } + + testGetTablesOperation("%", "%", JArrays.asList("VIEW")) { rs => + checkResult(Seq("view1"), rs) + } + + testGetTablesOperation("%", "%", JArrays.asList("TABLE", "VIEW")) { rs => + checkResult(Seq("table1", "table2", "view1"), rs) + } + } + } } diff --git a/sql/hive/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider b/sql/hive/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider similarity index 100% rename from sql/hive/src/main/resources/META-INF/services/org.apache.spark.deploy.security.HadoopDelegationTokenProvider rename to sql/hive/src/main/resources/META-INF/services/org.apache.spark.security.HadoopDelegationTokenProvider diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/security/HiveDelegationTokenProvider.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/security/HiveDelegationTokenProvider.scala index c0c46187b13a..faee405d70cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/security/HiveDelegationTokenProvider.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/security/HiveDelegationTokenProvider.scala @@ -23,7 +23,6 @@ import java.security.PrivilegedExceptionAction import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.metadata.Hive @@ -33,9 +32,9 @@ import org.apache.hadoop.security.token.Token import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.deploy.security.HadoopDelegationTokenProvider import org.apache.spark.internal.Logging import org.apache.spark.internal.config.KEYTAB +import org.apache.spark.security.HadoopDelegationTokenProvider import org.apache.spark.util.Utils private[spark] class HiveDelegationTokenProvider diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index dd0e1bd0fe30..1dd60c6757d4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -206,7 +206,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. 
- val testingVersions = Seq("2.3.2", "2.4.0") + val testingVersions = Seq("2.3.3", "2.4.0") protected var spark: SparkSession = _ @@ -260,19 +260,10 @@ object PROCESS_TABLES extends QueryTest with SQLTestUtils { // SPARK-22356: overlapped columns between data and partition schema in data source tables val tbl_with_col_overlap = s"tbl_with_col_overlap_$index" - // For Spark 2.2.0 and 2.1.x, the behavior is different from Spark 2.0, 2.2.1, 2.3+ - if (testingVersions(index).startsWith("2.1") || testingVersions(index) == "2.2.0") { - spark.sql("msck repair table " + tbl_with_col_overlap) - assert(spark.table(tbl_with_col_overlap).columns === Array("i", "j", "p")) - checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) - assert(sql("desc " + tbl_with_col_overlap).select("col_name") - .as[String].collect().mkString(",").contains("i,j,p")) - } else { - assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j")) - checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) - assert(sql("desc " + tbl_with_col_overlap).select("col_name") - .as[String].collect().mkString(",").contains("i,p,j")) - } + assert(spark.table(tbl_with_col_overlap).columns === Array("i", "p", "j")) + checkAnswer(spark.table(tbl_with_col_overlap), Row(1, 1, 1) :: Row(1, 1, 1) :: Nil) + assert(sql("desc " + tbl_with_col_overlap).select("col_name") + .as[String].collect().mkString(",").contains("i,p,j")) } } }
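For context on the streaming write changes in this patch: the built-in sinks (foreach, memory, console) are now resolved through TableProvider and SupportsStreamingWrite rather than StreamingWriteSupportProvider, while the user-facing DataStreamWriter API is unchanged. Below is a minimal usage sketch, not part of the patch itself; it assumes a local SparkSession built against this snapshot of master, and the application name and rate-source options are illustrative only.

    import org.apache.spark.sql.{ForeachWriter, Row, SparkSession}

    object ForeachSinkSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("foreach-sink-sketch") // illustrative name
          .getOrCreate()

        // The rate source is now a plain TableProvider (see RateStreamProvider above).
        val rates = spark.readStream
          .format("rate")
          .option("rowsPerSecond", 1)
          .load()

        // writeStream.foreach(...) now resolves to ForeachWriterTable, a write-only Table,
        // instead of ForeachWriteSupportProvider; the ForeachWriter contract is unchanged.
        val query = rates.writeStream
          .outputMode("append")
          .foreach(new ForeachWriter[Row] {
            override def open(partitionId: Long, epochId: Long): Boolean = true
            override def process(value: Row): Unit = println(value)
            override def close(errorOrNull: Throwable): Unit = ()
          })
          .start()

        query.awaitTermination()
      }
    }

The same query runs unchanged whether the sink is picked up through the new Table code path in DataStreamWriter or falls back to a V1 Sink via createV1Sink().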

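The SparkGetTablesOperation added in this patch is what serves java.sql.DatabaseMetaData.getTables for Thrift Server clients. The following client-side sketch is illustrative only and not part of the patch; it assumes a running Spark Thrift Server on the default localhost:10000, the Hive JDBC driver on the classpath, and placeholder credentials.

    import java.sql.DriverManager

    object GetTablesSketch {
      def main(args: Array[String]): Unit = {
        // Host, port, and user are assumptions for the sketch.
        val conn = DriverManager.getConnection("jdbc:hive2://localhost:10000", "user", "")
        try {
          // Served by SparkGetTablesOperation on the server side; the type array filters
          // results the same way the TGetTablesReq tableTypes list does in the test above.
          val rs = conn.getMetaData.getTables(null, "%", "%", Array("TABLE", "VIEW"))
          while (rs.next()) {
            println(s"${rs.getString("TABLE_SCHEM")}.${rs.getString("TABLE_NAME")}" +
              s" (${rs.getString("TABLE_TYPE")})")
          }
        } finally {
          conn.close()
        }
      }
    }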