1+ # pyright: reportPrivateUsage=false
12from __future__ import annotations
23
3- from typing import Optional , Union
4+ from builtins import bool as py_bool
5+ from typing import TYPE_CHECKING , cast
6+
7+ import numpy as np
48
59from .._internal import get_xp
610from ..common import _aliases , _helpers
711from ..common ._typing import NestedSequence , SupportsBufferProtocol
812from ._info import __array_namespace_info__
913from ._typing import Array , Device , DType
1014
11- import numpy as np
15+ if TYPE_CHECKING :
16+ from typing import Any , Literal , TypeAlias
17+
18+ from typing_extensions import Buffer , TypeIs
19+
20+ _Copy : TypeAlias = py_bool | Literal [2 ] | np ._CopyMode
1221
1322bool = np .bool_
1423
6574iinfo = get_xp (np )(_aliases .iinfo )
6675
6776
68- def _supports_buffer_protocol (obj ):
77+ def _supports_buffer_protocol (obj : object ) -> TypeIs [ Buffer ]: # pyright: ignore[reportUnusedFunction]
6978 try :
70- memoryview (obj )
79+ memoryview (obj ) # pyright: ignore[reportArgumentType]
7180 except TypeError :
7281 return False
7382 return True
@@ -78,18 +87,13 @@ def _supports_buffer_protocol(obj):
7887# complicated enough that it's easier to define it separately for each module
7988# rather than trying to combine everything into one function in common/
8089def asarray (
81- obj : (
82- Array
83- | bool | int | float | complex
84- | NestedSequence [bool | int | float | complex ]
85- | SupportsBufferProtocol
86- ),
90+ obj : Array | complex | NestedSequence [complex ] | SupportsBufferProtocol ,
8791 / ,
8892 * ,
89- dtype : Optional [ DType ] = None ,
90- device : Optional [ Device ] = None ,
91- copy : Optional [ Union [ bool , np . _CopyMode ]] = None ,
92- ** kwargs ,
93+ dtype : DType | None = None ,
94+ device : Device | None = None ,
95+ copy : _Copy | None = None ,
96+ ** kwargs : Any ,
9397) -> Array :
9498 """
9599 Array API compatibility wrapper for asarray().
@@ -106,51 +110,70 @@ def asarray(
106110 elif copy is True :
107111 copy = np ._CopyMode .ALWAYS
108112
109- return np .array (obj , copy = copy , dtype = dtype , ** kwargs )
113+ return np .array (obj , copy = copy , dtype = dtype , ** kwargs ) # pyright: ignore
110114
111115
112116def astype (
113117 x : Array ,
114118 dtype : DType ,
115119 / ,
116120 * ,
117- copy : bool = True ,
118- device : Optional [ Device ] = None ,
121+ copy : py_bool = True ,
122+ device : Device | None = None ,
119123) -> Array :
120124 _helpers ._check_device (np , device )
121125 return x .astype (dtype = dtype , copy = copy )
122126
123127
124128# count_nonzero returns a python int for axis=None and keepdims=False
125129# https://github.com/numpy/numpy/issues/17562
126- def count_nonzero (x : Array , axis = None , keepdims = False ) -> Array :
127- result = np .count_nonzero (x , axis = axis , keepdims = keepdims )
130+ def count_nonzero (
131+ x : Array ,
132+ axis : int | tuple [int , ...] | None = None ,
133+ keepdims : py_bool = False ,
134+ ) -> Array :
135+ result = cast ("Any" , np .count_nonzero (x , axis = axis , keepdims = keepdims )) # pyright: ignore
128136 if axis is None and not keepdims :
129137 return np .asarray (result )
130138 return result
131139
132140
133141# These functions are completely new here. If the library already has them
134142# (i.e., numpy 2.0), use the library version instead of our wrapper.
135- if hasattr (np , ' vecdot' ):
143+ if hasattr (np , " vecdot" ):
136144 vecdot = np .vecdot
137145else :
138146 vecdot = get_xp (np )(_aliases .vecdot )
139147
140- if hasattr (np , ' isdtype' ):
148+ if hasattr (np , " isdtype" ):
141149 isdtype = np .isdtype
142150else :
143151 isdtype = get_xp (np )(_aliases .isdtype )
144152
145- if hasattr (np , ' unstack' ):
153+ if hasattr (np , " unstack" ):
146154 unstack = np .unstack
147155else :
148156 unstack = get_xp (np )(_aliases .unstack )
149157
150- __all__ = _aliases .__all__ + ['__array_namespace_info__' , 'asarray' , 'astype' ,
151- 'acos' , 'acosh' , 'asin' , 'asinh' , 'atan' ,
152- 'atan2' , 'atanh' , 'bitwise_left_shift' ,
153- 'bitwise_invert' , 'bitwise_right_shift' ,
154- 'bool' , 'concat' , 'count_nonzero' , 'pow' ]
155-
156- _all_ignore = ['np' , 'get_xp' ]
158+ __all__ = [
159+ "__array_namespace_info__" ,
160+ "asarray" ,
161+ "astype" ,
162+ "acos" ,
163+ "acosh" ,
164+ "asin" ,
165+ "asinh" ,
166+ "atan" ,
167+ "atan2" ,
168+ "atanh" ,
169+ "bitwise_left_shift" ,
170+ "bitwise_invert" ,
171+ "bitwise_right_shift" ,
172+ "bool" ,
173+ "concat" ,
174+ "count_nonzero" ,
175+ "pow" ,
176+ ]
177+ __all__ += _aliases .__all__
178+
179+ _all_ignore = ["np" , "get_xp" ]
0 commit comments