Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 44 additions & 6 deletions pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@
notna,
)


from pandas.core.util.numba_ import GLOBAL_USE_NUMBA
from pandas.core import nanops_numba

if TYPE_CHECKING:
from collections.abc import Callable

Expand Down Expand Up @@ -97,6 +101,38 @@ def _f(*args, **kwargs):
return cast(F, _f)


class numba_switch:
def __init__(self, name=None, **kwargs) -> None:
self.name = name
self.kwargs = kwargs

def __call__(self, alt: F) -> F:
nb_name = self.name or alt.__name__

try:
nb_func = getattr(nanops_numba, nb_name)
except (AttributeError, NameError): # pragma: no cover
nb_func = None

@functools.wraps(alt)
def f(
values: np.ndarray,
*,
axis: AxisInt | None = None,
skipna: bool = True,
**kwds,
):
disallowed = values.dtype == "O"
if GLOBAL_USE_NUMBA and not disallowed:
result = nb_func(values, skipna=skipna, axis=axis, **kwds)
else:
result = alt(values, axis=axis, skipna=skipna, **kwds)

return result

return cast(F, f)


class bottleneck_switch:
def __init__(self, name=None, **kwargs) -> None:
self.name = name
Expand Down Expand Up @@ -593,6 +629,7 @@ def nanall(
return values.all(axis) # type: ignore[return-value]


@numba_switch()
@disallow("M8")
@_datetimelike_compat
@maybe_operate_rowwise
Expand Down Expand Up @@ -660,7 +697,7 @@ def _mask_datetimelike_result(
return result


@bottleneck_switch()
@numba_switch()
@_datetimelike_compat
def nanmean(
values: np.ndarray,
Expand Down Expand Up @@ -910,7 +947,7 @@ def _get_counts_nanvar(
return count, d


@bottleneck_switch(ddof=1)
@numba_switch(ddof=1)
def nanstd(
values,
*,
Expand Down Expand Up @@ -944,7 +981,7 @@ def nanstd(
>>> from pandas.core import nanops
>>> s = pd.Series([1, np.nan, 2, 3])
>>> nanops.nanstd(s.values)
1.0
np.float64(1.0)
"""
if values.dtype == "M8[ns]":
values = values.view("m8[ns]")
Expand All @@ -957,7 +994,7 @@ def nanstd(


@disallow("M8", "m8")
@bottleneck_switch(ddof=1)
@numba_switch(ddof=1)
def nanvar(
values: np.ndarray,
*,
Expand Down Expand Up @@ -991,7 +1028,7 @@ def nanvar(
>>> from pandas.core import nanops
>>> s = pd.Series([1, np.nan, 2, 3])
>>> nanops.nanvar(s.values)
1.0
np.float64(1.0)
"""
dtype = values.dtype
mask = _maybe_get_mask(values, skipna, mask)
Expand Down Expand Up @@ -1035,6 +1072,7 @@ def nanvar(
return result


@numba_switch()
@disallow("M8", "m8")
def nansem(
values: np.ndarray,
Expand Down Expand Up @@ -1089,7 +1127,7 @@ def nansem(


def _nanminmax(meth, fill_value_typ):
@bottleneck_switch(name=f"nan{meth}")
@numba_switch(name=f"nan{meth}")
@_datetimelike_compat
def reduction(
values: np.ndarray,
Expand Down
Loading
Loading