@@ -33,17 +33,16 @@ class providing the base-class of operations.
3333
3434from pandas ._libs import Timestamp
3535import pandas ._libs .groupby as libgroupby
36- from pandas ._typing import DtypeObj , FrameOrSeries , Scalar
36+ from pandas ._typing import FrameOrSeries , Scalar
3737from pandas .compat import set_function_name
3838from pandas .compat .numpy import function as nv
3939from pandas .errors import AbstractMethodError
4040from pandas .util ._decorators import Appender , Substitution , cache_readonly
4141
42- from pandas .core .dtypes .cast import maybe_downcast_to_dtype
42+ from pandas .core .dtypes .cast import maybe_cast_result
4343from pandas .core .dtypes .common import (
4444 ensure_float ,
4545 is_datetime64_dtype ,
46- is_extension_array_dtype ,
4746 is_integer_dtype ,
4847 is_numeric_dtype ,
4948 is_object_dtype ,
@@ -53,7 +52,7 @@ class providing the base-class of operations.
5352
5453from pandas .core import nanops
5554import pandas .core .algorithms as algorithms
56- from pandas .core .arrays import Categorical , DatetimeArray , try_cast_to_ea
55+ from pandas .core .arrays import Categorical , DatetimeArray
5756from pandas .core .base import DataError , PandasObject , SelectionMixin
5857import pandas .core .common as com
5958from pandas .core .frame import DataFrame
@@ -792,37 +791,6 @@ def _cumcount_array(self, ascending: bool = True):
792791 rev [sorter ] = np .arange (count , dtype = np .intp )
793792 return out [rev ].astype (np .int64 , copy = False )
794793
795- def _try_cast (self , result , obj , numeric_only : bool = False , how : str = "" ):
796- """
797- Try to cast the result to the desired type,
798- we may have roundtripped through object in the mean-time.
799-
800- If numeric_only is True, then only try to cast numerics
801- and not datetimelikes.
802-
803- """
804- if obj .ndim > 1 :
805- dtype = obj ._values .dtype
806- else :
807- dtype = obj .dtype
808- dtype = self ._result_dtype (dtype , how )
809-
810- if not is_scalar (result ):
811- if is_extension_array_dtype (dtype ) and dtype .kind != "M" :
812- # The function can return something of any type, so check
813- # if the type is compatible with the calling EA.
814- # datetime64tz is handled correctly in agg_series,
815- # so is excluded here.
816-
817- if len (result ) and isinstance (result [0 ], dtype .type ):
818- cls = dtype .construct_array_type ()
819- result = try_cast_to_ea (cls , result , dtype = dtype )
820-
821- elif numeric_only and is_numeric_dtype (dtype ) or not numeric_only :
822- result = maybe_downcast_to_dtype (result , dtype )
823-
824- return result
825-
826794 def _transform_should_cast (self , func_nm : str ) -> bool :
827795 """
828796 Parameters
@@ -853,7 +821,7 @@ def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
853821 continue
854822
855823 if self ._transform_should_cast (how ):
856- result = self . _try_cast (result , obj , how = how )
824+ result = maybe_cast_result (result , obj , how = how )
857825
858826 key = base .OutputKey (label = name , position = idx )
859827 output [key ] = result
@@ -896,12 +864,12 @@ def _cython_agg_general(
896864 assert len (agg_names ) == result .shape [1 ]
897865 for result_column , result_name in zip (result .T , agg_names ):
898866 key = base .OutputKey (label = result_name , position = idx )
899- output [key ] = self . _try_cast (result_column , obj , how = how )
867+ output [key ] = maybe_cast_result (result_column , obj , how = how )
900868 idx += 1
901869 else :
902870 assert result .ndim == 1
903871 key = base .OutputKey (label = name , position = idx )
904- output [key ] = self . _try_cast (result , obj , how = how )
872+ output [key ] = maybe_cast_result (result , obj , how = how )
905873 idx += 1
906874
907875 if len (output ) == 0 :
@@ -930,7 +898,7 @@ def _python_agg_general(self, func, *args, **kwargs):
930898
931899 assert result is not None
932900 key = base .OutputKey (label = name , position = idx )
933- output [key ] = self . _try_cast (result , obj , numeric_only = True )
901+ output [key ] = maybe_cast_result (result , obj , numeric_only = True )
934902
935903 if len (output ) == 0 :
936904 return self ._python_apply_general (f )
@@ -945,7 +913,7 @@ def _python_agg_general(self, func, *args, **kwargs):
945913 if is_numeric_dtype (values .dtype ):
946914 values = ensure_float (values )
947915
948- output [key ] = self . _try_cast (values [mask ], result )
916+ output [key ] = maybe_cast_result (values [mask ], result )
949917
950918 return self ._wrap_aggregated_output (output )
951919
@@ -1026,30 +994,6 @@ def _apply_filter(self, indices, dropna):
1026994 filtered = self ._selected_obj .where (mask ) # Fill with NaNs.
1027995 return filtered
1028996
1029- @staticmethod
1030- def _result_dtype (dtype , how ) -> DtypeObj :
1031- """
1032- Get the desired dtype of a groupby result based on the
1033- input dtype and how the aggregation is done.
1034-
1035- Parameters
1036- ----------
1037- dtype : dtype, type
1038- The input dtype of the groupby.
1039- how : str
1040- How the aggregation is performed.
1041-
1042- Returns
1043- -------
1044- The desired dtype of the aggregation result.
1045- """
1046- d = {
1047- (np .dtype (np .bool ), "add" ): np .dtype (np .int64 ),
1048- (np .dtype (np .bool ), "cumsum" ): np .dtype (np .int64 ),
1049- (np .dtype (np .bool ), "sum" ): np .dtype (np .int64 ),
1050- }
1051- return d .get ((dtype , how ), dtype )
1052-
1053997
1054998class GroupBy (_GroupBy ):
1055999 """
0 commit comments