Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pandas._config import get_option

from pandas._libs import NaT, Timedelta, Timestamp, iNaT, lib
from pandas._typing import ArrayLike, Dtype, F, Scalar
from pandas._typing import ArrayLike, Dtype, DtypeObj, F, Scalar
from pandas.compat._optional import import_optional_dependency

from pandas.core.dtypes.cast import _int64_max, maybe_upcast_putmask
Expand Down Expand Up @@ -133,7 +133,7 @@ def f(
return f


def _bn_ok_dtype(dtype: Dtype, name: str) -> bool:
def _bn_ok_dtype(dtype: DtypeObj, name: str) -> bool:
# Bottleneck chokes on datetime64, PeriodDtype (or and EA)
if not is_object_dtype(dtype) and not needs_i8_conversion(dtype):

Expand Down Expand Up @@ -166,7 +166,7 @@ def _has_infs(result) -> bool:


def _get_fill_value(
dtype: Dtype, fill_value: Optional[Scalar] = None, fill_value_typ=None
dtype: DtypeObj, fill_value: Optional[Scalar] = None, fill_value_typ=None
):
""" return the correct fill value for the dtype of the values """
if fill_value is not None:
Expand Down Expand Up @@ -270,9 +270,9 @@ def _get_values(
Potential copy of input value array
mask : Optional[ndarray[bool]]
Mask for values, if deemed necessary to compute
dtype : dtype
dtype : np.dtype
dtype for values
dtype_max : dtype
dtype_max : np.dtype
platform independent dtype
fill_value : Any
fill value used
Expand Down Expand Up @@ -312,20 +312,20 @@ def _get_values(
# return a platform independent precision dtype
dtype_max = dtype
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
dtype_max = np.int64
dtype_max = np.dtype(np.int64)
elif is_float_dtype(dtype):
dtype_max = np.float64
dtype_max = np.dtype(np.float64)

return values, mask, dtype, dtype_max, fill_value


def _na_ok_dtype(dtype) -> bool:
def _na_ok_dtype(dtype: DtypeObj) -> bool:
if needs_i8_conversion(dtype):
return False
return not issubclass(dtype.type, np.integer)


def _wrap_results(result, dtype: Dtype, fill_value=None):
def _wrap_results(result, dtype: DtypeObj, fill_value=None):
""" wrap our results if needed """
if is_datetime64_any_dtype(dtype):
if fill_value is None:
Expand Down Expand Up @@ -597,7 +597,7 @@ def get_median(x):
return np.nan
return np.nanmedian(x[mask])

values, mask, dtype, dtype_max, _ = _get_values(values, skipna, mask=mask)
values, mask, dtype, _, _ = _get_values(values, skipna, mask=mask)
if not is_float_dtype(values.dtype):
try:
values = values.astype("f8")
Expand Down Expand Up @@ -716,7 +716,7 @@ def nanstd(values, axis=None, skipna=True, ddof=1, mask=None):
1.0
"""
orig_dtype = values.dtype
values, mask, dtype, dtype_max, fill_value = _get_values(values, skipna, mask=mask)
values, mask, _, _, _ = _get_values(values, skipna, mask=mask)

result = np.sqrt(nanvar(values, axis=axis, skipna=skipna, ddof=ddof, mask=mask))
return _wrap_results(result, orig_dtype)
Expand Down Expand Up @@ -910,9 +910,7 @@ def nanargmax(
>>> nanops.nanargmax(arr, axis=1)
array([2, 2, 1, 1], dtype=int64)
"""
values, mask, dtype, _, _ = _get_values(
values, True, fill_value_typ="-inf", mask=mask
)
values, mask, _, _, _ = _get_values(values, True, fill_value_typ="-inf", mask=mask)
result = values.argmax(axis)
result = _maybe_arg_null_out(result, axis, mask, skipna)
return result
Expand Down Expand Up @@ -956,9 +954,7 @@ def nanargmin(
>>> nanops.nanargmin(arr, axis=1)
array([0, 0, 1, 1], dtype=int64)
"""
values, mask, dtype, _, _ = _get_values(
values, True, fill_value_typ="+inf", mask=mask
)
values, mask, _, _, _ = _get_values(values, True, fill_value_typ="+inf", mask=mask)
result = values.argmin(axis)
result = _maybe_arg_null_out(result, axis, mask, skipna)
return result
Expand Down