InMemoryCatalog.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable

 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.util.StringUtils

 /**
  * An in-memory (ephemeral) implementation of the system catalog.
@@ -47,11 +47,6 @@ class InMemoryCatalog extends ExternalCatalog {
   // Database name -> description
   private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc]

-  private def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
-    val regex = pattern.replaceAll("\\*", ".*").r
-    names.filter { funcName => regex.pattern.matcher(funcName).matches() }
-  }
-
   private def functionExists(db: String, funcName: String): Boolean = {
     requireDbExists(db)
     catalog(db).functions.contains(funcName)
@@ -141,7 +136,7 @@ class InMemoryCatalog extends ExternalCatalog {
   }

   override def listDatabases(pattern: String): Seq[String] = synchronized {
-    filterPattern(listDatabases(), pattern)
+    StringUtils.filterPattern(listDatabases(), pattern)
   }

   override def setCurrentDatabase(db: String): Unit = { /* no-op */ }
@@ -208,7 +203,7 @@ class InMemoryCatalog extends ExternalCatalog {
   }

   override def listTables(db: String, pattern: String): Seq[String] = synchronized {
-    filterPattern(listTables(db), pattern)
+    StringUtils.filterPattern(listTables(db), pattern)
   }

   // --------------------------------------------------------------------------
@@ -322,7 +317,7 @@ class InMemoryCatalog extends ExternalCatalog {

   override def listFunctions(db: String, pattern: String): Seq[String] = synchronized {
     requireDbExists(db)
-    filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
+    StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
   }

 }

SessionCatalog.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionE
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
+import org.apache.spark.sql.catalyst.util.StringUtils

 /**
  * An internal catalog that is used by a Spark Session. This internal catalog serves as a
@@ -297,9 +297,7 @@ class SessionCatalog(
   def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
     val dbTables =
       externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
-    val regex = pattern.replaceAll("\\*", ".*").r
-    val _tempTables = tempTables.keys.toSeq
-      .filter { t => regex.pattern.matcher(t).matches() }
+    val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
       .map { t => TableIdentifier(t) }
     dbTables ++ _tempTables
   }
@@ -613,9 +611,7 @@ class SessionCatalog(
   def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
     val dbFunctions =
       externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
-    val regex = pattern.replaceAll("\\*", ".*").r
-    val loadedFunctions = functionRegistry.listFunction()
-      .filter { f => regex.pattern.matcher(f).matches() }
+    val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
       .map { f => FunctionIdentifier(f) }
     // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry.
     // So, the returned list may have two entries for the same function.
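
For reference, a minimal sketch (not part of the diff) of the merge behavior that listTables now relies on: temporary-table names go through the same shared StringUtils.filterPattern helper that the external catalog uses, so both sides honor the '*' and '|' wildcard semantics. The table names below are made up for illustration.

// Sketch only: illustrative names, run outside of SessionCatalog.
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.util.StringUtils

object ListTablesSketch {
  def main(args: Array[String]): Unit = {
    // Pretend the external catalog already returned these for pattern "src*".
    val dbTables = Seq("src", "srcpart").map { t => TableIdentifier(t, Some("default")) }

    // Temporary tables are filtered with the same shared helper.
    val tempTableNames = Seq("src_temp", "other_temp")
    val _tempTables = StringUtils.filterPattern(tempTableNames, "src*")
      .map { t => TableIdentifier(t) }

    // Only src_temp survives the filter, so the combined result has three identifiers.
    println(dbTables ++ _tempTables)
  }
}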

StringUtils.scala
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.util

-import java.util.regex.Pattern
+import java.util.regex.{Pattern, PatternSyntaxException}

 import org.apache.spark.unsafe.types.UTF8String

@@ -52,4 +52,26 @@ object StringUtils {

   def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
   def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
+
+  /**
+   * This utility can be used for filtering the pattern in the "LIKE" clause of
+   * "SHOW TABLES / FUNCTIONS" DDL.
+   * @param names the list of names to be filtered
+   * @param pattern the filter pattern; only '*' and '|' are allowed as wildcards, anything else
+   *                follows regular expression convention; the match is case insensitive and
+   *                white space on both ends is ignored
+   * @return the filtered list of names, in order
+   */
+  def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
+    val funcNames = scala.collection.mutable.SortedSet.empty[String]
+    pattern.trim().split("\\|").foreach {
[Review comment — @andrewor14 (Contributor), Apr 6, 2016]

style:

....foreach { subPattern =>
  try {
  } catch {
  }
}

+      subPattern =>
+        try {
+          val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
+          funcNames ++= names.filter{name => regex.pattern.matcher(name).matches()}
[Review comment — Contributor]

style: filter { name => ... }

+        } catch {
+          case _: PatternSyntaxException =>
+        }
+    }
+    funcNames.toSeq
+  }
 }
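
To make the two style comments above concrete, here is a sketch (not part of this diff) of how filterPattern might look with both suggestions applied — foreach { subPattern => ... } and filter { name => ... } — with no change in behavior. It assumes the same surrounding object StringUtils and the PatternSyntaxException import added above.

  def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
    val funcNames = scala.collection.mutable.SortedSet.empty[String]
    pattern.trim().split("\\|").foreach { subPattern =>
      try {
        // '*' is the only wildcard that is rewritten; everything else is treated as a regex,
        // matched case-insensitively.
        val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
        funcNames ++= names.filter { name => regex.pattern.matcher(name).matches() }
      } catch {
        // An ill-formed sub-pattern contributes nothing, as in the original version.
        case _: PatternSyntaxException =>
      }
    }
    funcNames.toSeq
  }

Per the scaladoc and the new StringUtilsSuite test below, the helper trims the pattern, splits it on '|', matches each sub-pattern case-insensitively, and returns the union of the matches in sorted order.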

StringUtilsSuite.scala
@@ -31,4 +31,16 @@ class StringUtilsSuite extends SparkFunSuite {
     assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E")
     assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E")
   }
+
+  test("filter pattern") {
+    val names = Seq("a1", "a2", "b2", "c3")
+    assert(filterPattern(names, " * ") === Seq("a1", "a2", "b2", "c3"))
+    assert(filterPattern(names, "*a*") === Seq("a1", "a2"))
+    assert(filterPattern(names, " *a* ") === Seq("a1", "a2"))
+    assert(filterPattern(names, " a* ") === Seq("a1", "a2"))
+    assert(filterPattern(names, " a.* ") === Seq("a1", "a2"))
+    assert(filterPattern(names, " B.*|a* ") === Seq("a1", "a2", "b2"))
+    assert(filterPattern(names, " a. ") === Seq("a1", "a2"))
+    assert(filterPattern(names, " d* ") === Nil)
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (12 changes: 5 additions & 7 deletions)
@@ -24,6 +24,7 @@ import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.expressions.SortOrder
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.aggregate
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
 import org.apache.spark.sql.functions._
@@ -56,17 +57,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {

   test("show functions") {
     def getFunctions(pattern: String): Seq[Row] = {
-      val regex = java.util.regex.Pattern.compile(pattern)
-      sqlContext.sessionState.functionRegistry.listFunction()
-        .filter(regex.matcher(_).matches()).map(Row(_))
+      StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern)
+        .map(Row(_))
     }
-    checkAnswer(sql("SHOW functions"), getFunctions(".*"))
+    checkAnswer(sql("SHOW functions"), getFunctions("*"))
     Seq("^c*", "*e$", "log*", "*date*").foreach { pattern =>
       // For the pattern part, only '*' and '|' are allowed as wildcards.
       // For '*', we need to replace it to '.*'.
-      checkAnswer(
-        sql(s"SHOW FUNCTIONS '$pattern'"),
-        getFunctions(pattern.replaceAll("\\*", ".*")))
+      checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern))
     }
   }

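
As a usage-level illustration (not part of the diff), the patterns exercised by the test above correspond to SHOW FUNCTIONS queries like the following. This is a hypothetical spark-shell style session, and the exact rows returned depend on which functions are registered in the FunctionRegistry.

// Hypothetical spark-shell session; sqlContext is provided by the shell.
sqlContext.sql("SHOW FUNCTIONS 'log*'").show()
// Functions whose names start with "log", matched case-insensitively (e.g. log, log10, log2).

sqlContext.sql("SHOW FUNCTIONS '*date*'").show()
// Functions whose names contain "date" (e.g. current_date, date_add, datediff).

sqlContext.sql("SHOW FUNCTIONS 'log*|*date*'").show()
// '|' combines sub-patterns: the union of the two result sets, in sorted order.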