11import numpy as np
22
33from pandas import compat
4- from pandas.core.common import isnull, array_equivalent
4+ from pandas.core.common import isnull, array_equivalent, is_dtype_equal
55
66cdef NUMERIC_TYPES = (
77 bool ,
@@ -55,7 +55,7 @@ cpdef assert_dict_equal(a, b, bint compare_keys=True):
5555
5656 return True
5757
58- cpdef assert_almost_equal(a, b, bint check_less_precise = False ,
58+ cpdef assert_almost_equal(a, b, bint check_less_precise = False , check_dtype = True ,
5959 obj = None , lobj = None , robj = None ):
6060 """ Check that left and right objects are almost equal.
6161
@@ -66,6 +66,8 @@ cpdef assert_almost_equal(a, b, bint check_less_precise=False,
6666 check_less_precise : bool, default False
6767 Specify comparison precision.
6868 5 digits (False) or 3 digits (True) after decimal points are compared.
69+ check_dtype: bool, default True
70+ check dtype if both a and b are np.ndarray
6971 obj : str, default None
7072 Specify object name being compared, internally used to show appropriate
7173 assertion message
@@ -82,7 +84,7 @@ cpdef assert_almost_equal(a, b, bint check_less_precise=False,
8284 double diff = 0.0
8385 Py_ssize_t i, na, nb
8486 double fa, fb
85- bint is_unequal = False
87+ bint is_unequal = False , a_is_ndarray, b_is_ndarray
8688
8789 if lobj is None :
8890 lobj = a
@@ -97,36 +99,43 @@ cpdef assert_almost_equal(a, b, bint check_less_precise=False,
9799 assert a == b, " %r != %r " % (a, b)
98100 return True
99101
102+ a_is_ndarray = isinstance (a, np.ndarray)
103+ b_is_ndarray = isinstance (b, np.ndarray)
104+
105+ if obj is None :
106+ if a_is_ndarray or b_is_ndarray:
107+ obj = ' numpy array'
108+ else :
109+ obj = ' Iterable'
110+
100111 if isiterable(a):
101112
102113 if not isiterable(b):
103- from pandas.util.testing import raise_assert_detail
104- if obj is None :
105- obj = ' Iterable'
106- msg = " First object is iterable, second isn't"
107- raise_assert_detail(obj, msg, a, b)
114+ from pandas.util.testing import assert_class_equal
115+ # classes can't be the same, to raise error
116+ assert_class_equal(a, b, obj = obj)
108117
109118 assert has_length(a) and has_length(b), (
110119 " Can't compare objects without length, one or both is invalid: "
111- " (%r , %r )" % (a, b)
112- )
120+ " (%r , %r )" % (a, b))
113121
114- if isinstance (a, np.ndarray) and isinstance (b, np.ndarray):
115- if obj is None :
116- obj = ' numpy array'
122+ if a_is_ndarray and b_is_ndarray:
117123 na, nb = a.size, b.size
118124 if a.shape != b.shape:
119125 from pandas.util.testing import raise_assert_detail
120126 raise_assert_detail(obj, ' {0} shapes are different' .format(obj),
121127 a.shape, b.shape)
128+
129+ if check_dtype and not is_dtype_equal(a, b):
130+ from pandas.util.testing import assert_attr_equal
131+ assert_attr_equal(' dtype' , a, b, obj = obj)
132+
122133 try :
123134 if array_equivalent(a, b, strict_nan = True ):
124135 return True
125136 except :
126137 pass
127138 else :
128- if obj is None :
129- obj = ' Iterable'
130139 na, nb = len (a), len (b)
131140
132141 if na != nb:
@@ -149,54 +158,38 @@ cpdef assert_almost_equal(a, b, bint check_less_precise=False,
149158 return True
150159
151160 elif isiterable(b):
152- from pandas.util.testing import raise_assert_detail
153- if obj is None :
154- obj = ' Iterable'
155- msg = " Second object is iterable, first isn't"
156- raise_assert_detail(obj, msg, a, b)
161+ from pandas.util.testing import assert_class_equal
162+ # classes can't be the same, to raise error
163+ assert_class_equal(a, b, obj = obj)
157164
158- if isnull(a):
159- assert isnull(b), (
160- " First object is null, second isn't: %r != %r " % (a, b)
161- )
165+ if a == b:
166+ # object comparison
162167 return True
163- elif isnull(b):
164- assert isnull(a), (
165- " First object is not null, second is null: %r != %r " % (a, b)
166- )
168+ if isnull(a) and isnull(b):
169+ # nan / None comparison
167170 return True
168-
169- if is_comparable_as_number(a):
170- assert is_comparable_as_number(b), (
171- " First object is numeric, second is not: %r != %r " % (a, b)
172- )
171+ if is_comparable_as_number(a) and is_comparable_as_number(b):
172+ if array_equivalent(a, b, strict_nan = True ):
173+ # inf comparison
174+ return True
173175
174176 decimal = 5
175177
176178 # deal with differing dtypes
177179 if check_less_precise:
178180 decimal = 3
179181
180- if np.isinf(a):
181- assert np.isinf(b), " First object is inf, second isn't"
182- if np.isposinf(a):
183- assert np.isposinf(b), " First object is positive inf, second is negative inf"
184- else :
185- assert np.isneginf(b), " First object is negative inf, second is positive inf"
182+ fa, fb = a, b
183+
184+ # case for zero
185+ if abs (fa) < 1e-5 :
186+ if not decimal_almost_equal(fa, fb, decimal):
187+ assert False , (
188+ ' (very low values) expected %.5f but got %.5f , with decimal %d ' % (fb, fa, decimal)
189+ )
186190 else :
187- fa, fb = a, b
188-
189- # case for zero
190- if abs (fa) < 1e-5 :
191- if not decimal_almost_equal(fa, fb, decimal):
192- assert False , (
193- ' (very low values) expected %.5f but got %.5f , with decimal %d ' % (fb, fa, decimal)
194- )
195- else :
196- if not decimal_almost_equal(1 , fb / fa, decimal):
197- assert False , ' expected %.5f but got %.5f , with decimal %d ' % (fb, fa, decimal)
198-
199- else :
200- assert a == b, " %r != %r " % (a, b)
191+ if not decimal_almost_equal(1 , fb / fa, decimal):
192+ assert False , ' expected %.5f but got %.5f , with decimal %d ' % (fb, fa, decimal)
193+ return True
201194
202- return True
195+ raise AssertionError ( " {0} != {1} " .format(a, b))
0 commit comments