Skip to content

Commit 7c4948e

Browse files
committed
TST: (mpf) Add basic tests for the scalar
1 parent e5566fe commit 7c4948e

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import os
2+
3+
os.environ["NUMPY_EXPERIMENTAL_DTYPE_API"] = "1"
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import pytest
2+
3+
import sys
4+
import numpy as np
5+
import operator
6+
7+
from mpfdtype import MPFDType, MPFloat
8+
9+
10+
def test_create_scalar_simple():
11+
# currently inferring 53bit precision from float:
12+
assert MPFloat(12.).prec == 53
13+
# currently infers 64bit or 32bit depending on system:
14+
assert MPFloat(1).prec == sys.maxsize.bit_count() + 1
15+
16+
assert MPFloat(MPFloat(12.)).prec == 53
17+
assert MPFloat(MPFloat(1)).prec == sys.maxsize.bit_count() + 1
18+
19+
20+
def test_create_scalar_prec():
21+
assert MPFloat(1, prec=100).prec == 100
22+
assert MPFloat(12., prec=123).prec == 123
23+
assert MPFloat("12.234", prec=1000).prec == 1000
24+
25+
mpf1 = MPFloat("12.4325", prec=120)
26+
mpf2 = MPFloat(mpf1, prec=150)
27+
assert mpf1 == mpf2
28+
assert mpf2.prec == 150
29+
30+
31+
def test_basic_equality():
32+
assert MPFloat(12) == MPFloat(12.) == MPFloat("12.00", prec=10)
33+
34+
35+
@pytest.mark.parametrize("val", [123532.543, 12893283.5])
36+
def test_scalar_repr(val):
37+
# For non exponentials at least, the repr matches:
38+
val_repr = f"{val:e}".upper()
39+
expected = f"MPFloat('{val_repr}', prec=20)"
40+
assert repr(MPFloat(val, prec=20)) == expected
41+
42+
@pytest.mark.parametrize("op",
43+
["add", "sub", "mul", "pow"])
44+
@pytest.mark.parametrize("other", [3., 12.5, 100., np.nan, np.inf])
45+
def test_binary_ops(op, other):
46+
# Generally, the math ops should behave the same as double math if they
47+
# use double precision (which they currently do).
48+
# (double could have errors, but not for these simple ops)
49+
op = getattr(operator, op)
50+
try:
51+
expected = op(12.5, other)
52+
except Exception as e:
53+
with pytest.raises(type(e)):
54+
op(MPFloat(12.5), other)
55+
with pytest.raises(type(e)):
56+
op(12.5, MPFloat(other))
57+
with pytest.raises(type(e)):
58+
op(MPFloat(12.5), MPFloat(other))
59+
else:
60+
if np.isnan(expected):
61+
# Avoiding isnan (which was also not implemented when written)
62+
res = op(MPFloat(12.5), other)
63+
assert res != res
64+
res = op(12.5, MPFloat(other))
65+
assert res != res
66+
res = op(MPFloat(12.5), MPFloat(other))
67+
assert res != res
68+
else:
69+
assert op(MPFloat(12.5), other) == expected
70+
assert op(12.5, MPFloat(other)) == expected
71+
assert op(MPFloat(12.5), MPFloat(other)) == expected
72+
73+
74+
@pytest.mark.parametrize("op",
75+
["eq", "ne", "le", "lt", "ge", "gt"])
76+
@pytest.mark.parametrize("other", [3., 12.5, 100., np.nan, np.inf])
77+
def test_comparisons(op, other):
78+
op = getattr(operator, op)
79+
expected = op(12.5, other)
80+
assert op(MPFloat(12.5), other) is expected
81+
assert op(12.5, MPFloat(other)) is expected
82+
assert op(MPFloat(12.5), MPFloat(other)) is expected
83+
84+
85+
@pytest.mark.parametrize("op",
86+
["neg", "pos", "abs"])
87+
@pytest.mark.parametrize("val", [3., 12.5, 100., np.nan, np.inf])
88+
def test_comparisons(op, val):
89+
op = getattr(operator, op)
90+
expected = op(val)
91+
if np.isnan(expected):
92+
assert op(MPFloat(val)) != op(MPFloat(val))
93+
else:
94+
assert op(MPFloat(val)) == expected

0 commit comments

Comments
 (0)