@@ -401,6 +401,24 @@ def test_none_shape_bool(self, xp: ModuleType):
401
401
a = a [a ]
402
402
xp_assert_equal (isclose (a , b ), xp .asarray ([True , False ]))
403
403
404
+ @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" )
405
+ @pytest .mark .skip_xp_backend (Backend .TORCH , reason = "Array API 2024.12 support" )
406
+ def test_python_scalar (self , xp : ModuleType ):
407
+ a = xp .asarray ([0.0 , 0.1 ], dtype = xp .float32 )
408
+ xp_assert_equal (isclose (a , 0.0 ), xp .asarray ([True , False ]))
409
+ xp_assert_equal (isclose (0.0 , a ), xp .asarray ([True , False ]))
410
+
411
+ a = xp .asarray ([0 , 1 ], dtype = xp .int16 )
412
+ xp_assert_equal (isclose (a , 0 ), xp .asarray ([True , False ]))
413
+ xp_assert_equal (isclose (0 , a ), xp .asarray ([True , False ]))
414
+
415
+ xp_assert_equal (isclose (0 , 0 , xp = xp ), xp .asarray (True ))
416
+ xp_assert_equal (isclose (0 , 1 , xp = xp ), xp .asarray (False ))
417
+
418
+ def test_all_python_scalars (self ):
419
+ with pytest .raises (TypeError , match = "Unrecognized" ):
420
+ isclose (0 , 0 )
421
+
404
422
def test_xp (self , xp : ModuleType ):
405
423
a = xp .asarray ([0.0 , 0.0 ])
406
424
b = xp .asarray ([1e-9 , 1e-4 ])
@@ -413,30 +431,22 @@ def test_basic(self, xp: ModuleType):
413
431
# Using 0-dimensional array
414
432
a = xp .asarray (1 )
415
433
b = xp .asarray ([[1 , 2 ], [3 , 4 ]])
416
- k = xp .asarray ([[1 , 2 ], [3 , 4 ]])
417
- xp_assert_equal (kron (a , b ), k )
418
- a = xp .asarray ([[1 , 2 ], [3 , 4 ]])
419
- b = xp .asarray (1 )
420
- xp_assert_equal (kron (a , b ), k )
434
+ xp_assert_equal (kron (a , b ), b )
435
+ xp_assert_equal (kron (b , a ), b )
421
436
422
437
# Using 1-dimensional array
423
438
a = xp .asarray ([3 ])
424
439
b = xp .asarray ([[1 , 2 ], [3 , 4 ]])
425
440
k = xp .asarray ([[3 , 6 ], [9 , 12 ]])
426
441
xp_assert_equal (kron (a , b ), k )
427
- a = xp .asarray ([[1 , 2 ], [3 , 4 ]])
428
- b = xp .asarray ([3 ])
429
- xp_assert_equal (kron (a , b ), k )
442
+ xp_assert_equal (kron (b , a ), k )
430
443
431
444
# Using 3-dimensional array
432
445
a = xp .asarray ([[[1 ]], [[2 ]]])
433
446
b = xp .asarray ([[1 , 2 ], [3 , 4 ]])
434
447
k = xp .asarray ([[[1 , 2 ], [3 , 4 ]], [[2 , 4 ], [6 , 8 ]]])
435
448
xp_assert_equal (kron (a , b ), k )
436
- a = xp .asarray ([[1 , 2 ], [3 , 4 ]])
437
- b = xp .asarray ([[[1 ]], [[2 ]]])
438
- k = xp .asarray ([[[1 , 2 ], [3 , 4 ]], [[2 , 4 ], [6 , 8 ]]])
439
- xp_assert_equal (kron (a , b ), k )
449
+ xp_assert_equal (kron (b , a ), k )
440
450
441
451
def test_kron_smoke (self , xp : ModuleType ):
442
452
a = xp .ones ((3 , 3 ))
@@ -474,6 +484,18 @@ def test_kron_shape(
474
484
k = kron (a , b )
475
485
assert k .shape == expected_shape
476
486
487
+ def test_python_scalar (self , xp : ModuleType ):
488
+ a = 1
489
+ # Test no dtype promotion to xp.asarray(a); use b.dtype
490
+ b = xp .asarray ([[1 , 2 ], [3 , 4 ]], dtype = xp .int16 )
491
+ xp_assert_equal (kron (a , b ), b )
492
+ xp_assert_equal (kron (b , a ), b )
493
+ xp_assert_equal (kron (1 , 1 , xp = xp ), xp .asarray (1 ))
494
+
495
+ def test_all_python_scalars (self ):
496
+ with pytest .raises (TypeError , match = "Unrecognized" ):
497
+ kron (1 , 1 )
498
+
477
499
def test_device (self , xp : ModuleType , device : Device ):
478
500
x1 = xp .asarray ([1 , 2 , 3 ], device = device )
479
501
x2 = xp .asarray ([4 , 5 ], device = device )
@@ -601,6 +623,28 @@ def test_shapes(
601
623
actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
602
624
xp_assert_equal (actual , xp .empty ((0 ,)))
603
625
626
+ @pytest .mark .skip_xp_backend (Backend .NUMPY_READONLY , reason = "xp=xp" )
627
+ @pytest .mark .parametrize ("assume_unique" , [True , False ])
628
+ def test_python_scalar (self , xp : ModuleType , assume_unique : bool ):
629
+ # Test no dtype promotion to xp.asarray(x2); use x1.dtype
630
+ x1 = xp .asarray ([3 , 1 , 2 ], dtype = xp .int16 )
631
+ x2 = 3
632
+ actual = setdiff1d (x1 , x2 , assume_unique = assume_unique )
633
+ xp_assert_equal (actual , xp .asarray ([1 , 2 ], dtype = xp .int16 ))
634
+
635
+ actual = setdiff1d (x2 , x1 , assume_unique = assume_unique )
636
+ xp_assert_equal (actual , xp .asarray ([], dtype = xp .int16 ))
637
+
638
+ xp_assert_equal (
639
+ setdiff1d (0 , 0 , assume_unique = assume_unique , xp = xp ),
640
+ xp .asarray ([0 ])[:0 ], # Default int dtype for backend
641
+ )
642
+
643
+ @pytest .mark .parametrize ("assume_unique" , [True , False ])
644
+ def test_all_python_scalars (self , assume_unique : bool ):
645
+ with pytest .raises (TypeError , match = "Unrecognized" ):
646
+ setdiff1d (0 , 0 , assume_unique = assume_unique )
647
+
604
648
def test_device (self , xp : ModuleType , device : Device ):
605
649
x1 = xp .asarray ([3 , 8 , 20 ], device = device )
606
650
x2 = xp .asarray ([2 , 3 , 4 ], device = device )
0 commit comments