Commit fdafb35

Revert "Restrict checking the number of arguments."
This reverts commit 122a7bc.
1 parent 122a7bc commit fdafb35

File tree: 4 files changed, +56 −96 lines

python/pyspark/sql/functions.py

Lines changed: 27 additions & 82 deletions
@@ -2132,15 +2132,35 @@ def wrapper(*args):
         return wrapper


-def _resolve_decorator(create_udf, f, returnType):
-    # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf, @pandas_grouped_udf
+def _create_udf(f, returnType, pythonUdfType):
+
+    def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType):
+        if pythonUdfType == PythonUdfType.PANDAS_UDF:
+            import inspect
+            argspec = inspect.getargspec(f)
+            if len(argspec.args) == 0 and argspec.varargs is None:
+                raise ValueError(
+                    "0-arg pandas_udfs are not supported. "
+                    "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
+                )
+        elif pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF:
+            import inspect
+            argspec = inspect.getargspec(f)
+            if len(argspec.args) != 1 and argspec.varargs is None:
+                raise ValueError("Only 1-arg pandas_grouped_udfs are supported.")
+
+        udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType)
+        return udf_obj._wrapped()
+
+    # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
     if f is None or isinstance(f, (str, DataType)):
         # If DataType has been passed as a positional argument
         # for decorator use it as a returnType
         return_type = f or returnType
-        return functools.partial(create_udf, returnType=return_type)
+        return functools.partial(
+            _udf, returnType=return_type, pythonUdfType=pythonUdfType)
     else:
-        return create_udf(f=f, returnType=returnType)
+        return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType)


 @since(1.3)
@@ -2174,11 +2194,7 @@ def udf(f=None, returnType=StringType()):
     |         8|      JOHN DOE|          22|
     +----------+--------------+------------+
     """
-    def _create_udf(f, returnType):
-        udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=PythonUdfType.NORMAL_UDF)
-        return udf_obj._wrapped()
-
-    return _resolve_decorator(_create_udf, f, returnType)
+    return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF)


 @since(2.3)
@@ -2219,19 +2235,7 @@ def pandas_udf(f=None, returnType=StringType()):

     .. note:: The user-defined function must be deterministic.
     """
-    def _create_udf(f, returnType):
-        import inspect
-        argspec = inspect.getargspec(f)
-        if len(argspec.args) == 0 and argspec.varargs is None:
-            raise ValueError(
-                "0-arg pandas_udfs are not supported. "
-                "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
-            )
-
-        udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=PythonUdfType.PANDAS_UDF)
-        return udf_obj._wrapped()
-
-    return _resolve_decorator(_create_udf, f, returnType)
+    return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF)


 @since(2.3)
@@ -2276,66 +2280,7 @@ def pandas_grouped_udf(f=None, returnType=StructType()):

     .. note:: The user-defined function must be deterministic.
     """
-    def _create_udf(f, returnType):
-        import inspect
-        argspec = inspect.getargspec(f)
-        if len(argspec.args) != 1:
-            raise ValueError("Only 1-arg pandas_grouped_udfs are supported.")
-
-        # create a dummy udf object as a placeholder.
-        _udf_obj = UserDefinedFunction(
-            f, returnType, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF)
-
-        # It is possible for a callable instance without __name__ attribute or/and
-        # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
-        # we should avoid wrapping the attributes from the wrapped function to the wrapper
-        # function. So, we take out these attribute names from the default names to set and
-        # then manually assign it after being wrapped.
-        assignments = tuple(
-            a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')
-
-        @functools.wraps(_udf_obj.func, assigned=assignments)
-        def wrapper(df):
-
-            func = _udf_obj.func
-            returnType = _udf_obj.returnType
-
-            # The python executors expects the function to use pd.Series as input and output
-            # So we to create a wrapper function that turns that to a pd.DataFrame before passing
-            # down to the user function, then turn the result pd.DataFrame back into pd.Series
-            columns = df.columns
-
-            def wrapped(*cols):
-                from pyspark.sql.types import to_arrow_type
-                import pandas as pd
-                result = func(pd.concat(cols, axis=1, keys=columns))
-                if not isinstance(result, pd.DataFrame):
-                    raise TypeError("Return type of the user-defined function should be "
-                                    "Pandas.DataFrame, but is {}".format(type(result)))
-                if not len(result.columns) == len(returnType):
-                    raise RuntimeError(
-                        "Number of columns of the returned Pandas.DataFrame "
-                        "doesn't match specified schema. "
-                        "Expected: {} Actual: {}".format(len(returnType), len(result.columns)))
-                arrow_return_types = (to_arrow_type(field.dataType) for field in returnType)
-                return [(result[result.columns[i]], arrow_type)
-                        for i, arrow_type in enumerate(arrow_return_types)]
-
-            udf_obj = UserDefinedFunction(
-                wrapped, returnType, name=_udf_obj._name, pythonUdfType=_udf_obj.pythonUdfType)
-            return udf_obj(*[df[col] for col in df.columns])
-
-        wrapper.__name__ = _udf_obj._name
-        wrapper.__module__ = (_udf_obj.func.__module__ if hasattr(_udf_obj.func, '__module__')
-                              else _udf_obj.func.__class__.__module__)
-
-        wrapper.func = _udf_obj.func
-        wrapper.returnType = _udf_obj.returnType
-        wrapper.pythonUdfType = _udf_obj.pythonUdfType
-
-        return wrapper
-
-    return _resolve_decorator(_create_udf, f, returnType)
+    return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF)


 blacklist = ['map', 'since', 'ignore_unicode_prefix']
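
The restored `_create_udf` also carries the decorator dispatch: when `f` arrives as `None` or as a `str`/`DataType`, the call site was decorator-style (`@udf`, `@udf()`, `@udf(dataType)`), so a `functools.partial` is returned that finishes the wrapping once it receives the actual function. A minimal standalone sketch of that pattern (no PySpark required; `make_udf` and its attributes are illustrative, not the real API):

import functools

def make_udf(f=None, returnType="string"):
    # Decorator-style call: f is None (bare/empty parens) or the return
    # type passed positionally; defer wrapping until we get the function.
    if f is None or isinstance(f, str):
        return_type = f or returnType
        return functools.partial(make_udf, returnType=return_type)
    # Direct call: f is the function itself; wrap it now.
    @functools.wraps(f)
    def wrapper(*args):
        return f(*args)
    wrapper.returnType = returnType
    return wrapper

@make_udf                                   # bare decorator
def plus_one(x):
    return x + 1

@make_udf("long")                           # return type as positional arg
def plus_two(x):
    return x + 2

times_two = make_udf(lambda x: x * 2, returnType="long")  # plain call

print(plus_one(1), plus_two(2), times_two(3))  # 2 4 6
print(plus_two.returnType)                     # long

Sharing one dispatcher is what lets `udf`, `pandas_udf`, and `pandas_grouped_udf` each shrink to a single `return _create_udf(...)` line while still supporting all three call shapes.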

python/pyspark/sql/group.py

Lines changed: 27 additions & 1 deletion
@@ -245,7 +245,33 @@ def apply(self, udf):
         if not isinstance(udf.returnType, StructType):
             raise ValueError("The returnType of the pandas_grouped_udf must be a StructType")

-        udf_column = udf(self._df)
+        df = self._df
+        func = udf.func
+        returnType = udf.returnType
+
+        # The python executors expects the function to use pd.Series as input and output
+        # So we to create a wrapper function that turns that to a pd.DataFrame before passing
+        # down to the user function, then turn the result pd.DataFrame back into pd.Series
+        columns = df.columns
+
+        def wrapped(*cols):
+            from pyspark.sql.types import to_arrow_type
+            import pandas as pd
+            result = func(pd.concat(cols, axis=1, keys=columns))
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError("Return type of the user-defined function should be "
+                                "Pandas.DataFrame, but is {}".format(type(result)))
+            if not len(result.columns) == len(returnType):
+                raise RuntimeError(
+                    "Number of columns of the returned Pandas.DataFrame "
+                    "doesn't match specified schema. "
+                    "Expected: {} Actual: {}".format(len(returnType), len(result.columns)))
+            arrow_return_types = (to_arrow_type(field.dataType) for field in returnType)
+            return [(result[result.columns[i]], arrow_type)
+                    for i, arrow_type in enumerate(arrow_return_types)]
+
+        wrapped_udf_obj = pandas_grouped_udf(wrapped, returnType)
+        udf_column = wrapped_udf_obj(*[df[col] for col in df.columns])
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.sql_ctx)
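
The `wrapped` shim moved back into `apply()` does the Series-to-DataFrame plumbing: executors ship one `pd.Series` per column, the shim reassembles them into a `pd.DataFrame` for the user function, validates the result, and splits it back per column. A pandas-only sketch of that round trip (a hypothetical `normalize` stands in for the user function; the Arrow-type pairing via `to_arrow_type` is omitted):

import pandas as pd

def normalize(pdf):
    # Stand-in for the user's grouped function: DataFrame in, DataFrame out.
    return pdf.assign(v=(pdf.v - pdf.v.mean()) / pdf.v.std())

columns = ['id', 'v']
schema_len = 2  # plays the role of len(returnType)

def wrapped(*cols):
    # Reassemble the per-column Series into a named DataFrame.
    result = normalize(pd.concat(cols, axis=1, keys=columns))
    if not isinstance(result, pd.DataFrame):
        raise TypeError("expected a pandas.DataFrame, got {}".format(type(result)))
    if len(result.columns) != schema_len:
        raise RuntimeError("returned DataFrame doesn't match the schema")
    # Split back into per-column Series for serialization.
    return [result[c] for c in result.columns]

ids = pd.Series([1, 1, 2], name='id')
vals = pd.Series([1.0, 2.0, 3.0], name='v')
for series in wrapped(ids, vals):
    print(series)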

python/pyspark/sql/tests.py

Lines changed: 1 addition & 12 deletions
@@ -3516,7 +3516,7 @@ def test_wrong_return_type(self):
         with self.assertRaisesRegexp(Exception, 'Invalid.*type'):
             df.groupby('id').apply(foo).sort('id').toPandas()

-    def test_invalid_parameters(self):
+    def test_zero_or_more_than_1_parameters(self):
         from pyspark.sql.functions import pandas_grouped_udf
         error_str = 'Only 1-arg pandas_grouped_udfs are supported.'
         with QuietTest(self.sc):
@@ -3542,17 +3542,6 @@ def zero_no_type(pdf, x):
         def zero_with_type(pdf, x):
             return pdf

-        with self.assertRaisesRegexp(ValueError, error_str):
-            pandas_grouped_udf(lambda *args: args[0], 'one long')
-        with self.assertRaisesRegexp(ValueError, error_str):
-            @pandas_grouped_udf
-            def zero_no_type(*args):
-                return args[0]
-        with self.assertRaisesRegexp(ValueError, error_str):
-            @pandas_grouped_udf("one long")
-            def zero_with_type(*args):
-                return args[0]
-
     def test_wrong_args(self):
         from pyspark.sql.functions import udf, pandas_udf, pandas_grouped_udf, sum
         df = self.data
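
The deleted assertions exercised `*args` functions; under the restored check in `_create_udf`, `argspec.varargs is None` exempts a `*args` signature from the arity error, so those cases no longer raise. A quick illustration of the predicate (using `inspect.getfullargspec`, the Python 3 equivalent of the `getargspec` call in the diff; `grouped_udf_arity_ok` is a made-up name):

import inspect

def grouped_udf_arity_ok(f):
    # Mirrors the restored pandas_grouped_udf check: reject only when the
    # function declares != 1 positional args AND has no *args catch-all.
    spec = inspect.getfullargspec(f)
    return not (len(spec.args) != 1 and spec.varargs is None)

print(grouped_udf_arity_ok(lambda pdf: pdf))        # True:  exactly one arg
print(grouped_udf_arity_ok(lambda pdf, x: pdf))     # False: two args, rejected
print(grouped_udf_arity_ok(lambda *args: args[0]))  # True:  *args is exempt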

python/pyspark/worker.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def read_single_udf(pickleSer, infile, eval_type):
     if eval_type == PythonEvalType.SQL_PANDAS_UDF:
         return arg_offsets, wrap_pandas_udf(row_func, return_type)
     elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF:
-        # a groupby apply udf has already been wrapped
+        # a groupby apply udf has already been wrapped under apply()
         return arg_offsets, row_func
     else:
         return arg_offsets, wrap_udf(row_func, return_type)
