Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,51 +17,68 @@

package org.apache.spark.sql.catalyst.analysis

import java.lang.reflect.Modifier
import java.util.Locale
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.xml._
import org.apache.spark.sql.catalyst.util.StringKeyHashMap
import org.apache.spark.sql.types._


/**
* A catalog for looking up user defined functions, used by an [[Analyzer]].
*
* Note: The implementation should be thread-safe to allow concurrent access.
* Note:
* 1) The implementation should be thread-safe to allow concurrent access.
* 2) the database name is always case-sensitive here, callers are responsible to
* format the database name w.r.t. case-sensitive config.
*/
trait FunctionRegistry {

final def registerFunction(name: String, builder: FunctionBuilder): Unit = {
registerFunction(name, new ExpressionInfo(builder.getClass.getCanonicalName, name), builder)
final def registerFunction(name: FunctionIdentifier, builder: FunctionBuilder): Unit = {
val info = new ExpressionInfo(
builder.getClass.getCanonicalName, name.database.orNull, name.funcName)
registerFunction(name, info, builder)
}

def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder): Unit
def registerFunction(
name: FunctionIdentifier,
info: ExpressionInfo,
builder: FunctionBuilder): Unit

/* Create or replace a temporary function. */
final def createOrReplaceTempFunction(name: String, builder: FunctionBuilder): Unit = {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we already expose FunctionRegistry to the stable class UDFRegistration, I added this extra API for a helper function.

Ideally, this function should only exist in SessionCatalog.

registerFunction(
FunctionIdentifier(name),
builder)
}

@throws[AnalysisException]("If function does not exist")
def lookupFunction(name: String, children: Seq[Expression]): Expression
def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression

/* List all of the registered function names. */
def listFunction(): Seq[String]
def listFunction(): Seq[FunctionIdentifier]

/* Get the class of the registered function by specified name. */
def lookupFunction(name: String): Option[ExpressionInfo]
def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo]

/* Get the builder of the registered function by specified name. */
def lookupFunctionBuilder(name: String): Option[FunctionBuilder]
def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder]

/** Drop a function and return whether the function existed. */
def dropFunction(name: String): Boolean
def dropFunction(name: FunctionIdentifier): Boolean

/** Checks if a function with a given name exists. */
def functionExists(name: String): Boolean = lookupFunction(name).isDefined
def functionExists(name: FunctionIdentifier): Boolean = lookupFunction(name).isDefined

/** Clear all registered functions. */
def clear(): Unit
Expand All @@ -72,39 +89,47 @@ trait FunctionRegistry {

class SimpleFunctionRegistry extends FunctionRegistry {

protected val functionBuilders =
StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)
Copy link
Member Author

@gatorsmile gatorsmile May 30, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this PR, the codes has a bug. The database name could be case sensitive.

@GuardedBy("this")
private val functionBuilders =
new mutable.HashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)]

// Resolution of the function name is always case insensitive, but the database name
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks weird, database name is always case sensitive and function name is always case insenstive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is the resolution rule we are using now. : (

// depends on the caller
private def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = {
FunctionIdentifier(name.funcName.toLowerCase(Locale.ROOT), name.database)
}

override def registerFunction(
name: String,
name: FunctionIdentifier,
info: ExpressionInfo,
builder: FunctionBuilder): Unit = synchronized {
functionBuilders.put(name, (info, builder))
functionBuilders.put(normalizeFuncName(name), (info, builder))
}

override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
val func = synchronized {
functionBuilders.get(name).map(_._2).getOrElse {
functionBuilders.get(normalizeFuncName(name)).map(_._2).getOrElse {
throw new AnalysisException(s"undefined function $name")
}
}
func(children)
}

override def listFunction(): Seq[String] = synchronized {
functionBuilders.iterator.map(_._1).toList.sorted
Copy link
Member Author

@gatorsmile gatorsmile May 30, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sorted is useless. Thus, I removed it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think sorted output can make users easy to search for a function, shall we still keep it?

override def listFunction(): Seq[FunctionIdentifier] = synchronized {
functionBuilders.iterator.map(_._1).toList
}

override def lookupFunction(name: String): Option[ExpressionInfo] = synchronized {
functionBuilders.get(name).map(_._1)
override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = synchronized {
functionBuilders.get(normalizeFuncName(name)).map(_._1)
}

override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized {
functionBuilders.get(name).map(_._2)
override def lookupFunctionBuilder(
name: FunctionIdentifier): Option[FunctionBuilder] = synchronized {
functionBuilders.get(normalizeFuncName(name)).map(_._2)
}

override def dropFunction(name: String): Boolean = synchronized {
functionBuilders.remove(name).isDefined
override def dropFunction(name: FunctionIdentifier): Boolean = synchronized {
functionBuilders.remove(normalizeFuncName(name)).isDefined
}

override def clear(): Unit = synchronized {
Expand All @@ -125,28 +150,28 @@ class SimpleFunctionRegistry extends FunctionRegistry {
* functions are already filled in and the analyzer needs only to resolve attribute references.
*/
object EmptyFunctionRegistry extends FunctionRegistry {
override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder)
: Unit = {
override def registerFunction(
name: FunctionIdentifier, info: ExpressionInfo, builder: FunctionBuilder): Unit = {
throw new UnsupportedOperationException
}

override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
throw new UnsupportedOperationException
}

override def listFunction(): Seq[String] = {
override def listFunction(): Seq[FunctionIdentifier] = {
throw new UnsupportedOperationException
}

override def lookupFunction(name: String): Option[ExpressionInfo] = {
override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = {
throw new UnsupportedOperationException
}

override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = {
override def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder] = {
throw new UnsupportedOperationException
}

override def dropFunction(name: String): Boolean = {
override def dropFunction(name: FunctionIdentifier): Boolean = {
throw new UnsupportedOperationException
}

Expand Down Expand Up @@ -455,11 +480,13 @@ object FunctionRegistry {

val builtin: SimpleFunctionRegistry = {
val fr = new SimpleFunctionRegistry
expressions.foreach { case (name, (info, builder)) => fr.registerFunction(name, info, builder) }
expressions.foreach {
case (name, (info, builder)) => fr.registerFunction(FunctionIdentifier(name), info, builder)
}
fr
}

val functionSet: Set[String] = builtin.listFunction().toSet
val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet

/** See usage above. */
private def expression[T <: Expression](name: String)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1029,13 +1029,12 @@ class SessionCatalog(
requireDbExists(db)
val identifier = name.copy(database = Some(db))
if (functionExists(identifier)) {
// TODO: registry should just take in FunctionIdentifier for type safety
if (functionRegistry.functionExists(identifier.unquotedString)) {
if (functionRegistry.functionExists(identifier)) {
// If we have loaded this function into the FunctionRegistry,
// also drop it from there.
// For a permanent function, because we loaded it to the FunctionRegistry
// when it's first used, we also need to drop it from the FunctionRegistry.
functionRegistry.dropFunction(identifier.unquotedString)
functionRegistry.dropFunction(identifier)
}
externalCatalog.dropFunction(db, name.funcName)
} else if (!ignoreIfNotExists) {
Expand All @@ -1061,7 +1060,7 @@ class SessionCatalog(
def functionExists(name: FunctionIdentifier): Boolean = {
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
requireDbExists(db)
functionRegistry.functionExists(name.unquotedString) ||
functionRegistry.functionExists(name) ||
externalCatalog.functionExists(db, name.funcName)
}

Expand Down Expand Up @@ -1095,20 +1094,20 @@ class SessionCatalog(
ignoreIfExists: Boolean,
functionBuilder: Option[FunctionBuilder] = None): Unit = {
val func = funcDefinition.identifier
if (functionRegistry.functionExists(func.unquotedString) && !ignoreIfExists) {
if (functionRegistry.functionExists(func) && !ignoreIfExists) {
throw new AnalysisException(s"Function $func already exists")
}
val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName)
val builder =
functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className))
functionRegistry.registerFunction(func.unquotedString, info, builder)
functionRegistry.registerFunction(func, info, builder)
}

/**
* Drop a temporary function.
*/
def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = {
if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) {
if (!functionRegistry.dropFunction(FunctionIdentifier(name)) && !ignoreIfNotExists) {
throw new NoSuchTempFunctionException(name)
}
}
Expand All @@ -1123,8 +1122,8 @@ class SessionCatalog(
// A temporary function is a function that has been registered in functionRegistry
// without a database name, and is neither a built-in function nor a Hive function
name.database.isEmpty &&
functionRegistry.functionExists(name.funcName) &&
!FunctionRegistry.builtin.functionExists(name.funcName) &&
functionRegistry.functionExists(name) &&
!FunctionRegistry.builtin.functionExists(name) &&
!hiveFunctions.contains(name.funcName.toLowerCase(Locale.ROOT))
}

Expand All @@ -1140,8 +1139,8 @@ class SessionCatalog(
// TODO: just make function registry take in FunctionIdentifier instead of duplicating this
val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName)
val qualifiedName = name.copy(database = database)
functionRegistry.lookupFunction(name.funcName)
Copy link
Member Author

@gatorsmile gatorsmile May 30, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also sounds a bug. This line before this PR ignores the database name.

.orElse(functionRegistry.lookupFunction(qualifiedName.unquotedString))
functionRegistry.lookupFunction(name)
.orElse(functionRegistry.lookupFunction(qualifiedName))
.getOrElse {
val db = qualifiedName.database.get
requireDbExists(db)
Expand Down Expand Up @@ -1176,19 +1175,19 @@ class SessionCatalog(
// Note: the implementation of this function is a little bit convoluted.
// We probably shouldn't use a single FunctionRegistry to register all three kinds of functions
// (built-in, temp, and external).
if (name.database.isEmpty && functionRegistry.functionExists(name.funcName)) {
if (name.database.isEmpty && functionRegistry.functionExists(name)) {
// This function has been already loaded into the function registry.
return functionRegistry.lookupFunction(name.funcName, children)
return functionRegistry.lookupFunction(name, children)
}

// If the name itself is not qualified, add the current database to it.
val database = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
val qualifiedName = name.copy(database = Some(database))

if (functionRegistry.functionExists(qualifiedName.unquotedString)) {
if (functionRegistry.functionExists(qualifiedName)) {
// This function has been already loaded into the function registry.
// Unlike the above block, we find this function by using the qualified name.
return functionRegistry.lookupFunction(qualifiedName.unquotedString, children)
return functionRegistry.lookupFunction(qualifiedName, children)
}

// The function has not been loaded to the function registry, which means
Expand All @@ -1209,7 +1208,7 @@ class SessionCatalog(
// At here, we preserve the input from the user.
registerFunction(catalogFunction.copy(identifier = qualifiedName), ignoreIfExists = false)
// Now, we need to create the Expression.
functionRegistry.lookupFunction(qualifiedName.unquotedString, children)
functionRegistry.lookupFunction(qualifiedName, children)
}

/**
Expand All @@ -1229,8 +1228,8 @@ class SessionCatalog(
requireDbExists(dbName)
val dbFunctions = externalCatalog.listFunctions(dbName, pattern).map { f =>
FunctionIdentifier(f, Some(dbName)) }
val loadedFunctions =
StringUtils.filterPattern(functionRegistry.listFunction(), pattern).map { f =>
val loadedFunctions = StringUtils
.filterPattern(functionRegistry.listFunction().map(_.unquotedString), pattern).map { f =>
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR keeps the current behavior. However, I think it is also a bug. The user-specified pattern should not consider the database name.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can fix it as a follow-up

// In functionRegistry, function names are stored as an unquoted format.
Try(parser.parseFunctionIdentifier(f)) match {
case Success(e) => e
Expand All @@ -1243,7 +1242,7 @@ class SessionCatalog(
// The session catalog caches some persistent functions in the FunctionRegistry
// so there can be duplicates.
functions.map {
case f if FunctionRegistry.functionSet.contains(f.funcName) => (f, "SYSTEM")
case f if FunctionRegistry.functionSet.contains(f) => (f, "SYSTEM")
case f => (f, "USER")
}.distinct
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ abstract class SessionCatalogSuite extends PlanTest {
assert(!catalog.isTemporaryFunction(FunctionIdentifier("func1")))

// Returns false when the function is built-in or hive
assert(FunctionRegistry.builtin.functionExists("sum"))
assert(FunctionRegistry.builtin.functionExists(FunctionIdentifier("sum")))
assert(!catalog.isTemporaryFunction(FunctionIdentifier("sum")))
assert(!catalog.isTemporaryFunction(FunctionIdentifier("histogram_numeric")))
}
Expand Down
Loading