 from pyspark import since, SparkContext
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.sql.types import StringType, DataType, _parse_datatype_string
+from pyspark.sql.types import StringType, StructType, DataType, _parse_datatype_string
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.dataframe import DataFrame
 
@@ -2044,7 +2044,7 @@ class UserDefinedFunction(object):
 
     .. versionadded:: 1.3
     """
-    def __init__(self, func, returnType, name=None, vectorized=False):
+    def __init__(self, func, returnType, name=None, vectorized=False, grouped=False):
         if not callable(func):
             raise TypeError(
                 "Not a function or callable (__call__ is not defined): "
@@ -2059,6 +2059,7 @@ def __init__(self, func, returnType, name=None, vectorized=False):
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
         self.vectorized = vectorized
+        self.grouped = grouped
 
     @property
     def returnType(self):
@@ -2090,7 +2091,7 @@ def _create_judf(self):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt, self.vectorized)
+            self._name, wrapped_func, jdt, self.vectorized, self.grouped)
         return judf
 
     def __call__(self, *cols):
@@ -2122,13 +2123,14 @@ def wrapper(*args):
         wrapper.func = self.func
         wrapper.returnType = self.returnType
         wrapper.vectorized = self.vectorized
+        wrapper.grouped = self.grouped
 
         return wrapper
 
 
-def _create_udf(f, returnType, vectorized):
+def _create_udf(f, returnType, vectorized, grouped):
 
-    def _udf(f, returnType=StringType(), vectorized=vectorized):
+    def _udf(f, returnType=StringType(), vectorized=vectorized, grouped=grouped):
         if vectorized:
             import inspect
             argspec = inspect.getargspec(f)
@@ -2137,17 +2139,18 @@ def _udf(f, returnType=StringType(), vectorized=vectorized):
21372139 "0-arg pandas_udfs are not supported. "
21382140 "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
21392141 )
2140- udf_obj = UserDefinedFunction (f , returnType , vectorized = vectorized )
2142+ udf_obj = UserDefinedFunction (f , returnType , vectorized = vectorized , grouped = grouped )
21412143 return udf_obj ._wrapped ()
21422144
21432145 # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
21442146 if f is None or isinstance (f , (str , DataType )):
21452147 # If DataType has been passed as a positional argument
21462148 # for decorator use it as a returnType
21472149 return_type = f or returnType
2148- return functools .partial (_udf , returnType = return_type , vectorized = vectorized )
2150+ return functools .partial (
2151+ _udf , returnType = return_type , vectorized = vectorized , grouped = grouped )
21492152 else :
2150- return _udf (f = f , returnType = returnType , vectorized = vectorized )
2153+ return _udf (f = f , returnType = returnType , vectorized = vectorized , grouped = grouped )
21512154
21522155
21532156@since (1.3 )
@@ -2181,7 +2184,7 @@ def udf(f=None, returnType=StringType()):
     |         8|      JOHN DOE|          22|
     +----------+--------------+------------+
     """
-    return _create_udf(f, returnType=returnType, vectorized=False)
+    return _create_udf(f, returnType=returnType, vectorized=False, grouped=False)
 
 
 @since(2.3)
@@ -2192,67 +2195,82 @@ def pandas_udf(f=None, returnType=StringType()):
     :param f: user-defined function. A python function if used as a standalone function
     :param returnType: a :class:`pyspark.sql.types.DataType` object
 
-    The user-defined function can define one of the following transformations:
-
-    1. One or more `pandas.Series` -> A `pandas.Series`
-
-       This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and
-       :meth:`pyspark.sql.DataFrame.select`.
-       The returnType should be a primitive data type, e.g., `DoubleType()`.
-       The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.
-
-       >>> from pyspark.sql.types import IntegerType, StringType
-       >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
-       >>> @pandas_udf(returnType=StringType())
-       ... def to_upper(s):
-       ...     return s.str.upper()
-       ...
-       >>> @pandas_udf(returnType="integer")
-       ... def add_one(x):
-       ...     return x + 1
-       ...
-       >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
-       >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
-       ...     .show()  # doctest: +SKIP
-       +----------+--------------+------------+
-       |slen(name)|to_upper(name)|add_one(age)|
-       +----------+--------------+------------+
-       |         8|      JOHN DOE|          22|
-       +----------+--------------+------------+
-
-    2. A `pandas.DataFrame` -> A `pandas.DataFrame`
-
-       This udf is only used with :meth:`pyspark.sql.GroupedData.apply`.
-       The returnType should be a :class:`StructType` describing the schema of the returned
-       `pandas.DataFrame`.
-
-       >>> df = spark.createDataFrame(
-       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-       ...     ("id", "v"))
-       >>> @pandas_udf(returnType=df.schema)
-       ... def normalize(pdf):
-       ...     v = pdf.v
-       ...     return pdf.assign(v=(v - v.mean()) / v.std())
-       >>> df.groupby('id').apply(normalize).show()  # doctest: +SKIP
-       +---+-------------------+
-       | id|                  v|
-       +---+-------------------+
-       |  1|-0.7071067811865475|
-       |  1| 0.7071067811865475|
-       |  2|-0.8320502943378437|
-       |  2|-0.2773500981126146|
-       |  2| 1.1094003924504583|
-       +---+-------------------+
-
-    .. note:: This type of udf cannot be used with functions such as `withColumn` or `select`
-              because it defines a `DataFrame` transformation rather than a `Column`
-              transformation.
-
-    .. seealso:: :meth:`pyspark.sql.GroupedData.apply`
+    The user-defined function can define the following transformation:
+
+    One or more `pandas.Series` -> A `pandas.Series`
+
+    This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and
+    :meth:`pyspark.sql.DataFrame.select`.
+    The returnType should be a primitive data type, e.g., `DoubleType()`.
+    The length of the returned `pandas.Series` must be the same as that of the input `pandas.Series`.
+
+    >>> from pyspark.sql.types import IntegerType, StringType
+    >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
+    >>> @pandas_udf(returnType=StringType())
+    ... def to_upper(s):
+    ...     return s.str.upper()
+    ...
+    >>> @pandas_udf(returnType="integer")
+    ... def add_one(x):
+    ...     return x + 1
+    ...
+    >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+    >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
+    ...     .show()  # doctest: +SKIP
+    +----------+--------------+------------+
+    |slen(name)|to_upper(name)|add_one(age)|
+    +----------+--------------+------------+
+    |         8|      JOHN DOE|          22|
+    +----------+--------------+------------+
+
+    .. note:: The user-defined function must be deterministic.
+    """
+    return _create_udf(f, returnType=returnType, vectorized=True, grouped=False)
+
+
+@since(2.3)
+def pandas_grouped_udf(f=None, returnType=StructType()):
+    """
+    Creates a grouped vectorized user defined function (UDF).
+
+    :param f: user-defined function. A python function if used as a standalone function
+    :param returnType: a :class:`pyspark.sql.types.StructType` object
+
+    The grouped user-defined function can define the following transformation:
+
+    A `pandas.DataFrame` -> A `pandas.DataFrame`
+
+    This udf is only used with :meth:`pyspark.sql.GroupedData.apply`.
+    The returnType should be a :class:`StructType` describing the schema of the returned
+    `pandas.DataFrame`.
+
+    >>> df = spark.createDataFrame(
+    ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+    ...     ("id", "v"))
+    >>> @pandas_grouped_udf(returnType=df.schema)
+    ... def normalize(pdf):
+    ...     v = pdf.v
+    ...     return pdf.assign(v=(v - v.mean()) / v.std())
+    >>> df.groupby('id').apply(normalize).show()  # doctest: +SKIP
+    +---+-------------------+
+    | id|                  v|
+    +---+-------------------+
+    |  1|-0.7071067811865475|
+    |  1| 0.7071067811865475|
+    |  2|-0.8320502943378437|
+    |  2|-0.2773500981126146|
+    |  2| 1.1094003924504583|
+    +---+-------------------+
+
+    .. note:: This type of udf cannot be used with functions such as `withColumn` or `select`
+              because it defines a `DataFrame` transformation rather than a `Column`
+              transformation.
+
+    .. seealso:: :meth:`pyspark.sql.GroupedData.apply`
 
     .. note:: The user-defined function must be deterministic.
     """
-    return _create_udf(f, returnType=returnType, vectorized=True)
+    return _create_udf(f, returnType=returnType, vectorized=True, grouped=True)
 
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']
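
A minimal sketch of how the patch threads the new flag through: each public entry point fixes a (vectorized, grouped) pair when calling _create_udf, and the wrapper returned by _wrapped() exposes both flags as attributes. This is illustrative usage, not part of the diff; `plus_one` and `pandas_plus_one` are made-up names.

    from pyspark.sql.types import IntegerType

    @udf(IntegerType())       # udf(...)        -> vectorized=False, grouped=False
    def plus_one(x):
        return x + 1

    @pandas_udf("integer")    # pandas_udf(...) -> vectorized=True, grouped=False
    def pandas_plus_one(s):
        return s + 1          # operates on a whole pandas.Series

    # The wrapper carries the flags as metadata (set in _wrapped() above);
    # pandas_grouped_udf(...) would yield vectorized=True, grouped=True.
    assert plus_one.vectorized is False and plus_one.grouped is False
    assert pandas_plus_one.vectorized is True and pandas_plus_one.grouped is False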
0 commit comments
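
The pandas_grouped_udf doctest reuses df.schema as the returnType, but a grouped UDF can equally declare an explicit StructType matching the pandas.DataFrame it returns. A hedged sketch under that assumption; `schema` and `demean` are illustrative names:

    from pyspark.sql.types import StructType, StructField, LongType, DoubleType

    # Explicit result schema instead of reusing df.schema.
    schema = StructType([
        StructField("id", LongType()),
        StructField("v", DoubleType()),
    ])

    @pandas_grouped_udf(returnType=schema)
    def demean(pdf):
        # pdf is the full pandas.DataFrame for one group; the returned
        # frame must match `schema`.
        return pdf.assign(v=pdf.v - pdf.v.mean())

    df.groupby('id').apply(demean).show()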