From df4cb471c9712a2fe496664028d9303caebd8777 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 28 Jun 2019 11:20:14 +0800 Subject: [PATCH] Enhance SessionCatalog.listTables --- .../sql/catalyst/catalog/SessionCatalog.scala | 29 ++++++++- .../catalog/SessionCatalogSuite.scala | 65 +++++++++++++++++++ 2 files changed, 91 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index dcc62989d9d2..74559f5d8879 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -784,7 +784,19 @@ class SessionCatalog( * Note that, if the specified database is global temporary view database, we will list global * temporary views. */ - def listTables(db: String, pattern: String): Seq[TableIdentifier] = { + def listTables(db: String, pattern: String): Seq[TableIdentifier] = listTables(db, pattern, true) + + /** + * List all matching tables in the specified database, including local temporary views + * if includeLocalTempViews is enabled. + * + * Note that, if the specified database is global temporary view database, we will list global + * temporary views. + */ + def listTables( + db: String, + pattern: String, + includeLocalTempViews: Boolean): Seq[TableIdentifier] = { val dbName = formatDatabaseName(db) val dbTables = if (dbName == globalTempViewManager.database) { globalTempViewManager.listViewNames(pattern).map { name => @@ -796,12 +808,23 @@ class SessionCatalog( TableIdentifier(name, Some(dbName)) } } - val localTempViews = synchronized { + + if (includeLocalTempViews) { + dbTables ++ listLocalTempViews(pattern) + } else { + dbTables + } + } + + /** + * List all matching local temporary views. + */ + def listLocalTempViews(pattern: String): Seq[TableIdentifier] = { + synchronized { StringUtils.filterPattern(tempViews.keys.toSeq, pattern).map { name => TableIdentifier(name) } } - dbTables ++ localTempViews } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 5a9e4adf6d88..bce85534ce7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -717,6 +717,71 @@ abstract class SessionCatalogSuite extends AnalysisTest { } } + test("list tables with pattern and includeLocalTempViews") { + withEmptyCatalog { catalog => + catalog.createDatabase(newDb("mydb"), ignoreIfExists = false) + catalog.createTable(newTable("tbl1", "mydb"), ignoreIfExists = false) + catalog.createTable(newTable("tbl2", "mydb"), ignoreIfExists = false) + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("temp_view1", tempTable, overrideIfExists = false) + catalog.createTempView("temp_view4", tempTable, overrideIfExists = false) + + assert(catalog.listTables("mydb").toSet == catalog.listTables("mydb", "*").toSet) + assert(catalog.listTables("mydb").toSet == catalog.listTables("mydb", "*", true).toSet) + assert(catalog.listTables("mydb").toSet == + catalog.listTables("mydb", "*", false).toSet ++ catalog.listLocalTempViews("*")) + assert(catalog.listTables("mydb", "*", true).toSet == + Set(TableIdentifier("tbl1", Some("mydb")), + TableIdentifier("tbl2", Some("mydb")), + TableIdentifier("temp_view1"), + TableIdentifier("temp_view4"))) + assert(catalog.listTables("mydb", "*", false).toSet == + Set(TableIdentifier("tbl1", Some("mydb")), TableIdentifier("tbl2", Some("mydb")))) + assert(catalog.listTables("mydb", "tbl*", true).toSet == + Set(TableIdentifier("tbl1", Some("mydb")), TableIdentifier("tbl2", Some("mydb")))) + assert(catalog.listTables("mydb", "tbl*", false).toSet == + Set(TableIdentifier("tbl1", Some("mydb")), TableIdentifier("tbl2", Some("mydb")))) + assert(catalog.listTables("mydb", "temp_view*", true).toSet == + Set(TableIdentifier("temp_view1"), TableIdentifier("temp_view4"))) + assert(catalog.listTables("mydb", "temp_view*", false).toSet == Set.empty) + } + } + + test("list temporary view with pattern") { + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("temp_view1", tempTable, overrideIfExists = false) + catalog.createTempView("temp_view4", tempTable, overrideIfExists = false) + assert(catalog.listLocalTempViews("*").toSet == + Set(TableIdentifier("temp_view1"), TableIdentifier("temp_view4"))) + assert(catalog.listLocalTempViews("temp_view*").toSet == + Set(TableIdentifier("temp_view1"), TableIdentifier("temp_view4"))) + assert(catalog.listLocalTempViews("*1").toSet == Set(TableIdentifier("temp_view1"))) + assert(catalog.listLocalTempViews("does_not_exist").toSet == Set.empty) + } + } + + test("list global temporary view and local temporary view with pattern") { + withBasicCatalog { catalog => + val tempTable = Range(1, 10, 2, 10) + catalog.createTempView("temp_view1", tempTable, overrideIfExists = false) + catalog.createTempView("temp_view4", tempTable, overrideIfExists = false) + catalog.globalTempViewManager.create("global_temp_view1", tempTable, overrideIfExists = false) + catalog.globalTempViewManager.create("global_temp_view2", tempTable, overrideIfExists = false) + assert(catalog.listTables(catalog.globalTempViewManager.database, "*").toSet == + Set(TableIdentifier("temp_view1"), + TableIdentifier("temp_view4"), + TableIdentifier("global_temp_view1", Some(catalog.globalTempViewManager.database)), + TableIdentifier("global_temp_view2", Some(catalog.globalTempViewManager.database)))) + assert(catalog.listTables(catalog.globalTempViewManager.database, "*temp_view1").toSet == + Set(TableIdentifier("temp_view1"), + TableIdentifier("global_temp_view1", Some(catalog.globalTempViewManager.database)))) + assert(catalog.listTables(catalog.globalTempViewManager.database, "global*").toSet == + Set(TableIdentifier("global_temp_view1", Some(catalog.globalTempViewManager.database)), + TableIdentifier("global_temp_view2", Some(catalog.globalTempViewManager.database)))) + } + } + // -------------------------------------------------------------------------- // Partitions // --------------------------------------------------------------------------