Skip to content

Commit b76ccd8

Browse files
author
bomeng
committed
fix like pattern in show ddl
1 parent 78c1076 commit b76ccd8

File tree

5 files changed

+47
-24
lines changed

5 files changed

+47
-24
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.collection.mutable
2121

2222
import org.apache.spark.sql.AnalysisException
2323
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
24-
24+
import org.apache.spark.sql.catalyst.util.StringUtils
2525

2626
/**
2727
* An in-memory (ephemeral) implementation of the system catalog.
@@ -47,11 +47,6 @@ class InMemoryCatalog extends ExternalCatalog {
4747
// Database name -> description
4848
private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc]
4949

50-
private def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
51-
val regex = pattern.replaceAll("\\*", ".*").r
52-
names.filter { funcName => regex.pattern.matcher(funcName).matches() }
53-
}
54-
5550
private def functionExists(db: String, funcName: String): Boolean = {
5651
requireDbExists(db)
5752
catalog(db).functions.contains(funcName)
@@ -141,7 +136,7 @@ class InMemoryCatalog extends ExternalCatalog {
141136
}
142137

143138
override def listDatabases(pattern: String): Seq[String] = synchronized {
144-
filterPattern(listDatabases(), pattern)
139+
StringUtils.filterPattern(listDatabases(), pattern)
145140
}
146141

147142
override def setCurrentDatabase(db: String): Unit = { /* no-op */ }
@@ -208,7 +203,7 @@ class InMemoryCatalog extends ExternalCatalog {
208203
}
209204

210205
override def listTables(db: String, pattern: String): Seq[String] = synchronized {
211-
filterPattern(listTables(db), pattern)
206+
StringUtils.filterPattern(listTables(db), pattern)
212207
}
213208

214209
// --------------------------------------------------------------------------
@@ -322,7 +317,7 @@ class InMemoryCatalog extends ExternalCatalog {
322317

323318
override def listFunctions(db: String, pattern: String): Seq[String] = synchronized {
324319
requireDbExists(db)
325-
filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
320+
StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
326321
}
327322

328323
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionE
2828
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
2929
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
3030
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
31-
31+
import org.apache.spark.sql.catalyst.util.StringUtils
3232

3333
/**
3434
* An internal catalog that is used by a Spark Session. This internal catalog serves as a
@@ -297,9 +297,7 @@ class SessionCatalog(
297297
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
298298
val dbTables =
299299
externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
300-
val regex = pattern.replaceAll("\\*", ".*").r
301-
val _tempTables = tempTables.keys.toSeq
302-
.filter { t => regex.pattern.matcher(t).matches() }
300+
val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
303301
.map { t => TableIdentifier(t) }
304302
dbTables ++ _tempTables
305303
}
@@ -613,9 +611,7 @@ class SessionCatalog(
613611
def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
614612
val dbFunctions =
615613
externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
616-
val regex = pattern.replaceAll("\\*", ".*").r
617-
val loadedFunctions = functionRegistry.listFunction()
618-
.filter { f => regex.pattern.matcher(f).matches() }
614+
val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
619615
.map { f => FunctionIdentifier(f) }
620616
// TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry.
621617
// So, the returned list may have two entries for the same function.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.util
1919

20-
import java.util.regex.Pattern
20+
import java.util.regex.{Pattern, PatternSyntaxException}
2121

2222
import org.apache.spark.unsafe.types.UTF8String
2323

@@ -52,4 +52,26 @@ object StringUtils {
5252

5353
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
5454
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
55+
56+
/**
57+
* This utility can be used to filter names by the pattern in the "Like" clause of "Show Tables / Functions" DDL
58+
* @param names the names list to be filtered
59+
* @param pattern the filter pattern, only '*' and '|' are allowed as wildcards, others will
60+
* follow regular expression convention; matching is case insensitive and white spaces
61+
* on both ends will be ignored
62+
* @return the filtered names list in order
63+
*/
64+
def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
65+
val funcNames = scala.collection.mutable.SortedSet.empty[String]
66+
pattern.trim().split("\\|").foreach {
67+
subPattern =>
68+
try {
69+
val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
70+
funcNames ++= names.filter{name => regex.pattern.matcher(name).matches()}
71+
} catch {
72+
case _: PatternSyntaxException =>
73+
}
74+
}
75+
funcNames.toSeq
76+
}
5577
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,16 @@ class StringUtilsSuite extends SparkFunSuite {
3131
assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E")
3232
assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E")
3333
}
34+
35+
test("filter pattern") {
36+
val names = Seq("a1", "a2", "b2", "c3")
37+
assert(filterPattern(names, " * ") === Seq("a1", "a2", "b2", "c3"))
38+
assert(filterPattern(names, "*a*") === Seq("a1", "a2"))
39+
assert(filterPattern(names, " *a* ") === Seq("a1", "a2"))
40+
assert(filterPattern(names, " a* ") === Seq("a1", "a2"))
41+
assert(filterPattern(names, " a.* ") === Seq("a1", "a2"))
42+
assert(filterPattern(names, " B.*|a* ") === Seq("a1", "a2", "b2"))
43+
assert(filterPattern(names, " a. ") === Seq("a1", "a2"))
44+
assert(filterPattern(names, " d* ") === Nil)
45+
}
3446
}

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.apache.spark.AccumulatorSuite
2424
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
2525
import org.apache.spark.sql.catalyst.expressions.SortOrder
2626
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
27+
import org.apache.spark.sql.catalyst.util.StringUtils
2728
import org.apache.spark.sql.execution.aggregate
2829
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
2930
import org.apache.spark.sql.functions._
@@ -56,17 +57,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
5657

5758
test("show functions") {
5859
def getFunctions(pattern: String): Seq[Row] = {
59-
val regex = java.util.regex.Pattern.compile(pattern)
60-
sqlContext.sessionState.functionRegistry.listFunction()
61-
.filter(regex.matcher(_).matches()).map(Row(_))
60+
StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern)
61+
.map(Row(_))
6262
}
63-
checkAnswer(sql("SHOW functions"), getFunctions(".*"))
63+
checkAnswer(sql("SHOW functions"), getFunctions("*"))
6464
Seq("^c*", "*e$", "log*", "*date*").foreach { pattern =>
6565
// For the pattern part, only '*' and '|' are allowed as wildcards.
6666
// For '*', we need to replace it with '.*'.
67-
checkAnswer(
68-
sql(s"SHOW FUNCTIONS '$pattern'"),
69-
getFunctions(pattern.replaceAll("\\*", ".*")))
67+
checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern))
7068
}
7169
}
7270

0 commit comments

Comments
 (0)