@@ -17,6 +17,7 @@ def generate_numba_apply_func(
1717 kwargs : Dict [str , Any ],
1818 func : Callable [..., Scalar ],
1919 engine_kwargs : Optional [Dict [str , bool ]],
20+ name : str ,
2021):
2122 """
2223 Generate a numba jitted apply function specified by values from engine_kwargs.
@@ -37,14 +38,16 @@ def generate_numba_apply_func(
3738 function to be applied to each window and will be JITed
3839 engine_kwargs : dict
3940 dictionary of arguments to be passed into numba.jit
41+ name: str
42+ name of the caller (Rolling/Expanding)
4043
4144 Returns
4245 -------
4346 Numba function
4447 """
4548 nopython , nogil , parallel = get_jit_arguments (engine_kwargs , kwargs )
4649
47- cache_key = (func , "rolling_apply " )
50+ cache_key = (func , f" { name } _apply_single " )
4851 if cache_key in NUMBA_FUNC_CACHE :
4952 return NUMBA_FUNC_CACHE [cache_key ]
5053
@@ -153,3 +156,67 @@ def groupby_ewma(
153156 return result
154157
155158 return groupby_ewma
159+
160+
161+ def generate_numba_table_func (
162+ args : Tuple ,
163+ kwargs : Dict [str , Any ],
164+ func : Callable [..., np .ndarray ],
165+ engine_kwargs : Optional [Dict [str , bool ]],
166+ name : str ,
167+ ):
168+ """
169+ Generate a numba jitted function to apply window calculations table-wise.
170+
171+ Func will be passed a M window size x N number of columns array, and
172+ must return a 1 x N number of columns array. Func is intended to operate
173+ row-wise, but the result will be transposed for axis=1.
174+
175+ 1. jit the user's function
176+ 2. Return a rolling apply function with the jitted function inline
177+
178+ Parameters
179+ ----------
180+ args : tuple
181+ *args to be passed into the function
182+ kwargs : dict
183+ **kwargs to be passed into the function
184+ func : function
185+ function to be applied to each window and will be JITed
186+ engine_kwargs : dict
187+ dictionary of arguments to be passed into numba.jit
188+ name : str
189+ caller (Rolling/Expanding) and original method name for numba cache key
190+
191+ Returns
192+ -------
193+ Numba function
194+ """
195+ nopython , nogil , parallel = get_jit_arguments (engine_kwargs , kwargs )
196+
197+ cache_key = (func , f"{ name } _table" )
198+ if cache_key in NUMBA_FUNC_CACHE :
199+ return NUMBA_FUNC_CACHE [cache_key ]
200+
201+ numba_func = jit_user_function (func , nopython , nogil , parallel )
202+ numba = import_optional_dependency ("numba" )
203+
204+ @numba .jit (nopython = nopython , nogil = nogil , parallel = parallel )
205+ def roll_table (
206+ values : np .ndarray , begin : np .ndarray , end : np .ndarray , minimum_periods : int
207+ ):
208+ result = np .empty (values .shape )
209+ min_periods_mask = np .empty (values .shape )
210+ for i in numba .prange (len (result )):
211+ start = begin [i ]
212+ stop = end [i ]
213+ window = values [start :stop ]
214+ count_nan = np .sum (np .isnan (window ), axis = 0 )
215+ sub_result = numba_func (window , * args )
216+ nan_mask = len (window ) - count_nan >= minimum_periods
217+ min_periods_mask [i , :] = nan_mask
218+ result [i , :] = sub_result
219+ result = np .where (min_periods_mask , result , np .nan )
220+ return result
221+
222+ return roll_table
0 commit comments