Commit 4d2bd95

Introduce @pandas_grouped_udf decorator for grouped vectorized UDF.

1 parent 13c1559 commit 4d2bd95

File tree

11 files changed: +149 -127 lines changed
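Taken together, the change splits the vectorized UDF API in two: `pandas_udf` keeps the element-wise `pandas.Series` -> `pandas.Series` role, while the new `pandas_grouped_udf` takes over the `pandas.DataFrame` -> `pandas.DataFrame` case used with `groupby().apply()`. A minimal usage sketch based on the doctests in this commit (it assumes a running SparkSession named `spark` with pandas and Arrow available):

from pyspark.sql.functions import pandas_udf, pandas_grouped_udf
from pyspark.sql.types import LongType

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

# Element-wise vectorized UDF: one or more pandas.Series -> a pandas.Series.
add_one = pandas_udf(lambda s: s + 1, LongType())
df.select(add_one(df.id)).show()

# Grouped vectorized UDF: a pandas.DataFrame -> a pandas.DataFrame.
# Only valid with GroupedData.apply; returnType must be a StructType.
@pandas_grouped_udf(returnType=df.schema)
def normalize(pdf):
    v = pdf.v
    return pdf.assign(v=(v - v.mean()) / v.std())

df.groupby('id').apply(normalize).show()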

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 1 addition & 0 deletions

@@ -36,6 +36,7 @@ private[spark] object PythonEvalType {
   val NON_UDF = 0
   val SQL_BATCHED_UDF = 1
   val SQL_PANDAS_UDF = 2
+  val SQL_PANDAS_GROUPED_UDF = 3
 }
 
 /**

python/pyspark/serializers.py

Lines changed: 1 addition & 0 deletions

@@ -86,6 +86,7 @@ class PythonEvalType(object):
     NON_UDF = 0
     SQL_BATCHED_UDF = 1
     SQL_PANDAS_UDF = 2
+    SQL_PANDAS_GROUPED_UDF = 3
 
 
 class Serializer(object):
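These Python-side constants mirror `PythonEvalType` in PythonRunner.scala and must stay numerically in sync, since the JVM sends the eval type to the Python worker to tell it how to deserialize and apply the incoming function. A schematic sketch of that dispatch; the helper and the mode descriptions below are illustrative only, not code from this commit:

from pyspark.serializers import PythonEvalType

def describe_eval_type(eval_type):
    # Hypothetical helper: maps each eval type to the execution mode a
    # worker would select for it.
    return {
        PythonEvalType.NON_UDF: "plain RDD function over pickled batches",
        PythonEvalType.SQL_BATCHED_UDF: "row-based SQL UDF, pickled in batches",
        PythonEvalType.SQL_PANDAS_UDF: "vectorized UDF over Arrow batches as pandas.Series",
        PythonEvalType.SQL_PANDAS_GROUPED_UDF: "grouped vectorized UDF, one pandas.DataFrame per group",
    }[eval_type]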

python/pyspark/sql/functions.py

Lines changed: 85 additions & 67 deletions

@@ -28,7 +28,7 @@
 from pyspark import since, SparkContext
 from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.sql.types import StringType, DataType, _parse_datatype_string
+from pyspark.sql.types import StringType, StructType, DataType, _parse_datatype_string
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.dataframe import DataFrame
@@ -2044,7 +2044,7 @@ class UserDefinedFunction(object):
 
     .. versionadded:: 1.3
     """
-    def __init__(self, func, returnType, name=None, vectorized=False):
+    def __init__(self, func, returnType, name=None, vectorized=False, grouped=False):
         if not callable(func):
             raise TypeError(
                 "Not a function or callable (__call__ is not defined): "
@@ -2059,6 +2059,7 @@ def __init__(self, func, returnType, name=None, vectorized=False):
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
         self.vectorized = vectorized
+        self.grouped = grouped
 
     @property
     def returnType(self):
@@ -2090,7 +2091,7 @@ def _create_judf(self):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt, self.vectorized)
+            self._name, wrapped_func, jdt, self.vectorized, self.grouped)
         return judf
 
     def __call__(self, *cols):
@@ -2122,13 +2123,14 @@ def wrapper(*args):
         wrapper.func = self.func
         wrapper.returnType = self.returnType
         wrapper.vectorized = self.vectorized
+        wrapper.grouped = self.grouped
 
         return wrapper
 
 
-def _create_udf(f, returnType, vectorized):
+def _create_udf(f, returnType, vectorized, grouped):
 
-    def _udf(f, returnType=StringType(), vectorized=vectorized):
+    def _udf(f, returnType=StringType(), vectorized=vectorized, grouped=grouped):
         if vectorized:
             import inspect
             argspec = inspect.getargspec(f)
@@ -2137,17 +2139,18 @@ def _udf(f, returnType=StringType(), vectorized=vectorized):
                 "0-arg pandas_udfs are not supported. "
                 "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
             )
-        udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
+        udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized, grouped=grouped)
         return udf_obj._wrapped()
 
     # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
     if f is None or isinstance(f, (str, DataType)):
         # If DataType has been passed as a positional argument
         # for decorator use it as a returnType
         return_type = f or returnType
-        return functools.partial(_udf, returnType=return_type, vectorized=vectorized)
+        return functools.partial(
+            _udf, returnType=return_type, vectorized=vectorized, grouped=grouped)
     else:
-        return _udf(f=f, returnType=returnType, vectorized=vectorized)
+        return _udf(f=f, returnType=returnType, vectorized=vectorized, grouped=grouped)
 
 
 @since(1.3)
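Because `_create_udf` serves both invocation styles, the new `grouped` flag has to be threaded through the `functools.partial` branch (decorator called with arguments) as well as the direct branch. A small sketch of the two styles reaching the same code path, using the schema from the doctests (assumes a running SparkSession):

from pyspark.sql.functions import pandas_grouped_udf
from pyspark.sql.types import DoubleType, LongType, StructField, StructType

schema = StructType([StructField("id", LongType()), StructField("v", DoubleType())])

# Decorator form: f is None on the first call, so _create_udf returns a
# functools.partial that still carries vectorized=True and grouped=True.
@pandas_grouped_udf(returnType=schema)
def identity(pdf):
    return pdf

# Direct form: f is passed immediately, so the inner _udf runs right away.
identity2 = pandas_grouped_udf(lambda pdf: pdf, schema)

# Both wrappers advertise the flags that GroupedData.apply later checks.
assert identity.vectorized and identity.grouped
assert identity2.vectorized and identity2.grouped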
@@ -2181,7 +2184,7 @@ def udf(f=None, returnType=StringType()):
     |         8|      JOHN DOE|          22|
     +----------+--------------+------------+
     """
-    return _create_udf(f, returnType=returnType, vectorized=False)
+    return _create_udf(f, returnType=returnType, vectorized=False, grouped=False)
 
 
 @since(2.3)
@@ -2192,67 +2195,82 @@ def pandas_udf(f=None, returnType=StringType()):
     :param f: user-defined function. A python function if used as a standalone function
     :param returnType: a :class:`pyspark.sql.types.DataType` object
 
-    The user-defined function can define one of the following transformations:
-
-    1. One or more `pandas.Series` -> A `pandas.Series`
-
-       This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and
-       :meth:`pyspark.sql.DataFrame.select`.
-       The returnType should be a primitive data type, e.g., `DoubleType()`.
-       The length of the returned `pandas.Series` must be the same as that of the input `pandas.Series`.
-
-       >>> from pyspark.sql.types import IntegerType, StringType
-       >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
-       >>> @pandas_udf(returnType=StringType())
-       ... def to_upper(s):
-       ...     return s.str.upper()
-       ...
-       >>> @pandas_udf(returnType="integer")
-       ... def add_one(x):
-       ...     return x + 1
-       ...
-       >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
-       >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
-       ...     .show()  # doctest: +SKIP
-       +----------+--------------+------------+
-       |slen(name)|to_upper(name)|add_one(age)|
-       +----------+--------------+------------+
-       |         8|      JOHN DOE|          22|
-       +----------+--------------+------------+
-
-    2. A `pandas.DataFrame` -> A `pandas.DataFrame`
-
-       This udf is only used with :meth:`pyspark.sql.GroupedData.apply`.
-       The returnType should be a :class:`StructType` describing the schema of the returned
-       `pandas.DataFrame`.
-
-       >>> df = spark.createDataFrame(
-       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-       ...     ("id", "v"))
-       >>> @pandas_udf(returnType=df.schema)
-       ... def normalize(pdf):
-       ...     v = pdf.v
-       ...     return pdf.assign(v=(v - v.mean()) / v.std())
-       >>> df.groupby('id').apply(normalize).show()  # doctest: +SKIP
-       +---+-------------------+
-       | id|                  v|
-       +---+-------------------+
-       |  1|-0.7071067811865475|
-       |  1| 0.7071067811865475|
-       |  2|-0.8320502943378437|
-       |  2|-0.2773500981126146|
-       |  2| 1.1094003924504583|
-       +---+-------------------+
-
-       .. note:: This type of udf cannot be used with functions such as `withColumn` or `select`
-                 because it defines a `DataFrame` transformation rather than a `Column`
-                 transformation.
-
-       .. seealso:: :meth:`pyspark.sql.GroupedData.apply`
+    The user-defined function can define the following transformation:
+
+    One or more `pandas.Series` -> A `pandas.Series`
+
+    This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and
+    :meth:`pyspark.sql.DataFrame.select`.
+    The returnType should be a primitive data type, e.g., `DoubleType()`.
+    The length of the returned `pandas.Series` must be the same as that of the input `pandas.Series`.
+
+    >>> from pyspark.sql.types import IntegerType, StringType
+    >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
+    >>> @pandas_udf(returnType=StringType())
+    ... def to_upper(s):
+    ...     return s.str.upper()
+    ...
+    >>> @pandas_udf(returnType="integer")
+    ... def add_one(x):
+    ...     return x + 1
+    ...
+    >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+    >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
+    ...     .show()  # doctest: +SKIP
+    +----------+--------------+------------+
+    |slen(name)|to_upper(name)|add_one(age)|
+    +----------+--------------+------------+
+    |         8|      JOHN DOE|          22|
+    +----------+--------------+------------+
+
+    .. note:: The user-defined function must be deterministic.
+    """
+    return _create_udf(f, returnType=returnType, vectorized=True, grouped=False)
+
+
+@since(2.3)
+def pandas_grouped_udf(f=None, returnType=StructType()):
+    """
+    Creates a grouped vectorized user defined function (UDF).
+
+    :param f: user-defined function. A python function if used as a standalone function
+    :param returnType: a :class:`pyspark.sql.types.StructType` object
+
+    The grouped user-defined function can define the following transformation:
+
+    A `pandas.DataFrame` -> A `pandas.DataFrame`
+
+    This udf is only used with :meth:`pyspark.sql.GroupedData.apply`.
+    The returnType should be a :class:`StructType` describing the schema of the returned
+    `pandas.DataFrame`.
+
+    >>> df = spark.createDataFrame(
+    ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+    ...     ("id", "v"))
+    >>> @pandas_grouped_udf(returnType=df.schema)
+    ... def normalize(pdf):
+    ...     v = pdf.v
+    ...     return pdf.assign(v=(v - v.mean()) / v.std())
+    >>> df.groupby('id').apply(normalize).show()  # doctest: +SKIP
+    +---+-------------------+
+    | id|                  v|
+    +---+-------------------+
+    |  1|-0.7071067811865475|
+    |  1| 0.7071067811865475|
+    |  2|-0.8320502943378437|
+    |  2|-0.2773500981126146|
+    |  2| 1.1094003924504583|
+    +---+-------------------+
+
+    .. note:: This type of udf cannot be used with functions such as `withColumn` or `select`
+              because it defines a `DataFrame` transformation rather than a `Column`
+              transformation.
+
+    .. seealso:: :meth:`pyspark.sql.GroupedData.apply`
 
     .. note:: The user-defined function must be deterministic.
     """
-    return _create_udf(f, returnType=returnType, vectorized=True)
+    return _create_udf(f, returnType=returnType, vectorized=True, grouped=True)
 
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']

python/pyspark/sql/group.py

Lines changed: 11 additions & 10 deletions

@@ -206,18 +206,18 @@ def apply(self, udf):
         to the user-function and the returned `pandas.DataFrame`s are combined as a
         :class:`DataFrame`.
         The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
-        returnType of the pandas udf.
+        returnType of the pandas grouped udf.
 
         This function does not support partial aggregation, and requires shuffling all the data in
         the :class:`DataFrame`.
 
-        :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
+        :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_grouped_udf`
 
-        >>> from pyspark.sql.functions import pandas_udf
+        >>> from pyspark.sql.functions import pandas_grouped_udf
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
         ...     ("id", "v"))
-        >>> @pandas_udf(returnType=df.schema)
+        >>> @pandas_grouped_udf(returnType=df.schema)
         ... def normalize(pdf):
         ...     v = pdf.v
         ...     return pdf.assign(v=(v - v.mean()) / v.std())
@@ -232,16 +232,17 @@ def apply(self, udf):
         |  2| 1.1094003924504583|
         +---+-------------------+
 
-        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
+        .. seealso:: :meth:`pyspark.sql.functions.pandas_grouped_udf`
 
         """
-        from pyspark.sql.functions import pandas_udf
+        from pyspark.sql.functions import pandas_grouped_udf
 
         # Columns are special because hasattr always returns True
-        if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized:
-            raise ValueError("The argument to apply must be a pandas_udf")
+        if isinstance(udf, Column) or not hasattr(udf, 'func') \
+                or not udf.vectorized or not udf.grouped:
+            raise ValueError("The argument to apply must be a pandas_grouped_udf")
         if not isinstance(udf.returnType, StructType):
-            raise ValueError("The returnType of the pandas_udf must be a StructType")
+            raise ValueError("The returnType of the pandas_grouped_udf must be a StructType")
 
         df = self._df
         func = udf.func
@@ -268,7 +269,7 @@ def wrapped(*cols):
             return [(result[result.columns[i]], arrow_type)
                     for i, arrow_type in enumerate(arrow_return_types)]
 
-        wrapped_udf_obj = pandas_udf(wrapped, returnType)
+        wrapped_udf_obj = pandas_grouped_udf(wrapped, returnType)
         udf_column = wrapped_udf_obj(*[df[col] for col in df.columns])
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.sql_ctx)
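With the tightened check, an element-wise `pandas_udf` (vectorized but not grouped) is now rejected by `apply` up front instead of failing later at execution time. A sketch of the expected failure mode, reusing the `df` from the doctest above:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

plus_one = pandas_udf(lambda v: v + 1, DoubleType())  # grouped=False

try:
    df.groupby('id').apply(plus_one)
except ValueError as e:
    print(e)  # The argument to apply must be a pandas_grouped_udf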
