7575import pandas .core .indexes .base as ibase
7676from pandas .core .internals import BlockManager , make_block
7777from pandas .core .series import Series
78+ from pandas .core .util .numba_ import (
79+ check_kwargs_and_nopython ,
80+ get_jit_arguments ,
81+ jit_user_function ,
82+ split_for_numba ,
83+ validate_udf ,
84+ )
7885
7986from pandas .plotting import boxplot_frame_groupby
8087
@@ -154,6 +161,8 @@ def pinner(cls):
154161class SeriesGroupBy (GroupBy [Series ]):
155162 _apply_whitelist = base .series_apply_whitelist
156163
164+ _numba_func_cache : Dict [Callable , Callable ] = {}
165+
157166 def _iterate_slices (self ) -> Iterable [Series ]:
158167 yield self ._selected_obj
159168
@@ -463,11 +472,13 @@ def _aggregate_named(self, func, *args, **kwargs):
463472
464473 @Substitution (klass = "Series" , selected = "A." )
465474 @Appender (_transform_template )
466- def transform (self , func , * args , ** kwargs ):
475+ def transform (self , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs ):
467476 func = self ._get_cython_func (func ) or func
468477
469478 if not isinstance (func , str ):
470- return self ._transform_general (func , * args , ** kwargs )
479+ return self ._transform_general (
480+ func , * args , engine = engine , engine_kwargs = engine_kwargs , ** kwargs
481+ )
471482
472483 elif func not in base .transform_kernel_whitelist :
473484 msg = f"'{ func } ' is not a valid function name for transform(name)"
@@ -482,16 +493,33 @@ def transform(self, func, *args, **kwargs):
482493 result = getattr (self , func )(* args , ** kwargs )
483494 return self ._transform_fast (result , func )
484495
485- def _transform_general (self , func , * args , ** kwargs ):
496+ def _transform_general (
497+ self , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs
498+ ):
486499 """
487500 Transform with a non-str `func`.
488501 """
502+
503+ if engine == "numba" :
504+ nopython , nogil , parallel = get_jit_arguments (engine_kwargs )
505+ check_kwargs_and_nopython (kwargs , nopython )
506+ validate_udf (func )
507+ numba_func = self ._numba_func_cache .get (
508+ func , jit_user_function (func , nopython , nogil , parallel )
509+ )
510+
489511 klass = type (self ._selected_obj )
490512
491513 results = []
492514 for name , group in self :
493515 object .__setattr__ (group , "name" , name )
494- res = func (group , * args , ** kwargs )
516+ if engine == "numba" :
517+ values , index = split_for_numba (group )
518+ res = numba_func (values , index , * args )
519+ if func not in self ._numba_func_cache :
520+ self ._numba_func_cache [func ] = numba_func
521+ else :
522+ res = func (group , * args , ** kwargs )
495523
496524 if isinstance (res , (ABCDataFrame , ABCSeries )):
497525 res = res ._values
@@ -819,6 +847,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
819847
820848 _apply_whitelist = base .dataframe_apply_whitelist
821849
850+ _numba_func_cache : Dict [Callable , Callable ] = {}
851+
822852 _agg_see_also_doc = dedent (
823853 """
824854 See Also
@@ -1355,19 +1385,35 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
13551385 # Handle cases like BinGrouper
13561386 return self ._concat_objects (keys , values , not_indexed_same = not_indexed_same )
13571387
1358- def _transform_general (self , func , * args , ** kwargs ):
1388+ def _transform_general (
1389+ self , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs
1390+ ):
13591391 from pandas .core .reshape .concat import concat
13601392
13611393 applied = []
13621394 obj = self ._obj_with_exclusions
13631395 gen = self .grouper .get_iterator (obj , axis = self .axis )
1364- fast_path , slow_path = self ._define_paths (func , * args , ** kwargs )
1396+ if engine == "numba" :
1397+ nopython , nogil , parallel = get_jit_arguments (engine_kwargs )
1398+ check_kwargs_and_nopython (kwargs , nopython )
1399+ validate_udf (func )
1400+ numba_func = self ._numba_func_cache .get (
1401+ func , jit_user_function (func , nopython , nogil , parallel )
1402+ )
1403+ else :
1404+ fast_path , slow_path = self ._define_paths (func , * args , ** kwargs )
13651405
1366- path = None
13671406 for name , group in gen :
13681407 object .__setattr__ (group , "name" , name )
13691408
1370- if path is None :
1409+ if engine == "numba" :
1410+ values , index = split_for_numba (group )
1411+ res = numba_func (values , index , * args )
1412+ if func not in self ._numba_func_cache :
1413+ self ._numba_func_cache [func ] = numba_func
1414+ # Return the result as a DataFrame for concatenation later
1415+ res = DataFrame (res , index = group .index , columns = group .columns )
1416+ else :
13711417 # Try slow path and fast path.
13721418 try :
13731419 path , res = self ._choose_path (fast_path , slow_path , group )
@@ -1376,8 +1422,6 @@ def _transform_general(self, func, *args, **kwargs):
13761422 except ValueError as err :
13771423 msg = "transform must return a scalar value for each group"
13781424 raise ValueError (msg ) from err
1379- else :
1380- res = path (group )
13811425
13821426 if isinstance (res , Series ):
13831427
@@ -1411,13 +1455,15 @@ def _transform_general(self, func, *args, **kwargs):
14111455
14121456 @Substitution (klass = "DataFrame" , selected = "" )
14131457 @Appender (_transform_template )
1414- def transform (self , func , * args , ** kwargs ):
1458+ def transform (self , func , * args , engine = "cython" , engine_kwargs = None , ** kwargs ):
14151459
14161460 # optimized transforms
14171461 func = self ._get_cython_func (func ) or func
14181462
14191463 if not isinstance (func , str ):
1420- return self ._transform_general (func , * args , ** kwargs )
1464+ return self ._transform_general (
1465+ func , * args , engine = engine , engine_kwargs = engine_kwargs , ** kwargs
1466+ )
14211467
14221468 elif func not in base .transform_kernel_whitelist :
14231469 msg = f"'{ func } ' is not a valid function name for transform(name)"
@@ -1439,7 +1485,9 @@ def transform(self, func, *args, **kwargs):
14391485 ):
14401486 return self ._transform_fast (result , func )
14411487
1442- return self ._transform_general (func , * args , ** kwargs )
1488+ return self ._transform_general (
1489+ func , engine = engine , engine_kwargs = engine_kwargs , * args , ** kwargs
1490+ )
14431491
14441492 def _transform_fast (self , result : DataFrame , func_nm : str ) -> DataFrame :
14451493 """
0 commit comments