Skip to content

Commit 803850c

Browse files
fix Dirac basis evaluate_t
1 parent c75f5fc commit 803850c

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

src/aspire/basis/dirac.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,17 @@ def _evaluate_t(self, x):
6767
`self.count` and whose remaining dimensions correspond to
6868
higher dimensions of `v`.
6969
"""
70-
x = np.squeeze(x.asnumpy())
70+
import pdb
71+
72+
pdb.set_trace()
73+
x = m_reshape(x.asnumpy(), new_shape=self.ndim * (self.nres,) + (x.shape[0],))
7174
x, sz_roll = unroll_dim(x, self.ndim + 1)
7275
x = m_reshape(x, new_shape=(self._sz_prod,) + x.shape[self.ndim :])
7376
v = np.zeros(shape=(self.count,) + x.shape[1:], dtype=self.dtype)
7477
v = x[self._mask, ...]
7578
v = roll_dim(v, sz_roll)
7679

77-
return v
80+
return np.squeeze(v)
7881

7982
def expand(self, x):
8083
return self.evaluate_t(x)

tests/test_Diracbasis.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from aspire.basis import DiracBasis
66
from aspire.image import Image
7-
from aspire.utils.matlab_compat import m_flatten
7+
from aspire.utils.matlab_compat import m_flatten, m_reshape
88

99

1010
class DiracBasisTestCase(TestCase):
@@ -192,13 +192,23 @@ def testDiracEvaluate_t(self):
192192
],
193193
]
194194
)
195+
# First test single image
195196
result = self.basis.evaluate_t(x)
196-
197197
# evaluate_t should return a NumPy array
198198
self.assertTrue(isinstance(result, np.ndarray))
199-
199+
# the result should be a flattened array of the values of x
200+
# in particular, for one image, its shape should be (size*size,)
201+
# not (size*size, 1)
200202
self.assertTrue(np.allclose(result, m_flatten(x)))
201203

204+
# Now test a stack of images
205+
stack = np.array([x] * 10)
206+
result_stack = self.basis.evaluate_t(stack)
207+
# the result should be of the shape (size*size, 10)
208+
flat_x = m_flatten(x)
209+
compare_array = m_reshape(np.array([flat_x] * 10), (self.basis.nres**2, 10))
210+
self.assertTrue(np.allclose(result_stack, compare_array))
211+
202212
def testInitWithIntSize(self):
203213
# make sure we can instantiate with just an int as a shortcut
204214
self.assertEqual((8, 8), DiracBasis(8).sz)

0 commit comments

Comments
 (0)