@@ -304,7 +304,7 @@ def randn(shape, dtype):
304304
305305
306306def assert_close_to_numpy (actx , op , args ):
307- assert np .allclose (
307+ np .testing . assert_allclose (
308308 actx .to_numpy (
309309 op (actx .np , * [
310310 actx .from_numpy (arg ) if isinstance (arg , np .ndarray ) else arg
@@ -328,7 +328,7 @@ def assert_close_to_numpy_in_containers(actx, op, args):
328328 if isinstance (actx_result , DOFArray ):
329329 actx_result = actx_result [0 ]
330330
331- assert np .allclose (actx .to_numpy (actx_result ), ref_result )
331+ np .testing . assert_allclose (actx .to_numpy (actx_result ), ref_result )
332332
333333 # }}}
334334
@@ -342,7 +342,7 @@ def assert_close_to_numpy_in_containers(actx, op, args):
342342 if isinstance (obj_array_result , np .ndarray ):
343343 obj_array_result = obj_array_result [0 ][0 ]
344344
345- assert np .allclose (actx .to_numpy (obj_array_result ), ref_result )
345+ np .testing . assert_allclose (actx .to_numpy (obj_array_result ), ref_result )
346346
347347 # }}}
348348
@@ -383,7 +383,6 @@ def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype):
383383 pytest .skip (f"'{ sym_name } ' not implemented on '{ type (actx ).__name__ } '" )
384384
385385 ndofs = 512
386- args = [randn (ndofs , dtype ) for i in range (n_args )]
387386
388387 c_to_numpy_arc_functions = {
389388 "atan" : "arctan" ,
@@ -396,6 +395,12 @@ def evaluate(_np, *_args):
396395
397396 return func (* _args )
398397
398+ args = [randn (ndofs , dtype ) for i in range (n_args )]
399+ assert_close_to_numpy_in_containers (actx , evaluate , args )
400+
401+ args = [randn (ndofs , dtype ) for i in range (n_args )]
402+ for i in range (n_args ):
403+ args [i ][i ] = np .nan
399404 assert_close_to_numpy_in_containers (actx , evaluate , args )
400405
401406 if sym_name in ["where" , "min" , "max" , "any" , "all" , "conj" , "vdot" , "sum" ]:
@@ -636,11 +641,20 @@ def get_imag(ary):
636641
637642# {{{ reductions same as numpy
638643
644+ @pytest .mark .parametrize ("input_case" , ["numbers" , "nans" , "mixed" ])
639645@pytest .mark .parametrize ("op" , ["sum" , "min" , "max" ])
640- def test_reductions_same_as_numpy (actx_factory , op ):
646+ def test_reductions_same_as_numpy (actx_factory , op , input_case ):
641647 actx = actx_factory ()
642648
643- ary = np .random .randn (3000 )
649+ if input_case == "numbers" :
650+ ary = np .random .randn (3000 )
651+ elif input_case == "nans" :
652+ ary = np .array ([np .nan , np .nan ], dtype = np .float64 )
653+ elif input_case == "mixed" :
654+ ary = np .array ([np .nan , 1. ], dtype = np .float64 )
655+ else :
656+ raise ValueError ("invalid input case" )
657+
644658 np_red = getattr (np , op )(ary )
645659 actx_red = getattr (actx .np , op )(actx .from_numpy (ary ))
646660 actx_red = actx .to_numpy (actx_red )
@@ -652,7 +666,7 @@ def test_reductions_same_as_numpy(actx_factory, op):
652666 else :
653667 assert actx_red .shape == ()
654668
655- assert np .allclose ( np_red , actx_red )
669+ np .testing . assert_allclose ( actx_red , np_red )
656670
657671
658672@pytest .mark .parametrize ("sym_name" , ["any" , "all" ])
0 commit comments