Commit aadce78

[SPARK-7886] Use FunctionRegistry for built-in expressions in HiveContext.
1 parent 490d5a7 commit aadce78

File tree: 5 files changed, +42 / -47 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
Lines changed: 9 additions & 3 deletions

@@ -35,16 +35,16 @@ trait FunctionRegistry {
   def lookupFunction(name: String, children: Seq[Expression]): Expression
 }
 
-trait OverrideFunctionRegistry extends FunctionRegistry {
+class OverrideFunctionRegistry(underlying: FunctionRegistry) extends FunctionRegistry {
 
   private val functionBuilders = StringKeyHashMap[FunctionBuilder](caseSensitive = false)
 
   override def registerFunction(name: String, builder: FunctionBuilder): Unit = {
     functionBuilders.put(name, builder)
   }
 
-  abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
-    functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children))
+  override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+    functionBuilders.get(name).map(_(children)).getOrElse(underlying.lookupFunction(name, children))
   }
 }
 
@@ -133,6 +133,12 @@ object FunctionRegistry {
     expression[Sum]("sum")
   )
 
+  val builtin: FunctionRegistry = {
+    val fr = new SimpleFunctionRegistry
+    expressions.foreach { case (name, builder) => fr.registerFunction(name, builder) }
+    fr
+  }
+
   /** See usage above. */
   private def expression[T <: Expression](name: String)
       (implicit tag: ClassTag[T]): (String, FunctionBuilder) = {
sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
Lines changed: 2 additions & 5 deletions

@@ -120,11 +120,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
 
   // TODO how to handle the temp function per user session?
   @transient
-  protected[sql] lazy val functionRegistry: FunctionRegistry = {
-    val fr = new SimpleFunctionRegistry
-    FunctionRegistry.expressions.foreach { case (name, func) => fr.registerFunction(name, func) }
-    fr
-  }
+  protected[sql] lazy val functionRegistry: FunctionRegistry =
+    new OverrideFunctionRegistry(FunctionRegistry.builtin)
 
   @transient
   protected[sql] lazy val analyzer: Analyzer =
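Each SQLContext now layers a private OverrideFunctionRegistry over the single shared FunctionRegistry.builtin table instead of copying every built-in into a per-context SimpleFunctionRegistry. A hypothetical sketch of the resulting isolation (reg1, reg2, and myBuilder are illustrative names, not code from this commit):

val reg1 = new OverrideFunctionRegistry(FunctionRegistry.builtin)
val reg2 = new OverrideFunctionRegistry(FunctionRegistry.builtin)

reg1.registerFunction("f", myBuilder)  // recorded only in reg1's override table
// reg2 still sees only the built-ins, and FunctionRegistry.builtin itself is
// never mutated, so contexts cannot leak temporary functions into one another.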

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
Lines changed: 1 addition & 1 deletion

@@ -374,7 +374,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   // Note that HiveUDFs will be overridden by functions registered in this context.
   @transient
   override protected[sql] lazy val functionRegistry: FunctionRegistry =
-    new HiveFunctionRegistry with OverrideFunctionRegistry
+    new OverrideFunctionRegistry(new HiveFunctionRegistry(FunctionRegistry.builtin))
 
   /* An analyzer that uses the Hive metastore. */
   @transient
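Trait mixing (`new HiveFunctionRegistry with OverrideFunctionRegistry`) is replaced by explicit composition, which makes the resolution order readable directly off the constructor nesting. Annotated (the numbers give the order in which a lookup is attempted, per the code here and the Try fallback in hiveUdfs.scala below):

override protected[sql] lazy val functionRegistry: FunctionRegistry =
  new OverrideFunctionRegistry(      // 1. functions registered on this context
    new HiveFunctionRegistry(        // 3. Hive's own UDFs, as the last resort
      FunctionRegistry.builtin))     // 2. Spark built-ins, tried before Hive

This matches the comment above: Hive UDFs lose to context-registered functions, and now also to Spark's built-in implementations of the same names.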

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
Lines changed: 0 additions & 16 deletions

@@ -1353,14 +1353,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
       UnresolvedStar(Some(name))
 
     /* Aggregate Functions */
-    case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg))
-    case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg))
     case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1))
     case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr))
-    case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg))
     case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg))
-    case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg))
-    case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg))
 
     /* System functions about string operations */
     case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg))

@@ -1469,17 +1464,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
     case Token("[", child :: ordinal :: Nil) =>
       UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))
 
-    /* Other functions */
-    case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>
-      CreateArray(children.map(nodeToExpr))
-    case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand()
-    case Token("TOK_FUNCTION", Token(RAND(), Nil) :: seed :: Nil) => Rand(seed.toString.toLong)
-    case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) =>
-      Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType))
-    case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>
-      Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length))
-    case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr))
-
     /* Window Functions */
     case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) =>
       val function = UnresolvedWindowFunction(name, args.map(nodeToExpr))
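With the built-ins registered centrally, HiveQl no longer needs dedicated parser cases for avg, count, sum, max, min, array, rand, substr, and coalesce. A bare TOK_FUNCTION node now falls through to the generic function case, which emits an unresolved function for the analyzer to resolve through the FunctionRegistry chain; roughly like this (a paraphrase of the surviving generic case, not a line from this diff):

case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
  UnresolvedFunction(name, args.map(nodeToExpr))

Note that the TOK_FUNCTIONSTAR and TOK_FUNCTIONDI cases survive: the star and DISTINCT variants still need special handling in the parser.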

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
Lines changed: 30 additions & 22 deletions

@@ -33,43 +33,51 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper
 import org.apache.spark.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.types._
 
+import scala.util.Try
 
-private[hive] abstract class HiveFunctionRegistry
+
+private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry)
   extends analysis.FunctionRegistry with HiveInspectors {
 
   def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name)
 
   override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
-    // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
-    // not always serializable.
-    val functionInfo: FunctionInfo =
-      Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
-        throw new AnalysisException(s"undefined function $name"))
-
-    val functionClassName = functionInfo.getFunctionClass.getName
-
-    if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
-    } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
-    } else if (
-      classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
-    } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
-    } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
-    } else {
-      sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
+    Try(underlying.lookupFunction(name, children)).getOrElse {
+      // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
+      // not always serializable.
+      val functionInfo: FunctionInfo =
+        Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse(
+          throw new AnalysisException(s"undefined function $name"))
+
+      val functionClassName = functionInfo.getFunctionClass.getName
+
+      if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
+        HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
+      } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
+        HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
+      } else if (
+        classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
+        HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
+      } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
+        HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
+      } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
+        HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
+      } else {
+        sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
+      }
     }
   }
+
+  override def registerFunction(name: String, builder: FunctionBuilder): Unit =
+    throw new UnsupportedOperationException
 }
 
 private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
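The key change in lookupFunction is the Try wrapper: the underlying registry (the Spark built-ins) is consulted first, and any exception it throws is treated as a miss, triggering the fall-back to Hive's FunctionRegistry. A minimal, self-contained sketch of the idiom, with plain functions standing in for the two registries (names and return values are illustrative):

import scala.util.Try

// Primary path: succeeds only for names the built-in table knows.
def lookupBuiltin(name: String): String =
  if (name == "sum") "Sum(...)" else sys.error(s"undefined function $name")

// Fallback path: defers to Hive's registry.
def lookupHive(name: String): String = s"HiveUdf($name)"

// Try(...) converts the primary lookup's exception into a miss.
def lookup(name: String): String =
  Try(lookupBuiltin(name)).getOrElse(lookupHive(name))

lookup("sum")     // "Sum(...)": resolved by the built-in path
lookup("concat")  // "HiveUdf(concat)": falls back to Hive

registerFunction now throws UnsupportedOperationException because this registry sits at the bottom of the chain; registrations belong to the OverrideFunctionRegistry wrapped around it.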
