@@ -74,7 +74,7 @@ def test_assert_close_tolerance(xp: ModuleType):
7474@param_assert_equal_close
7575@pytest .mark .xfail_xp_backend (Backend .SPARSE , reason = "index by sparse array" )
7676def test_assert_close_equal_none_shape (xp : ModuleType , func : Callable [..., None ]): # type: ignore[explicit-any]
77- """On dask and other lazy backends, test that a shape with NaN's or None's
77+ """On Dask and other lazy backends, test that a shape with NaN's or None's
7878 can be compared to a real shape.
7979 """
8080 a = xp .asarray ([1 , 2 ])
@@ -99,18 +99,18 @@ def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]
9999
100100
101101def good_lazy (x : Array ) -> Array :
102- """A function that behaves well in dask and jax.jit"""
102+ """A function that behaves well in Dask and jax.jit"""
103103 return x * 2.0
104104
105105
106106def non_materializable (x : Array ) -> Array :
107107 """
108108 This function materializes the input array, so it will fail when wrapped in jax.jit
109- and it will trigger an expensive computation in dask .
109+ and it will trigger an expensive computation in Dask .
110110 """
111111 xp = array_namespace (x )
112112 # Crashes inside jax.jit
113- # On dask , this triggers two computations of the whole graph
113+ # On Dask , this triggers two computations of the whole graph
114114 if xp .any (x < 0.0 ) or xp .any (x > 10.0 ):
115115 msg = "Values must be in the [0, 10] range"
116116 raise ValueError (msg )
@@ -217,20 +217,20 @@ def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Arra
217217 erf = None
218218
219219
220- @pytest .mark .filterwarnings ("ignore:__array_wrap__:DeprecationWarning" ) # torch
220+ @pytest .mark .filterwarnings ("ignore:__array_wrap__:DeprecationWarning" ) # PyTorch
221221def test_lazy_xp_function_cython_ufuncs (xp : ModuleType , library : Backend ):
222222 pytest .importorskip ("scipy" )
223223 assert erf is not None
224224 x = xp .asarray ([6.0 , 7.0 ])
225225 if library in (Backend .ARRAY_API_STRICT , Backend .JAX ):
226- # array-api-strict arrays are auto-converted to numpy
226+ # array-api-strict arrays are auto-converted to NumPy
227227 # which results in an assertion error for mismatched namespaces
228- # eager jax arrays are auto-converted to numpy in eager jax
228+ # eager JAX arrays are auto-converted to NumPy in eager JAX
229229 # and fail in jax.jit (which lazy_xp_function tests here)
230230 with pytest .raises ((TypeError , AssertionError )):
231231 xp_assert_equal (cast (Array , erf (x )), xp .asarray ([1.0 , 1.0 ]))
232232 else :
233- # cupy, dask and sparse define __array_ufunc__ and dispatch accordingly
233+ # CuPy, Dask and sparse define __array_ufunc__ and dispatch accordingly
234234 # note that when sparse reduces to scalar it returns a np.generic, which
235235 # would make xp_assert_equal fail.
236236 xp_assert_equal (cast (Array , erf (x )), xp .asarray ([1.0 , 1.0 ]))
@@ -271,7 +271,7 @@ def test_lazy_xp_function_eagerly_raises(da: ModuleType):
271271
272272def f (x : Array ) -> Array :
273273 xp = array_namespace (x )
274- # Crash in jax.jit and trigger compute() on dask
274+ # Crash in jax.jit and trigger compute() on Dask
275275 if not xp .all (x ):
276276 msg = "Values must be non-zero"
277277 raise ValueError (msg )
0 commit comments