@@ -19,31 +19,37 @@ package org.apache.spark.sql.catalog.v2

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogManager

/**
* A trait to encapsulate catalog lookup function and helpful extractors.
*/
@Experimental
trait LookupCatalog {

protected def lookupCatalog(name: String): CatalogPlugin

type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier)
val catalogManager: CatalogManager

/**
* Extract catalog plugin and identifier from a multi-part identifier.
*/
object CatalogObjectIdentifier {
def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match {
case Seq(name) =>
Some((None, Identifier.of(Array.empty, name)))
case Seq(catalogName, tail @ _*) =>
def unapply(parts: Seq[String]): Option[(CatalogPlugin, Identifier)] = {
assert(parts.nonEmpty)
if (parts.length == 1) {
catalogManager.getDefaultCatalog().map { catalog =>
(catalog, Identifier.of(Array.empty, parts.last))
}
} else {
try {
Some((Some(lookupCatalog(catalogName)), Identifier.of(tail.init.toArray, tail.last)))
val catalog = catalogManager.getCatalog(parts.head)
Some((catalog, Identifier.of(parts.tail.init.toArray, parts.last)))
} catch {
case _: CatalogNotFoundException =>
Some((None, Identifier.of(parts.init.toArray, parts.last)))
catalogManager.getDefaultCatalog().map { catalog =>
(catalog, Identifier.of(parts.init.toArray, parts.last))
}
}
}
}
}

@@ -54,17 +60,11 @@ trait LookupCatalog
*/
object AsTableIdentifier {
def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {
case CatalogObjectIdentifier(None, ident) =>
ident.namespace match {
case Array() =>
Some(TableIdentifier(ident.name))
case Array(database) =>
Some(TableIdentifier(ident.name, Some(database)))
case _ =>
None
}
case _ =>
None
case CatalogObjectIdentifier(_, _) =>
throw new IllegalStateException(parts.mkString(".") + " is not a TableIdentifier.")
case Seq(tblName) => Some(TableIdentifier(tblName))
case Seq(dbName, tblName) => Some(TableIdentifier(tblName, Some(dbName)))
case _ => None
}
}
}
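
To make the new resolution rules concrete, here is a small self-contained sketch (plain Scala, not the real Spark classes) of what `CatalogObjectIdentifier` now does: a single-part name resolves against the default catalog, and a multi-part name tries its head as a catalog name, falling back to the default catalog when no such catalog is registered. The `registered` and `default` parameters stand in for `CatalogManager` and are illustrative only.

```scala
object ResolutionSketch {
  // Stand-in for org.apache.spark.sql.catalog.v2.Identifier.
  case class Identifier(namespace: Seq[String], name: String)

  // `registered` models the catalogs known to CatalogManager; `default` models the default v2 catalog.
  def resolve(
      parts: Seq[String],
      registered: Set[String],
      default: Option[String]): Option[(String, Identifier)] = {
    require(parts.nonEmpty)
    if (parts.length == 1) {
      // Single-part names go to the default catalog, if one is configured.
      default.map(d => (d, Identifier(Nil, parts.head)))
    } else if (registered.contains(parts.head)) {
      // Head is a registered catalog: the rest is namespace + table name.
      Some((parts.head, Identifier(parts.tail.init, parts.last)))
    } else {
      // Unknown head: treat the whole name as namespace + table in the default catalog.
      default.map(d => (d, Identifier(parts.init, parts.last)))
    }
  }

  def main(args: Array[String]): Unit = {
    val registered = Set("prod", "test")
    println(resolve(Seq("tbl"), registered, Some("prod")))        // default catalog "prod", name "tbl"
    println(resolve(Seq("prod", "db", "tbl"), registered, None))  // catalog "prod", namespace ["db"], name "tbl"
    println(resolve(Seq("db", "tbl"), registered, None))          // None: head is not a catalog and there is no default
  }
}
```

The last case is why the extractor's return type changes from `Some[...]` to `Option[...]`: with no default catalog and an unknown head, there is nothing to return.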
@@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}
import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, LookupCatalog}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
@@ -104,8 +104,7 @@ class Analyzer(
this(catalog, conf, conf.optimizerMaxIterations)
}

override protected def lookupCatalog(name: String): CatalogPlugin =
throw new CatalogNotFoundException("No catalog lookup function")
override lazy val catalogManager: CatalogManager = new CatalogManager(conf)

def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = {
AnalysisHelper.markInAnalyzer {
@@ -163,7 +162,6 @@ class Analyzer(
new SubstituteUnresolvedOrdinals(conf)),
Batch("Resolution", fixedPoint,
ResolveTableValuedFunctions ::
ResolveTables ::
ResolveRelations ::
ResolveReferences ::
ResolveCreateNamedStruct ::
@@ -658,20 +656,6 @@ class Analyzer(
}
}

/**
* Resolve table relations with concrete relations from v2 catalog.
*
* [[ResolveRelations]] still resolves v1 tables.
*/
object ResolveTables extends Rule[LogicalPlan] {
import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case u @ UnresolvedRelation(CatalogObjectIdentifier(Some(catalogPlugin), ident)) =>
loadTable(catalogPlugin, ident).map(DataSourceV2Relation.create).getOrElse(u)
}
}

/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
@@ -704,7 +688,7 @@ class Analyzer(
// Note this is compatible with the views defined by older versions of Spark(before 2.2), which
// have empty defaultDatabase and all the relations in viewText have database part defined.
def resolveRelation(plan: LogicalPlan): LogicalPlan = plan match {
case u @ UnresolvedRelation(AsTableIdentifier(ident)) if !isRunningDirectlyOnFiles(ident) =>
case u @ UnresolvedRelation(ident) =>
val defaultDatabase = AnalysisContext.get.defaultDatabase
val foundRelation = lookupTableFromCatalog(ident, u, defaultDatabase)
if (foundRelation != u) {
@@ -735,7 +719,7 @@ class Analyzer(
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case i @ InsertIntoTable(u @ UnresolvedRelation(AsTableIdentifier(ident)), _, child, _, _)
case i @ InsertIntoTable(u @ UnresolvedRelation(ident), _, child, _, _)
if child.resolved =>
EliminateSubqueryAliases(lookupTableFromCatalog(ident, u)) match {
case v: View =>
@@ -752,28 +736,47 @@
// and the default database is only used to look up a view);
// 3. Use the currentDb of the SessionCatalog.
private def lookupTableFromCatalog(
tableIdentifier: TableIdentifier,
nameParts: Seq[String],
u: UnresolvedRelation,
defaultDatabase: Option[String] = None): LogicalPlan = {
val tableIdentWithDb = tableIdentifier.copy(
database = tableIdentifier.database.orElse(defaultDatabase))
import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.CatalogHelper

val namePartsWithDb = if (nameParts.length == 1) {
defaultDatabase.toSeq ++ nameParts
} else {
nameParts
}

try {
catalog.lookupRelation(tableIdentWithDb)
namePartsWithDb match {
case AsTempViewIdentifier(ident) => catalog.lookupRelation(ident)

case CatalogObjectIdentifier(v2Catalog, ident) =>
val table = v2Catalog.asTableCatalog.loadTable(ident)
DataSourceV2Relation.create(table)

case _ =>
// The builtin hive catalog doesn't support more than 2 table name parts. Here we assume
// the first name part is a catalog which doesn't exist.
if (namePartsWithDb.length > 2) {
throw new CatalogNotFoundException(s"Catalog '${namePartsWithDb.head}' not found.")
}
catalog.lookupRelation(TableIdentifier(namePartsWithDb))
}
} catch {
case _: NoSuchTableException | _: NoSuchDatabaseException =>
u
}
}

// If the database part is specified, and we support running SQL directly on files, and
[Review comment from the PR author] This is not needed anymore because in #24741 we delay the error reporting of unresolved relations to CheckAnalysis.

// it's not a temporary view, and the table does not exist, then let's just return the
// original UnresolvedRelation. It is possible we are matching a query like "select *
// from parquet.`/path/to/query`". The plan will get resolved in the rule `ResolveDataSource`.
// Note that we are testing (!db_exists || !table_exists) because the catalog throws
// an exception from tableExists if the database does not exist.
private def isRunningDirectlyOnFiles(table: TableIdentifier): Boolean = {
table.database.isDefined && conf.runSQLonFile && !catalog.isTemporaryTable(table) &&
(!catalog.databaseExists(table.database.get) || !catalog.tableExists(table))
object AsTempViewIdentifier {
def unapply(parts: Seq[String]): Option[TableIdentifier] = {
if (parts.nonEmpty && parts.length <= 2) {
Some(TableIdentifier(parts)).filter(catalog.isTemporaryTable)
} else {
None
}
}
}
}
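
The lookup order in the rewritten `lookupTableFromCatalog` is easier to see in isolation. Below is a simplified stand-alone model (illustrative names and types only; the default-v2-catalog fallback is omitted): temporary views win, then tables from an explicitly named v2 catalog, then the session catalog for one- or two-part names, and anything longer whose head is not a known catalog fails with a catalog-not-found error.

```scala
// Simplified model of the new lookup order; `tempViews` and `v2Catalogs` stand in for the
// session catalog and the CatalogManager, and Resolved is only an illustration.
sealed trait Resolved
case class TempView(parts: Seq[String]) extends Resolved
case class V2Table(catalog: String, namespace: Seq[String], name: String) extends Resolved
case class V1Table(db: Option[String], table: String) extends Resolved

def lookup(
    nameParts: Seq[String],
    defaultDatabase: Option[String],
    tempViews: Set[Seq[String]],
    v2Catalogs: Set[String]): Either[String, Resolved] = {
  // Single-part names are qualified with the view's default database first, as in the diff.
  val parts = if (nameParts.length == 1) defaultDatabase.toSeq ++ nameParts else nameParts
  parts match {
    case p if p.length <= 2 && tempViews.contains(p) =>
      Right(TempView(p))                                // 1. temporary views are checked first
    case head +: rest if rest.nonEmpty && v2Catalogs.contains(head) =>
      Right(V2Table(head, rest.init, rest.last))        // 2. explicitly named v2 catalog
    case Seq(tbl) =>
      Right(V1Table(None, tbl))                         // 3. session catalog
    case Seq(db, tbl) =>
      Right(V1Table(Some(db), tbl))
    case p =>
      Left(s"Catalog '${p.head}' not found.")           // 4. >2 parts with an unknown head
  }
}
```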

@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.catalog

import scala.collection.mutable

import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Catalogs}
import org.apache.spark.sql.internal.SQLConf

/**
* A thread-safe manager for [[CatalogPlugin]]s. It tracks all the registered catalogs, and allow
* the caller to look up a catalog by name.
*/
class CatalogManager(conf: SQLConf) {

/**
* Tracks all the registered catalogs.
*/
private val catalogs = mutable.HashMap.empty[String, CatalogPlugin]

/**
* Looks up a catalog by name.
*/
def getCatalog(name: String): CatalogPlugin = synchronized {
catalogs.getOrElseUpdate(name, Catalogs.load(name, conf))
}

def getDefaultCatalog(): Option[CatalogPlugin] = {
conf.defaultV2Catalog.map(getCatalog)
}
}
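
A hedged usage sketch of the new class: `getCatalog` loads a plugin on first use and caches it under its name, and `getDefaultCatalog` only returns a value when a default v2 catalog is configured. The configuration key and the plugin class in the comment are assumptions for illustration; the actual implementation class is resolved by `Catalogs.load` from the conf.

```scala
import org.apache.spark.sql.catalyst.catalog.CatalogManager
import org.apache.spark.sql.internal.SQLConf

val conf = new SQLConf
// Assumption: catalogs are registered through a `spark.sql.catalog.<name>` style setting that
// Catalogs.load reads; the exact key and MyCatalogPlugin class are not part of this diff.
// conf.setConfString("spark.sql.catalog.prod", classOf[MyCatalogPlugin].getName)

val manager = new CatalogManager(conf)
val prod = manager.getCatalog("prod")       // loaded once via Catalogs.load, then served from the cache
val default = manager.getDefaultCatalog()   // None unless a default v2 catalog is configured
```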
@@ -85,6 +85,15 @@ case class QualifiedTableName(database: String, name: String) {

object TableIdentifier {
def apply(tableName: String): TableIdentifier = new TableIdentifier(tableName)

def apply(nameParts: Seq[String]): TableIdentifier = {
assert(nameParts.nonEmpty && nameParts.length <= 2)
if (nameParts.length == 1) {
TableIdentifier(nameParts.last)
} else {
TableIdentifier(nameParts.last, Some(nameParts.head))
}
}
}
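
A quick sketch of the new overload's behavior; it deliberately accepts only one- or two-part names, matching the assertion that `CreateViewCommand` relies on later in this diff.

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

TableIdentifier(Seq("tbl"))        // TableIdentifier("tbl")             -> no database
TableIdentifier(Seq("db", "tbl"))  // TableIdentifier("tbl", Some("db"))
// TableIdentifier(Seq("cat", "db", "tbl"))  // would fail the assertion: at most two name parts
```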


@@ -16,12 +16,16 @@
*/
package org.apache.spark.sql.catalyst.catalog.v2

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock
import org.scalatest.Inside
import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.CatalogManager
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -33,10 +37,17 @@ private case class TestCatalogPlugin(override val name: String) extends CatalogP
class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
import CatalystSqlParser._

private val catalogs = Seq("prod", "test").map(x => x -> new TestCatalogPlugin(x)).toMap
private val catalogs = Seq("prod", "test").map(x => x -> TestCatalogPlugin(x)).toMap

override def lookupCatalog(name: String): CatalogPlugin =
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
override val catalogManager: CatalogManager = {
val manager = mock(classOf[CatalogManager])
when(manager.getCatalog(any())).thenAnswer((invocation: InvocationOnMock) => {
val name = invocation.getArgument[String](0)
catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
})
when(manager.getDefaultCatalog()).thenReturn(None)
manager
}

test("catalog object identifier") {
Seq(
@@ -54,8 +65,9 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
case (sql, expectedCatalog, namespace, name) =>
inside(parseMultipartIdentifier(sql)) {
case CatalogObjectIdentifier(catalog, ident) =>
catalog shouldEqual expectedCatalog
Some(catalog) shouldEqual expectedCatalog
ident shouldEqual Identifier.of(namespace.toArray, name)
case _ => assert(expectedCatalog.isEmpty)
}
}
}
@@ -78,10 +90,11 @@ class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
"prod.func",
"prod.db.tbl",
"ns1.ns2.tbl").foreach { sql =>
parseMultipartIdentifier(sql) match {
case AsTableIdentifier(_) =>
fail(s"$sql should not be resolved as TableIdentifier")
case _ =>
val nameParts = parseMultipartIdentifier(sql)
if (nameParts.head == "prod") {
intercept[IllegalStateException](AsTableIdentifier.unapply(nameParts))
} else {
assert(AsTableIdentifier.unapply(nameParts).isEmpty)
}
}
}
@@ -607,12 +607,6 @@ class SparkSession private(
*/
@transient lazy val catalog: Catalog = new CatalogImpl(self)

@transient private lazy val catalogs = new mutable.HashMap[String, CatalogPlugin]()

private[sql] def catalog(name: String): CatalogPlugin = synchronized {
catalogs.getOrElseUpdate(name, Catalogs.load(name, sessionState.conf))
}

/**
* Returns the specified table/view as a `DataFrame`.
*
@@ -181,29 +181,32 @@ case class CreateViewCommand(
* Permanent views are not allowed to reference temp objects, including temp function and views
*/
private def verifyTemporaryObjectsNotExists(sparkSession: SparkSession): Unit = {

[Review comment] For the sake of consistency, I think it'd be better to use either sparkSession or session, like the methods below this one do.

import sparkSession.sessionState.analyzer.AsTableIdentifier

if (!isTemporary) {
// This func traverses the unresolved plan `child`. Below are the reasons:
// 1) Analyzer replaces unresolved temporary views by a SubqueryAlias with the corresponding
// logical plan. After replacement, it is impossible to detect whether the SubqueryAlias is
// added/generated from a temporary view.
// 2) The temp functions are represented by multiple classes. Most are inaccessible from this
// package (e.g., HiveGenericUDF).
child.collect {
child.foreach {
// Disallow creating permanent views based on temporary views.
case UnresolvedRelation(AsTableIdentifier(ident))
if sparkSession.sessionState.catalog.isTemporaryTable(ident) =>
// temporary views are only stored in the session catalog
throw new AnalysisException(s"Not allowed to create a permanent view $name by " +
s"referencing a temporary view $ident")
case UnresolvedRelation(parts) =>
// The `DataSourceResolution` rule guarantees this.
assert(parts.nonEmpty && parts.length <= 2)
val tblIdent = TableIdentifier(parts)
if (sparkSession.sessionState.catalog.isTemporaryTable(tblIdent)) {
// temporary views are only stored in the session catalog
throw new AnalysisException(s"Not allowed to create a permanent view $name by " +
s"referencing a temporary view $tblIdent")
}
case other if !other.resolved => other.expressions.flatMap(_.collect {
// Disallow creating permanent views based on temporary UDFs.
case e: UnresolvedFunction
if sparkSession.sessionState.catalog.isTemporaryFunction(e.name) =>
throw new AnalysisException(s"Not allowed to create a permanent view $name by " +
s"referencing a temporary function `${e.name}`")
})
case _ =>

[Review comment] I think it'd be better to add a comment like "do nothing" here.

}
}
}
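
For reference, a minimal stand-alone model of the rewritten check (the real command walks the unresolved plan and throws `AnalysisException`; here the relation names and the temp-view predicate are passed in directly, and the helper name is made up):

```scala
// Sketch only: reject any one- or two-part relation name that matches a known temporary view.
def verifyNoTempViewReferences(
    relationNameParts: Seq[Seq[String]],
    isTemporaryTable: Seq[String] => Boolean): Unit = {
  relationNameParts.foreach { parts =>
    // The DataSourceResolution rule guarantees at most two name parts reach this check.
    assert(parts.nonEmpty && parts.length <= 2)
    if (isTemporaryTable(parts)) {
      throw new IllegalArgumentException( // the real code throws AnalysisException
        s"Not allowed to create a permanent view by referencing a temporary view ${parts.mkString(".")}")
    }
  }
}
```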