Skip to content

Commit c8c5afc

Browse files
committed
Extend general tests to handle FB3D
1 parent dba3570 commit c8c5afc

File tree

2 files changed

+18
-6
lines changed

2 files changed

+18
-6
lines changed

tests/_basis_util.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,24 @@
33
import numpy as np
44

55
from aspire.image import Image
6+
from aspire.volume import Volume
67
from aspire.utils import gaussian_2d, utest_tolerance
78
from aspire.utils.coor_trans import grid_2d
89
from aspire.utils.random import randn
910

11+
1012
class SteerableMixin:
1113
def testEvaluateExpand(self):
1214
coef1 = randn(self.basis.count, seed=self.seed)
1315
coef1 = coef1.astype(self.dtype)
1416

15-
im = self.basis.evaluate(coef1)
16-
if isinstance(im, Image):
17-
im = im.asnumpy()
18-
coef2 = self.basis.expand(im)[0]
17+
x = self.basis.evaluate(coef1)
18+
if isinstance(x, Image) or isinstance(x, Volume):
19+
x = x.asnumpy()
20+
21+
coef2 = self.basis.expand(x)
22+
if coef2.ndim == 2:
23+
coef2 = coef2[0]
1924

2025
self.assertTrue(coef1.shape == coef2.shape)
2126
self.assertTrue(np.allclose(coef1, coef2, atol=utest_tolerance(self.dtype)))
@@ -25,7 +30,7 @@ def testAdjoint(self):
2530
u = u.astype(self.dtype)
2631

2732
Au = self.basis.evaluate(u)
28-
if isinstance(Au, Image):
33+
if isinstance(Au, Image) or isinstance(Au, Volume):
2934
Au = Au.asnumpy()
3035

3136
x = randn(*self.basis.sz, seed=self.seed)
@@ -131,3 +136,7 @@ def testModulated(self):
131136
energy_ratio = energy_outside / energy_total
132137

133138
self.assertTrue(energy_ratio < 0.10)
139+
140+
141+
class Steerable3DMixin(SteerableMixin):
142+
pass

tests/test_FBbasis3D.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@
66
from aspire.basis import FBBasis3D
77
from aspire.utils import utest_tolerance
88

9+
from ._basis_util import Steerable3DMixin
10+
911
DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data")
1012

1113

12-
class FBBasis3DTestCase(TestCase):
14+
class FBBasis3DTestCase(TestCase, Steerable3DMixin):
1315
def setUp(self):
1416
self.dtype = np.float32
1517
self.basis = FBBasis3D((8, 8, 8), dtype=self.dtype)
18+
self.seed = 9161341
1619

1720
def tearDown(self):
1821
pass

0 commit comments

Comments
 (0)