
Commit 7bccc3b

Author: Davies Liu (committed)

python udf

1 parent 58dee20 commit 7bccc3b

File tree: 3 files changed, +98 −1 lines changed

python/pyspark/sql.py

Lines changed: 49 additions & 0 deletions
@@ -2554,6 +2554,45 @@ def _(col):
     return staticmethod(_)
 
 
+class UserDefinedFunction(object):
+    def __init__(self, func, returnType):
+        self.func = func
+        self.returnType = returnType
+        self._judf = self._create_judf()
+
+    def _create_judf(self):
+        f = self.func
+        sc = SparkContext._active_spark_context
+        # TODO(davies): refactor
+        func = lambda _, it: imap(lambda x: f(*x), it)
+        command = (func, None,
+                   AutoBatchedSerializer(PickleSerializer()),
+                   AutoBatchedSerializer(PickleSerializer()))
+        ser = CloudPickleSerializer()
+        pickled_command = ser.dumps(command)
+        if len(pickled_command) > (1 << 20):  # 1M
+            broadcast = sc.broadcast(pickled_command)
+            pickled_command = ser.dumps(broadcast)
+        broadcast_vars = ListConverter().convert(
+            [x._jbroadcast for x in sc._pickled_broadcast_vars],
+            sc._gateway._gateway_client)
+        sc._pickled_broadcast_vars.clear()
+        env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
+        includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+        jdt = ssql_ctx.parseDataType(self.returnType.json())
+        judf = sc._jvm.Dsl.pythonUDF(f.__name__, bytearray(pickled_command), env, includes,
+                                     sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt)
+        return judf
+
+    def __call__(self, *cols):
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = self._judf.apply(sc._jvm.Dsl.toColumns(jcols))
+        return Column(jc)
+
+
 class Dsl(object):
     """
     A collections of builtin aggregators
@@ -2612,6 +2651,16 @@ def approxCountDistinct(col, rsd=None):
         jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
         return Column(jc)
 
+    @staticmethod
+    def udf(f, returnType=StringType()):
+        """Create a user defined function (UDF)
+
+        >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
+        >>> df.select(slen(df.name).As('slen')).collect()
+        [Row(slen=5), Row(slen=3)]
+        """
+        return UserDefinedFunction(f, returnType)
+
 
 def _test():
     import doctest
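Taken together with the doctest above, here is a minimal end-to-end sketch of how the new API is meant to be used from PySpark. Everything outside the `Dsl.udf` call is an assumption: `sc` stands for an existing SparkContext, and `inferSchema` is used as the era's way of building a DataFrame from an RDD of Rows.

    from pyspark.sql import SQLContext, Dsl, Row, IntegerType

    sqlCtx = SQLContext(sc)  # assumes an existing SparkContext `sc`
    df = sqlCtx.inferSchema(sc.parallelize([Row(name='Alice'), Row(name='Bob')]))

    # wrap a plain Python function as a UDF with an explicit return type
    slen = Dsl.udf(lambda s: len(s), IntegerType())

    # the resulting UserDefinedFunction is callable on columns and yields a Column
    df.select(slen(df.name).As('slen')).collect()
    # -> [Row(slen=5), Row(slen=3)]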

sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala

Lines changed: 22 additions & 1 deletion
@@ -17,7 +17,11 @@
 
 package org.apache.spark.sql
 
-import java.util.{List => JList}
+import java.util.{List => JList, Map => JMap}
+
+import org.apache.spark.Accumulator
+import org.apache.spark.api.python.PythonBroadcast
+import org.apache.spark.broadcast.Broadcast
 
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
@@ -177,6 +181,23 @@ object Dsl {
     cols.toList.toSeq
   }
 
+  /**
+   * This is a private API for Python
+   * TODO: move this to a private package
+   */
+  def pythonUDF(
+      name: String,
+      command: Array[Byte],
+      envVars: JMap[String, String],
+      pythonIncludes: JList[String],
+      pythonExec: String,
+      broadcastVars: JList[Broadcast[PythonBroadcast]],
+      accumulator: Accumulator[JList[Array[Byte]]],
+      dataType: DataType): UserDefinedPythonFunction = {
+    UserDefinedPythonFunction(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
+      accumulator, dataType)
+  }
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////
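`pythonUDF` is the Py4J entry point that `_create_judf` calls from sql.py above, and its `command` parameter arrives as plain bytes. That is why the Python side checks the pickled size first: payloads over 1 MB are shipped once as a broadcast variable and only the small pickled handle crosses the gateway. A runnable pure-Python sketch of that size check, with stdlib `pickle` standing in for CloudPickleSerializer and a stub standing in for `sc.broadcast`:

    import pickle

    ONE_MB = 1 << 20  # same threshold as the `1 << 20` check in _create_judf

    def ship_command(command, broadcast=lambda b: ('broadcast-handle', len(b))):
        # serialize the command; oversized payloads go out via a broadcast handle
        payload = pickle.dumps(command)
        if len(payload) > ONE_MB:
            payload = pickle.dumps(broadcast(payload))  # ship the handle, not the bytes
        return bytearray(payload)

    print(len(ship_command(('a', 'small', 'command'))))  # small: inlined as-is
    print(len(ship_command('x' * (2 * ONE_MB))))         # large: only the handle is sent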

sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala

Lines changed: 27 additions & 0 deletions
@@ -17,7 +17,13 @@
 
 package org.apache.spark.sql
 
+import java.util.{List => JList, Map => JMap}
+
+import org.apache.spark.Accumulator
+import org.apache.spark.api.python.PythonBroadcast
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.expressions.ScalaUdf
+import org.apache.spark.sql.execution.PythonUDF
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -37,3 +43,24 @@ case class UserDefinedFunction(f: AnyRef, dataType: DataType) {
     Column(ScalaUdf(f, dataType, exprs.map(_.expr)))
   }
 }
+
+/**
+ * A user-defined Python function. To create one, use the `pythonUDF` functions in [[Dsl]].
+ * This is used by Python API.
+ */
+private[sql] case class UserDefinedPythonFunction(
+    name: String,
+    command: Array[Byte],
+    envVars: JMap[String, String],
+    pythonIncludes: JList[String],
+    pythonExec: String,
+    broadcastVars: JList[Broadcast[PythonBroadcast]],
+    accumulator: Accumulator[JList[Array[Byte]]],
+    dataType: DataType) {
+
+  def apply(exprs: Column*): Column = {
+    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
+      accumulator, dataType, exprs.map(_.expr))
+    Column(udf)
+  }
+}
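`UserDefinedPythonFunction.apply` packs the serialized command into a `PythonUDF` expression, and on the workers that command must consume an iterator of argument tuples; the `func = lambda _, it: imap(lambda x: f(*x), it)` line in `_create_judf` does that lifting. A self-contained sketch of the same idea, with Python 3's `map` standing in for the commit's Python 2 `imap`:

    def lift(f):
        # the first argument is the partition index, which the UDF wrapper ignores;
        # each element of `it` is a tuple of column values for one row
        return lambda _, it: map(lambda x: f(*x), it)

    slen = lift(len)
    print(list(slen(0, [('Alice',), ('Bob',)])))  # [5, 3]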
