|
4 | 4 |
|
5 | 5 | from aspire.basis import DiracBasis |
6 | 6 | from aspire.image import Image |
7 | | -from aspire.utils.matlab_compat import m_flatten |
| 7 | +from aspire.utils.matlab_compat import m_flatten, m_reshape |
8 | 8 |
|
9 | 9 |
|
10 | 10 | class DiracBasisTestCase(TestCase): |
@@ -192,13 +192,23 @@ def testDiracEvaluate_t(self): |
192 | 192 | ], |
193 | 193 | ] |
194 | 194 | ) |
| 195 | + # First test single image |
195 | 196 | result = self.basis.evaluate_t(x) |
196 | | - |
197 | 197 | # evaluate_t should return a NumPy array |
198 | 198 | self.assertTrue(isinstance(result, np.ndarray)) |
199 | | - |
| 199 | + # the result should be a flattened array of the values of x |
| 200 | + # in particular, for one image, its shape should be (size*size,) |
| 201 | + # not (size*size, 1) |
200 | 202 | self.assertTrue(np.allclose(result, m_flatten(x))) |
201 | 203 |
|
| 204 | + # Now test a stack of images |
| 205 | + stack = np.array([x] * 10) |
| 206 | + result_stack = self.basis.evaluate_t(stack) |
| 207 | + # the result should be of the shape (size*size, 10) |
| 208 | + flat_x = m_flatten(x) |
| 209 | + compare_array = m_reshape(np.array([flat_x] * 10), (self.basis.nres**2, 10)) |
| 210 | + self.assertTrue(np.allclose(result_stack, compare_array)) |
| 211 | + |
202 | 212 | def testInitWithIntSize(self): |
203 | 213 | # make sure we can instantiate with just an int as a shortcut |
204 | 214 | self.assertEqual((8, 8), DiracBasis(8).sz) |
0 commit comments