@@ -367,6 +367,32 @@ def _wrap_results(result, dtype: np.dtype, fill_value=None):
367367 return result
368368
369369
370+ def _datetimelike_compat (func ):
371+ """
372+ If we have datetime64 or timedelta64 values, ensure we have a correct
373+ mask before calling the wrapped function, then cast back afterwards.
374+ """
375+
376+ @functools .wraps (func )
377+ def new_func (values , * , axis = None , skipna = True , mask = None , ** kwargs ):
378+ orig_values = values
379+
380+ datetimelike = values .dtype .kind in ["m" , "M" ]
381+ if datetimelike and mask is None :
382+ mask = isna (values )
383+
384+ result = func (values , axis = axis , skipna = skipna , mask = mask , ** kwargs )
385+
386+ if datetimelike :
387+ result = _wrap_results (result , orig_values .dtype , fill_value = iNaT )
388+ if not skipna :
389+ result = _mask_datetimelike_result (result , axis , mask , orig_values )
390+
391+ return result
392+
393+ return new_func
394+
395+
370396def _na_for_min_count (
371397 values : np .ndarray , axis : Optional [int ]
372398) -> Union [Scalar , np .ndarray ]:
@@ -480,6 +506,7 @@ def nanall(
480506
481507
482508@disallow ("M8" )
509+ @_datetimelike_compat
483510def nansum (
484511 values : np .ndarray ,
485512 * ,
@@ -511,25 +538,18 @@ def nansum(
511538 >>> nanops.nansum(s)
512539 3.0
513540 """
514- orig_values = values
515-
516541 values , mask , dtype , dtype_max , _ = _get_values (
517542 values , skipna , fill_value = 0 , mask = mask
518543 )
519544 dtype_sum = dtype_max
520- datetimelike = False
521545 if is_float_dtype (dtype ):
522546 dtype_sum = dtype
523547 elif is_timedelta64_dtype (dtype ):
524- datetimelike = True
525548 dtype_sum = np .float64
526549
527550 the_sum = values .sum (axis , dtype = dtype_sum )
528551 the_sum = _maybe_null_out (the_sum , axis , mask , values .shape , min_count = min_count )
529552
530- the_sum = _wrap_results (the_sum , dtype )
531- if datetimelike and not skipna :
532- the_sum = _mask_datetimelike_result (the_sum , axis , mask , orig_values )
533553 return the_sum
534554
535555
@@ -552,6 +572,7 @@ def _mask_datetimelike_result(
552572
553573@disallow (PeriodDtype )
554574@bottleneck_switch ()
575+ @_datetimelike_compat
555576def nanmean (
556577 values : np .ndarray ,
557578 * ,
@@ -583,18 +604,14 @@ def nanmean(
583604 >>> nanops.nanmean(s)
584605 1.5
585606 """
586- orig_values = values
587-
588607 values , mask , dtype , dtype_max , _ = _get_values (
589608 values , skipna , fill_value = 0 , mask = mask
590609 )
591610 dtype_sum = dtype_max
592611 dtype_count = np .float64
593612
594613 # not using needs_i8_conversion because that includes period
595- datetimelike = False
596614 if dtype .kind in ["m" , "M" ]:
597- datetimelike = True
598615 dtype_sum = np .float64
599616 elif is_integer_dtype (dtype ):
600617 dtype_sum = np .float64
@@ -616,9 +633,6 @@ def nanmean(
616633 else :
617634 the_mean = the_sum / count if count > 0 else np .nan
618635
619- the_mean = _wrap_results (the_mean , dtype )
620- if datetimelike and not skipna :
621- the_mean = _mask_datetimelike_result (the_mean , axis , mask , orig_values )
622636 return the_mean
623637
624638
@@ -875,7 +889,7 @@ def nanvar(values, *, axis=None, skipna=True, ddof=1, mask=None):
875889 # precision as the original values array.
876890 if is_float_dtype (dtype ):
877891 result = result .astype (dtype )
878- return _wrap_results ( result , values . dtype )
892+ return result
879893
880894
881895@disallow ("M8" , "m8" )
@@ -930,6 +944,7 @@ def nansem(
930944
931945def _nanminmax (meth , fill_value_typ ):
932946 @bottleneck_switch (name = "nan" + meth )
947+ @_datetimelike_compat
933948 def reduction (
934949 values : np .ndarray ,
935950 * ,
@@ -938,13 +953,10 @@ def reduction(
938953 mask : Optional [np .ndarray ] = None ,
939954 ) -> Dtype :
940955
941- orig_values = values
942956 values , mask , dtype , dtype_max , fill_value = _get_values (
943957 values , skipna , fill_value_typ = fill_value_typ , mask = mask
944958 )
945959
946- datetimelike = orig_values .dtype .kind in ["m" , "M" ]
947-
948960 if (axis is not None and values .shape [axis ] == 0 ) or values .size == 0 :
949961 try :
950962 result = getattr (values , meth )(axis , dtype = dtype_max )
@@ -954,12 +966,7 @@ def reduction(
954966 else :
955967 result = getattr (values , meth )(axis )
956968
957- result = _wrap_results (result , dtype , fill_value )
958969 result = _maybe_null_out (result , axis , mask , values .shape )
959-
960- if datetimelike and not skipna :
961- result = _mask_datetimelike_result (result , axis , mask , orig_values )
962-
963970 return result
964971
965972 return reduction
0 commit comments