9797 get_indexer_dict ,
9898)
9999
100- _CYTHON_FUNCTIONS = {
101- "aggregate" : {
102- "add" : "group_add" ,
103- "prod" : "group_prod" ,
104- "min" : "group_min" ,
105- "max" : "group_max" ,
106- "mean" : "group_mean" ,
107- "median" : "group_median" ,
108- "var" : "group_var" ,
109- "first" : "group_nth" ,
110- "last" : "group_last" ,
111- "ohlc" : "group_ohlc" ,
112- },
113- "transform" : {
114- "cumprod" : "group_cumprod" ,
115- "cumsum" : "group_cumsum" ,
116- "cummin" : "group_cummin" ,
117- "cummax" : "group_cummax" ,
118- "rank" : "group_rank" ,
119- },
120- }
121-
122-
123- @functools .lru_cache (maxsize = None )
124- def _get_cython_function (kind : str , how : str , dtype : np .dtype , is_numeric : bool ):
125-
126- dtype_str = dtype .name
127- ftype = _CYTHON_FUNCTIONS [kind ][how ]
128-
129- # see if there is a fused-type version of function
130- # only valid for numeric
131- f = getattr (libgroupby , ftype , None )
132- if f is not None :
133- if is_numeric :
134- return f
135- elif dtype == object :
136- if "object" not in f .__signatures__ :
137- # raise NotImplementedError here rather than TypeError later
100+
class WrappedCythonOp:
    """
    Dispatch logic for functions defined in _libs.groupby

    An instance is bound to a single operation, described by ``kind``
    ("aggregate" or "transform") and ``how`` (e.g. "add", "cumsum"). Its
    methods answer the questions needed before a kernel is called: whether
    the operation is allowed for a dtype, which kernel to use, and what
    shape and dtype the output buffer must have.
    """

    def __init__(self, kind: str, how: str):
        self.kind = kind
        self.how = how

    # kind -> how -> name of the attribute to look up on ``libgroupby``
    _CYTHON_FUNCTIONS = {
        "aggregate": {
            "add": "group_add",
            "prod": "group_prod",
            "min": "group_min",
            "max": "group_max",
            "mean": "group_mean",
            "median": "group_median",
            "var": "group_var",
            "first": "group_nth",
            "last": "group_last",
            "ohlc": "group_ohlc",
        },
        "transform": {
            "cumprod": "group_cumprod",
            "cumsum": "group_cumsum",
            "cummin": "group_cummin",
            "cummax": "group_cummax",
            "rank": "group_rank",
        },
    }

    _cython_arity = {"ohlc": 4}  # OHLC

    # Note: we make this a classmethod and pass kind+how so that caching
    #  works at the class level and not the instance level
    @classmethod
    @functools.lru_cache(maxsize=None)
    def _get_cython_function(
        cls, kind: str, how: str, dtype: np.dtype, is_numeric: bool
    ):
        """
        Look up the ``libgroupby`` function registered for ``kind``/``how``.

        Parameters
        ----------
        kind : str
            Key into ``_CYTHON_FUNCTIONS`` ("aggregate" or "transform").
        how : str
            Operation name within ``kind``.
        dtype : np.dtype
            dtype of the values the function will be applied to.
        is_numeric : bool
            Whether ``dtype`` is numeric.

        Returns
        -------
        callable

        Raises
        ------
        NotImplementedError
            If ``libgroupby`` has no attribute with the registered name, if
            the dtype is neither numeric nor object, or if the dtype is
            object and the function has no "object" signature.
        """
        dtype_str = dtype.name
        ftype = cls._CYTHON_FUNCTIONS[kind][how]

        # see if there is a fused-type version of function
        # only valid for numeric
        f = getattr(libgroupby, ftype, None)
        if f is not None:
            if is_numeric:
                return f
            if dtype == object and "object" in f.__signatures__:
                return f

        # raise NotImplementedError here rather than TypeError later
        raise NotImplementedError(
            f"function is not implemented for this dtype: "
            f"[how->{how},dtype->{dtype_str}]"
        )

    def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):
        """
        Find the appropriate cython function, casting if necessary.

        Parameters
        ----------
        values : np.ndarray
        is_numeric : bool

        Returns
        -------
        func : callable
        values : np.ndarray

        Raises
        ------
        NotImplementedError
            If ``how`` is "median" or "cumprod" and the values are not
            numeric, or if no function is available for the dtype.
        """
        how = self.how
        kind = self.kind

        if how in ["median", "cumprod"]:
            # these two only have float64 implementations
            if not is_numeric:
                raise NotImplementedError(
                    f"function is not implemented for this dtype: "
                    f"[how->{how},dtype->{values.dtype.name}]"
                )
            values = ensure_float64(values)
            func = getattr(libgroupby, f"group_{how}_float64")
            return func, values

        func = self._get_cython_function(kind, how, values.dtype, is_numeric)

        if values.dtype.kind in ["i", "u"]:
            if how in ["add", "var", "prod", "mean", "ohlc"]:
                # result may still include NaN, so we have to cast
                values = ensure_float64(values)

        return func, values

    def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
        """
        Check if we can do this operation with our cython functions.

        Parameters
        ----------
        dtype : DtypeObj
            dtype of the values to operate on.
        is_numeric : bool, default False
            Whether ``dtype`` is numeric; numeric dtypes are always allowed.

        Raises
        ------
        NotImplementedError
            This is either not a valid function for this dtype, or
            valid but not implemented in cython.
        """
        how = self.how

        if is_numeric:
            # never an invalid op for those dtypes, so return early as fastpath
            return

        if is_categorical_dtype(dtype) or is_sparse(dtype):
            # categoricals are only 1d, so we
            #  are not setup for dim transforming
            raise NotImplementedError(f"{dtype} dtype not supported")
        elif is_datetime64_any_dtype(dtype):
            # we raise NotImplemented if this is an invalid operation
            #  entirely, e.g. adding datetimes
            if how in ["add", "prod", "cumsum", "cumprod"]:
                raise NotImplementedError(
                    f"datetime64 type does not support {how} operations"
                )
        elif is_timedelta64_dtype(dtype):
            if how in ["prod", "cumprod"]:
                raise NotImplementedError(
                    f"timedelta64 type does not support {how} operations"
                )

    def get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
        """
        Compute the shape of the output buffer for this operation.

        Parameters
        ----------
        ngroups : int
            Number of groups.
        values : np.ndarray
            Values the operation will be applied to.

        Returns
        -------
        Shape
            ``(ngroups, 4)`` for "ohlc", ``values.shape`` for transforms,
            otherwise ``(ngroups,) + values.shape[1:]``.

        Raises
        ------
        NotImplementedError
            If ``how`` has arity greater than 1 and is not "ohlc".
        """
        how = self.how
        kind = self.kind

        arity = self._cython_arity.get(how, 1)

        out_shape: Shape
        if how == "ohlc":
            out_shape = (ngroups, 4)
        elif arity > 1:
            raise NotImplementedError(
                "arity of more than 1 is not supported for the 'how' argument"
            )
        elif kind == "transform":
            out_shape = values.shape
        else:
            out_shape = (ngroups,) + values.shape[1:]
        return out_shape

    def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
        """
        Determine the dtype of the output buffer for this operation.

        Parameters
        ----------
        dtype : np.dtype
            dtype of the values passed to the cython function.

        Returns
        -------
        np.dtype
            float64 for "rank"; for numeric input, the dtype with the same
            kind and itemsize as ``dtype``; otherwise object.
        """
        how = self.how

        if how == "rank":
            out_dtype = "float64"
        else:
            if is_numeric_dtype(dtype):
                # e.g. "i8", "f4": kind character followed by byte width
                out_dtype = f"{dtype.kind}{dtype.itemsize}"
            else:
                out_dtype = "object"
        return np.dtype(out_dtype)
148266
149267
150268class BaseGrouper :
@@ -437,8 +555,6 @@ def get_group_levels(self) -> List[Index]:
437555 # ------------------------------------------------------------
438556 # Aggregation functions
439557
440- _cython_arity = {"ohlc" : 4 } # OHLC
441-
442558 @final
443559 def _is_builtin_func (self , arg ):
444560 """
@@ -447,80 +563,6 @@ def _is_builtin_func(self, arg):
447563 """
448564 return SelectionMixin ._builtin_table .get (arg , arg )
449565
450- @final
451- def _get_cython_func_and_vals (
452- self , kind : str , how : str , values : np .ndarray , is_numeric : bool
453- ):
454- """
455- Find the appropriate cython function, casting if necessary.
456-
457- Parameters
458- ----------
459- kind : str
460- how : str
461- values : np.ndarray
462- is_numeric : bool
463-
464- Returns
465- -------
466- func : callable
467- values : np.ndarray
468- """
469- if how in ["median" , "cumprod" ]:
470- # these two only have float64 implementations
471- if is_numeric :
472- values = ensure_float64 (values )
473- else :
474- raise NotImplementedError (
475- f"function is not implemented for this dtype: "
476- f"[how->{ how } ,dtype->{ values .dtype .name } ]"
477- )
478- func = getattr (libgroupby , f"group_{ how } _float64" )
479- return func , values
480-
481- func = _get_cython_function (kind , how , values .dtype , is_numeric )
482-
483- if values .dtype .kind in ["i" , "u" ]:
484- if how in ["add" , "var" , "prod" , "mean" , "ohlc" ]:
485- # result may still include NaN, so we have to cast
486- values = ensure_float64 (values )
487-
488- return func , values
489-
490- @final
491- def _disallow_invalid_ops (
492- self , dtype : DtypeObj , how : str , is_numeric : bool = False
493- ):
494- """
495- Check if we can do this operation with our cython functions.
496-
497- Raises
498- ------
499- NotImplementedError
500- This is either not a valid function for this dtype, or
501- valid but not implemented in cython.
502- """
503- if is_numeric :
504- # never an invalid op for those dtypes, so return early as fastpath
505- return
506-
507- if is_categorical_dtype (dtype ) or is_sparse (dtype ):
508- # categoricals are only 1d, so we
509- # are not setup for dim transforming
510- raise NotImplementedError (f"{ dtype } dtype not supported" )
511- elif is_datetime64_any_dtype (dtype ):
512- # we raise NotImplemented if this is an invalid operation
513- # entirely, e.g. adding datetimes
514- if how in ["add" , "prod" , "cumsum" , "cumprod" ]:
515- raise NotImplementedError (
516- f"datetime64 type does not support { how } operations"
517- )
518- elif is_timedelta64_dtype (dtype ):
519- if how in ["prod" , "cumprod" ]:
520- raise NotImplementedError (
521- f"timedelta64 type does not support { how } operations"
522- )
523-
524566 @final
525567 def _ea_wrap_cython_operation (
526568 self , kind : str , values , how : str , axis : int , min_count : int = - 1 , ** kwargs
@@ -593,9 +635,11 @@ def _cython_operation(
593635 dtype = values .dtype
594636 is_numeric = is_numeric_dtype (dtype )
595637
638+ cy_op = WrappedCythonOp (kind = kind , how = how )
639+
596640 # can we do this operation with our cython functions
597641 # if not raise NotImplementedError
598- self . _disallow_invalid_ops (dtype , how , is_numeric )
642+ cy_op . disallow_invalid_ops (dtype , is_numeric )
599643
600644 if is_extension_array_dtype (dtype ):
601645 return self ._ea_wrap_cython_operation (
@@ -637,43 +681,23 @@ def _cython_operation(
637681 if not is_complex_dtype (dtype ):
638682 values = ensure_float64 (values )
639683
640- arity = self ._cython_arity .get (how , 1 )
641684 ngroups = self .ngroups
685+ comp_ids , _ , _ = self .group_info
642686
643687 assert axis == 1
644688 values = values .T
645- if how == "ohlc" :
646- out_shape = (ngroups , 4 )
647- elif arity > 1 :
648- raise NotImplementedError (
649- "arity of more than 1 is not supported for the 'how' argument"
650- )
651- elif kind == "transform" :
652- out_shape = values .shape
653- else :
654- out_shape = (ngroups ,) + values .shape [1 :]
655-
656- func , values = self ._get_cython_func_and_vals (kind , how , values , is_numeric )
657-
658- if how == "rank" :
659- out_dtype = "float"
660- else :
661- if is_numeric :
662- out_dtype = f"{ values .dtype .kind } { values .dtype .itemsize } "
663- else :
664- out_dtype = "object"
665689
666- codes , _ , _ = self .group_info
690+ out_shape = cy_op .get_output_shape (ngroups , values )
691+ func , values = cy_op .get_cython_func_and_vals (values , is_numeric )
692+ out_dtype = cy_op .get_out_dtype (values .dtype )
667693
668694 result = maybe_fill (np .empty (out_shape , dtype = out_dtype ))
669695 if kind == "aggregate" :
670- counts = np .zeros (self . ngroups , dtype = np .int64 )
671- result = self . _aggregate (result , counts , values , codes , func , min_count )
696+ counts = np .zeros (ngroups , dtype = np .int64 )
697+ func (result , counts , values , comp_ids , min_count )
672698 elif kind == "transform" :
673699 # TODO: min_count
674- result = self ._transform (
675- result , values , codes , func , is_datetimelike , ** kwargs
676- )
700+ func (result , values , comp_ids , ngroups , is_datetimelike , ** kwargs )
677701
678702 if is_integer_dtype (result .dtype ) and not is_datetimelike :
679703 mask = result == iNaT
@@ -697,28 +721,6 @@ def _cython_operation(
697721
698722 return op_result
699723
700- @final
701- def _aggregate (
702- self , result , counts , values , comp_ids , agg_func , min_count : int = - 1
703- ):
704- if agg_func is libgroupby .group_nth :
705- # different signature from the others
706- agg_func (result , counts , values , comp_ids , min_count , rank = 1 )
707- else :
708- agg_func (result , counts , values , comp_ids , min_count )
709-
710- return result
711-
712- @final
713- def _transform (
714- self , result , values , comp_ids , transform_func , is_datetimelike : bool , ** kwargs
715- ):
716-
717- _ , _ , ngroups = self .group_info
718- transform_func (result , values , comp_ids , ngroups , is_datetimelike , ** kwargs )
719-
720- return result
721-
722724 def agg_series (self , obj : Series , func : F ):
723725 # Caller is responsible for checking ngroups != 0
724726 assert self .ngroups != 0
0 commit comments