From b814a126bf7470be3fbd92d6c6c3ecbad2b8db1b Mon Sep 17 00:00:00 2001 From: ksonj Date: Thu, 26 Mar 2015 10:20:22 +0100 Subject: [PATCH 1/4] support functools.partial as udf --- python/pyspark/sql/functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5873f09ae3275..daaed77dc7458 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -123,7 +123,7 @@ def _create_judf(self): pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env, + judf = sc._jvm.UserDefinedPythonFunction(f.__repr__(), bytearray(pickled_command), env, includes, sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt) return judf From 2c761004f5901c3298760ec9a59566c8708262ce Mon Sep 17 00:00:00 2001 From: ksonj Date: Fri, 27 Mar 2015 08:24:02 +0100 Subject: [PATCH 2/4] Makes UDFs work with all types of callables --- python/pyspark/sql/functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index daaed77dc7458..0f2ff2c2e41f9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -123,7 +123,8 @@ def _create_judf(self): pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) jdt = ssql_ctx.parseDataType(self.returnType.json()) - judf = sc._jvm.UserDefinedPythonFunction(f.__repr__(), bytearray(pickled_command), env, + fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ + judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes, sc.pythonExec, broadcast_vars, sc._javaAccumulator, jdt) return judf From d81b02b9552e5faaeb295d6ef4c9a695715b8300 Mon Sep 17 00:00:00 2001 From: ksonj Date: Mon, 30 Mar 2015 10:16:38 +0200 Subject: [PATCH 3/4] added tests for udf with partial function and callable object --- python/pyspark/sql/tests.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2720439416682..8bd76a65c8a97 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -25,6 +25,7 @@ import shutil import tempfile import pickle +import functools import py4j @@ -41,6 +42,7 @@ from pyspark.sql.types import * from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase +from pyspark.sql.functions import UserDefinedFunction class ExamplePointUDT(UserDefinedType): @@ -114,6 +116,31 @@ def tearDownClass(cls): ReusedPySparkTestCase.tearDownClass() shutil.rmtree(cls.tempdir.name, ignore_errors=True) + def test_udf_with_callable(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + class PlusFour: + def __call__(self, col): + if col is not None: + return col + 4 + call = PlusFour() + pudf = UserDefinedFunction(call, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + + def test_udf_with_partial_function(self): + d = [Row(number=i, squared=i**2) for i in range(10)] + rdd = self.sc.parallelize(d) + data = self.sqlCtx.createDataFrame(rdd) + def some_func(col, param): + if col is not None: + return col + param + pfunc = functools.partial(some_func, param=4) + pudf = UserDefinedFunction(pfunc, LongType()) + res = data.select(pudf(data['number']).alias('plus_four')) + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) + def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() From ea66f3df12de4d5d3056d0ca97f3b9b15f9d99a7 Mon Sep 17 00:00:00 2001 From: ksonj Date: Tue, 31 Mar 2015 08:54:12 +0200 Subject: [PATCH 4/4] Inserted blank lines for PEP8 compliance --- python/pyspark/sql/tests.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8bd76a65c8a97..81bfad56a2d48 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -120,10 +120,12 @@ def test_udf_with_callable(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) data = self.sqlCtx.createDataFrame(rdd) + class PlusFour: def __call__(self, col): if col is not None: return col + 4 + call = PlusFour() pudf = UserDefinedFunction(call, LongType()) res = data.select(pudf(data['number']).alias('plus_four')) @@ -133,9 +135,11 @@ def test_udf_with_partial_function(self): d = [Row(number=i, squared=i**2) for i in range(10)] rdd = self.sc.parallelize(d) data = self.sqlCtx.createDataFrame(rdd) + def some_func(col, param): if col is not None: return col + param + pfunc = functools.partial(some_func, param=4) pudf = UserDefinedFunction(pfunc, LongType()) res = data.select(pudf(data['number']).alias('plus_four'))