11import functools
22import itertools
33import operator
4- from typing import Any , Optional , Tuple , Union
4+ from typing import Any , Optional , Tuple , Union , cast
55
66import numpy as np
77
88from pandas ._config import get_option
99
1010from pandas ._libs import NaT , Timedelta , Timestamp , iNaT , lib
11- from pandas ._typing import ArrayLike , Dtype , Scalar
11+ from pandas ._typing import ArrayLike , Dtype , F , Scalar
1212from pandas .compat ._optional import import_optional_dependency
1313
1414from pandas .core .dtypes .cast import _int64_max , maybe_upcast_putmask
@@ -57,7 +57,7 @@ def __init__(self, *dtypes):
5757 def check (self , obj ) -> bool :
5858 return hasattr (obj , "dtype" ) and issubclass (obj .dtype .type , self .dtypes )
5959
60- def __call__ (self , f ) :
60+ def __call__ (self , f : F ) -> F :
6161 @functools .wraps (f )
6262 def _f (* args , ** kwargs ):
6363 obj_iter = itertools .chain (args , kwargs .values ())
@@ -78,7 +78,7 @@ def _f(*args, **kwargs):
7878 raise TypeError (e ) from e
7979 raise
8080
81- return _f
81+ return cast ( F , _f )
8282
8383
8484class bottleneck_switch :
@@ -878,7 +878,7 @@ def nanargmax(
878878 axis : Optional [int ] = None ,
879879 skipna : bool = True ,
880880 mask : Optional [np .ndarray ] = None ,
881- ) -> int :
881+ ) -> Union [ int , np . ndarray ] :
882882 """
883883 Parameters
884884 ----------
@@ -890,15 +890,25 @@ def nanargmax(
890890
891891 Returns
892892 -------
893- result : int
894- The index of max value in specified axis or -1 in the NA case
893+ result : int or ndarray[int]
894+ The index/indices of max value in specified axis or -1 in the NA case
895895
896896 Examples
897897 --------
898898 >>> import pandas.core.nanops as nanops
899- >>> s = pd.Series ([1, 2, 3, np.nan, 4])
900- >>> nanops.nanargmax(s )
899+ >>> arr = np.array ([1, 2, 3, np.nan, 4])
900+ >>> nanops.nanargmax(arr )
901901 4
902+
903+ >>> arr = np.array(range(12), dtype=np.float64).reshape(4, 3)
904+ >>> arr[2:, 2] = np.nan
905+ >>> arr
906+ array([[ 0., 1., 2.],
907+ [ 3., 4., 5.],
908+ [ 6., 7., nan],
909+ [ 9., 10., nan]])
910+ >>> nanops.nanargmax(arr, axis=1)
911+ array([2, 2, 1, 1], dtype=int64)
902912 """
903913 values , mask , dtype , _ , _ = _get_values (
904914 values , True , fill_value_typ = "-inf" , mask = mask
@@ -914,7 +924,7 @@ def nanargmin(
914924 axis : Optional [int ] = None ,
915925 skipna : bool = True ,
916926 mask : Optional [np .ndarray ] = None ,
917- ) -> int :
927+ ) -> Union [ int , np . ndarray ] :
918928 """
919929 Parameters
920930 ----------
@@ -926,15 +936,25 @@ def nanargmin(
926936
927937 Returns
928938 -------
929- result : int
930- The index of min value in specified axis or -1 in the NA case
939+ result : int or ndarray[int]
940+ The index/indices of min value in specified axis or -1 in the NA case
931941
932942 Examples
933943 --------
934944 >>> import pandas.core.nanops as nanops
935- >>> s = pd.Series ([1, 2, 3, np.nan, 4])
936- >>> nanops.nanargmin(s )
945+ >>> arr = np.array ([1, 2, 3, np.nan, 4])
946+ >>> nanops.nanargmin(arr )
937947 0
948+
949+ >>> arr = np.array(range(12), dtype=np.float64).reshape(4, 3)
950+ >>> arr[2:, 0] = np.nan
951+ >>> arr
952+ array([[ 0., 1., 2.],
953+ [ 3., 4., 5.],
954+ [nan, 7., 8.],
955+ [nan, 10., 11.]])
956+ >>> nanops.nanargmin(arr, axis=1)
957+ array([0, 0, 1, 1], dtype=int64)
938958 """
939959 values , mask , dtype , _ , _ = _get_values (
940960 values , True , fill_value_typ = "+inf" , mask = mask
0 commit comments