77
88from pandas ._config import get_option
99
10- from pandas ._libs import NaT , Timedelta , Timestamp , iNaT , lib
10+ from pandas ._libs import NaT , Period , Timedelta , Timestamp , iNaT , lib
1111from pandas ._typing import Dtype , Scalar
1212from pandas .compat ._optional import import_optional_dependency
1313
1717 is_any_int_dtype ,
1818 is_bool_dtype ,
1919 is_complex ,
20- is_datetime64_dtype ,
21- is_datetime64tz_dtype ,
22- is_datetime_or_timedelta_dtype ,
20+ is_datetime64_any_dtype ,
2321 is_float ,
2422 is_float_dtype ,
2523 is_integer ,
2826 is_object_dtype ,
2927 is_scalar ,
3028 is_timedelta64_dtype ,
29+ needs_i8_conversion ,
3130 pandas_dtype ,
3231)
32+ from pandas .core .dtypes .dtypes import PeriodDtype
3333from pandas .core .dtypes .missing import isna , na_value_for_dtype , notna
3434
3535from pandas .core .construction import extract_array
@@ -134,10 +134,8 @@ def f(
134134
135135
136136def _bn_ok_dtype (dtype : Dtype , name : str ) -> bool :
137- # Bottleneck chokes on datetime64
138- if not is_object_dtype (dtype ) and not (
139- is_datetime_or_timedelta_dtype (dtype ) or is_datetime64tz_dtype (dtype )
140- ):
137+ # Bottleneck chokes on datetime64, PeriodDtype (or any EA)
138+ if not is_object_dtype (dtype ) and not needs_i8_conversion (dtype ):
141139
142140 # GH 15507
143141 # bottleneck does not properly upcast during the sum
@@ -283,17 +281,16 @@ def _get_values(
283281 # with scalar fill_value. This guarantee is important for the
284282 # maybe_upcast_putmask call below
285283 assert is_scalar (fill_value )
284+ values = extract_array (values , extract_numpy = True )
286285
287286 mask = _maybe_get_mask (values , skipna , mask )
288287
289- values = extract_array (values , extract_numpy = True )
290288 dtype = values .dtype
291289
292- if is_datetime_or_timedelta_dtype ( values ) or is_datetime64tz_dtype (values ):
290+ if needs_i8_conversion (values ):
293291 # changing timedelta64/datetime64 to int64 needs to happen after
294292 # finding `mask` above
295- values = getattr (values , "asi8" , values )
296- values = values .view (np .int64 )
293+ values = np .asarray (values .view ("i8" ))
297294
298295 dtype_ok = _na_ok_dtype (dtype )
299296
@@ -307,7 +304,8 @@ def _get_values(
307304
308305 if skipna and copy :
309306 values = values .copy ()
310- if dtype_ok :
307+ assert mask is not None # for mypy
308+ if dtype_ok and mask .any ():
311309 np .putmask (values , mask , fill_value )
312310
313311 # promote if needed
@@ -325,13 +323,14 @@ def _get_values(
325323
326324
327325def _na_ok_dtype (dtype ) -> bool :
328- # TODO: what about datetime64tz? PeriodDtype?
329- return not issubclass (dtype .type , (np .integer , np .timedelta64 , np .datetime64 ))
326+ if needs_i8_conversion (dtype ):
327+ return False
328+ return not issubclass (dtype .type , np .integer )
330329
331330
332331def _wrap_results (result , dtype : Dtype , fill_value = None ):
333332 """ wrap our results if needed """
334- if is_datetime64_dtype ( dtype ) or is_datetime64tz_dtype (dtype ):
333+ if is_datetime64_any_dtype (dtype ):
335334 if fill_value is None :
336335 # GH#24293
337336 fill_value = iNaT
@@ -342,7 +341,8 @@ def _wrap_results(result, dtype: Dtype, fill_value=None):
342341 result = np .nan
343342 result = Timestamp (result , tz = tz )
344343 else :
345- result = result .view (dtype )
344+ # If we have float dtype, taking a view will give the wrong result
345+ result = result .astype (dtype )
346346 elif is_timedelta64_dtype (dtype ):
347347 if not isinstance (result , np .ndarray ):
348348 if result == fill_value :
@@ -356,6 +356,14 @@ def _wrap_results(result, dtype: Dtype, fill_value=None):
356356 else :
357357 result = result .astype ("m8[ns]" ).view (dtype )
358358
359+ elif isinstance (dtype , PeriodDtype ):
360+ if is_float (result ) and result .is_integer ():
361+ result = int (result )
362+ if is_integer (result ):
363+ result = Period ._from_ordinal (result , freq = dtype .freq )
364+ else :
365+ raise NotImplementedError (type (result ), result )
366+
359367 return result
360368
361369
@@ -542,12 +550,7 @@ def nanmean(values, axis=None, skipna=True, mask=None):
542550 )
543551 dtype_sum = dtype_max
544552 dtype_count = np .float64
545- if (
546- is_integer_dtype (dtype )
547- or is_timedelta64_dtype (dtype )
548- or is_datetime64_dtype (dtype )
549- or is_datetime64tz_dtype (dtype )
550- ):
553+ if is_integer_dtype (dtype ) or needs_i8_conversion (dtype ):
551554 dtype_sum = np .float64
552555 elif is_float_dtype (dtype ):
553556 dtype_sum = dtype
0 commit comments