Skip to content

Commit 789e642

Browse files
committed
Restrict pandas grouped UDFs to functions that take exactly one argument.
1 parent 10512a6 commit 789e642

File tree

2 files changed: +34 -3 lines changed

2 files changed: +34 -3 lines changed

python/pyspark/sql/functions.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2135,15 +2135,20 @@ def wrapper(*args):
21352135
def _create_udf(f, returnType, pythonUdfType):
21362136

21372137
def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType):
2138-
if pythonUdfType == PythonUdfType.PANDAS_UDF \
2139-
or pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF:
2138+
if pythonUdfType == PythonUdfType.PANDAS_UDF:
21402139
import inspect
21412140
argspec = inspect.getargspec(f)
21422141
if len(argspec.args) == 0 and argspec.varargs is None:
21432142
raise ValueError(
2144-
"0-arg pandas_udfs/pandas_grouped_udfs are not supported. "
2143+
"0-arg pandas_udfs are not supported. "
21452144
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
21462145
)
2146+
elif pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF:
2147+
import inspect
2148+
argspec = inspect.getargspec(f)
2149+
if len(argspec.args) != 1 and argspec.varargs is None:
2150+
raise ValueError("Only 1-arg pandas_grouped_udfs are supported.")
2151+
21472152
udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType)
21482153
return udf_obj._wrapped()
21492154

python/pyspark/sql/tests.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3516,6 +3516,32 @@ def test_wrong_return_type(self):
35163516
with self.assertRaisesRegexp(Exception, 'Invalid.*type'):
35173517
df.groupby('id').apply(foo).sort('id').toPandas()
35183518

3519+
def test_zero_or_more_than_1_parameters(self):
3520+
from pyspark.sql.functions import pandas_grouped_udf
3521+
error_str = 'Only 1-arg pandas_grouped_udfs are supported.'
3522+
with QuietTest(self.sc):
3523+
with self.assertRaisesRegexp(ValueError, error_str):
3524+
pandas_grouped_udf(lambda: 1, 'one long')
3525+
with self.assertRaisesRegexp(ValueError, error_str):
3526+
@pandas_grouped_udf
3527+
def zero_no_type():
3528+
return 1
3529+
with self.assertRaisesRegexp(ValueError, error_str):
3530+
@pandas_grouped_udf("one long")
3531+
def zero_with_type():
3532+
return 1
3533+
3534+
with self.assertRaisesRegexp(ValueError, error_str):
3535+
pandas_grouped_udf(lambda pdf, x: pdf, 'one long')
3536+
with self.assertRaisesRegexp(ValueError, error_str):
3537+
@pandas_grouped_udf
3538+
def zero_no_type(pdf, x):
3539+
return pdf
3540+
with self.assertRaisesRegexp(ValueError, error_str):
3541+
@pandas_grouped_udf("one long")
3542+
def zero_with_type(pdf, x):
3543+
return pdf
3544+
35193545
def test_wrong_args(self):
35203546
from pyspark.sql.functions import udf, pandas_udf, pandas_grouped_udf, sum
35213547
df = self.data

0 commit comments

Comments
 (0)