7 changes: 6 additions & 1 deletion project/MimaExcludes.scala
@@ -171,7 +171,12 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"),

// [SPARK-39506] As part of the 3-layer-namespace effort, add currentCatalog, setCurrentCatalog and listCatalogs APIs to the Catalog interface
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.currentCatalog"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.setCurrentCatalog"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.listCatalogs")
)

def excludes(version: String) = version match {
21 changes: 21 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -589,4 +589,25 @@ abstract class Catalog {
* @since 2.0.0
*/
def refreshByPath(path: String): Unit

/**
* Returns the current catalog in this session.
*
* @since 3.4.0
*/
def currentCatalog(): String

/**
* Sets the current catalog in this session.
*
* @since 3.4.0
*/
def setCurrentCatalog(catalogName: String): Unit

/**
* Returns a list of catalogs available in this session.
*
* @since 3.4.0
*/
def listCatalogs(): Dataset[CatalogMetadata]
}
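
A minimal usage sketch of the new API (illustrative only; it assumes a SparkSession `spark` with a catalog named `testcat` registered, e.g. via the spark.sql.catalog.testcat config, as in the tests further below):

// Illustrative sketch, not part of the diff.
spark.catalog.setCurrentCatalog("testcat")
assert(spark.catalog.currentCatalog() == "testcat")
val catalogNames = spark.catalog.listCatalogs().collect().map(_.name)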
sql/core/src/main/scala/org/apache/spark/sql/catalog/interface.scala
@@ -26,6 +26,25 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams
// Note: all classes here are expected to be wrapped in Datasets and so must extend
// DefinedByConstructorParams for the catalog to be able to create encoders for them.

/**
* A catalog in Spark, as returned by the `listCatalogs` method defined in [[Catalog]].
*
* @param name name of the catalog
* @param description description of the catalog
* @since 3.4.0
*/
class CatalogMetadata(
val name: String,
@Nullable val description: String)
extends DefinedByConstructorParams {

override def toString: String = {
"Catalog[" +
s"name='$name', " +
Option(description).map { d => s"description='$d'] " }.getOrElse("]")
}
}
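
For reference, a small sketch of what the `toString` above produces; the direct instantiation is purely illustrative:

// Hypothetical direct construction, only to show the formatting above.
val noDesc = new CatalogMetadata(name = "testcat", description = null)
noDesc.toString   // "Catalog[name='testcat', ]"
val withDesc = new CatalogMetadata(name = "testcat", description = "test catalog")
withDesc.toString // "Catalog[name='testcat', description='test catalog'] " (trailing space comes from the literal)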

/**
* A database in Spark, as returned by the `listDatabases` method defined in [[Catalog]].
*
sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala
@@ -21,7 +21,7 @@ import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import org.apache.spark.sql._
import org.apache.spark.sql.catalog.{Catalog, Column, Database, Function, Table}
import org.apache.spark.sql.catalog.{Catalog, CatalogMetadata, Column, Database, Function, Table}
import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedTable, ResolvedView, UnresolvedDBObjectName, UnresolvedNamespace, UnresolvedTable, UnresolvedTableOrView}
import org.apache.spark.sql.catalyst.catalog._
@@ -589,10 +589,20 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
* @since 2.0.0
*/
override def uncacheTable(tableName: String): Unit = {
val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
sessionCatalog.lookupTempView(tableIdent).map(uncacheView).getOrElse {
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName),
cascade = true)
// We first try to parse `tableName` as a 1- or 2-part identifier. If parsing succeeds and the
// name refers to a temp view in the session catalog, we uncache that view; otherwise we
// uncache the table through the cache manager. If parsing fails (e.g. `tableName` is a
// 3-part name), we fall back to uncaching it through the cache manager directly.
try {
val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)
sessionCatalog.lookupTempView(tableIdent).map(uncacheView).getOrElse {
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName),
cascade = true)
}
} catch {
case _: org.apache.spark.sql.catalyst.parser.ParseException =>
sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName),
cascade = true)
}
}
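
For context, a rough sketch of the calls this fallback enables; the catalog, database and table names are illustrative:

// 1- or 2-part names still parse and may resolve to a temp view:
spark.catalog.uncacheTable("my_temp_view")
spark.catalog.uncacheTable("my_db.my_table")
// A 3-part name makes parseTableIdentifier throw a ParseException, so the
// catch branch uncaches it through the cache manager instead:
spark.catalog.uncacheTable("testcat.my_db.my_table")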

@@ -671,6 +681,40 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog {
override def refreshByPath(resourcePath: String): Unit = {
sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath)
}

/**
* Returns the current default catalog in this session.
*
* @since 3.4.0
*/
override def currentCatalog(): String = {
sparkSession.sessionState.catalogManager.currentCatalog.name()
}

/**
* Sets the current default catalog in this session.
*
* @since 3.4.0
*/
override def setCurrentCatalog(catalogName: String): Unit = {
sparkSession.sessionState.catalogManager.setCurrentCatalog(catalogName)
}

/**
* Returns a list of catalogs in this session.
*
* @since 3.4.0
*/
override def listCatalogs(): Dataset[CatalogMetadata] = {
val catalogs = sparkSession.sessionState.catalogManager.listCatalogs(None)
CatalogImpl.makeDataset(catalogs.map(name => makeCatalog(name)), sparkSession)
}

private def makeCatalog(name: String): CatalogMetadata = {
new CatalogMetadata(
name = name,
description = null)
}
}
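
A rough sketch of what `listCatalogs` surfaces to users, assuming a session where only `testcat` has been loaded (consistent with the test below); `description` is currently always null because `makeCatalog` passes null:

// Illustrative; output is roughly:
spark.catalog.listCatalogs().show()
// +-------+-----------+
// |   name|description|
// +-------+-----------+
// |testcat|       null|
// +-------+-----------+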


sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala
@@ -21,7 +21,7 @@ import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.catalog.{Column, Database, Function, Table}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, ScalaReflection, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
@@ -63,6 +63,12 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
sessionCatalog.createTable(utils.newTable(name, db), ignoreIfExists = false)
}

private def createTable(name: String, db: String, catalog: String, source: String,
schema: StructType, option: Map[String, String], description: String): DataFrame = {
spark.catalog.createTable(Array(catalog, db, name).mkString("."), source,
schema, description, option)
}

private def createTempTable(name: String): Unit = {
createTempView(sessionCatalog, name, Range(1, 2, 3, 4), overrideIfExists = true)
}
@@ -579,12 +585,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
val tableSchema = new StructType().add("i", "int")
val description = "this is a test table"

val df = spark.catalog.createTable(
tableName = Array(catalogName, dbName, tableName).mkString("."),
source = classOf[FakeV2Provider].getName,
schema = tableSchema,
description = description,
options = Map.empty[String, String])
val df = createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName,
tableSchema, Map.empty[String, String], description)
assert(df.schema.equals(tableSchema))

val testCatalog =
@@ -603,12 +605,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
val tableSchema = new StructType().add("i", "int")
val description = "this is a test table"

val df = spark.catalog.createTable(
tableName = Array(catalogName, dbName, tableName).mkString("."),
source = classOf[FakeV2Provider].getName,
schema = tableSchema,
description = description,
options = Map("path" -> dir.getAbsolutePath))
val df = createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName,
tableSchema, Map("path" -> dir.getAbsolutePath), description)
assert(df.schema.equals(tableSchema))

val testCatalog =
@@ -630,23 +628,13 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
val tableName = "my_table"
val tableSchema = new StructType().add("i", "int")
val description = "this is a test managed table"

spark.catalog.createTable(
tableName = Array(catalogName, dbName, tableName).mkString("."),
source = classOf[FakeV2Provider].getName,
schema = tableSchema,
description = description,
options = Map.empty[String, String])
createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName, tableSchema,
Map.empty[String, String], description)

val tableName2 = "my_table2"
val description2 = "this is a test external table"

spark.catalog.createTable(
tableName = Array(catalogName, dbName, tableName2).mkString("."),
source = classOf[FakeV2Provider].getName,
schema = tableSchema,
description = description2,
options = Map("path" -> dir.getAbsolutePath))
createTable(tableName2, dbName, catalogName, classOf[FakeV2Provider].getName, tableSchema,
Map("path" -> dir.getAbsolutePath), description2)

val tables = spark.catalog.listTables("testcat.my_db").collect()
assert(tables.size == 2)
@@ -689,12 +677,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
val tableSchema = new StructType().add("i", "int")
val description = "this is a test table"

spark.catalog.createTable(
tableName = Array(catalogName, dbName, tableName).mkString("."),
source = classOf[FakeV2Provider].getName,
schema = tableSchema,
description = description,
options = Map.empty[String, String])
createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName, tableSchema,
Map.empty[String, String], description)

val t = spark.catalog.getTable(Array(catalogName, dbName, tableName).mkString("."))
val expectedTable =
@@ -721,13 +705,8 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
val tableSchema = new StructType().add("i", "int")

assert(!spark.catalog.tableExists(Array(catalogName, dbName, tableName).mkString(".")))

spark.catalog.createTable(
tableName = Array(catalogName, dbName, tableName).mkString("."),
source = classOf[FakeV2Provider].getName,
schema = tableSchema,
description = "",
options = Map.empty[String, String])
createTable(tableName, dbName, catalogName, classOf[FakeV2Provider].getName, tableSchema,
Map.empty[String, String], "")

assert(spark.catalog.tableExists(Array(catalogName, dbName, tableName).mkString(".")))
}
@@ -743,4 +722,31 @@ class CatalogSuite extends SharedSparkSession with AnalysisTest with BeforeAndAf
val catalogName2 = "catalog_not_exists"
assert(!spark.catalog.databaseExists(Array(catalogName2, dbName).mkString(".")))
}

test("SPARK-39506: three layer namespace compatibility - cache table, isCached and" +
"uncacheTable") {
val tableSchema = new StructType().add("i", "int")
createTable("my_table", "my_db", "testcat", classOf[FakeV2Provider].getName,
tableSchema, Map.empty[String, String], "")
createTable("my_table2", "my_db", "testcat", classOf[FakeV2Provider].getName,
tableSchema, Map.empty[String, String], "")

spark.catalog.cacheTable("testcat.my_db.my_table", StorageLevel.DISK_ONLY)
assert(spark.table("testcat.my_db.my_table").storageLevel == StorageLevel.DISK_ONLY)
assert(spark.catalog.isCached("testcat.my_db.my_table"))

spark.catalog.cacheTable("testcat.my_db.my_table2")
assert(spark.catalog.isCached("testcat.my_db.my_table2"))

spark.catalog.uncacheTable("testcat.my_db.my_table")
assert(!spark.catalog.isCached("testcat.my_db.my_table"))
}

test("SPARK-39506: test setCurrentCatalog, currentCatalog and listCatalogs") {
spark.catalog.setCurrentCatalog("testcat")
assert(spark.catalog.currentCatalog().equals("testcat"))
spark.catalog.setCurrentCatalog("spark_catalog")
assert(spark.catalog.currentCatalog().equals("spark_catalog"))
assert(spark.catalog.listCatalogs().collect().map(c => c.name).toSet == Set("testcat"))
Reviewer comment (Contributor): not related to this PR, but we should figure out why spark_catalog is missed here.

}
}