@@ -36,7 +36,6 @@ class providing the base-class of operations.
3636from pandas ._libs import Timestamp
3737import pandas ._libs .groupby as libgroupby
3838from pandas ._typing import FrameOrSeries , Scalar
39- from pandas .compat import set_function_name
4039from pandas .compat .numpy import function as nv
4140from pandas .errors import AbstractMethodError
4241from pandas .util ._decorators import Appender , Substitution , cache_readonly , doc
@@ -945,6 +944,32 @@ def _wrap_transformed_output(self, output: Mapping[base.OutputKey, np.ndarray]):
945944 def _wrap_applied_output (self , keys , values , not_indexed_same : bool = False ):
946945 raise AbstractMethodError (self )
947946
947+ def _agg_general (
948+ self , numeric_only = True , min_count = - 1 , * , alias : str , npfunc : Callable
949+ ):
950+ self ._set_group_selection ()
951+
952+ # try a cython aggregation if we can
953+ try :
954+ return self ._cython_agg_general (
955+ how = alias , alt = npfunc , numeric_only = numeric_only , min_count = min_count ,
956+ )
957+ except DataError :
958+ pass
959+ except NotImplementedError as err :
960+ if "function is not implemented for this dtype" in str (
961+ err
962+ ) or "category dtype not supported" in str (err ):
963+ # raised in _get_cython_function, in some cases can
964+ # be trimmed by implementing cython funcs for more dtypes
965+ pass
966+ else :
967+ raise
968+
969+ # apply a non-cython aggregation
970+ result = self .aggregate (lambda x : npfunc (x , axis = self .axis ))
971+ return result
972+
948973 def _cython_agg_general (
949974 self , how : str , alt = None , numeric_only : bool = True , min_count : int = - 1
950975 ):
@@ -1438,105 +1463,62 @@ def size(self):
14381463 result = self ._obj_1d_constructor (result )
14391464 return self ._reindex_output (result , fill_value = 0 )
14401465
1441- @classmethod
1442- def _add_numeric_operations (cls ):
1443- """
1444- Add numeric operations to the GroupBy generically.
def sum(self, numeric_only=True, min_count=0):
    """
    Compute sum of group values.

    Parameters
    ----------
    numeric_only : bool, default True
        Include only float, int, boolean columns. If None, will attempt to use
        everything, then use only numeric data.
    min_count : int, default 0
        The required number of valid values to perform the operation. If fewer
        than ``min_count`` non-NA values are present the result will be NA.

    Returns
    -------
    Series or DataFrame
        Computed sum of values within each group.
    """
    return self._agg_general(
        numeric_only=numeric_only, min_count=min_count, alias="add", npfunc=np.sum
    )
1470+
def prod(self, numeric_only=True, min_count=0):
    """
    Compute prod of group values.

    Parameters
    ----------
    numeric_only : bool, default True
        Include only float, int, boolean columns. If None, will attempt to use
        everything, then use only numeric data.
    min_count : int, default 0
        The required number of valid values to perform the operation. If fewer
        than ``min_count`` non-NA values are present the result will be NA.

    Returns
    -------
    Series or DataFrame
        Computed prod of values within each group.
    """
    return self._agg_general(
        numeric_only=numeric_only, min_count=min_count, alias="prod", npfunc=np.prod
    )
1475+
def min(self, numeric_only=False, min_count=-1):
    """
    Compute min of group values.

    Parameters
    ----------
    numeric_only : bool, default False
        Include only float, int, boolean columns. If None, will attempt to use
        everything, then use only numeric data.
    min_count : int, default -1
        The required number of valid values to perform the operation. If fewer
        than ``min_count`` non-NA values are present the result will be NA.

    Returns
    -------
    Series or DataFrame
        Computed min of values within each group.
    """
    return self._agg_general(
        numeric_only=numeric_only, min_count=min_count, alias="min", npfunc=np.min
    )
1480+
def max(self, numeric_only=False, min_count=-1):
    """
    Compute max of group values.

    Parameters
    ----------
    numeric_only : bool, default False
        Include only float, int, boolean columns. If None, will attempt to use
        everything, then use only numeric data.
    min_count : int, default -1
        The required number of valid values to perform the operation. If fewer
        than ``min_count`` non-NA values are present the result will be NA.

    Returns
    -------
    Series or DataFrame
        Computed max of values within each group.
    """
    return self._agg_general(
        numeric_only=numeric_only, min_count=min_count, alias="max", npfunc=np.max
    )
1485+
1486+ @staticmethod
1487+ def _get_loc (x , axis : int = 0 , * , loc : int ):
1488+ """Helper function for first/last item that isn't NA.
14451489 """
14461490
1447- def groupby_function (
1448- name : str ,
1449- alias : str ,
1450- npfunc ,
1451- numeric_only : bool = True ,
1452- min_count : int = - 1 ,
1453- ):
1491+ def get_loc_notna (x , loc : int ):
1492+ x = x .to_numpy ()
1493+ x = x [notna (x )]
1494+ if len (x ) == 0 :
1495+ return np .nan
1496+ return x [loc ]
14541497
1455- _local_template = """
1456- Compute %(f)s of group values.
1457-
1458- Parameters
1459- ----------
1460- numeric_only : bool, default %(no)s
1461- Include only float, int, boolean columns. If None, will attempt to use
1462- everything, then use only numeric data.
1463- min_count : int, default %(mc)s
1464- The required number of valid values to perform the operation. If fewer
1465- than ``min_count`` non-NA values are present the result will be NA.
1466-
1467- Returns
1468- -------
1469- Series or DataFrame
1470- Computed %(f)s of values within each group.
1471- """
1472-
1473- @Substitution (name = "groupby" , f = name , no = numeric_only , mc = min_count )
1474- @Appender (_common_see_also )
1475- @Appender (_local_template )
1476- def func (self , numeric_only = numeric_only , min_count = min_count ):
1477- self ._set_group_selection ()
1478-
1479- # try a cython aggregation if we can
1480- try :
1481- return self ._cython_agg_general (
1482- how = alias ,
1483- alt = npfunc ,
1484- numeric_only = numeric_only ,
1485- min_count = min_count ,
1486- )
1487- except DataError :
1488- pass
1489- except NotImplementedError as err :
1490- if "function is not implemented for this dtype" in str (
1491- err
1492- ) or "category dtype not supported" in str (err ):
1493- # raised in _get_cython_function, in some cases can
1494- # be trimmed by implementing cython funcs for more dtypes
1495- pass
1496- else :
1497- raise
1498-
1499- # apply a non-cython aggregation
1500- result = self .aggregate (lambda x : npfunc (x , axis = self .axis ))
1501- return result
1502-
1503- set_function_name (func , name , cls )
1504-
1505- return func
1506-
1507- def first_compat (x , axis = 0 ):
1508- def first (x ):
1509- x = x .to_numpy ()
1510-
1511- x = x [notna (x )]
1512- if len (x ) == 0 :
1513- return np .nan
1514- return x [0 ]
1515-
1516- if isinstance (x , DataFrame ):
1517- return x .apply (first , axis = axis )
1518- else :
1519- return first (x )
1520-
1521- def last_compat (x , axis = 0 ):
1522- def last (x ):
1523- x = x .to_numpy ()
1524- x = x [notna (x )]
1525- if len (x ) == 0 :
1526- return np .nan
1527- return x [- 1 ]
1528-
1529- if isinstance (x , DataFrame ):
1530- return x .apply (last , axis = axis )
1531- else :
1532- return last (x )
1533-
1534- cls .sum = groupby_function ("sum" , "add" , np .sum , min_count = 0 )
1535- cls .prod = groupby_function ("prod" , "prod" , np .prod , min_count = 0 )
1536- cls .min = groupby_function ("min" , "min" , np .min , numeric_only = False )
1537- cls .max = groupby_function ("max" , "max" , np .max , numeric_only = False )
1538- cls .first = groupby_function ("first" , "first" , first_compat , numeric_only = False )
1539- cls .last = groupby_function ("last" , "last" , last_compat , numeric_only = False )
1498+ if isinstance (x , DataFrame ):
1499+ return x .apply (get_loc_notna , axis = axis , loc = loc )
1500+ else :
1501+ return get_loc_notna (x , loc = loc )
1502+
def first(self, numeric_only=False, min_count=-1):
    """
    Compute first of group values.

    Parameters
    ----------
    numeric_only : bool, default False
        Include only float, int, boolean columns. If None, will attempt to use
        everything, then use only numeric data.
    min_count : int, default -1
        The required number of valid values to perform the operation. If fewer
        than ``min_count`` non-NA values are present the result will be NA.

    Returns
    -------
    Series or DataFrame
        Computed first of values within each group.
    """
    # non-cython fallback: pick the first non-NA entry
    first_compat = partial(self._get_loc, loc=0)

    return self._agg_general(
        numeric_only=numeric_only,
        min_count=min_count,
        alias="first",
        npfunc=first_compat,
    )
1512+
def last(self, numeric_only=False, min_count=-1):
    """
    Compute last of group values.

    Parameters
    ----------
    numeric_only : bool, default False
        Include only float, int, boolean columns. If None, will attempt to use
        everything, then use only numeric data.
    min_count : int, default -1
        The required number of valid values to perform the operation. If fewer
        than ``min_count`` non-NA values are present the result will be NA.

    Returns
    -------
    Series or DataFrame
        Computed last of values within each group.
    """
    # non-cython fallback: pick the last non-NA entry
    last_compat = partial(self._get_loc, loc=-1)

    return self._agg_general(
        numeric_only=numeric_only,
        min_count=min_count,
        alias="last",
        npfunc=last_compat,
    )
15401522
15411523 @Substitution (name = "groupby" )
15421524 @Appender (_common_see_also )
0 commit comments