Skip to content

Commit 24e1e41

Browse files
wangyumgatorsmile
authored andcommitted
[SPARK-28196][SQL] Add a new listTables and listLocalTempViews APIs for SessionCatalog
## What changes were proposed in this pull request? This pr add two API for [SessionCatalog](https://github.com/apache/spark/blob/df4cb471c9712a2fe496664028d9303caebd8777/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala): ```scala def listTables(db: String, pattern: String, includeLocalTempViews: Boolean): Seq[TableIdentifier] def listLocalTempViews(pattern: String): Seq[TableIdentifier] ``` Because in some cases `listTables` does not need local temporary view and sometimes only need list local temporary view. ## How was this patch tested? unit tests Closes #24995 from wangyum/SPARK-28196. Authored-by: Yuming Wang <[email protected]> Signed-off-by: gatorsmile <[email protected]>
1 parent e0e2144 commit 24e1e41

File tree

2 files changed

+91
-3
lines changed

2 files changed

+91
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,19 @@ class SessionCatalog(
784784
* Note that, if the specified database is global temporary view database, we will list global
785785
* temporary views.
786786
*/
787-
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
787+
def listTables(db: String, pattern: String): Seq[TableIdentifier] = listTables(db, pattern, true)
788+
789+
/**
790+
* List all matching tables in the specified database, including local temporary views
791+
* if includeLocalTempViews is enabled.
792+
*
793+
* Note that, if the specified database is global temporary view database, we will list global
794+
* temporary views.
795+
*/
796+
def listTables(
797+
db: String,
798+
pattern: String,
799+
includeLocalTempViews: Boolean): Seq[TableIdentifier] = {
788800
val dbName = formatDatabaseName(db)
789801
val dbTables = if (dbName == globalTempViewManager.database) {
790802
globalTempViewManager.listViewNames(pattern).map { name =>
@@ -796,12 +808,23 @@ class SessionCatalog(
796808
TableIdentifier(name, Some(dbName))
797809
}
798810
}
799-
val localTempViews = synchronized {
811+
812+
if (includeLocalTempViews) {
813+
dbTables ++ listLocalTempViews(pattern)
814+
} else {
815+
dbTables
816+
}
817+
}
818+
819+
/**
820+
* List all matching local temporary views.
821+
*/
822+
def listLocalTempViews(pattern: String): Seq[TableIdentifier] = {
823+
synchronized {
800824
StringUtils.filterPattern(tempViews.keys.toSeq, pattern).map { name =>
801825
TableIdentifier(name)
802826
}
803827
}
804-
dbTables ++ localTempViews
805828
}
806829

807830
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,71 @@ abstract class SessionCatalogSuite extends AnalysisTest {
717717
}
718718
}
719719

720+
test("list tables with pattern and includeLocalTempViews") {
721+
withEmptyCatalog { catalog =>
722+
catalog.createDatabase(newDb("mydb"), ignoreIfExists = false)
723+
catalog.createTable(newTable("tbl1", "mydb"), ignoreIfExists = false)
724+
catalog.createTable(newTable("tbl2", "mydb"), ignoreIfExists = false)
725+
val tempTable = Range(1, 10, 2, 10)
726+
catalog.createTempView("temp_view1", tempTable, overrideIfExists = false)
727+
catalog.createTempView("temp_view4", tempTable, overrideIfExists = false)
728+
729+
assert(catalog.listTables("mydb").toSet == catalog.listTables("mydb", "*").toSet)
730+
assert(catalog.listTables("mydb").toSet == catalog.listTables("mydb", "*", true).toSet)
731+
assert(catalog.listTables("mydb").toSet ==
732+
catalog.listTables("mydb", "*", false).toSet ++ catalog.listLocalTempViews("*"))
733+
assert(catalog.listTables("mydb", "*", true).toSet ==
734+
Set(TableIdentifier("tbl1", Some("mydb")),
735+
TableIdentifier("tbl2", Some("mydb")),
736+
TableIdentifier("temp_view1"),
737+
TableIdentifier("temp_view4")))
738+
assert(catalog.listTables("mydb", "*", false).toSet ==
739+
Set(TableIdentifier("tbl1", Some("mydb")), TableIdentifier("tbl2", Some("mydb"))))
740+
assert(catalog.listTables("mydb", "tbl*", true).toSet ==
741+
Set(TableIdentifier("tbl1", Some("mydb")), TableIdentifier("tbl2", Some("mydb"))))
742+
assert(catalog.listTables("mydb", "tbl*", false).toSet ==
743+
Set(TableIdentifier("tbl1", Some("mydb")), TableIdentifier("tbl2", Some("mydb"))))
744+
assert(catalog.listTables("mydb", "temp_view*", true).toSet ==
745+
Set(TableIdentifier("temp_view1"), TableIdentifier("temp_view4")))
746+
assert(catalog.listTables("mydb", "temp_view*", false).toSet == Set.empty)
747+
}
748+
}
749+
750+
test("list temporary view with pattern") {
751+
withBasicCatalog { catalog =>
752+
val tempTable = Range(1, 10, 2, 10)
753+
catalog.createTempView("temp_view1", tempTable, overrideIfExists = false)
754+
catalog.createTempView("temp_view4", tempTable, overrideIfExists = false)
755+
assert(catalog.listLocalTempViews("*").toSet ==
756+
Set(TableIdentifier("temp_view1"), TableIdentifier("temp_view4")))
757+
assert(catalog.listLocalTempViews("temp_view*").toSet ==
758+
Set(TableIdentifier("temp_view1"), TableIdentifier("temp_view4")))
759+
assert(catalog.listLocalTempViews("*1").toSet == Set(TableIdentifier("temp_view1")))
760+
assert(catalog.listLocalTempViews("does_not_exist").toSet == Set.empty)
761+
}
762+
}
763+
764+
test("list global temporary view and local temporary view with pattern") {
765+
withBasicCatalog { catalog =>
766+
val tempTable = Range(1, 10, 2, 10)
767+
catalog.createTempView("temp_view1", tempTable, overrideIfExists = false)
768+
catalog.createTempView("temp_view4", tempTable, overrideIfExists = false)
769+
catalog.globalTempViewManager.create("global_temp_view1", tempTable, overrideIfExists = false)
770+
catalog.globalTempViewManager.create("global_temp_view2", tempTable, overrideIfExists = false)
771+
assert(catalog.listTables(catalog.globalTempViewManager.database, "*").toSet ==
772+
Set(TableIdentifier("temp_view1"),
773+
TableIdentifier("temp_view4"),
774+
TableIdentifier("global_temp_view1", Some(catalog.globalTempViewManager.database)),
775+
TableIdentifier("global_temp_view2", Some(catalog.globalTempViewManager.database))))
776+
assert(catalog.listTables(catalog.globalTempViewManager.database, "*temp_view1").toSet ==
777+
Set(TableIdentifier("temp_view1"),
778+
TableIdentifier("global_temp_view1", Some(catalog.globalTempViewManager.database))))
779+
assert(catalog.listTables(catalog.globalTempViewManager.database, "global*").toSet ==
780+
Set(TableIdentifier("global_temp_view1", Some(catalog.globalTempViewManager.database)),
781+
TableIdentifier("global_temp_view2", Some(catalog.globalTempViewManager.database))))
782+
}
783+
}
784+
720785
// --------------------------------------------------------------------------
721786
// Partitions
722787
// --------------------------------------------------------------------------

0 commit comments

Comments
 (0)