@@ -2059,6 +2059,7 @@ def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
20592059 return self ._get_cythonized_result (
20602060 "group_quantile" ,
20612061 aggregate = True ,
2062+ needs_counts = True ,
20622063 needs_values = True ,
20632064 needs_mask = True ,
20642065 cython_dtype = np .dtype (np .float64 ),
@@ -2072,6 +2073,7 @@ def post_processor(vals: np.ndarray, inference: Optional[Type]) -> np.ndarray:
20722073 self ._get_cythonized_result (
20732074 "group_quantile" ,
20742075 aggregate = True ,
2076+ needs_counts = True ,
20752077 needs_values = True ,
20762078 needs_mask = True ,
20772079 cython_dtype = np .dtype (np .float64 ),
@@ -2348,9 +2350,10 @@ def _get_cythonized_result(
23482350 how : str ,
23492351 cython_dtype : np .dtype ,
23502352 aggregate : bool = False ,
2353+ needs_counts : bool = False ,
23512354 needs_values : bool = False ,
2355+ min_count : Optional [int ] = None ,
23522356 needs_mask : bool = False ,
2353- needs_ngroups : bool = False ,
23542357 result_is_index : bool = False ,
23552358 pre_processing = None ,
23562359 post_processing = None ,
@@ -2367,14 +2370,16 @@ def _get_cythonized_result(
23672370 aggregate : bool, default False
23682371 Whether the result should be aggregated to match the number of
23692372 groups
2373+ needs_counts : bool, default False
2374+ Whether the counts should be a part of the Cython call signature
23702375 needs_values : bool, default False
23712376 Whether the values should be a part of the Cython call
23722377 signature
2378+ min_count : int, default None
2379+ When not None, passed as a positional argument to the Cython call (after the labels)
23732380 needs_mask : bool, default False
23742381 Whether boolean mask needs to be part of the Cython call
23752382 signature
2376- needs_ngroups : bool, default False
2377- Whether number of groups is part of the Cython call signature
23782383 result_is_index : bool, default False
23792384 Whether the result of the Cython operation is an index of
23802385 values to be retrieved, instead of the actual values themselves
@@ -2414,74 +2419,63 @@ def _get_cythonized_result(
24142419 labels , _ , ngroups = grouper .group_info
24152420 output : Dict [base .OutputKey , np .ndarray ] = {}
24162421 base_func = getattr (libgroupby , how )
2422+ inferences = None
24172423
2418- if how == "group_quantile" :
2419- values = self ._obj_with_exclusions ._values
2420- result_sz = ngroups if aggregate else len (values )
2424+ values = self ._obj_with_exclusions ._values
2425+ result_sz = ngroups if aggregate else len (values )
2426+ if self ._obj_with_exclusions .ndim == 1 :
2427+ width = 1
2428+ else :
2429+ width = len (self ._obj_with_exclusions .columns )
2430+ result = np .zeros ((result_sz , width ), dtype = cython_dtype )
2431+ func = partial (base_func , result )
24212432
2422- vals , inferences = pre_processing (values )
2423- if self ._obj_with_exclusions .ndim == 1 :
2424- width = 1
2425- vals = np .reshape (vals , (- 1 , 1 ))
2426- else :
2427- width = len (self ._obj_with_exclusions .columns )
2428- result = np .zeros ((result_sz , width ), dtype = cython_dtype )
2433+ if needs_counts :
24292434 counts = np .zeros (self .ngroups , dtype = np .int64 )
2430- mask = isna (vals ).view (np .uint8 )
2431-
2432- func = partial (base_func , result , counts , vals , labels , - 1 , mask )
2433- func (** kwargs ) # Call func to modify indexer values in place
2434- result = post_processing (result , inferences )
2435+ func = partial (func , counts )
24352436
2437+ if needs_values :
2438+ vals = values
2439+ if pre_processing :
2440+ vals , inferences = pre_processing (vals )
24362441 if self ._obj_with_exclusions .ndim == 1 :
2437- key = base .OutputKey (label = self ._obj_with_exclusions .name , position = 0 )
2438- output [key ] = result [:, 0 ]
2439- else :
2440- for idx , name in enumerate (self ._obj_with_exclusions .columns ):
2441- key = base .OutputKey (label = name , position = idx )
2442- output [key ] = result [:, idx ]
2442+ vals = np .reshape (vals , (- 1 , 1 ))
2443+ func = partial (func , vals )
24432444
2444- if aggregate :
2445- return self ._wrap_aggregated_output (output )
2446- else :
2447- return self ._wrap_transformed_output (output )
2445+ # Groupby always needs labels
2446+ func = partial (func , labels )
24482447
2449- for idx , obj in enumerate (self ._iterate_slices ()):
2450- name = obj .name
2451- values = obj ._values
2448+ if min_count is not None :
2449+ func = partial (func , min_count )
24522450
2453- if aggregate :
2454- result_sz = ngroups
2451+ if needs_mask :
2452+ if self ._obj_with_exclusions .ndim == 1 :
2453+ # If needs_values is True, don't need to reshape again
2454+ if needs_values :
2455+ mask = isna (vals ).view (np .uint8 )
2456+ else :
2457+ mask = isna (np .reshape (values , (- 1 , 1 ))).view (np .uint8 )
24552458 else :
2456- result_sz = len (values )
2457-
2458- result = np .zeros (result_sz , dtype = cython_dtype )
2459- func = partial (base_func , result , labels )
2460- inferences = None
2461-
2462- if needs_values :
2463- vals = values
2464- if pre_processing :
2465- vals , inferences = pre_processing (vals )
2466- func = partial (func , vals )
2467-
2468- if needs_mask :
24692459 mask = isna (values ).view (np .uint8 )
2470- func = partial (func , mask )
2471-
2472- if needs_ngroups :
2473- func = partial (func , ngroups )
2460+ func = partial (func , mask )
24742461
2475- func (** kwargs ) # Call func to modify indexer values in place
2462+ func (** kwargs ) # Call func to modify indexer values in place
24762463
2477- if result_is_index :
2478- result = algorithms .take_nd (values , result )
2464+ # TODO: Probably not correct: result is now allocated 2-D (result_sz, width), so taking from values with it needs checking
2465+ if result_is_index :
2466+ result = algorithms .take_nd (values , result )
24792467
2480- if post_processing :
2481- result = post_processing (result , inferences )
2468+ if post_processing :
2469+ result = post_processing (result , inferences )
24822470
2483- key = base .OutputKey (label = name , position = idx )
2484- output [key ] = result
2471+ # TODO: Perhaps there is a better way to get result into output
2472+ if self ._obj_with_exclusions .ndim == 1 :
2473+ key = base .OutputKey (label = self ._obj_with_exclusions .name , position = 0 )
2474+ output [key ] = result [:, 0 ]
2475+ else :
2476+ for idx , name in enumerate (self ._obj_with_exclusions .columns ):
2477+ key = base .OutputKey (label = name , position = idx )
2478+ output [key ] = result [:, idx ]
24852479
24862480 if aggregate :
24872481 return self ._wrap_aggregated_output (output )
0 commit comments