|
25 | 25 | import shutil |
26 | 26 | import tempfile |
27 | 27 | import pickle |
| 28 | +import functools |
28 | 29 |
|
29 | 30 | import py4j |
30 | 31 |
|
|
41 | 42 | from pyspark.sql.types import * |
42 | 43 | from pyspark.sql.types import UserDefinedType, _infer_type |
43 | 44 | from pyspark.tests import ReusedPySparkTestCase |
| 45 | +from pyspark.sql.functions import UserDefinedFunction |
44 | 46 |
|
45 | 47 |
|
46 | 48 | class ExamplePointUDT(UserDefinedType): |
@@ -114,6 +116,35 @@ def tearDownClass(cls): |
114 | 116 | ReusedPySparkTestCase.tearDownClass() |
115 | 117 | shutil.rmtree(cls.tempdir.name, ignore_errors=True) |
116 | 118 |
|
| 119 | + def test_udf_with_callable(self): |
| 120 | + d = [Row(number=i, squared=i**2) for i in range(10)] |
| 121 | + rdd = self.sc.parallelize(d) |
| 122 | + data = self.sqlCtx.createDataFrame(rdd) |
| 123 | + |
| 124 | + class PlusFour: |
| 125 | + def __call__(self, col): |
| 126 | + if col is not None: |
| 127 | + return col + 4 |
| 128 | + |
| 129 | + call = PlusFour() |
| 130 | + pudf = UserDefinedFunction(call, LongType()) |
| 131 | + res = data.select(pudf(data['number']).alias('plus_four')) |
| 132 | + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) |
| 133 | + |
| 134 | + def test_udf_with_partial_function(self): |
| 135 | + d = [Row(number=i, squared=i**2) for i in range(10)] |
| 136 | + rdd = self.sc.parallelize(d) |
| 137 | + data = self.sqlCtx.createDataFrame(rdd) |
| 138 | + |
| 139 | + def some_func(col, param): |
| 140 | + if col is not None: |
| 141 | + return col + param |
| 142 | + |
| 143 | + pfunc = functools.partial(some_func, param=4) |
| 144 | + pudf = UserDefinedFunction(pfunc, LongType()) |
| 145 | + res = data.select(pudf(data['number']).alias('plus_four')) |
| 146 | + self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85) |
| 147 | + |
117 | 148 | def test_udf(self): |
118 | 149 | self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) |
119 | 150 | [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() |
|
0 commit comments