From ceb58cae21411d90f08a4446254d26c7f991cf6d Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 20 Jun 2019 12:42:21 -0700 Subject: [PATCH 1/2] SPARK-27919: Add v2 session catalog. --- .../spark/sql/catalog/v2/LookupCatalog.scala | 103 ++- .../sql/catalyst/analysis/Analyzer.scala | 6 + .../analysis/UpdateAttributeNullability.scala | 2 +- .../plans/logical/basicLogicalOperators.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 5 + .../catalog/v2/LookupCatalogSuite.scala | 107 +++ .../datasources/DataSourceResolution.scala | 52 +- .../datasources/v2/V2SessionCatalog.scala | 255 +++++++ .../internal/BaseSessionStateBuilder.scala | 2 +- .../command/PlanResolutionSuite.scala | 123 +++- .../v2/V2SessionCatalogSuite.scala | 683 ++++++++++++++++++ .../sql/sources/v2/DataSourceV2SQLSuite.scala | 94 ++- .../sql/hive/HiveSessionStateBuilder.scala | 2 +- 13 files changed, 1349 insertions(+), 87 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala index 5464a7496d23..5f7ee30cdab7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala @@ -17,36 +17,91 @@ package org.apache.spark.sql.catalog.v2 +import scala.util.control.NonFatal + import org.apache.spark.annotation.Experimental +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.TableIdentifier /** * A trait to encapsulate catalog lookup function and helpful extractors. */ @Experimental -trait LookupCatalog { +trait LookupCatalog extends Logging { + + import LookupCatalog._ + protected def defaultCatalogName: Option[String] = None protected def lookupCatalog(name: String): CatalogPlugin - type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier) + /** + * Returns the default catalog. When set, this catalog is used for all identifiers that do not + * set a specific catalog. When this is None, the session catalog is responsible for the + * identifier. + * + * If this is None and a table's provider (source) is a v2 provider, the v2 session catalog will + * be used. + */ + def defaultCatalog: Option[CatalogPlugin] = { + try { + defaultCatalogName.map(lookupCatalog) + } catch { + case NonFatal(e) => + logError(s"Cannot load default v2 catalog: ${defaultCatalogName.get}", e) + None + } + } /** - * Extract catalog plugin and identifier from a multi-part identifier. + * This catalog is a v2 catalog that delegates to the v1 session catalog. it is used when the + * session catalog is responsible for an identifier, but the source requires the v2 catalog API. 
+ * This happens when the source implementation extends the v2 TableProvider API and is not listed + * in the fallback configuration, spark.sql.sources.write.useV1SourceList */ - object CatalogObjectIdentifier { - def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match { - case Seq(name) => - Some((None, Identifier.of(Array.empty, name))) + def sessionCatalog: Option[CatalogPlugin] = { + try { + Some(lookupCatalog(SESSION_CATALOG_NAME)) + } catch { + case NonFatal(e) => + logError("Cannot load v2 session catalog", e) + None + } + } + + /** + * Extract catalog plugin and remaining identifier names. + * + * This does not substitute the default catalog if no catalog is set in the identifier. + */ + private object CatalogAndIdentifier { + def unapply(parts: Seq[String]): Some[(Option[CatalogPlugin], Seq[String])] = parts match { + case Seq(_) => + Some((None, parts)) case Seq(catalogName, tail @ _*) => try { - Some((Some(lookupCatalog(catalogName)), Identifier.of(tail.init.toArray, tail.last))) + Some((Some(lookupCatalog(catalogName)), tail)) } catch { case _: CatalogNotFoundException => - Some((None, Identifier.of(parts.init.toArray, parts.last))) + Some((None, parts)) } } } + type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier) + + /** + * Extract catalog and identifier from a multi-part identifier with the default catalog if needed. + */ + object CatalogObjectIdentifier { + def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match { + case CatalogAndIdentifier(maybeCatalog, nameParts) => + Some(( + maybeCatalog.orElse(defaultCatalog), + Identifier.of(nameParts.init.toArray, nameParts.last) + )) + } + } + /** * Extract legacy table identifier from a multi-part identifier. * @@ -54,12 +109,12 @@ trait LookupCatalog { */ object AsTableIdentifier { def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match { - case CatalogObjectIdentifier(None, ident) => - ident.namespace match { - case Array() => - Some(TableIdentifier(ident.name)) - case Array(database) => - Some(TableIdentifier(ident.name, Some(database))) + case CatalogAndIdentifier(None, names) if defaultCatalog.isEmpty => + names match { + case Seq(name) => + Some(TableIdentifier(name)) + case Seq(database, name) => + Some(TableIdentifier(name, Some(database))) case _ => None } @@ -67,4 +122,22 @@ trait LookupCatalog { None } } + + /** + * For temp views, extract a table identifier from a multi-part identifier if it has no catalog. 
+ */ + object AsTemporaryViewIdentifier { + def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match { + case CatalogAndIdentifier(None, Seq(table)) => + Some(TableIdentifier(table)) + case CatalogAndIdentifier(None, Seq(database, table)) => + Some(TableIdentifier(table, Some(database))) + case _ => + None + } + } +} + +object LookupCatalog { + val SESSION_CATALOG_NAME: String = "session" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5d37e909f80a..c2e259b3960d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -104,6 +104,8 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + override protected def defaultCatalogName: Option[String] = conf.defaultV2Catalog + override protected def lookupCatalog(name: String): CatalogPlugin = throw new CatalogNotFoundException("No catalog lookup function") @@ -713,6 +715,10 @@ class Analyzer( u } + case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident)) + if catalog.isTemporaryTable(ident) => + resolveRelation(lookupTableFromCatalog(ident, u, AnalysisContext.get.defaultDatabase)) + // The view's child should be a logical plan parsed from the `desc.viewText`, the variable // `viewText` should be defined, or else we throw an error on the generation of the View // operator. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateAttributeNullability.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateAttributeNullability.scala index 2210e180bc75..3eae34da7e50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateAttributeNullability.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UpdateAttributeNullability.scala @@ -37,7 +37,7 @@ object UpdateAttributeNullability extends Rule[LogicalPlan] { case p if !p.resolved => p // Skip leaf node, as it has no child and no need to update nullability. case p: LeafNode => p - case p: LogicalPlan => + case p: LogicalPlan if p.childrenResolved => val nullabilities = p.children.flatMap(c => c.output).groupBy(_.exprId).map { // If there are multiple Attributes having the same ExprId, we need to resolve // the conflict of nullable field. We do not really expect this to happen. 
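
Reviewer sketch (not part of the diff): the three extractors added to LookupCatalog above are easiest to read together. The self-contained Scala sketch below shows which extractor claims a parsed multi-part identifier when a default catalog named "prod" is configured; the anonymous CatalogPlugin and the describe helper are hypothetical stand-ins that mirror the fixtures used in LookupCatalogSuite later in this patch.

import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object ResolutionSketch extends LookupCatalog {
  // Hypothetical catalog registered under the name "prod".
  private val prod: CatalogPlugin = new CatalogPlugin {
    override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = ()
    override def name(): String = "prod"
  }

  override protected def defaultCatalogName: Option[String] = Some("prod")

  override protected def lookupCatalog(name: String): CatalogPlugin = name match {
    case "prod" => prod
    case other => throw new CatalogNotFoundException(s"$other not found")
  }

  // Same precedence as the Analyzer change above: possible temp views are checked
  // first, then catalog-qualified (or default-catalog) identifiers. AsTableIdentifier
  // only matches when no default catalog is configured.
  def describe(parts: Seq[String]): String = parts match {
    case AsTemporaryViewIdentifier(ident) => s"may be a temp view: $ident"
    case CatalogObjectIdentifier(Some(catalog), ident) => s"catalog ${catalog.name}: $ident"
    case AsTableIdentifier(ident) => s"v1 session catalog: $ident"
    case _ => "unresolved"
  }
}

// describe(Seq("tbl"))               matches AsTemporaryViewIdentifier
// describe(Seq("prod", "db", "tbl")) matches CatalogObjectIdentifier with catalog "prod"
// describe(Seq("db", "tbl"))         matches AsTemporaryViewIdentifier ("db" is not a catalog)
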
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 273a4389ba40..8c7c0726876c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -432,7 +432,7 @@ case class CreateTableAsSelect( override def children: Seq[LogicalPlan] = Seq(query) - override lazy val resolved: Boolean = { + override lazy val resolved: Boolean = childrenResolved && { // the table schema is created from the query schema, so the only resolution needed is to check // that the columns referenced by the table's partitioning exist in the query schema val references = partitioning.flatMap(_.references).toSet diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index af67632706df..e2636d27e353 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1833,6 +1833,11 @@ object SQLConf { .stringConf .createOptional + val V2_SESSION_CATALOG = buildConf("spark.sql.catalog.session") + .doc("Name of the default v2 catalog, used when a catalog is not identified in queries") + .stringConf + .createWithDefault("org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog") + val LEGACY_LOOSE_UPCAST = buildConf("spark.sql.legacy.looseUpcast") .doc("When true, the upcast will be loose and allows string to atomic types.") .booleanConf diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala index 783751ff7986..56d785e40dc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala @@ -85,4 +85,111 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside { } } } + + test("temporary table identifier") { + Seq( + ("tbl", TableIdentifier("tbl")), + ("db.tbl", TableIdentifier("tbl", Some("db"))), + ("`db.tbl`", TableIdentifier("db.tbl")), + ("parquet.`file:/tmp/db.tbl`", TableIdentifier("file:/tmp/db.tbl", Some("parquet"))), + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", + TableIdentifier("s3://buck/tmp/abc.json", Some("org.apache.spark.sql.json")))).foreach { + case (sqlIdent: String, expectedTableIdent: TableIdentifier) => + // when there is no catalog and the namespace has one part, the rule should match + inside(parseMultipartIdentifier(sqlIdent)) { + case AsTemporaryViewIdentifier(ident) => + ident shouldEqual expectedTableIdent + } + } + + Seq("prod.func", "prod.db.tbl", "test.db.tbl", "ns1.ns2.tbl", "test.ns1.ns2.ns3.tbl") + .foreach { sqlIdent => + inside(parseMultipartIdentifier(sqlIdent)) { + case AsTemporaryViewIdentifier(_) => + fail("AsTemporaryTableIdentifier should not match when " + + "the catalog is set or the namespace has multiple parts") + case _ => + // expected + } + } + } +} + +class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog with Inside { + import CatalystSqlParser._ + + private val catalogs = Seq("prod", "test").map(x => x -> 
new TestCatalogPlugin(x)).toMap + + override def defaultCatalogName: Option[String] = Some("prod") + + override def lookupCatalog(name: String): CatalogPlugin = + catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found")) + + test("catalog object identifier") { + Seq( + ("tbl", catalogs.get("prod"), Seq.empty, "tbl"), + ("db.tbl", catalogs.get("prod"), Seq("db"), "tbl"), + ("prod.func", catalogs.get("prod"), Seq.empty, "func"), + ("ns1.ns2.tbl", catalogs.get("prod"), Seq("ns1", "ns2"), "tbl"), + ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"), + ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"), + ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"), + ("`db.tbl`", catalogs.get("prod"), Seq.empty, "db.tbl"), + ("parquet.`file:/tmp/db.tbl`", catalogs.get("prod"), Seq("parquet"), "file:/tmp/db.tbl"), + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", catalogs.get("prod"), + Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach { + case (sql, expectedCatalog, namespace, name) => + inside(parseMultipartIdentifier(sql)) { + case CatalogObjectIdentifier(catalog, ident) => + catalog shouldEqual expectedCatalog + ident shouldEqual Identifier.of(namespace.toArray, name) + } + } + } + + test("table identifier") { + Seq( + "tbl", + "db.tbl", + "`db.tbl`", + "parquet.`file:/tmp/db.tbl`", + "`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", + "prod.func", + "prod.db.tbl", + "ns1.ns2.tbl").foreach { sql => + parseMultipartIdentifier(sql) match { + case AsTableIdentifier(_) => + fail(s"$sql should not be resolved as TableIdentifier") + case _ => + } + } + } + + test("temporary table identifier") { + Seq( + ("tbl", TableIdentifier("tbl")), + ("db.tbl", TableIdentifier("tbl", Some("db"))), + ("`db.tbl`", TableIdentifier("db.tbl")), + ("parquet.`file:/tmp/db.tbl`", TableIdentifier("file:/tmp/db.tbl", Some("parquet"))), + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", + TableIdentifier("s3://buck/tmp/abc.json", Some("org.apache.spark.sql.json")))).foreach { + case (sqlIdent: String, expectedTableIdent: TableIdentifier) => + // when there is no catalog and the namespace has one part, the rule should match + inside(parseMultipartIdentifier(sqlIdent)) { + case AsTemporaryViewIdentifier(ident) => + ident shouldEqual expectedTableIdent + } + } + + Seq("prod.func", "prod.db.tbl", "test.db.tbl", "ns1.ns2.tbl", "test.ns1.ns2.ns3.tbl") + .foreach { sqlIdent => + inside(parseMultipartIdentifier(sqlIdent)) { + case AsTemporaryViewIdentifier(_) => + fail("AsTemporaryTableIdentifier should not match when " + + "the catalog is set or the namespace has multiple parts") + case _ => + // expected + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 26f7230c8fe8..1b7bb169b36f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -26,31 +26,32 @@ import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.CastSupport -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils} +import 
org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, UnresolvedCatalogRelation} import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan} import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, DropTableCommand} +import org.apache.spark.sql.execution.datasources.v2.{CatalogTableAsV2, DataSourceV2Relation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.TableProvider import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} case class DataSourceResolution( conf: SQLConf, - findCatalog: String => CatalogPlugin) - extends Rule[LogicalPlan] with CastSupport with LookupCatalog { + lookup: LookupCatalog) + extends Rule[LogicalPlan] with CastSupport { import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + import lookup._ - override protected def lookupCatalog(name: String): CatalogPlugin = findCatalog(name) - - def defaultCatalog: Option[CatalogPlugin] = conf.defaultV2Catalog.map(findCatalog) + lazy val v2SessionCatalog: CatalogPlugin = lookup.sessionCatalog + .getOrElse(throw new AnalysisException("No v2 session catalog implementation is available")) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case CreateTableStatement( AsTableIdentifier(table), schema, partitionCols, bucketSpec, properties, V1WriteProvider(provider), options, location, comment, ifNotExists) => - + // the source is v1, the identifier has no catalog, and there is no default v2 catalog val tableDesc = buildCatalogTable(table, schema, partitionCols, bucketSpec, properties, provider, options, location, comment, ifNotExists) val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists @@ -58,18 +59,22 @@ case class DataSourceResolution( CreateTable(tableDesc, mode, None) case create: CreateTableStatement => - // the provider was not a v1 source, convert to a v2 plan + // the provider was not a v1 source or a v2 catalog is the default, convert to a v2 plan val CatalogObjectIdentifier(maybeCatalog, identifier) = create.tableName - val catalog = maybeCatalog.orElse(defaultCatalog) - .getOrElse(throw new AnalysisException( - s"No catalog specified for table ${identifier.quoted} and no default catalog is set")) - .asTableCatalog - convertCreateTable(catalog, identifier, create) + maybeCatalog match { + case Some(catalog) => + // the identifier had a catalog, or there is a default v2 catalog + convertCreateTable(catalog.asTableCatalog, identifier, create) + case _ => + // the identifier had no catalog and no default catalog is set, but the source is v2. 
+ // use the v2 session catalog, which delegates to the global v1 session catalog + convertCreateTable(v2SessionCatalog.asTableCatalog, identifier, create) + } case CreateTableAsSelectStatement( AsTableIdentifier(table), query, partitionCols, bucketSpec, properties, V1WriteProvider(provider), options, location, comment, ifNotExists) => - + // the source is v1, the identifier has no catalog, and there is no default v2 catalog val tableDesc = buildCatalogTable(table, new StructType, partitionCols, bucketSpec, properties, provider, options, location, comment, ifNotExists) val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists @@ -77,13 +82,17 @@ case class DataSourceResolution( CreateTable(tableDesc, mode, Some(query)) case create: CreateTableAsSelectStatement => - // the provider was not a v1 source, convert to a v2 plan + // the provider was not a v1 source or a v2 catalog is the default, convert to a v2 plan val CatalogObjectIdentifier(maybeCatalog, identifier) = create.tableName - val catalog = maybeCatalog.orElse(defaultCatalog) - .getOrElse(throw new AnalysisException( - s"No catalog specified for table ${identifier.quoted} and no default catalog is set")) - .asTableCatalog - convertCTAS(catalog, identifier, create) + maybeCatalog match { + case Some(catalog) => + // the identifier had a catalog, or there is a default v2 catalog + convertCTAS(catalog.asTableCatalog, identifier, create) + case _ => + // the identifier had no catalog and no default catalog is set, but the source is v2. + // use the v2 session catalog, which delegates to the global v1 session catalog + convertCTAS(v2SessionCatalog.asTableCatalog, identifier, create) + } case DropTableStatement(CatalogObjectIdentifier(Some(catalog), ident), ifExists, _) => DropTable(catalog.asTableCatalog, ident, ifExists) @@ -118,6 +127,9 @@ case class DataSourceResolution( if newColumns.forall(_.name.size == 1) => // only top-level adds are supported using AlterTableAddColumnsCommand AlterTableAddColumnsCommand(table, newColumns.map(convertToStructField)) + + case DataSourceV2Relation(CatalogTableAsV2(catalogTable), _, _) => + UnresolvedCatalogRelation(catalogTable) } object V1WriteProvider { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala new file mode 100644 index 000000000000..4cd0346b57e7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util +import java.util.Locale + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog, TableChange} +import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform} +import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog} +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.internal.SessionState +import org.apache.spark.sql.sources.v2.{Table, TableCapability} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * A [[TableCatalog]] that translates calls to the v1 SessionCatalog. + */ +class V2SessionCatalog(sessionState: SessionState) extends TableCatalog { + def this() = { + this(SparkSession.active.sessionState) + } + + private lazy val catalog: SessionCatalog = sessionState.catalog + + private var _name: String = _ + + override def name: String = _name + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { + this._name = name + } + + override def listTables(namespace: Array[String]): Array[Identifier] = { + namespace match { + case Array(db) => + catalog.listTables(db).map(ident => Identifier.of(Array(db), ident.table)).toArray + case _ => + throw new NoSuchNamespaceException(namespace) + } + } + + override def loadTable(ident: Identifier): Table = { + val catalogTable = try { + catalog.getTableMetadata(ident.asTableIdentifier) + } catch { + case _: NoSuchTableException => + throw new NoSuchTableException(ident) + } + + CatalogTableAsV2(catalogTable) + } + + override def invalidateTable(ident: Identifier): Unit = { + catalog.refreshTable(ident.asTableIdentifier) + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + + val (partitionColumns, maybeBucketSpec) = V2SessionCatalog.convertTransforms(partitions) + val provider = properties.getOrDefault("provider", sessionState.conf.defaultDataSourceName) + val tableProperties = properties.asScala + val location = Option(properties.get("location")) + val storage = DataSource.buildStorageFormatFromOptions(tableProperties.toMap) + .copy(locationUri = location.map(CatalogUtils.stringToURI)) + + val tableDesc = CatalogTable( + identifier = ident.asTableIdentifier, + tableType = CatalogTableType.MANAGED, + storage = storage, + schema = schema, + provider = Some(provider), + partitionColumnNames = partitionColumns, + bucketSpec = maybeBucketSpec, + properties = tableProperties.toMap, + tracksPartitionsInCatalog = sessionState.conf.manageFilesourcePartitions, + comment = Option(properties.get("comment"))) + + try { + catalog.createTable(tableDesc, ignoreIfExists = false) + } catch { + case _: TableAlreadyExistsException => + throw new TableAlreadyExistsException(ident) + } + + loadTable(ident) + } + + override def alterTable( + ident: Identifier, + changes: TableChange*): Table = { + val catalogTable = try { + 
catalog.getTableMetadata(ident.asTableIdentifier) + } catch { + case _: NoSuchTableException => + throw new NoSuchTableException(ident) + } + + val properties = CatalogV2Util.applyPropertiesChanges(catalogTable.properties, changes) + val schema = CatalogV2Util.applySchemaChanges(catalogTable.schema, changes) + + try { + catalog.alterTable(catalogTable.copy(properties = properties, schema = schema)) + } catch { + case _: NoSuchTableException => + throw new NoSuchTableException(ident) + } + + loadTable(ident) + } + + override def dropTable(ident: Identifier): Boolean = { + try { + if (loadTable(ident) != null) { + catalog.dropTable( + ident.asTableIdentifier, + ignoreIfNotExists = true, + purge = true /* skip HDFS trash */) + true + } else { + false + } + } catch { + case _: NoSuchTableException => + false + } + } + + implicit class TableIdentifierHelper(ident: Identifier) { + def asTableIdentifier: TableIdentifier = { + ident.namespace match { + case Array(db) => + TableIdentifier(ident.name, Some(db)) + case Array() => + TableIdentifier(ident.name, Some(catalog.getCurrentDatabase)) + case _ => + throw new NoSuchTableException(ident) + } + } + } + + override def toString: String = s"V2SessionCatalog($name)" +} + +/** + * An implementation of catalog v2 [[Table]] to expose v1 table metadata. + */ +case class CatalogTableAsV2(v1Table: CatalogTable) extends Table { + implicit class IdentifierHelper(identifier: TableIdentifier) { + def quoted: String = { + identifier.database match { + case Some(db) => + Seq(db, identifier.table).map(quote).mkString(".") + case _ => + quote(identifier.table) + + } + } + + private def quote(part: String): String = { + if (part.contains(".") || part.contains("`")) { + s"`${part.replace("`", "``")}`" + } else { + part + } + } + } + + def catalogTable: CatalogTable = v1Table + + lazy val options: Map[String, String] = { + v1Table.storage.locationUri match { + case Some(uri) => + v1Table.storage.properties + ("path" -> uri.toString) + case _ => + v1Table.storage.properties + } + } + + override lazy val properties: util.Map[String, String] = v1Table.properties.asJava + + override lazy val schema: StructType = v1Table.schema + + override lazy val partitioning: Array[Transform] = { + val partitions = new mutable.ArrayBuffer[Transform]() + + v1Table.partitionColumnNames.foreach { col => + partitions += LogicalExpressions.identity(col) + } + + v1Table.bucketSpec.foreach { spec => + partitions += LogicalExpressions.bucket(spec.numBuckets, spec.bucketColumnNames: _*) + } + + partitions.toArray + } + + override def name: String = v1Table.identifier.quoted + + override def capabilities: util.Set[TableCapability] = new util.HashSet[TableCapability]() + + override def toString: String = s"CatalogTableAsV2($name)" +} + +private[sql] object V2SessionCatalog { + /** + * Convert v2 Transforms to v1 partition columns and an optional bucket spec. 
+ */ + private def convertTransforms(partitions: Seq[Transform]): (Seq[String], Option[BucketSpec]) = { + val identityCols = new mutable.ArrayBuffer[String] + var bucketSpec = Option.empty[BucketSpec] + + partitions.map { + case IdentityTransform(FieldReference(Seq(col))) => + identityCols += col + + case BucketTransform(numBuckets, FieldReference(Seq(col))) => + bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil)) + + case transform => + throw new UnsupportedOperationException( + s"SessionCatalog does not support partition transform: $transform") + } + + (identityCols, bucketSpec) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 8dc30eaa3a31..b05a5dfea3ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -170,7 +170,7 @@ abstract class BaseSessionStateBuilder( new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: - DataSourceResolution(conf, session.catalog(_)) +: + DataSourceResolution(conf, this) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 727160dafc5d..7df0dabd67f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -21,7 +21,7 @@ import java.net.URI import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode} -import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, TableCatalog, TestTableCatalog} +import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog, TableCatalog, TestTableCatalog} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} @@ -43,17 +43,43 @@ class PlanResolutionSuite extends AnalysisTest { newCatalog } - private val lookupCatalog: String => CatalogPlugin = { - case "testcat" => - testCat - case name => - throw new CatalogNotFoundException(s"No such catalog: $name") + private val v2SessionCatalog = { + val newCatalog = new TestTableCatalog + newCatalog.initialize("session", CaseInsensitiveStringMap.empty()) + newCatalog + } + + private val lookupWithDefault: LookupCatalog = new LookupCatalog { + override protected def defaultCatalogName: Option[String] = Some("testcat") + + override protected def lookupCatalog(name: String): CatalogPlugin = name match { + case "testcat" => + testCat + case "session" => + v2SessionCatalog + case _ => + throw new CatalogNotFoundException(s"No such catalog: $name") + } + } + + private val lookupWithoutDefault: LookupCatalog = new LookupCatalog { + override protected def defaultCatalogName: Option[String] = None + + override protected def lookupCatalog(name: String): CatalogPlugin = name match { + case "testcat" => + testCat + case "session" => + v2SessionCatalog + case _ => + throw new CatalogNotFoundException(s"No such catalog: $name") + } } 
- def parseAndResolve(query: String): LogicalPlan = { + def parseAndResolve(query: String, withDefault: Boolean = false): LogicalPlan = { val newConf = conf.copy() newConf.setConfString("spark.sql.default.catalog", "testcat") - DataSourceResolution(newConf, lookupCatalog).apply(parsePlan(query)) + DataSourceResolution(newConf, if (withDefault) lookupWithDefault else lookupWithoutDefault) + .apply(parsePlan(query)) } private def parseResolveCompare(query: String, expected: LogicalPlan): Unit = @@ -338,7 +364,46 @@ class PlanResolutionSuite extends AnalysisTest { } } - test("Test v2 CreateTable with data source v2 provider") { + test("Test v2 CreateTable with default catalog") { + val sql = + s""" + |CREATE TABLE IF NOT EXISTS mydb.table_name ( + | id bigint, + | description string, + | point struct) + |USING parquet + |COMMENT 'table comment' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |OPTIONS (path 's3://bucket/path/to/data', other 20) + """.stripMargin + + val expectedProperties = Map( + "p1" -> "v1", + "p2" -> "v2", + "other" -> "20", + "provider" -> "parquet", + "location" -> "s3://bucket/path/to/data", + "comment" -> "table comment") + + parseAndResolve(sql, withDefault = true) match { + case create: CreateV2Table => + assert(create.catalog.name == "testcat") + assert(create.tableName == Identifier.of(Array("mydb"), "table_name")) + assert(create.tableSchema == new StructType() + .add("id", LongType) + .add("description", StringType) + .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))) + assert(create.partitioning.isEmpty) + assert(create.properties == expectedProperties) + assert(create.ignoreIfExists) + + case other => + fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Test v2 CreateTable with data source v2 provider and no default") { val sql = s""" |CREATE TABLE IF NOT EXISTS mydb.page_view ( @@ -360,7 +425,7 @@ class PlanResolutionSuite extends AnalysisTest { parseAndResolve(sql) match { case create: CreateV2Table => - assert(create.catalog.name == "testcat") + assert(create.catalog.name == "session") assert(create.tableName == Identifier.of(Array("mydb"), "page_view")) assert(create.tableSchema == new StructType() .add("id", LongType) @@ -410,7 +475,41 @@ class PlanResolutionSuite extends AnalysisTest { } } - test("Test v2 CTAS with data source v2 provider") { + test("Test v2 CTAS with default catalog") { + val sql = + s""" + |CREATE TABLE IF NOT EXISTS mydb.table_name + |USING parquet + |COMMENT 'table comment' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |OPTIONS (path 's3://bucket/path/to/data', other 20) + |AS SELECT * FROM src + """.stripMargin + + val expectedProperties = Map( + "p1" -> "v1", + "p2" -> "v2", + "other" -> "20", + "provider" -> "parquet", + "location" -> "s3://bucket/path/to/data", + "comment" -> "table comment") + + parseAndResolve(sql, withDefault = true) match { + case ctas: CreateTableAsSelect => + assert(ctas.catalog.name == "testcat") + assert(ctas.tableName == Identifier.of(Array("mydb"), "table_name")) + assert(ctas.properties == expectedProperties) + assert(ctas.writeOptions == Map("other" -> "20")) + assert(ctas.partitioning.isEmpty) + assert(ctas.ignoreIfExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableAsSelect].getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Test v2 CTAS with data source v2 provider and no default") { val sql = s""" |CREATE TABLE IF NOT EXISTS mydb.page_view @@ -430,7 
+529,7 @@ class PlanResolutionSuite extends AnalysisTest { parseAndResolve(sql) match { case ctas: CreateTableAsSelect => - assert(ctas.catalog.name == "testcat") + assert(ctas.catalog.name == "session") assert(ctas.tableName == Identifier.of(Array("mydb"), "page_view")) assert(ctas.properties == expectedProperties) assert(ctas.writeOptions.isEmpty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala new file mode 100644 index 000000000000..3822882cc91c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -0,0 +1,683 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util +import java.util.Collections + +import scala.collection.JavaConverters._ + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalog.v2.{Catalogs, Identifier, TableCatalog, TableChange} +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class V2SessionCatalogSuite + extends SparkFunSuite with SharedSQLContext with BeforeAndAfter with BeforeAndAfterAll { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + + private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + private val schema: StructType = new StructType() + .add("id", IntegerType) + .add("data", StringType) + + override protected def beforeAll(): Unit = { + super.beforeAll() + spark.sql("""CREATE DATABASE IF NOT EXISTS db""") + spark.sql("""CREATE DATABASE IF NOT EXISTS ns""") + spark.sql("""CREATE DATABASE IF NOT EXISTS ns2""") + } + + override protected def afterAll(): Unit = { + spark.sql("""DROP TABLE IF EXISTS db.test_table""") + spark.sql("""DROP DATABASE IF EXISTS db""") + spark.sql("""DROP DATABASE IF EXISTS ns""") + spark.sql("""DROP DATABASE IF EXISTS ns2""") + super.afterAll() + } + + after { + newCatalog().dropTable(testIdent) + } + + private def newCatalog(): TableCatalog = { + val newCatalog = new V2SessionCatalog(spark.sessionState) + newCatalog.initialize("test", CaseInsensitiveStringMap.empty()) + newCatalog + } + + private val testIdent = Identifier.of(Array("db"), "test_table") + + test("Catalogs 
can load the catalog") { + val catalog = newCatalog() + + val conf = new SQLConf + conf.setConfString("spark.sql.catalog.test", catalog.getClass.getName) + + val loaded = Catalogs.load("test", conf) + assert(loaded.getClass == catalog.getClass) + } + + test("listTables") { + val catalog = newCatalog() + val ident1 = Identifier.of(Array("ns"), "test_table_1") + val ident2 = Identifier.of(Array("ns"), "test_table_2") + val ident3 = Identifier.of(Array("ns2"), "test_table_1") + + assert(catalog.listTables(Array("ns")).isEmpty) + + catalog.createTable(ident1, schema, Array.empty, emptyProps) + + assert(catalog.listTables(Array("ns")).toSet == Set(ident1)) + assert(catalog.listTables(Array("ns2")).isEmpty) + + catalog.createTable(ident3, schema, Array.empty, emptyProps) + catalog.createTable(ident2, schema, Array.empty, emptyProps) + + assert(catalog.listTables(Array("ns")).toSet == Set(ident1, ident2)) + assert(catalog.listTables(Array("ns2")).toSet == Set(ident3)) + + catalog.dropTable(ident1) + + assert(catalog.listTables(Array("ns")).toSet == Set(ident2)) + + catalog.dropTable(ident2) + + assert(catalog.listTables(Array("ns")).isEmpty) + assert(catalog.listTables(Array("ns2")).toSet == Set(ident3)) + + catalog.dropTable(ident3) + } + + test("createTable") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) + assert(parsed == Seq("db", "test_table")) + assert(table.schema == schema) + assert(table.properties.asScala == Map()) + + assert(catalog.tableExists(testIdent)) + } + + test("createTable: with properties") { + val catalog = newCatalog() + + val properties = new util.HashMap[String, String]() + properties.put("property", "value") + + assert(!catalog.tableExists(testIdent)) + + val table = catalog.createTable(testIdent, schema, Array.empty, properties) + + val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) + assert(parsed == Seq("db", "test_table")) + assert(table.schema == schema) + assert(table.properties == properties) + + assert(catalog.tableExists(testIdent)) + } + + test("createTable: table already exists") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + val exc = intercept[TableAlreadyExistsException] { + catalog.createTable(testIdent, schema, Array.empty, emptyProps) + } + + assert(exc.message.contains(table.name())) + assert(exc.message.contains("already exists")) + + assert(catalog.tableExists(testIdent)) + } + + test("tableExists") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(catalog.tableExists(testIdent)) + + catalog.dropTable(testIdent) + + assert(!catalog.tableExists(testIdent)) + } + + test("loadTable") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + val loaded = catalog.loadTable(testIdent) + + assert(table.name == loaded.name) + assert(table.schema == loaded.schema) + assert(table.properties == loaded.properties) + } + + test("loadTable: table does not exist") { + val catalog = newCatalog() + + val exc = intercept[NoSuchTableException] { + catalog.loadTable(testIdent) + } + + assert(exc.message.contains(testIdent.quoted)) + assert(exc.message.contains("not found")) + } + + test("invalidateTable") { + 
val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + catalog.invalidateTable(testIdent) + + val loaded = catalog.loadTable(testIdent) + + assert(table.name == loaded.name) + assert(table.schema == loaded.schema) + assert(table.properties == loaded.properties) + } + + test("invalidateTable: table does not exist") { + val catalog = newCatalog() + + assert(catalog.tableExists(testIdent) === false) + + catalog.invalidateTable(testIdent) + } + + test("alterTable: add property") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.properties.asScala == Map()) + + val updated = catalog.alterTable(testIdent, TableChange.setProperty("prop-1", "1")) + assert(updated.properties.asScala == Map("prop-1" -> "1")) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map("prop-1" -> "1")) + + assert(table.properties.asScala == Map()) + } + + test("alterTable: add property to existing") { + val catalog = newCatalog() + + val properties = new util.HashMap[String, String]() + properties.put("prop-1", "1") + + val table = catalog.createTable(testIdent, schema, Array.empty, properties) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + + val updated = catalog.alterTable(testIdent, TableChange.setProperty("prop-2", "2")) + assert(updated.properties.asScala == Map("prop-1" -> "1", "prop-2" -> "2")) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map("prop-1" -> "1", "prop-2" -> "2")) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + } + + test("alterTable: remove existing property") { + val catalog = newCatalog() + + val properties = new util.HashMap[String, String]() + properties.put("prop-1", "1") + + val table = catalog.createTable(testIdent, schema, Array.empty, properties) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + + val updated = catalog.alterTable(testIdent, TableChange.removeProperty("prop-1")) + assert(updated.properties.asScala == Map()) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map()) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + } + + test("alterTable: remove missing property") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.properties.asScala == Map()) + + val updated = catalog.alterTable(testIdent, TableChange.removeProperty("prop-1")) + assert(updated.properties.asScala == Map()) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map()) + + assert(table.properties.asScala == Map()) + } + + test("alterTable: add top-level column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType)) + + assert(updated.schema == schema.add("ts", TimestampType)) + } + + test("alterTable: add required column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.addColumn(Array("ts"), TimestampType, false)) + + assert(updated.schema == schema.add("ts", TimestampType, nullable = false)) + } + + test("alterTable: add column with comment") { + val catalog = newCatalog() + + 
val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.addColumn(Array("ts"), TimestampType, false, "comment text")) + + val field = StructField("ts", TimestampType, nullable = false).withComment("comment text") + assert(updated.schema == schema.add(field)) + } + + test("alterTable: add nested column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.addColumn(Array("point", "z"), DoubleType)) + + val expectedSchema = schema.add("point", pointStruct.add("z", DoubleType)) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: add column to primitive field fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.addColumn(Array("data", "ts"), TimestampType)) + } + + assert(exc.getMessage.contains("Not a struct")) + assert(exc.getMessage.contains("data")) + + // the table has not changed + assert(catalog.loadTable(testIdent).schema == schema) + } + + test("alterTable: add field to missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.addColumn(Array("missing_col", "new_field"), StringType)) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: update column data type") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType)) + + val expectedSchema = new StructType().add("id", LongType).add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: update column data type and nullability") { + val catalog = newCatalog() + + val originalSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("data", StringType) + val table = catalog.createTable(testIdent, originalSchema, Array.empty, emptyProps) + + assert(table.schema == originalSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.updateColumnType(Array("id"), LongType, true)) + + val expectedSchema = new StructType().add("id", LongType).add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: update optional column to required fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType, false)) + } + + assert(exc.getMessage.contains("Cannot change optional column to required")) + assert(exc.getMessage.contains("id")) + } + + test("alterTable: update missing column fails") { + val catalog = newCatalog() + + val table = 
catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.updateColumnType(Array("missing_col"), LongType)) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: add comment") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.updateColumnComment(Array("id"), "comment text")) + + val expectedSchema = new StructType() + .add("id", IntegerType, nullable = true, "comment text") + .add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: replace comment") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("id"), "comment text")) + + val expectedSchema = new StructType() + .add("id", IntegerType, nullable = true, "replacement comment") + .add("data", StringType) + + val updated = catalog.alterTable(testIdent, + TableChange.updateColumnComment(Array("id"), "replacement comment")) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: add comment to missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.updateColumnComment(Array("missing_col"), "comment")) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: rename top-level column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, TableChange.renameColumn(Array("id"), "some_id")) + + val expectedSchema = new StructType().add("some_id", IntegerType).add("data", StringType) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: rename nested column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.renameColumn(Array("point", "x"), "first")) + + val newPointStruct = new StructType().add("first", DoubleType).add("y", DoubleType) + val expectedSchema = schema.add("point", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: rename struct column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.renameColumn(Array("point"), "p")) + + val newPointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val expectedSchema = schema.add("p", newPointStruct) + + assert(updated.schema == 
expectedSchema) + } + + test("alterTable: rename missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.renameColumn(Array("missing_col"), "new_name")) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: multiple changes") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.renameColumn(Array("point", "x"), "first"), + TableChange.renameColumn(Array("point", "y"), "second")) + + val newPointStruct = new StructType().add("first", DoubleType).add("second", DoubleType) + val expectedSchema = schema.add("point", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: delete top-level column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.deleteColumn(Array("id"))) + + val expectedSchema = new StructType().add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: delete nested column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.deleteColumn(Array("point", "y"))) + + val newPointStruct = new StructType().add("x", DoubleType) + val expectedSchema = schema.add("point", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: delete missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"))) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: delete missing nested column fails") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"))) + } + + assert(exc.getMessage.contains("z")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: table does not exist") { + val catalog = newCatalog() + + val exc = intercept[NoSuchTableException] { + catalog.alterTable(testIdent, TableChange.setProperty("prop", "val")) + } + + assert(exc.message.contains(testIdent.quoted)) + assert(exc.message.contains("not found")) + } + + test("dropTable") { + val catalog = newCatalog() + + 
assert(!catalog.tableExists(testIdent)) + + catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(catalog.tableExists(testIdent)) + + val wasDropped = catalog.dropTable(testIdent) + + assert(wasDropped) + assert(!catalog.tableExists(testIdent)) + } + + test("dropTable: table does not exist") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + val wasDropped = catalog.dropTable(testIdent) + + assert(!wasDropped) + assert(!catalog.tableExists(testIdent)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index 96345e22dbd5..01752125ac26 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -21,9 +21,10 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalog.v2.Identifier import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{LongType, StringType, StructType} @@ -37,7 +38,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn before { spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) spark.conf.set("spark.sql.catalog.testcat2", classOf[TestInMemoryTableCatalog].getName) - spark.conf.set("spark.sql.default.catalog", "testcat") + spark.conf.set("spark.sql.catalog.session", classOf[TestInMemoryTableCatalog].getName) val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") df.createOrReplaceTempView("source") @@ -47,8 +48,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn after { spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog].clearTables() - spark.sql("DROP TABLE source") - spark.sql("DROP TABLE source2") + spark.catalog("session").asInstanceOf[TestInMemoryTableCatalog].clearTables() } test("CreateTable: use v2 plan because catalog is set") { @@ -66,13 +66,13 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) } - test("CreateTable: use v2 plan because provider is v2") { + test("CreateTable: use v2 plan and session catalog when provider is v2") { spark.sql(s"CREATE TABLE table_name (id bigint, data string) USING $orc2") - val testCatalog = spark.catalog("testcat").asTableCatalog + val testCatalog = spark.catalog("session").asTableCatalog val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) - assert(table.name == "testcat.table_name") + assert(table.name == "session.table_name") assert(table.partitioning.isEmpty) assert(table.properties == Map("provider" -> orc2).asJava) assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) @@ -137,22 +137,23 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq.empty) } - test("CreateTable: fail analysis when default catalog is 
needed but missing") { - val originalDefaultCatalog = conf.getConfString("spark.sql.default.catalog") - try { - conf.unsetConf("spark.sql.default.catalog") + test("CreateTable: use default catalog for v2 sources when default catalog is set") { + val sparkSession = spark.newSession() + sparkSession.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) + sparkSession.conf.set("spark.sql.default.catalog", "testcat") + sparkSession.sql(s"CREATE TABLE table_name (id bigint, data string) USING foo") - val exc = intercept[AnalysisException] { - spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source") - } + val testCatalog = sparkSession.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) - assert(exc.getMessage.contains("No catalog specified for table")) - assert(exc.getMessage.contains("table_name")) - assert(exc.getMessage.contains("no default catalog is set")) + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) - } finally { - conf.setConfString("spark.sql.default.catalog", originalDefaultCatalog) - } + // check that the table is empty + val rdd = sparkSession.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) } test("CreateTableAsSelect: use v2 plan because catalog is set") { @@ -172,13 +173,13 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) } - test("CreateTableAsSelect: use v2 plan because provider is v2") { + test("CreateTableAsSelect: use v2 plan and session catalog when provider is v2") { spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source") - val testCatalog = spark.catalog("testcat").asTableCatalog + val testCatalog = spark.catalog("session").asTableCatalog val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) - assert(table.name == "testcat.table_name") + assert(table.name == "session.table_name") assert(table.partitioning.isEmpty) assert(table.properties == Map("provider" -> orc2).asJava) assert(table.schema == new StructType() @@ -251,22 +252,43 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), spark.table("source")) } - test("CreateTableAsSelect: fail analysis when default catalog is needed but missing") { - val originalDefaultCatalog = conf.getConfString("spark.sql.default.catalog") - try { - conf.unsetConf("spark.sql.default.catalog") + test("CreateTableAsSelect: use default catalog for v2 sources when default catalog is set") { + val sparkSession = spark.newSession() + sparkSession.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) + sparkSession.conf.set("spark.sql.default.catalog", "testcat") - val exc = intercept[AnalysisException] { - spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source") - } + val df = sparkSession.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") + df.createOrReplaceTempView("source") - assert(exc.getMessage.contains("No catalog specified for table")) - assert(exc.getMessage.contains("table_name")) - assert(exc.getMessage.contains("no default catalog is 
set")) + // setting the default catalog breaks the reference to source because the default catalog is + // used and AsTableIdentifier no longer matches + sparkSession.sql(s"CREATE TABLE table_name USING foo AS SELECT id, data FROM source") - } finally { - conf.setConfString("spark.sql.default.catalog", originalDefaultCatalog) - } + val testCatalog = sparkSession.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd = sparkSession.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), sparkSession.table("source")) + } + + test("CreateTableAsSelect: v2 session catalog can load v1 source table") { + val sparkSession = spark.newSession() + sparkSession.conf.set("spark.sql.catalog.session", classOf[V2SessionCatalog].getName) + + val df = sparkSession.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") + df.createOrReplaceTempView("source") + + sparkSession.sql(s"CREATE TABLE table_name USING parquet AS SELECT id, data FROM source") + + // use the catalog name to force loading with the v2 catalog + checkAnswer(sparkSession.sql(s"TABLE session.table_name"), sparkSession.table("source")) } test("DropTable: basic") { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index b04b3f1031d7..2fa108825982 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -74,7 +74,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: new FallBackFileSourceV2(session) +: - DataSourceResolution(conf, session.catalog(_)) +: + DataSourceResolution(conf, this) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = From 4169a8760d6e358914071460c91f381d4ae89b0b Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 9 Jul 2019 10:35:45 -0700 Subject: [PATCH 2/2] Fix temp view handling: views have precedence over catalog lookups. 
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 12 ++++++++---- .../sql/catalyst/catalog/v2/LookupCatalogSuite.scala | 4 ++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c2e259b3960d..1d0dba262c10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -669,6 +669,10 @@ class Analyzer( import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident)) + if catalog.isTemporaryTable(ident) => + u // temporary views take precedence over catalog table names + case u @ UnresolvedRelation(CatalogObjectIdentifier(Some(catalogPlugin), ident)) => loadTable(catalogPlugin, ident).map(DataSourceV2Relation.create).getOrElse(u) } @@ -706,6 +710,10 @@ class Analyzer( // Note this is compatible with the views defined by older versions of Spark(before 2.2), which // have empty defaultDatabase and all the relations in viewText have database part defined. def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match { + case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident)) + if catalog.isTemporaryTable(ident) => + resolveRelation(lookupTableFromCatalog(ident, u, AnalysisContext.get.defaultDatabase)) + case u @ UnresolvedRelation(AsTableIdentifier(ident)) if !isRunningDirectlyOnFiles(ident) => val defaultDatabase = AnalysisContext.get.defaultDatabase val foundRelation = lookupTableFromCatalog(ident, u, defaultDatabase) @@ -715,10 +723,6 @@ class Analyzer( u } - case u @ UnresolvedRelation(AsTemporaryViewIdentifier(ident)) - if catalog.isTemporaryTable(ident) => - resolveRelation(lookupTableFromCatalog(ident, u, AnalysisContext.get.defaultDatabase)) - // The view's child should be a logical plan parsed from the `desc.viewText`, the variable // `viewText` should be defined, or else we throw an error on the generation of the View // operator. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala index 56d785e40dc1..52543d16d481 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala @@ -106,7 +106,7 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside { .foreach { sqlIdent => inside(parseMultipartIdentifier(sqlIdent)) { case AsTemporaryViewIdentifier(_) => - fail("AsTemporaryTableIdentifier should not match when " + + fail("AsTemporaryViewIdentifier should not match when " + "the catalog is set or the namespace has multiple parts") case _ => // expected @@ -185,7 +185,7 @@ class LookupCatalogWithDefaultSuite extends SparkFunSuite with LookupCatalog wit .foreach { sqlIdent => inside(parseMultipartIdentifier(sqlIdent)) { case AsTemporaryViewIdentifier(_) => - fail("AsTemporaryTableIdentifier should not match when " + + fail("AsTemporaryViewIdentifier should not match when " + "the catalog is set or the namespace has multiple parts") case _ => // expected