@@ -33,7 +33,7 @@ class providing the base-class of operations.
3333
3434from pandas ._libs import Timestamp
3535import pandas ._libs .groupby as libgroupby
36- from pandas ._typing import FrameOrSeries , Scalar
36+ from pandas ._typing import DtypeObj , FrameOrSeries , Scalar
3737from pandas .compat import set_function_name
3838from pandas .compat .numpy import function as nv
3939from pandas .errors import AbstractMethodError
@@ -42,7 +42,6 @@ class providing the base-class of operations.
4242from pandas .core .dtypes .cast import maybe_downcast_to_dtype
4343from pandas .core .dtypes .common import (
4444 ensure_float ,
45- groupby_result_dtype ,
4645 is_datetime64_dtype ,
4746 is_extension_array_dtype ,
4847 is_integer_dtype ,
@@ -795,7 +794,7 @@ def _cumcount_array(self, ascending: bool = True):
795794
796795 def _try_cast (self , result , obj , numeric_only : bool = False , how : str = "" ):
797796 """
798- Try to cast the result to our obj original type,
797+ Try to cast the result to the desired type,
799798 we may have roundtripped through object in the mean-time.
800799
801800 If numeric_only is True, then only try to cast numerics
@@ -806,8 +805,7 @@ def _try_cast(self, result, obj, numeric_only: bool = False, how: str = ""):
806805 dtype = obj ._values .dtype
807806 else :
808807 dtype = obj .dtype
809-
810- dtype = groupby_result_dtype (dtype , how )
808+ dtype = self ._result_dtype (dtype , how )
811809
812810 if not is_scalar (result ):
813811 if is_extension_array_dtype (dtype ) and dtype .kind != "M" :
@@ -1028,6 +1026,30 @@ def _apply_filter(self, indices, dropna):
10281026 filtered = self ._selected_obj .where (mask ) # Fill with NaNs.
10291027 return filtered
10301028
1029+ @staticmethod
1030+ def _result_dtype (dtype , how ) -> DtypeObj :
1031+ """
1032+ Get the desired dtype of a groupby result based on the
1033+ input dtype and how the aggregation is done.
1034+
1035+ Parameters
1036+ ----------
1037+ dtype : dtype, type
1038+ The input dtype of the groupby.
1039+ how : str
1040+ How the aggregation is performed.
1041+
1042+ Returns
1043+ -------
1044+ The desired dtype of the aggregation result.
1045+ """
1046+ d = {
1047+ (np .dtype (np .bool ), "add" ): np .dtype (np .int64 ),
1048+ (np .dtype (np .bool ), "cumsum" ): np .dtype (np .int64 ),
1049+ (np .dtype (np .bool ), "sum" ): np .dtype (np .int64 ),
1050+ }
1051+ return d .get ((dtype , how ), dtype )
1052+
10311053
10321054class GroupBy (_GroupBy ):
10331055 """
0 commit comments