@@ -17,13 +17,15 @@ class providing the base-class of operations.
1717 Callable ,
1818 Dict ,
1919 FrozenSet ,
20+ Generic ,
2021 Hashable ,
2122 Iterable ,
2223 List ,
2324 Mapping ,
2425 Optional ,
2526 Tuple ,
2627 Type ,
28+ TypeVar ,
2729 Union ,
2830)
2931
@@ -376,13 +378,14 @@ def _group_selection_context(groupby):
376378]
377379
378380
379- class _GroupBy (PandasObject , SelectionMixin ):
381+ class _GroupBy (PandasObject , SelectionMixin , Generic [ FrameOrSeries ] ):
380382 _group_selection = None
381383 _apply_whitelist : FrozenSet [str ] = frozenset ()
384+ obj : FrameOrSeries
382385
383386 def __init__ (
384387 self ,
385- obj : NDFrame ,
388+ obj : FrameOrSeries ,
386389 keys : Optional [_KeysArgType ] = None ,
387390 axis : int = 0 ,
388391 level = None ,
@@ -1079,7 +1082,11 @@ def _apply_filter(self, indices, dropna):
10791082 return filtered
10801083
10811084
1082- class GroupBy (_GroupBy ):
1085+ # We require another typevar to track operations that expand dimensions, like ohlc
1086+ FrameOrSeries2 = TypeVar ("FrameOrSeries2" , bound = NDFrame )
1087+
1088+
1089+ class GroupBy (_GroupBy [FrameOrSeries ]):
10831090 """
10841091 Class for grouping and aggregating relational data.
10851092
@@ -1390,25 +1397,25 @@ def size(self):
13901397 return self ._reindex_output (result , fill_value = 0 )
13911398
13921399 @doc (_agg_template , fname = "sum" , no = True , mc = 0 )
1393- def sum (self , numeric_only : bool = True , min_count : int = 0 ):
1400+ def sum (self , numeric_only : bool = True , min_count : int = 0 ) -> FrameOrSeries :
13941401 return self ._agg_general (
13951402 numeric_only = numeric_only , min_count = min_count , alias = "add" , npfunc = np .sum
13961403 )
13971404
13981405 @doc (_agg_template , fname = "prod" , no = True , mc = 0 )
1399- def prod (self , numeric_only : bool = True , min_count : int = 0 ):
1406+ def prod (self , numeric_only : bool = True , min_count : int = 0 ) -> FrameOrSeries :
14001407 return self ._agg_general (
14011408 numeric_only = numeric_only , min_count = min_count , alias = "prod" , npfunc = np .prod
14021409 )
14031410
14041411 @doc (_agg_template , fname = "min" , no = False , mc = - 1 )
1405- def min (self , numeric_only : bool = False , min_count : int = - 1 ):
1412+ def min (self , numeric_only : bool = False , min_count : int = - 1 ) -> FrameOrSeries :
14061413 return self ._agg_general (
14071414 numeric_only = numeric_only , min_count = min_count , alias = "min" , npfunc = np .min
14081415 )
14091416
14101417 @doc (_agg_template , fname = "max" , no = False , mc = - 1 )
1411- def max (self , numeric_only : bool = False , min_count : int = - 1 ):
1418+ def max (self , numeric_only : bool = False , min_count : int = - 1 ) -> FrameOrSeries :
14121419 return self ._agg_general (
14131420 numeric_only = numeric_only , min_count = min_count , alias = "max" , npfunc = np .max
14141421 )
@@ -1431,7 +1438,7 @@ def get_loc_notna(x, loc: int):
14311438 return get_loc_notna (x , loc = loc )
14321439
14331440 @doc (_agg_template , fname = "first" , no = False , mc = - 1 )
1434- def first (self , numeric_only : bool = False , min_count : int = - 1 ):
1441+ def first (self , numeric_only : bool = False , min_count : int = - 1 ) -> FrameOrSeries :
14351442 first_compat = partial (self ._get_loc , loc = 0 )
14361443
14371444 return self ._agg_general (
@@ -1441,8 +1448,7 @@ def first(self, numeric_only: bool = False, min_count: int = -1):
14411448 npfunc = first_compat ,
14421449 )
14431450
1444- @doc (_agg_template , fname = "last" , no = False , mc = - 1 )
1445- def last (self , numeric_only : bool = False , min_count : int = - 1 ):
1451+ def last (self , numeric_only : bool = False , min_count : int = - 1 ) -> FrameOrSeries :
14461452 last_compat = partial (self ._get_loc , loc = - 1 )
14471453
14481454 return self ._agg_general (
@@ -2467,8 +2473,8 @@ def tail(self, n=5):
24672473 return self ._selected_obj [mask ]
24682474
24692475 def _reindex_output (
2470- self , output : FrameOrSeries , fill_value : Scalar = np .NaN
2471- ) -> FrameOrSeries :
2476+ self , output : FrameOrSeries2 , fill_value : Scalar = np .NaN
2477+ ) -> FrameOrSeries2 :
24722478 """
24732479 If we have categorical groupers, then we might want to make sure that
24742480 we have a fully re-indexed output to the levels. This means expanding
0 commit comments