@@ -42,6 +42,7 @@ class providing the base-class of operations.
4242from pandas .core .dtypes .cast import maybe_downcast_to_dtype
4343from pandas .core .dtypes .common import (
4444 ensure_float ,
45+ groupby_result_dtype ,
4546 is_datetime64_dtype ,
4647 is_extension_array_dtype ,
4748 is_integer_dtype ,
@@ -792,7 +793,7 @@ def _cumcount_array(self, ascending: bool = True):
792793 rev [sorter ] = np .arange (count , dtype = np .intp )
793794 return out [rev ].astype (np .int64 , copy = False )
794795
795- def _try_cast (self , result , obj , numeric_only : bool = False ):
796+ def _try_cast (self , result , obj , numeric_only : bool = False , how : str = "" ):
796797 """
797798 Try to cast the result to our obj original type,
798799 we may have roundtripped through object in the mean-time.
@@ -806,6 +807,8 @@ def _try_cast(self, result, obj, numeric_only: bool = False):
806807 else :
807808 dtype = obj .dtype
808809
810+ dtype = groupby_result_dtype (dtype , how )
811+
809812 if not is_scalar (result ):
810813 if is_extension_array_dtype (dtype ) and dtype .kind != "M" :
811814 # The function can return something of any type, so check
@@ -852,7 +855,7 @@ def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
852855 continue
853856
854857 if self ._transform_should_cast (how ):
855- result = self ._try_cast (result , obj )
858+ result = self ._try_cast (result , obj , how = how )
856859
857860 key = base .OutputKey (label = name , position = idx )
858861 output [key ] = result
@@ -895,12 +898,12 @@ def _cython_agg_general(
895898 assert len (agg_names ) == result .shape [1 ]
896899 for result_column , result_name in zip (result .T , agg_names ):
897900 key = base .OutputKey (label = result_name , position = idx )
898- output [key ] = self ._try_cast (result_column , obj )
901+ output [key ] = self ._try_cast (result_column , obj , how = how )
899902 idx += 1
900903 else :
901904 assert result .ndim == 1
902905 key = base .OutputKey (label = name , position = idx )
903- output [key ] = self ._try_cast (result , obj )
906+ output [key ] = self ._try_cast (result , obj , how = how )
904907 idx += 1
905908
906909 if len (output ) == 0 :
0 commit comments