@@ -296,7 +296,9 @@ def trans(x):
296296 return result
297297
298298
299- def maybe_cast_result (result , obj : "Series" , numeric_only : bool = False , how : str = "" ):
299+ def maybe_cast_result (
300+ result : ArrayLike , obj : "Series" , numeric_only : bool = False , how : str = ""
301+ ) -> ArrayLike :
300302 """
301303 Try casting result to a different type if appropriate
302304
@@ -319,19 +321,20 @@ def maybe_cast_result(result, obj: "Series", numeric_only: bool = False, how: st
319321 dtype = obj .dtype
320322 dtype = maybe_cast_result_dtype (dtype , how )
321323
322- if not is_scalar (result ):
323- if (
324- is_extension_array_dtype (dtype )
325- and not is_categorical_dtype (dtype )
326- and dtype .kind != "M"
327- ):
328- # We have to special case categorical so as not to upcast
329- # things like counts back to categorical
330- cls = dtype .construct_array_type ()
331- result = maybe_cast_to_extension_array (cls , result , dtype = dtype )
324+ assert not is_scalar (result )
325+
326+ if (
327+ is_extension_array_dtype (dtype )
328+ and not is_categorical_dtype (dtype )
329+ and dtype .kind != "M"
330+ ):
331+ # We have to special case categorical so as not to upcast
332+ # things like counts back to categorical
333+ cls = dtype .construct_array_type ()
334+ result = maybe_cast_to_extension_array (cls , result , dtype = dtype )
332335
333- elif numeric_only and is_numeric_dtype (dtype ) or not numeric_only :
334- result = maybe_downcast_to_dtype (result , dtype )
336+ elif numeric_only and is_numeric_dtype (dtype ) or not numeric_only :
337+ result = maybe_downcast_to_dtype (result , dtype )
335338
336339 return result
337340
0 commit comments