@@ -1783,48 +1783,42 @@ def test_trunc(x):
17831783
17841784def _check_binary_with_scalars (func_data , x1x2 ):
17851785 x1 , x2 = x1x2
1786- func , name , refimpl , kwds , expected_dtype = func_data
1786+ func_name , refimpl , kwds , expected_dtype = func_data
1787+ func = getattr (xp , func_name )
17871788 out = func (x1 , x2 )
17881789 in_dtypes , in_shapes , (x1a , x2a ) = _convert_scalars_helper (x1 , x2 )
17891790 _assert_correctness_binary (
1790- name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , expected_dtype , ** kwds
1791+ func_name , refimpl , in_dtypes , in_shapes , (x1a , x2a ), out , expected_dtype , ** kwds
17911792 )
17921793
17931794
17941795def _filter_zero (x ):
17951796 return x != 0 if dh .is_scalar (x ) else (not xp .any (x == 0 ))
17961797
1797- # workarounds for xp.copysign etc only available in 2023.12
1798- # Without it, test suite fails to import with ARRAY_API_VERSION=2022.12
1799- _xp_copysign = getattr (xp , "copysign" , None )
1800- _xp_hypot = getattr (xp , "hypot" , None )
1801- _xp_maximum = getattr (xp , "maximum" , None )
1802- _xp_minimum = getattr (xp , "minimum" , None )
1803-
18041798
18051799@pytest .mark .min_version ("2024.12" )
18061800@pytest .mark .parametrize ('func_data' ,
1807- # xp_func, name , refimpl, kwargs, expected_dtype
1801+ # func_name , refimpl, kwargs, expected_dtype
18081802 [
1809- (xp . add , "add" , operator .add , {}, None ),
1810- (xp . atan2 , "atan2" , math .atan2 , {}, None ),
1811- (_xp_copysign , "copysign" , math .copysign , {}, None ),
1812- (xp . divide , "divide" , operator .truediv , {"filter_" : lambda s : s != 0 }, None ),
1813- (_xp_hypot , "hypot" , math .hypot , {}, None ),
1814- (xp . logaddexp , "logaddexp" , logaddexp_refimpl , {}, None ),
1815- (_xp_maximum , "maximum" , max , {'strict_check' : True }, None ),
1816- (_xp_minimum , "minimum" , min , {'strict_check' : True }, None ),
1817- (xp . multiply , "mul " , operator .mul , {}, None ),
1818- (xp . subtract , "sub " , operator .sub , {}, None ),
1819-
1820- (xp . equal , "equal" , operator .eq , {}, xp .bool ),
1821- (xp . not_equal , "neq " , operator .ne , {}, xp .bool ),
1822- (xp . less , "less" , operator .lt , {}, xp .bool ),
1823- (xp . less_equal , "les_equal " , operator .le , {}, xp .bool ),
1824- (xp . greater , "greater" , operator .gt , {}, xp .bool ),
1825- (xp . greater_equal , "greater_equal" , operator .ge , {}, xp .bool ),
1803+ ("add" , operator .add , {}, None ),
1804+ ("atan2" , math .atan2 , {}, None ),
1805+ ("copysign" , math .copysign , {}, None ),
1806+ ("divide" , operator .truediv , {"filter_" : lambda s : s != 0 }, None ),
1807+ ("hypot" , math .hypot , {}, None ),
1808+ ("logaddexp" , logaddexp_refimpl , {}, None ),
1809+ ("maximum" , max , {'strict_check' : True }, None ),
1810+ ("minimum" , min , {'strict_check' : True }, None ),
1811+ ("multiply " , operator .mul , {}, None ),
1812+ ("subtract " , operator .sub , {}, None ),
1813+
1814+ ("equal" , operator .eq , {}, xp .bool ),
1815+ ("not_equal " , operator .ne , {}, xp .bool ),
1816+ ("less" , operator .lt , {}, xp .bool ),
1817+ ("less_equal " , operator .le , {}, xp .bool ),
1818+ ("greater" , operator .gt , {}, xp .bool ),
1819+ ("greater_equal" , operator .ge , {}, xp .bool ),
18261820 ],
1827- ids = lambda func_data : func_data [1 ] # use names for test IDs
1821+ ids = lambda func_data : func_data [0 ] # use names for test IDs
18281822)
18291823@given (x1x2 = hh .array_and_py_scalar (dh .real_float_dtypes ))
18301824def test_binary_with_scalars_real (func_data , x1x2 ):
@@ -1833,13 +1827,13 @@ def test_binary_with_scalars_real(func_data, x1x2):
18331827
18341828@pytest .mark .min_version ("2024.12" )
18351829@pytest .mark .parametrize ('func_data' ,
1836- # xp_func, name , refimpl, kwargs, expected_dtype
1830+ # func_name , refimpl, kwargs, expected_dtype
18371831 [
1838- (xp . logical_and , "logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }, None ),
1839- (xp . logical_or , "logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }, None ),
1840- (xp . logical_xor , "logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }, None ),
1832+ ("logical_and" , operator .and_ , {"expr_template" : "({} or {})={}" }, None ),
1833+ ("logical_or" , operator .or_ , {"expr_template" : "({} or {})={}" }, None ),
1834+ ("logical_xor" , operator .xor , {"expr_template" : "({} or {})={}" }, None ),
18411835 ],
1842- ids = lambda func_data : func_data [1 ] # use names for test IDs
1836+ ids = lambda func_data : func_data [0 ] # use names for test IDs
18431837)
18441838@given (x1x2 = hh .array_and_py_scalar ([xp .bool ]))
18451839def test_binary_with_scalars_bool (func_data , x1x2 ):
@@ -1848,36 +1842,34 @@ def test_binary_with_scalars_bool(func_data, x1x2):
18481842
18491843@pytest .mark .min_version ("2024.12" )
18501844@pytest .mark .parametrize ('func_data' ,
1851- # xp_func, name , refimpl, kwargs, expected_dtype
1845+ # func_name , refimpl, kwargs, expected_dtype
18521846 [
1853-
1854- (xp .floor_divide , "floor_divide" , operator .floordiv , {}, None ),
1855- (xp .remainder , "remainder" , operator .mod , {}, None ),
1847+ ("floor_divide" , operator .floordiv , {}, None ),
1848+ ("remainder" , operator .mod , {}, None ),
18561849 ],
1857- ids = lambda func_data : func_data [1 ] # use names for test IDs
1850+ ids = lambda func_data : func_data [0 ] # use names for test IDs
18581851)
18591852@given (x1x2 = hh .array_and_py_scalar ([xp .int64 ]))
18601853def test_binary_with_scalars_int (func_data , x1x2 ):
1861-
18621854 assume (_filter_zero (x1x2 [1 ]))
18631855 assume (_filter_zero (x1x2 [0 ]) and _filter_zero (x1x2 [1 ]))
18641856 _check_binary_with_scalars (func_data , x1x2 )
18651857
18661858
18671859@pytest .mark .min_version ("2024.12" )
18681860@pytest .mark .parametrize ('func_data' ,
1869- # xp_func, name , refimpl, kwargs, expected_dtype
1861+ # func_name , refimpl, kwargs, expected_dtype
18701862 [
1871- (xp . bitwise_and , "bitwise_and" , operator .and_ , {}, None ),
1872- (xp . bitwise_or , "bitwise_or" , operator .or_ , {}, None ),
1873- (xp . bitwise_xor , "bitwise_xor" , operator .xor , {}, None ),
1863+ ("bitwise_and" , operator .and_ , {}, None ),
1864+ ("bitwise_or" , operator .or_ , {}, None ),
1865+ ("bitwise_xor" , operator .xor , {}, None ),
18741866 ],
1875- ids = lambda func_data : func_data [1 ] # use names for test IDs
1867+ ids = lambda func_data : func_data [0 ] # use names for test IDs
18761868)
18771869@given (x1x2 = hh .array_and_py_scalar ([xp .int32 ]))
18781870def test_binary_with_scalars_bitwise (func_data , x1x2 ):
1879- xp_func , name , refimpl , kwargs , expected = func_data
1871+ func_name , refimpl , kwargs , expected = func_data
18801872 # repack the refimpl
18811873 refimpl_ = lambda l , r : mock_int_dtype (refimpl (l , r ), xp .int32 )
1882- _check_binary_with_scalars ((xp_func , name , refimpl_ , kwargs ,expected ), x1x2 )
1874+ _check_binary_with_scalars ((func_name , refimpl_ , kwargs , expected ), x1x2 )
18831875
0 commit comments