diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala
index 932d32022702b..5464a7496d23d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 @Experimental
 trait LookupCatalog {
 
-  def lookupCatalog: Option[(String) => CatalogPlugin] = None
+  protected def lookupCatalog(name: String): CatalogPlugin
 
   type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier)
 
@@ -34,27 +34,23 @@ trait LookupCatalog {
    * Extract catalog plugin and identifier from a multi-part identifier.
    */
   object CatalogObjectIdentifier {
-    def unapply(parts: Seq[String]): Option[CatalogObjectIdentifier] = lookupCatalog.map { lookup =>
-      parts match {
-        case Seq(name) =>
-          (None, Identifier.of(Array.empty, name))
-        case Seq(catalogName, tail @ _*) =>
-          try {
-            val catalog = lookup(catalogName)
-            (Some(catalog), Identifier.of(tail.init.toArray, tail.last))
-          } catch {
-            case _: CatalogNotFoundException =>
-              (None, Identifier.of(parts.init.toArray, parts.last))
-          }
-      }
+    def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match {
+      case Seq(name) =>
+        Some((None, Identifier.of(Array.empty, name)))
+      case Seq(catalogName, tail @ _*) =>
+        try {
+          Some((Some(lookupCatalog(catalogName)), Identifier.of(tail.init.toArray, tail.last)))
+        } catch {
+          case _: CatalogNotFoundException =>
+            Some((None, Identifier.of(parts.init.toArray, parts.last)))
+        }
     }
   }
 
   /**
    * Extract legacy table identifier from a multi-part identifier.
    *
-   * For legacy support only. Please use
-   * [[org.apache.spark.sql.catalog.v2.LookupCatalog.CatalogObjectIdentifier]] in DSv2 code paths.
+   * For legacy support only. Please use [[CatalogObjectIdentifier]] instead on DSv2 code paths.
    */
   object AsTableIdentifier {
     def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 2672583ec1749..546675edcb2e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalog.v2.{CatalogPlugin, LookupCatalog}
+import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
@@ -96,18 +96,15 @@ object AnalysisContext {
 class Analyzer(
     catalog: SessionCatalog,
     conf: SQLConf,
-    maxIterations: Int,
-    override val lookupCatalog: Option[(String) => CatalogPlugin] = None)
+    maxIterations: Int)
   extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog {
 
   def this(catalog: SessionCatalog, conf: SQLConf) = {
     this(catalog, conf, conf.optimizerMaxIterations)
   }
 
-  def this(lookupCatalog: Option[(String) => CatalogPlugin], catalog: SessionCatalog,
-      conf: SQLConf) = {
-    this(catalog, conf, conf.optimizerMaxIterations, lookupCatalog)
-  }
+  override protected def lookupCatalog(name: String): CatalogPlugin =
+    throw new CatalogNotFoundException("No catalog lookup function")
 
   def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = {
     AnalysisHelper.markInAnalyzer {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala
new file mode 100644
index 0000000000000..783751ff79865
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.catalog.v2
+
+import org.scalatest.Inside
+import org.scalatest.Matchers._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+private case class TestCatalogPlugin(override val name: String) extends CatalogPlugin {
+
+  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = Unit
+}
+
+class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
+  import CatalystSqlParser._
+
+  private val catalogs = Seq("prod", "test").map(x => x -> new TestCatalogPlugin(x)).toMap
+
+  override def lookupCatalog(name: String): CatalogPlugin =
+    catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
+
+  test("catalog object identifier") {
+    Seq(
+      ("tbl", None, Seq.empty, "tbl"),
+      ("db.tbl", None, Seq("db"), "tbl"),
+      ("prod.func", catalogs.get("prod"), Seq.empty, "func"),
+      ("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"),
+      ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
+      ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
+      ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
+      ("`db.tbl`", None, Seq.empty, "db.tbl"),
+      ("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"),
+      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None,
+        Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
+      case (sql, expectedCatalog, namespace, name) =>
+        inside(parseMultipartIdentifier(sql)) {
+          case CatalogObjectIdentifier(catalog, ident) =>
+            catalog shouldEqual expectedCatalog
+            ident shouldEqual Identifier.of(namespace.toArray, name)
+        }
+    }
+  }
+
+  test("table identifier") {
+    Seq(
+      ("tbl", "tbl", None),
+      ("db.tbl", "tbl", Some("db")),
+      ("`db.tbl`", "db.tbl", None),
+      ("parquet.`file:/tmp/db.tbl`", "file:/tmp/db.tbl", Some("parquet")),
+      ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", "s3://buck/tmp/abc.json",
+        Some("org.apache.spark.sql.json"))).foreach {
+      case (sql, table, db) =>
+        inside (parseMultipartIdentifier(sql)) {
+          case AsTableIdentifier(ident) =>
+            ident shouldEqual TableIdentifier(table, db)
+        }
+    }
+    Seq(
+      "prod.func",
+      "prod.db.tbl",
+      "ns1.ns2.tbl").foreach { sql =>
+      parseMultipartIdentifier(sql) match {
+        case AsTableIdentifier(_) =>
+          fail(s"$sql should not be resolved as TableIdentifier")
+        case _ =>
+      }
+    }
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala
deleted file mode 100644
index 0f2d67eaa9b20..0000000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.catalyst.catalog.v2
-
-import org.scalatest.Matchers._
-
-import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin}
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer}
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
-
-private class TestCatalogPlugin(override val name: String) extends CatalogPlugin {
-
-  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = Unit
-}
-
-class ResolveMultipartIdentifierSuite extends AnalysisTest {
-  import CatalystSqlParser._
-
-  private val analyzer = makeAnalyzer(caseSensitive = false)
-
-  private val catalogs = Seq("prod", "test").map(name => name -> new TestCatalogPlugin(name)).toMap
-
-  private def lookupCatalog(catalog: String): CatalogPlugin =
-    catalogs.getOrElse(catalog, throw new CatalogNotFoundException("Not found"))
-
-  private def makeAnalyzer(caseSensitive: Boolean) = {
-    val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
-    new Analyzer(Some(lookupCatalog _), null, conf)
-  }
-
-  override protected def getAnalyzer(caseSensitive: Boolean) = analyzer
-
-  private def checkResolution(sqlText: String, expectedCatalog: Option[CatalogPlugin],
-      expectedNamespace: Array[String], expectedName: String): Unit = {
-
-    import analyzer.CatalogObjectIdentifier
-    val CatalogObjectIdentifier(catalog, ident) = parseMultipartIdentifier(sqlText)
-    catalog shouldEqual expectedCatalog
-    ident.namespace shouldEqual expectedNamespace
-    ident.name shouldEqual expectedName
-  }
-
-  private def checkTableResolution(sqlText: String,
-      expectedIdent: Option[TableIdentifier]): Unit = {
-
-    import analyzer.AsTableIdentifier
-    parseMultipartIdentifier(sqlText) match {
-      case AsTableIdentifier(ident) =>
-        assert(Some(ident) === expectedIdent)
-      case _ =>
-        assert(None === expectedIdent)
-    }
-  }
-
-  test("resolve multipart identifier") {
-    checkResolution("tbl", None, Array.empty, "tbl")
-    checkResolution("db.tbl", None, Array("db"), "tbl")
-    checkResolution("prod.func", catalogs.get("prod"), Array.empty, "func")
-    checkResolution("ns1.ns2.tbl", None, Array("ns1", "ns2"), "tbl")
-    checkResolution("prod.db.tbl", catalogs.get("prod"), Array("db"), "tbl")
-    checkResolution("test.db.tbl", catalogs.get("test"), Array("db"), "tbl")
-    checkResolution("test.ns1.ns2.ns3.tbl",
-      catalogs.get("test"), Array("ns1", "ns2", "ns3"), "tbl")
-    checkResolution("`db.tbl`", None, Array.empty, "db.tbl")
-    checkResolution("parquet.`file:/tmp/db.tbl`", None, Array("parquet"), "file:/tmp/db.tbl")
-    checkResolution("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None,
-      Array("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")
-  }
-
-  test("resolve table identifier") {
-    checkTableResolution("tbl", Some(TableIdentifier("tbl")))
-    checkTableResolution("db.tbl", Some(TableIdentifier("tbl", Some("db"))))
-    checkTableResolution("prod.func", None)
-    checkTableResolution("ns1.ns2.tbl", None)
-    checkTableResolution("prod.db.tbl", None)
-    checkTableResolution("`db.tbl`", Some(TableIdentifier("db.tbl")))
-    checkTableResolution("parquet.`file:/tmp/db.tbl`",
-      Some(TableIdentifier("file:/tmp/db.tbl", Some("parquet"))))
-    checkTableResolution("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`",
-      Some(TableIdentifier("s3://buck/tmp/abc.json", Some("org.apache.spark.sql.json"))))
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
index 72b05036fb04d..635100229fbcf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
@@ -41,7 +41,7 @@ case class DataSourceResolution(
 
   import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
 
-  override def lookupCatalog: Option[String => CatalogPlugin] = Some(findCatalog)
+  override protected def lookupCatalog(name: String): CatalogPlugin = findCatalog(name)
 
   def defaultCatalog: Option[CatalogPlugin] = conf.defaultV2Catalog.map(findCatalog)