Skip to content

Commit 8900328

Browse files
zero323 authored and cmonkey committed
[SPARK-19427][PYTHON][SQL] Support data type string as a returnType argument of UDF
## What changes were proposed in this pull request? Add support for data type string as a return type argument of `UserDefinedFunction`: ```python f = udf(lambda x: x, "integer") f.returnType ## IntegerType ``` ## How was this patch tested? Existing unit tests, additional unit tests covering new feature. Author: zero323 <[email protected]> Closes apache#16769 from zero323/SPARK-19427.
1 parent 7730de4 commit 8900328

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

python/pyspark/sql/functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from pyspark import since, SparkContext
2828
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
2929
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
30-
from pyspark.sql.types import StringType
30+
from pyspark.sql.types import StringType, DataType, _parse_datatype_string
3131
from pyspark.sql.column import Column, _to_java_column, _to_seq
3232
from pyspark.sql.dataframe import DataFrame
3333

@@ -1865,7 +1865,9 @@ class UserDefinedFunction(object):
18651865
"""
18661866
def __init__(self, func, returnType, name=None):
18671867
self.func = func
1868-
self.returnType = returnType
1868+
self.returnType = (
1869+
returnType if isinstance(returnType, DataType)
1870+
else _parse_datatype_string(returnType))
18691871
# Stores UserDefinedPythonFunctions jobj, once initialized
18701872
self._judf_placeholder = None
18711873
self._name = name or (
@@ -1909,7 +1911,7 @@ def udf(f, returnType=StringType()):
19091911
it is present in the query.
19101912
19111913
:param f: python function
1912-
:param returnType: a :class:`pyspark.sql.types.DataType` object
1914+
:param returnType: a :class:`pyspark.sql.types.DataType` object or data type string.
19131915
19141916
>>> from pyspark.sql.types import IntegerType
19151917
>>> slen = udf(lambda s: len(s), IntegerType())

python/pyspark/sql/tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,6 +489,21 @@ def test_udf_defers_judf_initalization(self):
489489
"judf should be initialized after UDF has been called."
490490
)
491491

492+
def test_udf_with_string_return_type(self):
493+
from pyspark.sql.functions import UserDefinedFunction
494+
495+
add_one = UserDefinedFunction(lambda x: x + 1, "integer")
496+
make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
497+
make_array = UserDefinedFunction(
498+
lambda x: [float(x) for x in range(x, x + 3)], "array<double>")
499+
500+
expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
501+
actual = (self.spark.range(1, 2).toDF("x")
502+
.select(add_one("x"), make_pair("x"), make_array("x"))
503+
.first())
504+
505+
self.assertTupleEqual(expected, actual)
506+
492507
def test_basic_functions(self):
493508
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
494509
df = self.spark.read.json(rdd)

0 commit comments

Comments (0)