33import numpy as np
44
55from aspire .image import Image
6+ from aspire .volume import Volume
67from aspire .utils import gaussian_2d , utest_tolerance
78from aspire .utils .coor_trans import grid_2d
89from aspire .utils .random import randn
910
11+
1012class SteerableMixin :
1113 def testEvaluateExpand (self ):
1214 coef1 = randn (self .basis .count , seed = self .seed )
1315 coef1 = coef1 .astype (self .dtype )
1416
15- im = self .basis .evaluate (coef1 )
16- if isinstance (im , Image ):
17- im = im .asnumpy ()
18- coef2 = self .basis .expand (im )[0 ]
17+ x = self .basis .evaluate (coef1 )
18+ if isinstance (x , Image ) or isinstance (x , Volume ):
19+ x = x .asnumpy ()
20+
21+ coef2 = self .basis .expand (x )
22+ if coef2 .ndim == 2 :
23+ coef2 = coef2 [0 ]
1924
2025 self .assertTrue (coef1 .shape == coef2 .shape )
2126 self .assertTrue (np .allclose (coef1 , coef2 , atol = utest_tolerance (self .dtype )))
@@ -25,7 +30,7 @@ def testAdjoint(self):
2530 u = u .astype (self .dtype )
2631
2732 Au = self .basis .evaluate (u )
28- if isinstance (Au , Image ):
33+ if isinstance (Au , Image ) or isinstance ( Au , Volume ) :
2934 Au = Au .asnumpy ()
3035
3136 x = randn (* self .basis .sz , seed = self .seed )
@@ -131,3 +136,7 @@ def testModulated(self):
131136 energy_ratio = energy_outside / energy_total
132137
133138 self .assertTrue (energy_ratio < 0.10 )
139+
140+
141+ class Steerable3DMixin (SteerableMixin ):
142+ pass
0 commit comments