
Commit b8624b0

ueshin authored and gatorsmile committed
[SPARK-20396][SQL][PYSPARK][FOLLOW-UP] groupby().apply() with pandas udf
## What changes were proposed in this pull request?

This is a follow-up of #18732. This PR modifies the `GroupedData.apply()` method to convert a pandas udf to a grouped udf implicitly.

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN <[email protected]>

Closes #19517 from ueshin/issues/SPARK-20396/fup2.
1 parent 568763b commit b8624b0
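
For context, a minimal usage sketch of the API this follow-up touches. This is not part of the commit; it assumes a Spark 2.3-style session bound to `spark` with pandas and Arrow installed, and the column names `id` and `v` are illustrative:

    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import StructType, StructField, LongType, DoubleType

    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

    # A 1-arg pandas_udf mapping a pandas.DataFrame to a pandas.DataFrame; after this
    # change, GroupedData.apply() converts it to a grouped udf implicitly.
    schema = StructType([StructField("id", LongType()), StructField("v", DoubleType())])

    @pandas_udf(schema)
    def subtract_mean(pdf):
        return pdf.assign(v=pdf.v - pdf.v.mean())

    df.groupby("id").apply(subtract_mean).show()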

File tree: 13 files changed, +114 / -52 lines


core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ private[spark] object PythonEvalType {
  val NON_UDF = 0
  val SQL_BATCHED_UDF = 1
  val SQL_PANDAS_UDF = 2
+  val SQL_PANDAS_GROUPED_UDF = 3
}

/**

python/pyspark/serializers.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ class PythonEvalType(object):
    NON_UDF = 0
    SQL_BATCHED_UDF = 1
    SQL_PANDAS_UDF = 2
+    SQL_PANDAS_GROUPED_UDF = 3


class Serializer(object):
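
As a quick cross-check, a sketch (not part of the commit) of the invariant this addition relies on: the Python constants mirror the Scala `PythonEvalType` object in `PythonRunner.scala`, and the worker dispatches on them when reading udfs:

    from pyspark.serializers import PythonEvalType

    # Values must stay in sync with the Scala side (PythonRunner.scala above).
    assert PythonEvalType.SQL_PANDAS_UDF == 2
    assert PythonEvalType.SQL_PANDAS_GROUPED_UDF == 3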

python/pyspark/sql/functions.py

Lines changed: 21 additions & 12 deletions
@@ -2038,13 +2038,22 @@ def _wrap_function(sc, func, returnType):
                           sc.pythonVer, broadcast_vars, sc._javaAccumulator)


+class PythonUdfType(object):
+    # row-at-a-time UDFs
+    NORMAL_UDF = 0
+    # scalar vectorized UDFs
+    PANDAS_UDF = 1
+    # grouped vectorized UDFs
+    PANDAS_GROUPED_UDF = 2
+
+
class UserDefinedFunction(object):
    """
    User defined function in Python

    .. versionadded:: 1.3
    """
-    def __init__(self, func, returnType, name=None, vectorized=False):
+    def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF):
        if not callable(func):
            raise TypeError(
                "Not a function or callable (__call__ is not defined): "
@@ -2058,7 +2067,7 @@ def __init__(self, func, returnType, name=None, vectorized=False):
        self._name = name or (
            func.__name__ if hasattr(func, '__name__')
            else func.__class__.__name__)
-        self.vectorized = vectorized
+        self.pythonUdfType = pythonUdfType

    @property
    def returnType(self):
@@ -2090,7 +2099,7 @@ def _create_judf(self):
        wrapped_func = _wrap_function(sc, self.func, self.returnType)
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt, self.vectorized)
+            self._name, wrapped_func, jdt, self.pythonUdfType)
        return judf

    def __call__(self, *cols):
@@ -2121,33 +2130,33 @@ def wrapper(*args):

        wrapper.func = self.func
        wrapper.returnType = self.returnType
-        wrapper.vectorized = self.vectorized
+        wrapper.pythonUdfType = self.pythonUdfType

        return wrapper


-def _create_udf(f, returnType, vectorized):
+def _create_udf(f, returnType, pythonUdfType):

-    def _udf(f, returnType=StringType(), vectorized=vectorized):
-        if vectorized:
+    def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType):
+        if pythonUdfType == PythonUdfType.PANDAS_UDF:
            import inspect
            argspec = inspect.getargspec(f)
            if len(argspec.args) == 0 and argspec.varargs is None:
                raise ValueError(
                    "0-arg pandas_udfs are not supported. "
                    "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
                )
-        udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
+        udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType)
        return udf_obj._wrapped()

    # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
    if f is None or isinstance(f, (str, DataType)):
        # If DataType has been passed as a positional argument
        # for decorator use it as a returnType
        return_type = f or returnType
-        return functools.partial(_udf, returnType=return_type, vectorized=vectorized)
+        return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType)
    else:
-        return _udf(f=f, returnType=returnType, vectorized=vectorized)
+        return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType)


@since(1.3)
@@ -2181,7 +2190,7 @@ def udf(f=None, returnType=StringType()):
    | 8| JOHN DOE| 22|
    +----------+--------------+------------+
    """
-    return _create_udf(f, returnType=returnType, vectorized=False)
+    return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF)


@since(2.3)
@@ -2252,7 +2261,7 @@ def pandas_udf(f=None, returnType=StringType()):

    .. note:: The user-defined function must be deterministic.
    """
-    return _create_udf(f, returnType=returnType, vectorized=True)
+    return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF)


blacklist = ['map', 'since', 'ignore_unicode_prefix']
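
To illustrate the new tagging, a small sketch (not part of the commit, assuming the patched functions.py; `slen` and `plus_one` are placeholder names): both `udf` and `pandas_udf` now go through `_create_udf`, which records the udf kind on the returned wrapper as `pythonUdfType` in place of the old boolean `vectorized` flag:

    from pyspark.sql.functions import udf, pandas_udf, PythonUdfType
    from pyspark.sql.types import LongType

    slen = udf(lambda s: len(s), LongType())            # row-at-a-time
    plus_one = pandas_udf(lambda v: v + 1, LongType())  # scalar vectorized

    # The wrapper exposes the tag that is passed to the JVM-side UserDefinedPythonFunction.
    assert slen.pythonUdfType == PythonUdfType.NORMAL_UDF
    assert plus_one.pythonUdfType == PythonUdfType.PANDAS_UDF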

python/pyspark/sql/group.py

Lines changed: 9 additions & 5 deletions
@@ -19,6 +19,7 @@
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal
from pyspark.sql.dataframe import DataFrame
+from pyspark.sql.functions import PythonUdfType, UserDefinedFunction
from pyspark.sql.types import *

__all__ = ["GroupedData"]
@@ -235,11 +236,13 @@ def apply(self, udf):
        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

        """
-        from pyspark.sql.functions import pandas_udf
+        import inspect

        # Columns are special because hasattr always return True
-        if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized:
-            raise ValueError("The argument to apply must be a pandas_udf")
+        if isinstance(udf, Column) or not hasattr(udf, 'func') \
+                or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \
+                or len(inspect.getargspec(udf.func).args) != 1:
+            raise ValueError("The argument to apply must be a 1-arg pandas_udf")
        if not isinstance(udf.returnType, StructType):
            raise ValueError("The returnType of the pandas_udf must be a StructType")

@@ -268,8 +271,9 @@ def wrapped(*cols):
            return [(result[result.columns[i]], arrow_type)
                    for i, arrow_type in enumerate(arrow_return_types)]

-        wrapped_udf_obj = pandas_udf(wrapped, returnType)
-        udf_column = wrapped_udf_obj(*[df[col] for col in df.columns])
+        udf_obj = UserDefinedFunction(
+            wrapped, returnType, name=udf.__name__, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF)
+        udf_column = udf_obj(*[df[col] for col in df.columns])
        jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
        return DataFrame(jdf, self.sql_ctx)
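
The tightened check means `apply()` now also rejects pandas udfs whose function does not take exactly one argument. A sketch of the accepted and rejected cases (not part of the commit; it assumes an active SparkSession bound to `spark`):

    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import StructType, StructField, LongType, DoubleType

    df = spark.createDataFrame([(1, 1.0), (2, 2.0)], ("id", "v"))
    schema = StructType([StructField("id", LongType()), StructField("v", DoubleType())])

    # Accepted: a 1-arg pandas_udf with a StructType return type.
    identity_udf = pandas_udf(lambda pdf: pdf, schema)
    df.groupby("id").apply(identity_udf)

    # Rejected: wrong arity now hits the new 1-arg error message.
    two_arg = pandas_udf(lambda x, y: x, schema)
    try:
        df.groupby("id").apply(two_arg)
    except ValueError as e:
        assert "1-arg pandas_udf" in str(e)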

python/pyspark/sql/tests.py

Lines changed: 37 additions & 0 deletions
@@ -3383,6 +3383,15 @@ def test_vectorized_udf_varargs(self):
        res = df.select(f(col('id')))
        self.assertEquals(df.collect(), res.collect())

+    def test_vectorized_udf_unsupported_types(self):
+        from pyspark.sql.functions import pandas_udf, col
+        schema = StructType([StructField("dt", DateType(), True)])
+        df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
+        f = pandas_udf(lambda x: x, DateType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
+                df.select(f(col('dt'))).collect()
+

@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class GroupbyApplyTests(ReusedPySparkTestCase):
@@ -3492,6 +3501,18 @@ def normalize(pdf):
        expected = expected.assign(norm=expected.norm.astype('float64'))
        self.assertFramesEqual(expected, result)

+    def test_datatype_string(self):
+        from pyspark.sql.functions import pandas_udf
+        df = self.data
+
+        foo_udf = pandas_udf(
+            lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            "id long, v int, v1 double, v2 long")
+
+        result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
+        expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
+        self.assertFramesEqual(expected, result)
+
    def test_wrong_return_type(self):
        from pyspark.sql.functions import pandas_udf
        df = self.data
@@ -3517,9 +3538,25 @@ def test_wrong_args(self):
            df.groupby('id').apply(sum(df.v))
        with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
            df.groupby('id').apply(df.v + 1)
+        with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
+            df.groupby('id').apply(
+                pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
+        with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
+            df.groupby('id').apply(
+                pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())])))
        with self.assertRaisesRegexp(ValueError, 'returnType'):
            df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType()))

+    def test_unsupported_types(self):
+        from pyspark.sql.functions import pandas_udf, col
+        schema = StructType(
+            [StructField("id", LongType(), True), StructField("dt", DateType(), True)])
+        df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema)
+        f = pandas_udf(lambda x: x, df.schema)
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
+                df.groupby('id').apply(f).collect()
+

if __name__ == "__main__":
    from pyspark.sql.tests import *

python/pyspark/worker.py

Lines changed: 17 additions & 22 deletions
@@ -32,7 +32,7 @@
from pyspark.serializers import write_with_length, write_int, read_long, \
    write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
    BatchedSerializer, ArrowStreamPandasSerializer
-from pyspark.sql.types import to_arrow_type, StructType
+from pyspark.sql.types import to_arrow_type
from pyspark import shuffle

pickleSer = PickleSerializer()
@@ -74,28 +74,19 @@ def wrap_udf(f, return_type):


def wrap_pandas_udf(f, return_type):
-    # If the return_type is a StructType, it indicates this is a groupby apply udf,
-    # and has already been wrapped under apply(), otherwise, it's a vectorized column udf.
-    # We can distinguish these two by return type because in groupby apply, we always specify
-    # returnType as a StructType, and in vectorized column udf, StructType is not supported.
-    #
-    # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs
-    if isinstance(return_type, StructType):
-        return lambda *a: f(*a)
-    else:
-        arrow_return_type = to_arrow_type(return_type)
+    arrow_return_type = to_arrow_type(return_type)

-        def verify_result_length(*a):
-            result = f(*a)
-            if not hasattr(result, "__len__"):
-                raise TypeError("Return type of the user-defined functon should be "
-                                "Pandas.Series, but is {}".format(type(result)))
-            if len(result) != len(a[0]):
-                raise RuntimeError("Result vector from pandas_udf was not the required length: "
-                                   "expected %d, got %d" % (len(a[0]), len(result)))
-            return result
+    def verify_result_length(*a):
+        result = f(*a)
+        if not hasattr(result, "__len__"):
+            raise TypeError("Return type of the user-defined functon should be "
+                            "Pandas.Series, but is {}".format(type(result)))
+        if len(result) != len(a[0]):
+            raise RuntimeError("Result vector from pandas_udf was not the required length: "
+                               "expected %d, got %d" % (len(a[0]), len(result)))
+        return result

-        return lambda *a: (verify_result_length(*a), arrow_return_type)
+    return lambda *a: (verify_result_length(*a), arrow_return_type)


def read_single_udf(pickleSer, infile, eval_type):
@@ -111,6 +102,9 @@ def read_single_udf(pickleSer, infile, eval_type):
    # the last returnType will be the return type of UDF
    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
        return arg_offsets, wrap_pandas_udf(row_func, return_type)
+    elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF:
+        # a groupby apply udf has already been wrapped under apply()
+        return arg_offsets, row_func
    else:
        return arg_offsets, wrap_udf(row_func, return_type)

@@ -133,7 +127,8 @@ def read_udfs(pickleSer, infile, eval_type):

    func = lambda _, it: map(mapper, it)

-    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
+    if eval_type == PythonEvalType.SQL_PANDAS_UDF \
+            or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF:
        ser = ArrowStreamPandasSerializer()
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)
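
For scalar pandas udfs the length verification above is unchanged, while grouped udfs now skip the wrapping entirely. A sketch (not part of the commit, assuming a session bound to `spark` with Arrow available) of the failure the check is meant to catch:

    import pandas as pd
    from pyspark.sql.functions import pandas_udf, col
    from pyspark.sql.types import LongType

    df = spark.range(4)

    # Returns a fixed-length Series regardless of the input batch size, so the worker
    # raises RuntimeError("Result vector from pandas_udf was not the required length: ..."),
    # which surfaces on the driver as a failed job.
    bad_udf = pandas_udf(lambda v: pd.Series([1]), LongType())

    df.select(bad_udf(col("id"))).collect()  # expected to fail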

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala

Lines changed: 5 additions & 4 deletions
@@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre
 * This is used by DataFrame.groupby().apply().
 */
case class FlatMapGroupsInPandas(
-  groupingAttributes: Seq[Attribute],
-  functionExpr: Expression,
-  output: Seq[Attribute],
-  child: LogicalPlan) extends UnaryNode {
+    groupingAttributes: Seq[Attribute],
+    functionExpr: Expression,
+    output: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode {
+
  /**
   * This is needed because output attributes are considered `references` when
   * passed through the constructor.

sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

Lines changed: 4 additions & 3 deletions
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.usePrettyExpression
import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
-import org.apache.spark.sql.execution.python.PythonUDF
+import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{NumericType, StructType}

@@ -437,7 +437,7 @@ class RelationalGroupedDataset protected[sql](
  }

  /**
-   * Applies a vectorized python user-defined function to each group of data.
+   * Applies a grouped vectorized python user-defined function to each group of data.
   * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`.
   * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results
   * for all groups are combined into a new [[DataFrame]].
@@ -449,7 +449,8 @@ class RelationalGroupedDataset protected[sql](
   * workers.
   */
  private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = {
-    require(expr.vectorized, "Must pass a vectorized python udf")
+    require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF,
+      "Must pass a grouped vectorized python udf")
    require(expr.dataType.isInstanceOf[StructType],
      "The returnType of the vectorized python udf must be a StructType")
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala

Lines changed: 5 additions & 1 deletion
@@ -137,11 +137,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
        udf.references.subsetOf(child.outputSet)
      }
      if (validUdfs.nonEmpty) {
+        if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) {
+          throw new IllegalArgumentException("Can not use grouped vectorized UDFs")
+        }
+
        val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
          AttributeReference(s"pythonUDF$i", u.dataType)()
        }

-        val evaluation = validUdfs.partition(_.vectorized) match {
+        val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match {
          case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
            ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
          case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala

Lines changed: 1 addition & 1 deletion
@@ -94,7 +94,7 @@ case class FlatMapGroupsInPandasExec(

    val columnarBatchIter = new ArrowPythonRunner(
      chainedFunc, bufferSize, reuseWorker,
-      PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema)
+      PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema)
      .compute(grouped, context.partitionId(), context)

    columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))
