diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 01fc5d65c0363..fb71155657f2d 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -171,7 +171,12 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
-    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions")
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"),
+
+    // [SPARK-39506] As part of the 3 layer namespace effort, add currentCatalog, setCurrentCatalog and listCatalogs APIs to the Catalog interface
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.currentCatalog"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.setCurrentCatalog"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listCatalogs")
   )
 
   def excludes(version: String) = version match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 1436574c0d90a..e75ba094da4f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -589,4 +589,25 @@ abstract class Catalog {
    * @since 2.0.0
    */
   def refreshByPath(path: String): Unit
+
+  /**
+   * Returns the current catalog in this session.
+   *
+   * @since 3.4.0
+   */
+  def currentCatalog(): String
+
+  /**
+   * Sets the current catalog in this session.
+   *
+   * @since 3.4.0
+   */
+  def setCurrentCatalog(catalogName: String): Unit
+
+  /**
+   * Returns a list of catalogs available in this session.
+   *
+   * @since 3.4.0
+   */
+  def listCatalogs(): Dataset[CatalogMetadata]
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
index 1e4e0b1474550..84839d2d1fdb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
@@ -26,6 +26,25 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams
 // Note: all classes here are expected to be wrapped in Datasets and so must extend
 // DefinedByConstructorParams for the catalog to be able to create encoders for them.
 
+/**
+ * A catalog in Spark, as returned by the `listCatalogs` method defined in [[Catalog]].
+ *
+ * @param name name of the catalog
+ * @param description description of the catalog
+ * @since 3.4.0
+ */
+class CatalogMetadata(
+    val name: String,
+    @Nullable val description: String)
+  extends DefinedByConstructorParams {
+
+  override def toString: String = {
+    "Catalog[" +
+      s"name='$name', " +
+      Option(description).map { d => s"description='$d'] " }.getOrElse("]")
+  }
+}
+
 /**
  * A database in Spark, as returned by the `listDatabases` method defined in [[Catalog]].
  *
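For context, a minimal usage sketch of the API surface added above, as it would look from a user's SparkSession. The catalog name `testcat` and the plugin class `com.example.MyCatalogPlugin` are illustrative assumptions, not part of this patch:

    import org.apache.spark.sql.SparkSession

    // Register an extra V2 catalog so there is something besides spark_catalog.
    val spark = SparkSession.builder()
      .master("local[*]")
      .config("spark.sql.catalog.testcat", "com.example.MyCatalogPlugin")
      .getOrCreate()

    spark.catalog.currentCatalog()               // "spark_catalog" by default
    spark.catalog.setCurrentCatalog("testcat")
    spark.catalog.currentCatalog()               // "testcat"

    // listCatalogs() returns a Dataset[CatalogMetadata]: one row per catalog,
    // carrying the catalog name and a nullable description.
    spark.catalog.listCatalogs().collect().foreach(c => println(c.name))
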
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
index f89a87c301149..49cb9a3e897bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -21,7 +21,7 @@ import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalog.{Catalog, Column, Database, Function, Table}
+import org.apache.spark.sql.catalog.{Catalog, CatalogMetadata, Column, Database, Function, Table}
 import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedTable, ResolvedView, UnresolvedDBObjectName, UnresolvedNamespace, UnresolvedTable, UnresolvedTableOrView}
 import org.apache.spark.sql.catalyst.catalog._
@@ -589,10 +589,20 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
    * @since 2.0.0
    */
   override def uncacheTable(tableName: String): Unit = {
-    val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
-    sessionCatalog.lookupTempView(tableIdent).map(uncacheView).getOrElse {
-      sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName),
-        cascade = true)
+    // We first try to parse `tableName` as a 1 or 2 part name. If that succeeds, we check
+    // whether it is a temp view and uncache the view if it is; otherwise we uncache the query
+    // from the cache manager. If `tableName` does not parse (e.g. it is a 3 part name), we
+    // uncache it directly from the cache manager.
+    try {
+      val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
+      sessionCatalog.lookupTempView(tableIdent).map(uncacheView).getOrElse {
+        sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName),
+          cascade = true)
+      }
+    } catch {
+      case e: org.apache.spark.sql.catalyst.parser.ParseException =>
+        sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName),
+          cascade = true)
     }
   }
 
@@ -671,6 +681,40 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
   override def refreshByPath(resourcePath: String): Unit = {
     sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
   }
+
+  /**
+   * Returns the current default catalog in this session.
+   *
+   * @since 3.4.0
+   */
+  override def currentCatalog(): String = {
+    sparkSession.sessionState.catalogManager.currentCatalog.name()
+  }
+
+  /**
+   * Sets the current default catalog in this session.
+   *
+   * @since 3.4.0
+   */
+  override def setCurrentCatalog(catalogName: String): Unit = {
+    sparkSession.sessionState.catalogManager.setCurrentCatalog(catalogName)
+  }
+
+  /**
+   * Returns a list of catalogs in this session.
+   *
+   * @since 3.4.0
+   */
+  override def listCatalogs(): Dataset[CatalogMetadata] = {
+    val catalogs = sparkSession.sessionState.catalogManager.listCatalogs(None)
+    CatalogImpl.makeDataset(catalogs.map(name => makeCatalog(name)), sparkSession)
+  }
+
+  private def makeCatalog(name: String): CatalogMetadata = {
+    new CatalogMetadata(
+      name = name,
+      description = null)
+  }
 }
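A short illustration (not part of the patch) of why `uncacheTable` above needs the `ParseException` fallback: `parseTableIdentifier` accepts at most a two part name, so a three layer `catalog.database.table` name throws and the call has to go straight to the shared cache manager. The identifiers mirror the ones used in the test suite below:

    // Before this change the uncacheTable call failed while parsing the 3 part name;
    // with the fallback it is uncached through the cache manager instead.
    spark.catalog.cacheTable("testcat.my_db.my_table")
    spark.catalog.isCached("testcat.my_db.my_table")      // true
    spark.catalog.uncacheTable("testcat.my_db.my_table")
    spark.catalog.isCached("testcat.my_db.my_table")      // false
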
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
index 4844884f6935d..a1a946ddd7155 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -21,7 +21,7 @@ import java.io.File
 
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, DataFrame}
 import org.apache.spark.sql.catalog.{Column, Database, Function, Table}
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, ScalaReflection, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
@@ -63,6 +63,12 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     sessionCatalog.createTable(utils.newTable(name, db), ignoreIfExists = false)
   }
 
+  private def createTable(name: String, db: String, catalog: String, source: String,
+      schema: StructType, option: Map[String, String], description: String): DataFrame = {
+    spark.catalog.createTable(Array(catalog, db, name).mkString("."), source,
+      schema, description, option)
+  }
+
   private def createTempTable(name: String): Unit = {
     createTempView(sessionCatalog, name, Range(1, 2, 3, 4), overrideIfExists = true)
   }
@@ -579,12 +585,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     val tableSchema = new StructType().add("i", "int")
     val description = "this is a test table"
 
-    val df = spark.catalog.createTable(
-      tableName = Array(catalogName, dbName, tableName).mkString("."),
-      source = classOf[FakeV2Provider].getName,
-      schema = tableSchema,
-      description = description,
-      options = Map.empty[String, String])
+    val df = createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName,
+      tableSchema, Map.empty[String, String], description)
     assert(df.schema.equals(tableSchema))
 
     val testCatalog =
@@ -603,12 +605,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     val tableSchema = new StructType().add("i", "int")
     val description = "this is a test table"
 
-    val df = spark.catalog.createTable(
-      tableName = Array(catalogName, dbName, tableName).mkString("."),
-      source = classOf[FakeV2Provider].getName,
-      schema = tableSchema,
-      description = description,
-      options = Map("path" -> dir.getAbsolutePath))
+    val df = createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName,
+      tableSchema, Map("path" -> dir.getAbsolutePath), description)
     assert(df.schema.equals(tableSchema))
 
     val testCatalog =
@@ -630,23 +628,13 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     val tableName = "my_table"
     val tableSchema = new StructType().add("i", "int")
     val description = "this is a test managed table"
-
-    spark.catalog.createTable(
-      tableName = Array(catalogName, dbName, tableName).mkString("."),
-      source = classOf[FakeV2Provider].getName,
-      schema = tableSchema,
-      description = description,
-      options = Map.empty[String, String])
+    createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName, tableSchema,
+      Map.empty[String, String], description)
 
     val tableName2 = "my_table2"
     val description2 = "this is a test external table"
-
-    spark.catalog.createTable(
-      tableName = Array(catalogName, dbName, tableName2).mkString("."),
-      source = classOf[FakeV2Provider].getName,
-      schema = tableSchema,
-      description = description2,
-      options = Map("path" -> dir.getAbsolutePath))
+    createTable(tableName2, dbName, catalogName, classOf[FakeV2Provider].getName, tableSchema,
+      Map("path" -> dir.getAbsolutePath), description2)
 
     val tables = spark.catalog.listTables("testcat.my_db").collect()
     assert(tables.size == 2)
@@ -689,12 +677,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     val tableSchema = new StructType().add("i", "int")
     val description = "this is a test table"
 
-    spark.catalog.createTable(
-      tableName = Array(catalogName, dbName, tableName).mkString("."),
-      source = classOf[FakeV2Provider].getName,
-      schema = tableSchema,
-      description = description,
-      options = Map.empty[String, String])
+    createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName, tableSchema,
+      Map.empty[String, String], description)
 
     val t = spark.catalog.getTable(Array(catalogName, dbName, tableName).mkString("."))
     val expectedTable =
@@ -721,13 +705,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     val tableSchema = new StructType().add("i", "int")
 
     assert(!spark.catalog.tableExists(Array(catalogName, dbName, tableName).mkString(".")))
-
-    spark.catalog.createTable(
-      tableName = Array(catalogName, dbName, tableName).mkString("."),
-      source = classOf[FakeV2Provider].getName,
-      schema = tableSchema,
-      description = "",
-      options = Map.empty[String, String])
+    createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName, tableSchema,
+      Map.empty[String, String], "")
 
     assert(spark.catalog.tableExists(Array(catalogName, dbName, tableName).mkString(".")))
   }
@@ -743,4 +722,31 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
     val catalogName2 = "catalog_not_exists"
     assert(!spark.catalog.databaseExists(Array(catalogName2, dbName).mkString(".")))
   }
+
+  test("SPARK-39506: three layer namespace compatibility - cache table, isCached and " +
+    "uncacheTable") {
+    val tableSchema = new StructType().add("i", "int")
+    createTable("my_table", "my_db", "testcat", classOf[FakeV2Provider].getName,
+      tableSchema, Map.empty[String, String], "")
+    createTable("my_table2", "my_db", "testcat", classOf[FakeV2Provider].getName,
+      tableSchema, Map.empty[String, String], "")
+
+    spark.catalog.cacheTable("testcat.my_db.my_table", StorageLevel.DISK_ONLY)
+    assert(spark.table("testcat.my_db.my_table").storageLevel == StorageLevel.DISK_ONLY)
+    assert(spark.catalog.isCached("testcat.my_db.my_table"))
+
+    spark.catalog.cacheTable("testcat.my_db.my_table2")
+    assert(spark.catalog.isCached("testcat.my_db.my_table2"))
+
+    spark.catalog.uncacheTable("testcat.my_db.my_table")
+    assert(!spark.catalog.isCached("testcat.my_db.my_table"))
+  }
+
+  test("SPARK-39506: test setCurrentCatalog, currentCatalog and listCatalogs") {
+    spark.catalog.setCurrentCatalog("testcat")
+    assert(spark.catalog.currentCatalog().equals("testcat"))
+    spark.catalog.setCurrentCatalog("spark_catalog")
+    assert(spark.catalog.currentCatalog().equals("spark_catalog"))
+    assert(spark.catalog.listCatalogs().collect().map(c => c.name).toSet == Set("testcat"))
+  }
 }
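One behavior worth noting for reviewers: `CatalogImpl.makeCatalog` leaves the description as null, and `CatalogMetadata.toString` only prints the description when it is non-null. A small sketch of the resulting strings, with illustrative values:

    new CatalogMetadata("testcat", null).toString
    // "Catalog[name='testcat', ]"
    new CatalogMetadata("main", "demo catalog").toString
    // "Catalog[name='main', description='demo catalog'] "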