@@ -36,7 +36,6 @@ class providing the base-class of operations.
3636from pandas ._libs import Timestamp
3737import pandas ._libs .groupby as libgroupby
3838from pandas ._typing import FrameOrSeries , Scalar
39- from pandas .compat import set_function_name
4039from pandas .compat .numpy import function as nv
4140from pandas .errors import AbstractMethodError
4241from pandas .util ._decorators import Appender , Substitution , cache_readonly , doc
@@ -942,6 +941,32 @@ def _wrap_transformed_output(self, output: Mapping[base.OutputKey, np.ndarray]):
942941 def _wrap_applied_output (self , keys , values , not_indexed_same : bool = False ):
943942 raise AbstractMethodError (self )
944943
944+ def _agg_general (
945+ self , numeric_only = True , min_count = - 1 , * , alias : str , npfunc : Callable
946+ ):
947+ self ._set_group_selection ()
948+
949+ # try a cython aggregation if we can
950+ try :
951+ return self ._cython_agg_general (
952+ how = alias , alt = npfunc , numeric_only = numeric_only , min_count = min_count ,
953+ )
954+ except DataError :
955+ pass
956+ except NotImplementedError as err :
957+ if "function is not implemented for this dtype" in str (
958+ err
959+ ) or "category dtype not supported" in str (err ):
960+ # raised in _get_cython_function, in some cases can
961+ # be trimmed by implementing cython funcs for more dtypes
962+ pass
963+ else :
964+ raise
965+
966+ # apply a non-cython aggregation
967+ result = self .aggregate (lambda x : npfunc (x , axis = self .axis ))
968+ return result
969+
945970 def _cython_agg_general (
946971 self , how : str , alt = None , numeric_only : bool = True , min_count : int = - 1
947972 ):
@@ -1424,105 +1449,62 @@ def size(self):
14241449 result .name = self .obj .name
14251450 return self ._reindex_output (result , fill_value = 0 )
14261451
1427- @classmethod
1428- def _add_numeric_operations (cls ):
1429- """
1430- Add numeric operations to the GroupBy generically.
1452+ def sum (self , numeric_only = True , min_count = 0 ):
1453+ return self ._agg_general (
1454+ numeric_only = numeric_only , min_count = min_count , alias = "add" , npfunc = np .sum
1455+ )
1456+
1457+ def prod (self , numeric_only = True , min_count = 0 ):
1458+ return self ._agg_general (
1459+ numeric_only = numeric_only , min_count = min_count , alias = "prod" , npfunc = np .prod
1460+ )
1461+
1462+ def min (self , numeric_only = False , min_count = - 1 ):
1463+ return self ._agg_general (
1464+ numeric_only = numeric_only , min_count = min_count , alias = "min" , npfunc = np .min
1465+ )
1466+
1467+ def max (self , numeric_only = False , min_count = - 1 ):
1468+ return self ._agg_general (
1469+ numeric_only = numeric_only , min_count = min_count , alias = "max" , npfunc = np .max
1470+ )
1471+
1472+ @staticmethod
1473+ def _get_loc (x , axis : int = 0 , * , loc : int ):
1474+ """Helper function for first/last item that isn't NA.
14311475 """
14321476
1433- def groupby_function (
1434- name : str ,
1435- alias : str ,
1436- npfunc ,
1437- numeric_only : bool = True ,
1438- min_count : int = - 1 ,
1439- ):
1477+ def get_loc_notna (x , loc : int ):
1478+ x = x .to_numpy ()
1479+ x = x [notna (x )]
1480+ if len (x ) == 0 :
1481+ return np .nan
1482+ return x [loc ]
14401483
1441- _local_template = """
1442- Compute %(f)s of group values.
1443-
1444- Parameters
1445- ----------
1446- numeric_only : bool, default %(no)s
1447- Include only float, int, boolean columns. If None, will attempt to use
1448- everything, then use only numeric data.
1449- min_count : int, default %(mc)s
1450- The required number of valid values to perform the operation. If fewer
1451- than ``min_count`` non-NA values are present the result will be NA.
1452-
1453- Returns
1454- -------
1455- Series or DataFrame
1456- Computed %(f)s of values within each group.
1457- """
1458-
1459- @Substitution (name = "groupby" , f = name , no = numeric_only , mc = min_count )
1460- @Appender (_common_see_also )
1461- @Appender (_local_template )
1462- def func (self , numeric_only = numeric_only , min_count = min_count ):
1463- self ._set_group_selection ()
1464-
1465- # try a cython aggregation if we can
1466- try :
1467- return self ._cython_agg_general (
1468- how = alias ,
1469- alt = npfunc ,
1470- numeric_only = numeric_only ,
1471- min_count = min_count ,
1472- )
1473- except DataError :
1474- pass
1475- except NotImplementedError as err :
1476- if "function is not implemented for this dtype" in str (
1477- err
1478- ) or "category dtype not supported" in str (err ):
1479- # raised in _get_cython_function, in some cases can
1480- # be trimmed by implementing cython funcs for more dtypes
1481- pass
1482- else :
1483- raise
1484-
1485- # apply a non-cython aggregation
1486- result = self .aggregate (lambda x : npfunc (x , axis = self .axis ))
1487- return result
1488-
1489- set_function_name (func , name , cls )
1490-
1491- return func
1492-
1493- def first_compat (x , axis = 0 ):
1494- def first (x ):
1495- x = x .to_numpy ()
1496-
1497- x = x [notna (x )]
1498- if len (x ) == 0 :
1499- return np .nan
1500- return x [0 ]
1501-
1502- if isinstance (x , DataFrame ):
1503- return x .apply (first , axis = axis )
1504- else :
1505- return first (x )
1506-
1507- def last_compat (x , axis = 0 ):
1508- def last (x ):
1509- x = x .to_numpy ()
1510- x = x [notna (x )]
1511- if len (x ) == 0 :
1512- return np .nan
1513- return x [- 1 ]
1514-
1515- if isinstance (x , DataFrame ):
1516- return x .apply (last , axis = axis )
1517- else :
1518- return last (x )
1519-
1520- cls .sum = groupby_function ("sum" , "add" , np .sum , min_count = 0 )
1521- cls .prod = groupby_function ("prod" , "prod" , np .prod , min_count = 0 )
1522- cls .min = groupby_function ("min" , "min" , np .min , numeric_only = False )
1523- cls .max = groupby_function ("max" , "max" , np .max , numeric_only = False )
1524- cls .first = groupby_function ("first" , "first" , first_compat , numeric_only = False )
1525- cls .last = groupby_function ("last" , "last" , last_compat , numeric_only = False )
1484+ if isinstance (x , DataFrame ):
1485+ return x .apply (get_loc_notna , axis = axis , loc = loc )
1486+ else :
1487+ return get_loc_notna (x , loc = loc )
1488+
1489+ def first (self , numeric_only = False , min_count = - 1 ):
1490+ first_compat = partial (self ._get_loc , loc = 0 )
1491+
1492+ return self ._agg_general (
1493+ numeric_only = numeric_only ,
1494+ min_count = min_count ,
1495+ alias = "first" ,
1496+ npfunc = first_compat ,
1497+ )
1498+
1499+ def last (self , numeric_only = False , min_count = - 1 ):
1500+ last_compat = partial (self ._get_loc , loc = - 1 )
1501+
1502+ return self ._agg_general (
1503+ numeric_only = numeric_only ,
1504+ min_count = min_count ,
1505+ alias = "last" ,
1506+ npfunc = last_compat ,
1507+ )
15261508
15271509 @Substitution (name = "groupby" )
15281510 @Appender (_common_see_also )
0 commit comments