|
27 | 27 | from pyspark import since, SparkContext |
28 | 28 | from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix |
29 | 29 | from pyspark.serializers import PickleSerializer, AutoBatchedSerializer |
30 | | -from pyspark.sql.types import StringType |
| 30 | +from pyspark.sql.types import StringType, DataType, _parse_datatype_string |
31 | 31 | from pyspark.sql.column import Column, _to_java_column, _to_seq |
32 | 32 | from pyspark.sql.dataframe import DataFrame |
33 | 33 |
|
@@ -1865,7 +1865,9 @@ class UserDefinedFunction(object): |
1865 | 1865 | """ |
1866 | 1866 | def __init__(self, func, returnType, name=None): |
1867 | 1867 | self.func = func |
1868 | | - self.returnType = returnType |
| 1868 | + self.returnType = ( |
| 1869 | + returnType if isinstance(returnType, DataType) |
| 1870 | + else _parse_datatype_string(returnType)) |
1869 | 1871 | # Stores UserDefinedPythonFunctions jobj, once initialized |
1870 | 1872 | self._judf_placeholder = None |
1871 | 1873 | self._name = name or ( |
@@ -1909,7 +1911,7 @@ def udf(f, returnType=StringType()): |
1909 | 1911 | it is present in the query. |
1910 | 1912 |
|
1911 | 1913 | :param f: python function |
1912 | | - :param returnType: a :class:`pyspark.sql.types.DataType` object |
| 1914 | +    :param returnType: a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string (e.g. ``"integer"``).
1913 | 1915 |
|
1914 | 1916 | >>> from pyspark.sql.types import IntegerType |
1915 | 1917 | >>> slen = udf(lambda s: len(s), IntegerType()) |
|
0 commit comments