1010 _real_numeric_dtypes ,
1111 _numeric_dtypes ,
1212 _result_type ,
13- _dtype_categories as _dtype_dtype_categories ,
13+ _dtype_categories ,
1414)
1515from ._array_object import Array
1616from ._flags import requires_api_version
@@ -46,11 +46,26 @@ def _binary_ufunc_proto(x1, x2, dtype_category, func_name, np_func):
4646
4747
4848def create_binary_func (func_name , dtype_category , np_func ):
49- def inner (x1 : Array , x2 : Array , / ) -> Array :
49+ def inner (x1 , x2 , / ) -> Array :
5050 return _binary_ufunc_proto (x1 , x2 , dtype_category , func_name , np_func )
5151 return inner
5252
5353
54+ # static type annotation for ArrayOrPythonScalar arguments given a category
55+ # NB: keep the keys in sync with the _dtype_categories dict
56+ _annotations = {
57+ "all" : "bool | int | float | complex | Array" ,
58+ "real numeric" : "int | float | Array" ,
59+ "numeric" : "int | float | complex | Array" ,
60+ "integer" : "int | Array" ,
61+ "integer or boolean" : "int | bool | Array" ,
62+ "boolean" : "bool | Array" ,
63+ "real floating-point" : "float | Array" ,
64+ "complex floating-point" : "complex | Array" ,
65+ "floating-point" : "float | complex | Array" ,
66+ }
67+
68+
5469# func_name: dtype_category (must match that from _dtypes.py)
5570_binary_funcs = {
5671 "add" : "numeric" ,
@@ -97,7 +112,7 @@ def inner(x1: Array, x2: Array, /) -> Array:
97112# create and attach functions to the module
98113for func_name , dtype_category in _binary_funcs .items ():
99114 # sanity check
100- assert dtype_category in _dtype_dtype_categories
115+ assert dtype_category in _dtype_categories
101116
102117 numpy_name = _numpy_renames .get (func_name , func_name )
103118 np_func = getattr (np , numpy_name )
@@ -106,6 +121,8 @@ def inner(x1: Array, x2: Array, /) -> Array:
106121 func .__name__ = func_name
107122
108123 func .__doc__ = _binary_docstring_template % (numpy_name , numpy_name )
124+ func .__annotations__ ['x1' ] = _annotations [dtype_category ]
125+ func .__annotations__ ['x2' ] = _annotations [dtype_category ]
109126
110127 vars ()[func_name ] = func
111128
@@ -117,20 +134,22 @@ def inner(x1: Array, x2: Array, /) -> Array:
117134nextafter = requires_api_version ('2024.12' )(nextafter ) # noqa: F821
118135
119136
120- def bitwise_left_shift (x1 : Array , x2 : Array , / ) -> Array :
137+ def bitwise_left_shift (x1 : int | Array , x2 : int | Array , / ) -> Array :
121138 is_negative = np .any (x2 ._array < 0 ) if isinstance (x2 , Array ) else x2 < 0
122139 if is_negative :
123140 raise ValueError ("bitwise_left_shift(x1, x2) is only defined for x2 >= 0" )
124141 return _bitwise_left_shift (x1 , x2 ) # noqa: F821
125- bitwise_left_shift .__doc__ = _bitwise_left_shift .__doc__ # noqa: F821
142+ if _bitwise_left_shift .__doc__ : # noqa: F821
143+ bitwise_left_shift .__doc__ = _bitwise_left_shift .__doc__ # noqa: F821
126144
127145
128- def bitwise_right_shift (x1 : Array , x2 : Array , / ) -> Array :
146+ def bitwise_right_shift (x1 : int | Array , x2 : int | Array , / ) -> Array :
129147 is_negative = np .any (x2 ._array < 0 ) if isinstance (x2 , Array ) else x2 < 0
130148 if is_negative :
131149 raise ValueError ("bitwise_left_shift(x1, x2) is only defined for x2 >= 0" )
132150 return _bitwise_right_shift (x1 , x2 ) # noqa: F821
133- bitwise_right_shift .__doc__ = _bitwise_right_shift .__doc__ # noqa: F821
151+ if _bitwise_right_shift .__doc__ : # noqa: F821
152+ bitwise_right_shift .__doc__ = _bitwise_right_shift .__doc__ # noqa: F821
134153
135154
136155# clean up to not pollute the namespace
0 commit comments