@@ -2132,15 +2132,35 @@ def wrapper(*args):
21322132 return wrapper
21332133
21342134
2135- def _resolve_decorator (create_udf , f , returnType ):
2136- # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf, @pandas_grouped_udf
2135+ def _create_udf (f , returnType , pythonUdfType ):
2136+
2137+ def _udf (f , returnType = StringType (), pythonUdfType = pythonUdfType ):
2138+ if pythonUdfType == PythonUdfType .PANDAS_UDF :
2139+ import inspect
2140+ argspec = inspect .getargspec (f )
2141+ if len (argspec .args ) == 0 and argspec .varargs is None :
2142+ raise ValueError (
2143+ "0-arg pandas_udfs are not supported. "
2144+ "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
2145+ )
2146+ elif pythonUdfType == PythonUdfType .PANDAS_GROUPED_UDF :
2147+ import inspect
2148+ argspec = inspect .getargspec (f )
2149+ if len (argspec .args ) != 1 and argspec .varargs is None :
2150+ raise ValueError ("Only 1-arg pandas_grouped_udfs are supported." )
2151+
2152+ udf_obj = UserDefinedFunction (f , returnType , pythonUdfType = pythonUdfType )
2153+ return udf_obj ._wrapped ()
2154+
2155+ # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
21372156 if f is None or isinstance (f , (str , DataType )):
21382157 # If DataType has been passed as a positional argument
21392158 # for decorator use it as a returnType
21402159 return_type = f or returnType
2141- return functools .partial (create_udf , returnType = return_type )
2160+ return functools .partial (
2161+ _udf , returnType = return_type , pythonUdfType = pythonUdfType )
21422162 else :
2143- return create_udf (f = f , returnType = returnType )
2163+ return _udf (f = f , returnType = returnType , pythonUdfType = pythonUdfType )
21442164
21452165
21462166@since (1.3 )
@@ -2174,11 +2194,7 @@ def udf(f=None, returnType=StringType()):
21742194 | 8| JOHN DOE| 22|
21752195 +----------+--------------+------------+
21762196 """
2177- def _create_udf (f , returnType ):
2178- udf_obj = UserDefinedFunction (f , returnType , pythonUdfType = PythonUdfType .NORMAL_UDF )
2179- return udf_obj ._wrapped ()
2180-
2181- return _resolve_decorator (_create_udf , f , returnType )
2197+ return _create_udf (f , returnType = returnType , pythonUdfType = PythonUdfType .NORMAL_UDF )
21822198
21832199
21842200@since (2.3 )
@@ -2219,19 +2235,7 @@ def pandas_udf(f=None, returnType=StringType()):
22192235
22202236 .. note:: The user-defined function must be deterministic.
22212237 """
2222- def _create_udf (f , returnType ):
2223- import inspect
2224- argspec = inspect .getargspec (f )
2225- if len (argspec .args ) == 0 and argspec .varargs is None :
2226- raise ValueError (
2227- "0-arg pandas_udfs are not supported. "
2228- "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
2229- )
2230-
2231- udf_obj = UserDefinedFunction (f , returnType , pythonUdfType = PythonUdfType .PANDAS_UDF )
2232- return udf_obj ._wrapped ()
2233-
2234- return _resolve_decorator (_create_udf , f , returnType )
2238+ return _create_udf (f , returnType = returnType , pythonUdfType = PythonUdfType .PANDAS_UDF )
22352239
22362240
22372241@since (2.3 )
@@ -2276,66 +2280,7 @@ def pandas_grouped_udf(f=None, returnType=StructType()):
22762280
22772281 .. note:: The user-defined function must be deterministic.
22782282 """
2279- def _create_udf (f , returnType ):
2280- import inspect
2281- argspec = inspect .getargspec (f )
2282- if len (argspec .args ) != 1 :
2283- raise ValueError ("Only 1-arg pandas_grouped_udfs are supported." )
2284-
2285- # create a dummy udf object as a placeholder.
2286- _udf_obj = UserDefinedFunction (
2287- f , returnType , pythonUdfType = PythonUdfType .PANDAS_GROUPED_UDF )
2288-
2289- # It is possible for a callable instance without __name__ attribute or/and
2290- # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
2291- # we should avoid wrapping the attributes from the wrapped function to the wrapper
2292- # function. So, we take out these attribute names from the default names to set and
2293- # then manually assign it after being wrapped.
2294- assignments = tuple (
2295- a for a in functools .WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__' )
2296-
2297- @functools .wraps (_udf_obj .func , assigned = assignments )
2298- def wrapper (df ):
2299-
2300- func = _udf_obj .func
2301- returnType = _udf_obj .returnType
2302-
2303- # The python executors expects the function to use pd.Series as input and output
2304- # So we to create a wrapper function that turns that to a pd.DataFrame before passing
2305- # down to the user function, then turn the result pd.DataFrame back into pd.Series
2306- columns = df .columns
2307-
2308- def wrapped (* cols ):
2309- from pyspark .sql .types import to_arrow_type
2310- import pandas as pd
2311- result = func (pd .concat (cols , axis = 1 , keys = columns ))
2312- if not isinstance (result , pd .DataFrame ):
2313- raise TypeError ("Return type of the user-defined function should be "
2314- "Pandas.DataFrame, but is {}" .format (type (result )))
2315- if not len (result .columns ) == len (returnType ):
2316- raise RuntimeError (
2317- "Number of columns of the returned Pandas.DataFrame "
2318- "doesn't match specified schema. "
2319- "Expected: {} Actual: {}" .format (len (returnType ), len (result .columns )))
2320- arrow_return_types = (to_arrow_type (field .dataType ) for field in returnType )
2321- return [(result [result .columns [i ]], arrow_type )
2322- for i , arrow_type in enumerate (arrow_return_types )]
2323-
2324- udf_obj = UserDefinedFunction (
2325- wrapped , returnType , name = _udf_obj ._name , pythonUdfType = _udf_obj .pythonUdfType )
2326- return udf_obj (* [df [col ] for col in df .columns ])
2327-
2328- wrapper .__name__ = _udf_obj ._name
2329- wrapper .__module__ = (_udf_obj .func .__module__ if hasattr (_udf_obj .func , '__module__' )
2330- else _udf_obj .func .__class__ .__module__ )
2331-
2332- wrapper .func = _udf_obj .func
2333- wrapper .returnType = _udf_obj .returnType
2334- wrapper .pythonUdfType = _udf_obj .pythonUdfType
2335-
2336- return wrapper
2337-
2338- return _resolve_decorator (_create_udf , f , returnType )
2283+ return _create_udf (f , returnType = returnType , pythonUdfType = PythonUdfType .PANDAS_GROUPED_UDF )
23392284
23402285
23412286blacklist = ['map' , 'since' , 'ignore_unicode_prefix' ]
0 commit comments