Skip to content

Commit 98f72df

Browse files
ksonjJoshRosen
authored andcommitted
[SPARK-6553] [pyspark] Support functools.partial as UDF
Use `f.__repr__()` instead of `f.__name__` when instantiating `UserDefinedFunction`s, so `functools.partial`s may be used. Author: ksonj <[email protected]> Closes #5206 from ksonj/partials and squashes the following commits: ea66f3d [ksonj] Inserted blank lines for PEP8 compliance d81b02b [ksonj] added tests for udf with partial function and callable object 2c76100 [ksonj] Makes UDFs work with all types of callables b814a12 [ksonj] support functools.partial as udf
1 parent bc04fa2 commit 98f72df

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

python/pyspark/sql/functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def _create_judf(self):
123123
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
124124
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
125125
jdt = ssql_ctx.parseDataType(self.returnType.json())
126-
judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
126+
fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
127+
judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env,
127128
includes, sc.pythonExec, broadcast_vars,
128129
sc._javaAccumulator, jdt)
129130
return judf

python/pyspark/sql/tests.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import shutil
2626
import tempfile
2727
import pickle
28+
import functools
2829

2930
import py4j
3031

@@ -41,6 +42,7 @@
4142
from pyspark.sql.types import *
4243
from pyspark.sql.types import UserDefinedType, _infer_type
4344
from pyspark.tests import ReusedPySparkTestCase
45+
from pyspark.sql.functions import UserDefinedFunction
4446

4547

4648
class ExamplePointUDT(UserDefinedType):
@@ -114,6 +116,35 @@ def tearDownClass(cls):
114116
ReusedPySparkTestCase.tearDownClass()
115117
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
116118

119+
def test_udf_with_callable(self):
120+
d = [Row(number=i, squared=i**2) for i in range(10)]
121+
rdd = self.sc.parallelize(d)
122+
data = self.sqlCtx.createDataFrame(rdd)
123+
124+
class PlusFour:
125+
def __call__(self, col):
126+
if col is not None:
127+
return col + 4
128+
129+
call = PlusFour()
130+
pudf = UserDefinedFunction(call, LongType())
131+
res = data.select(pudf(data['number']).alias('plus_four'))
132+
self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
133+
134+
def test_udf_with_partial_function(self):
135+
d = [Row(number=i, squared=i**2) for i in range(10)]
136+
rdd = self.sc.parallelize(d)
137+
data = self.sqlCtx.createDataFrame(rdd)
138+
139+
def some_func(col, param):
140+
if col is not None:
141+
return col + param
142+
143+
pfunc = functools.partial(some_func, param=4)
144+
pudf = UserDefinedFunction(pfunc, LongType())
145+
res = data.select(pudf(data['number']).alias('plus_four'))
146+
self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)
147+
117148
def test_udf(self):
118149
self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
119150
[row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()

0 commit comments

Comments
 (0)