22from typing import Union
33
44import pytest
5- from hypothesis import given
5+ from hypothesis import given , assume
66from hypothesis import strategies as st
77
88from . import _array_module as xp
@@ -23,26 +23,43 @@ def float32(n: Union[int, float]) -> float:
2323 return struct .unpack ("!f" , struct .pack ("!f" , float (n )))[0 ]
2424
2525
26+ def _float_match_complex (complex_dtype ):
27+ return xp .float32 if complex_dtype == xp .complex64 else xp .float64
28+
29+
2630@given (
27- x_dtype = non_complex_dtypes () ,
28- dtype = non_complex_dtypes () ,
31+ x_dtype = hh . all_dtypes ,
32+ dtype = hh . all_dtypes ,
2933 kw = hh .kwargs (copy = st .booleans ()),
3034 data = st .data (),
3135)
3236def test_astype (x_dtype , dtype , kw , data ):
37+ _complex_dtypes = (xp .complex64 , xp .complex128 )
38+
3339 if xp .bool in (x_dtype , dtype ):
3440 elements_strat = hh .from_dtype (x_dtype )
3541 else :
36- m1 , M1 = dh .dtype_ranges [x_dtype ]
37- m2 , M2 = dh .dtype_ranges [dtype ]
42+
3843 if dh .is_int_dtype (x_dtype ):
3944 cast = int
40- elif x_dtype == xp .float32 :
45+ elif x_dtype in ( xp .float32 , xp . complex64 ) :
4146 cast = float32
4247 else :
4348 cast = float
49+
50+ real_dtype = x_dtype
51+ if x_dtype in _complex_dtypes :
52+ real_dtype = _float_match_complex (x_dtype )
53+ m1 , M1 = dh .dtype_ranges [real_dtype ]
54+
55+ real_dtype = dtype
56+ if dtype in _complex_dtypes :
57+ real_dtype = _float_match_complex (x_dtype )
58+ m2 , M2 = dh .dtype_ranges [real_dtype ]
59+
4460 min_value = cast (max (m1 , m2 ))
4561 max_value = cast (min (M1 , M2 ))
62+
4663 elements_strat = hh .from_dtype (
4764 x_dtype ,
4865 min_value = min_value ,
@@ -54,6 +71,11 @@ def test_astype(x_dtype, dtype, kw, data):
5471 hh .arrays (dtype = x_dtype , shape = hh .shapes (), elements = elements_strat ), label = "x"
5572 )
5673
74+ # according to the spec, "Casting a complex floating-point array to a real-valued
75+ # data type should not be permitted."
76+ # https://data-apis.org/array-api/latest/API_specification/generated/array_api.astype.html#astype
77+ assume (not ((x_dtype in _complex_dtypes ) and (dtype not in _complex_dtypes )))
78+
5779 out = xp .astype (x , dtype , ** kw )
5880
5981 ph .assert_kw_dtype ("astype" , kw_dtype = dtype , out_dtype = out .dtype )
0 commit comments