From 8ea92beeb734407e1666598713407fbe562e457a Mon Sep 17 00:00:00 2001 From: zero323 Date: Wed, 1 Feb 2017 18:56:11 +0100 Subject: [PATCH] Support data type string as a returnType argument of UDF --- python/pyspark/sql/functions.py | 8 +++++--- python/pyspark/sql/tests.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 02c2350dc2d6..d1b2e4391f24 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -27,7 +27,7 @@ from pyspark import since, SparkContext from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -from pyspark.sql.types import StringType +from pyspark.sql.types import StringType, DataType, _parse_datatype_string from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.dataframe import DataFrame @@ -1825,7 +1825,9 @@ class UserDefinedFunction(object): """ def __init__(self, func, returnType, name=None): self.func = func - self.returnType = returnType + self.returnType = ( + returnType if isinstance(returnType, DataType) + else _parse_datatype_string(returnType)) # Stores UserDefinedPythonFunctions jobj, once initialized self._judf_placeholder = None self._name = name or ( @@ -1869,7 +1871,7 @@ def udf(f, returnType=StringType()): it is present in the query. :param f: python function - :param returnType: a :class:`pyspark.sql.types.DataType` object + :param returnType: a :class:`pyspark.sql.types.DataType` object or data type string. 
>>> from pyspark.sql.types import IntegerType >>> slen = udf(lambda s: len(s), IntegerType()) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2fea4ac41f0d..e1ca34152931 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -489,6 +489,21 @@ def test_udf_defers_judf_initalization(self): "judf should be initialized after UDF has been called." ) + def test_udf_with_string_return_type(self): + from pyspark.sql.functions import UserDefinedFunction + + add_one = UserDefinedFunction(lambda x: x + 1, "integer") + make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x: bigint, y: bigint>") + make_array = UserDefinedFunction( + lambda x: [float(x) for x in range(x, x + 3)], "array<double>") + + expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0]) + actual = (self.spark.range(1, 2).toDF("x") + .select(add_one("x"), make_pair("x"), make_array("x")) + .first()) + + self.assertTupleEqual(expected, actual) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.spark.read.json(rdd)