11from __future__ import annotations
22
33import contextlib
4+ import typing
45import warnings
5- from typing import TYPE_CHECKING , Any
66
77# array-api-strict#6
8- import array_api_strict as xp # type: ignore[import-untyped]
8+ import array_api_strict as xp # type: ignore[import-untyped] # pyright: ignore[reportMissingTypeStubs]
99import numpy as np
1010import pytest
1111from numpy .testing import assert_allclose , assert_array_equal , assert_equal
1212
1313from array_api_extra import atleast_nd , cov , create_diagonal , expand_dims , kron , sinc
1414
15- if TYPE_CHECKING :
16- Array = Any # To be changed to a Protocol later (see array-api#589)
15+ if typing . TYPE_CHECKING :
16+ from array_api_extra . _typing import Array
1717
1818
1919class TestAtLeastND :
@@ -131,7 +131,7 @@ def test_1d(self):
131131
132132 @pytest .mark .parametrize ("n" , range (1 , 10 ))
133133 @pytest .mark .parametrize ("offset" , range (1 , 10 ))
134- def test_create_diagonal (self , n , offset ):
134+ def test_create_diagonal (self , n : int , offset : int ):
135135 # from scipy._lib tests
136136 rng = np .random .default_rng (2347823 )
137137 one = xp .asarray (1.0 )
@@ -180,9 +180,9 @@ def test_basic(self):
180180 assert_array_equal (kron (a , b , xp = xp ), k )
181181
182182 def test_kron_smoke (self ):
183- a = xp .ones ([ 3 , 3 ] )
184- b = xp .ones ([ 3 , 3 ] )
185- k = xp .ones ([ 9 , 9 ] )
183+ a = xp .ones (( 3 , 3 ) )
184+ b = xp .ones (( 3 , 3 ) )
185+ k = xp .ones (( 9 , 9 ) )
186186
187187 assert_array_equal (kron (a , b , xp = xp ), k )
188188
@@ -197,7 +197,7 @@ def test_kron_smoke(self):
197197 ((2 , 0 , 0 , 2 ), (2 , 0 , 2 )),
198198 ],
199199 )
200- def test_kron_shape (self , shape_a , shape_b ):
200+ def test_kron_shape (self , shape_a : tuple [ int , ...], shape_b : tuple [ int , ...] ):
201201 a = xp .ones (shape_a )
202202 b = xp .ones (shape_b )
203203 normalised_shape_a = xp .asarray (
@@ -271,7 +271,7 @@ def test_simple(self):
271271 assert_allclose (w , xp .flip (w , axis = 0 ))
272272
273273 @pytest .mark .parametrize ("x" , [0 , 1 + 3j ])
274- def test_dtype (self , x ):
274+ def test_dtype (self , x : int | complex ):
275275 with pytest .raises (ValueError , match = "real floating data type" ):
276276 sinc (xp .asarray (x ), xp = xp )
277277
0 commit comments