Commit 158ad0b
[SPARK-2097][SQL] UDF Support
This patch adds the ability to register lambda functions written in Python, Java or Scala as UDFs for use in SQL or HiveQL.

Scala:
```scala
registerFunction("strLenScala", (_: String).length)
sql("SELECT strLenScala('test')")
```

Python:
```python
sqlCtx.registerFunction("strLenPython", lambda x: len(x), IntegerType())
sqlCtx.sql("SELECT strLenPython('test')")
```

Java:
```java
sqlContext.registerFunction("stringLengthJava", new UDF1<String, Integer>() {
  @Override
  public Integer call(String str) throws Exception {
    return str.length();
  }
}, DataType.IntegerType);

sqlContext.sql("SELECT stringLengthJava('test')");
```

Author: Michael Armbrust <[email protected]>

Closes #1063 from marmbrus/udfs and squashes the following commits:

9eda0fe [Michael Armbrust] newline
747c05e [Michael Armbrust] Add some scala UDF tests.
d92727d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs
005d684 [Michael Armbrust] Fix naming and formatting.
d14dac8 [Michael Armbrust] Fix last line of autogened java files.
8135c48 [Michael Armbrust] Move UDF unit tests to pyspark.
40b0ffd [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs
6a36890 [Michael Armbrust] Switch logging so that SQLContext can be serializable.
7a83101 [Michael Armbrust] Drop toString
795fd15 [Michael Armbrust] Try to avoid capturing SQLContext.
e54fb45 [Michael Armbrust] Docs and tests.
437cbe3 [Michael Armbrust] Update use of dataTypes, fix some python tests, address review comments.
01517d6 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs
8e6c932 [Michael Armbrust] WIP
3f96a52 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs
6237c8d [Michael Armbrust] WIP
2766f0b [Michael Armbrust] Move udfs support to SQL from hive. Add support for Java UDFs.
0f7d50c [Michael Armbrust] Draft of native Spark SQL UDFs for Scala and Python.
1 parent 4c47711 commit 158ad0b

File tree: 38 files changed (+1861 −19 lines)


python/pyspark/sql.py

Lines changed: 38 additions & 1 deletion
```diff
@@ -28,9 +28,13 @@
 from operator import itemgetter
 
 from pyspark.rdd import RDD, PipelinedRDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer
+from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
+
+from itertools import chain, ifilter, imap
 
 from py4j.protocol import Py4JError
+from py4j.java_collections import ListConverter, MapConverter
+
 
 __all__ = [
     "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
@@ -932,6 +936,39 @@ def _ssql_ctx(self):
             self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
         return self._scala_SQLContext
 
+    def registerFunction(self, name, f, returnType=StringType()):
+        """Registers a lambda function as a UDF so it can be used in SQL statements.
+
+        In addition to a name and the function itself, the return type can be optionally specified.
+        When the return type is not given it defaults to a string and conversion will automatically
+        be done. For any other return type, the produced object must match the specified type.
+
+        >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
+        >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
+        [Row(c0=u'4')]
+        >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
+        >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
+        [Row(c0=4)]
+        >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+        >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+        [Row(c0=5)]
+        """
+        func = lambda _, it: imap(lambda x: f(*x), it)
+        command = (func,
+                   BatchedSerializer(PickleSerializer(), 1024),
+                   BatchedSerializer(PickleSerializer(), 1024))
+        env = MapConverter().convert(self._sc.environment,
+                                     self._sc._gateway._gateway_client)
+        includes = ListConverter().convert(self._sc._python_includes,
+                                           self._sc._gateway._gateway_client)
+        self._ssql_ctx.registerPython(name,
+                                      bytearray(CloudPickleSerializer().dumps(command)),
+                                      env,
+                                      includes,
+                                      self._sc.pythonExec,
+                                      self._sc._javaAccumulator,
+                                      str(returnType))
+
     def inferSchema(self, rdd):
         """Infer and apply a schema to an RDD of L{Row}s.
 
```
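For intuition about the `func` wrapper above: the user's UDF is lifted into a function over an iterator of argument tuples, evaluated lazily one partition at a time, and it is this wrapper, together with the two batched pickle serializers, that gets CloudPickled into `command` and shipped to the JVM. A minimal sketch, not part of the commit — Python 2 idioms (`imap`, the `print` statement) match the `ifilter`/`imap` import in the diff, and `f` and the sample tuples are illustrative:

```python
# Minimal sketch, assuming Python 2 (matching the itertools.imap import above).
from itertools import imap

f = lambda x, y: len(x) + y                     # a user UDF, as in the twoArgs doctest
func = lambda _, it: imap(lambda x: f(*x), it)  # the per-partition wrapper from the diff

# Each partition arrives as an iterator of argument tuples; the first parameter
# (the split index) is ignored, and the UDF is mapped lazily over the iterator.
partition = iter([("test", 1), ("spark", 2)])
print list(func(0, partition))                  # prints [5, 7]
```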

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 32 additions & 0 deletions
```diff
@@ -18,17 +18,49 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.catalyst.expressions.Expression
+import scala.collection.mutable
 
 /** A catalog for looking up user defined functions, used by an [[Analyzer]]. */
 trait FunctionRegistry {
+  type FunctionBuilder = Seq[Expression] => Expression
+
+  def registerFunction(name: String, builder: FunctionBuilder): Unit
+
   def lookupFunction(name: String, children: Seq[Expression]): Expression
 }
 
+trait OverrideFunctionRegistry extends FunctionRegistry {
+
+  val functionBuilders = new mutable.HashMap[String, FunctionBuilder]()
+
+  def registerFunction(name: String, builder: FunctionBuilder) = {
+    functionBuilders.put(name, builder)
+  }
+
+  abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+    functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children))
+  }
+}
+
+class SimpleFunctionRegistry extends FunctionRegistry {
+  val functionBuilders = new mutable.HashMap[String, FunctionBuilder]()
+
+  def registerFunction(name: String, builder: FunctionBuilder) = {
+    functionBuilders.put(name, builder)
+  }
+
+  override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+    functionBuilders(name)(children)
+  }
+}
+
 /**
  * A trivial catalog that returns an error when a function is requested. Used for testing when all
  * functions are already filled in and the analyser needs only to resolve attribute references.
  */
 object EmptyFunctionRegistry extends FunctionRegistry {
+  def registerFunction(name: String, builder: FunctionBuilder) = ???
+
   def lookupFunction(name: String, children: Seq[Expression]): Expression = {
     throw new UnsupportedOperationException
   }
```
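The registry added above uses Scala's stackable-trait pattern: `OverrideFunctionRegistry` declares `lookupFunction` as `abstract override`, so a mixed-in instance consults its own `functionBuilders` map first and only falls back to `super.lookupFunction` on a miss. A minimal self-contained sketch of that mechanism (not part of the commit; `Expression`, `Literal`, and `BuiltinRegistry` are illustrative stand-ins, not Catalyst's real types):

```scala
import scala.collection.mutable

object RegistrySketch extends App {
  // Stand-ins for Catalyst's types, just enough to run the example.
  trait Expression
  case class Literal(value: Any) extends Expression

  trait FunctionRegistry {
    type FunctionBuilder = Seq[Expression] => Expression
    def registerFunction(name: String, builder: FunctionBuilder): Unit
    def lookupFunction(name: String, children: Seq[Expression]): Expression
  }

  // Plays the role of an underlying catalog with fixed built-ins; registerFunction
  // is left abstract so the mixed-in trait below supplies it.
  abstract class BuiltinRegistry extends FunctionRegistry {
    def lookupFunction(name: String, children: Seq[Expression]): Expression =
      Literal("builtin:" + name)
  }

  // Same shape as the patch: user registrations shadow the underlying catalog,
  // and the `abstract override` falls back to super.lookupFunction on a miss.
  trait OverrideFunctionRegistry extends FunctionRegistry {
    val functionBuilders = new mutable.HashMap[String, FunctionBuilder]()

    def registerFunction(name: String, builder: FunctionBuilder) =
      functionBuilders.put(name, builder)

    abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression =
      functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children))
  }

  val registry = new BuiltinRegistry with OverrideFunctionRegistry
  println(registry.lookupFunction("upper", Nil))  // Literal(builtin:upper) -- falls through
  registry.registerFunction("upper", _ => Literal("user-defined"))
  println(registry.lookupFunction("upper", Nil))  // Literal(user-defined) -- shadowed
}
```

The same shape lets user registrations be layered over an existing function catalog without modifying it: the trait is simply mixed onto whatever registry supplies the built-ins.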
