Skip to content

Commit 5a6d862

Browse files
committed
Add shape checks
To make sure we don't get the same inadvertent broadcasting bug again.
1 parent c7313b6 commit 5a6d862

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

tests/_basis_util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def testGaussianExpand(self):
4848

4949
if isinstance(im2, Image):
5050
im2 = im2.asnumpy()
51+
im2 = im2[0]
5152

5253
# For small L there's too much clipping at high freqs to get 1e-3
5354
# accuracy.
@@ -56,6 +57,7 @@ def testGaussianExpand(self):
5657
else:
5758
atol = 1e-3
5859

60+
self.assertTrue(im1.shape == im2.shape)
5961
self.assertTrue(np.allclose(im1, im2, atol=atol))
6062

6163
def testIsotropic(self):
@@ -107,6 +109,7 @@ def testEvaluateExpand(self):
107109
im = im.asnumpy()
108110
coef2 = self.basis.expand(im)[0]
109111

112+
self.assertTrue(coef1.shape == coef2.shape)
110113
self.assertTrue(np.allclose(coef1, coef2, atol=utest_tolerance(self.dtype)))
111114

112115
def testAdjoint(self):
@@ -125,4 +128,5 @@ def testAdjoint(self):
125128
Au_dot_x = np.sum(Au * x)
126129
u_dot_ATx = np.sum(u * ATx)
127130

131+
self.assertTrue(Au_dot_x.shape == u_dot_ATx.shape)
128132
self.assertTrue(np.isclose(Au_dot_x, u_dot_ATx))

0 commit comments

Comments
 (0)