|
7 | 7 | more details. |
8 | 8 |
|
9 | 9 | """ |
| 10 | +from numpy import bool_ as bool |
10 | 11 | from numpy import ( |
| 12 | + complex64, |
| 13 | + complex128, |
11 | 14 | dtype, |
12 | | - bool_ as bool, |
13 | | - intp, |
| 15 | + float32, |
| 16 | + float64, |
14 | 17 | int8, |
15 | 18 | int16, |
16 | 19 | int32, |
17 | 20 | int64, |
| 21 | + intp, |
18 | 22 | uint8, |
19 | 23 | uint16, |
20 | 24 | uint32, |
21 | 25 | uint64, |
22 | | - float32, |
23 | | - float64, |
24 | | - complex64, |
25 | | - complex128, |
26 | 26 | ) |
27 | 27 |
|
| 28 | +from ._typing import Device, DType |
| 29 | + |
28 | 30 |
|
29 | 31 | class __array_namespace_info__: |
30 | 32 | """ |
@@ -131,7 +133,11 @@ def default_device(self): |
131 | 133 | """ |
132 | 134 | return "cpu" |
133 | 135 |
|
134 | | - def default_dtypes(self, *, device=None): |
| 136 | + def default_dtypes( |
| 137 | + self, |
| 138 | + *, |
| 139 | + device: Device | None = None, |
| 140 | + ) -> dict[str, dtype[intp | float64 | complex128]]: |
135 | 141 | """ |
136 | 142 | The default data types used for new NumPy arrays. |
137 | 143 |
|
@@ -183,7 +189,12 @@ def default_dtypes(self, *, device=None): |
183 | 189 | "indexing": dtype(intp), |
184 | 190 | } |
185 | 191 |
|
186 | | - def dtypes(self, *, device=None, kind=None): |
| 192 | + def dtypes( |
| 193 | + self, |
| 194 | + *, |
| 195 | + device: Device | None = None, |
| 196 | + kind: str | tuple[str, ...] | None = None, |
| 197 | + ) -> dict[str, DType]: |
187 | 198 | """ |
188 | 199 | The array API data types supported by NumPy. |
189 | 200 |
|
@@ -260,7 +271,7 @@ def dtypes(self, *, device=None, kind=None): |
260 | 271 | "complex128": dtype(complex128), |
261 | 272 | } |
262 | 273 | if kind == "bool": |
263 | | - return {"bool": bool} |
| 274 | + return {"bool": dtype(bool)} |
264 | 275 | if kind == "signed integer": |
265 | 276 | return { |
266 | 277 | "int8": dtype(int8), |
@@ -312,13 +323,13 @@ def dtypes(self, *, device=None, kind=None): |
312 | 323 | "complex128": dtype(complex128), |
313 | 324 | } |
314 | 325 | if isinstance(kind, tuple): |
315 | | - res = {} |
| 326 | + res: dict[str, DType] = {} |
316 | 327 | for k in kind: |
317 | 328 | res.update(self.dtypes(kind=k)) |
318 | 329 | return res |
319 | 330 | raise ValueError(f"unsupported kind: {kind!r}") |
320 | 331 |
|
321 | | - def devices(self): |
| 332 | + def devices(self) -> list[Device]: |
322 | 333 | """ |
323 | 334 | The devices supported by NumPy. |
324 | 335 |
|
|
0 commit comments