Skip to content

Commit 6ebf09f

Browse files
Merge remote-tracking branch 'origin/develop' into ctf_estimator_result
2 parents 23c30c7 + 2d67cd3 commit 6ebf09f

31 files changed

+625
-823
lines changed

gallery/tutorials/class_averaging.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
large_pca_implementation="legacy",
115115
nn_implementation="legacy",
116116
bispectrum_implementation="legacy",
117+
num_procs=1, # Change to "auto" if your machine has many processors
117118
)
118119

119120
classes, reflections, dists = rir.classify()
@@ -167,6 +168,7 @@
167168
large_pca_implementation="legacy",
168169
nn_implementation="sklearn",
169170
bispectrum_implementation="legacy",
171+
num_procs=1, # Change to "auto" if your machine has many processors
170172
)
171173

172174
classes, reflections, dists = noisy_rir.classify()

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,13 @@ def read(fname):
3030
"matplotlib>=3.2.0",
3131
"mrcfile",
3232
"numpy==1.21.5",
33+
"packaging",
3334
"pandas==1.3.5",
35+
"psutil",
3436
"pyfftw",
3537
"PyWavelets",
3638
"pillow",
39+
"ray",
3740
"scipy==1.7.3",
3841
"scikit-learn",
3942
"scikit-image",

src/aspire/basis/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
# isort: on
99

10-
from .dirac import DiracBasis
1110
from .fb_2d import FBBasis2D
1211
from .fb_3d import FBBasis3D
1312
from .ffb_2d import FFBBasis2D

src/aspire/basis/basis.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,24 +74,55 @@ def evaluate(self, v):
7474
Evaluate coefficient vector in basis
7575
7676
:param v: A coefficient vector (or an array of coefficient vectors)
77-
to be evaluated. The first dimension must equal `self.count`.
77+
to be evaluated. The first dimension must correspond to the number of
78+
coefficient vectors, while the second must correspond to `self.count`
7879
:return: The evaluation of the coefficient vector(s) `v` for this basis.
79-
This is an array whose first dimensions equal `self.z` and the
80-
remaining dimensions correspond to dimensions two and higher of `v`.
80+
This is an Image or a Volume object containing one image/volume for each
81+
coefficient vector, and of size `self.sz`.
8182
"""
83+
if v.dtype != self.dtype:
84+
logger.warning(
85+
f"{self.__class__.__name__}::evaluate"
86+
f" Inconsistent dtypes v: {v.dtype} self: {self.dtype}"
87+
)
88+
89+
if self.ndim == 2:
90+
return Image(self._evaluate(v))
91+
elif self.ndim == 3:
92+
return Volume(self._evaluate(v))
93+
94+
def _evaluate(self, v):
8295
raise NotImplementedError("subclasses must implement this")
8396

8497
def evaluate_t(self, v):
8598
"""
8699
Evaluate coefficient in dual basis
87100
88-
:param v: The coefficient array to be evaluated. The first dimensions
89-
must equal `self.sz`.
90-
:return: The evaluation of the coefficient array `v` in the dual
101+
:param v: An Image or Volume object whose size matches `self.sz`.
102+
:return: The evaluation of the Image or Volume object `v` in the dual
91103
basis of `basis`.
92-
This is an array of vectors whose first dimension equals `self.count`
93-
and whose remaining dimensions correspond to higher dimensions of `v`.
104+
This is an array of vectors whose first dimension equals the number of
105+
images/volumes in `v`. and whose second dimension is `self.count`.
94106
"""
107+
if v.dtype != self.dtype:
108+
logger.warning(
109+
f"{self.__class__.__name__}::evaluate_t"
110+
f" Inconsistent dtypes v: {v.dtype} self: {self.dtype}"
111+
)
112+
113+
if not isinstance(v, Image) and not isinstance(v, Volume):
114+
if self.ndim == 2:
115+
_class = Image
116+
elif self.ndim == 3:
117+
_class = Volume
118+
logger.warning(
119+
f"{self.__class__.__name__}::evaluate_t"
120+
f" passed numpy array instead of {_class}."
121+
)
122+
v = _class(v)
123+
return self._evaluate_t(v)
124+
125+
def _evaluate_t(self, v):
95126
raise NotImplementedError("Subclasses should implement this")
96127

97128
def mat_evaluate(self, V):
@@ -104,7 +135,7 @@ def mat_evaluate(self, V):
104135
-`self.sz` corresponding to the evaluation of `V` in
105136
this basis.
106137
"""
107-
return mdim_mat_fun_conj(V, 1, len(self.sz), self.evaluate)
138+
return mdim_mat_fun_conj(V, 1, len(self.sz), self._evaluate)
108139

109140
def mat_evaluate_t(self, X):
110141
"""
@@ -120,7 +151,7 @@ def mat_evaluate_t(self, X):
120151
function calculates V = B' * X * B, where the rows of `B`, rows
121152
of 'X', and columns of `X` are read as vectorized arrays.
122153
"""
123-
return mdim_mat_fun_conj(X, len(self.sz), 1, self.evaluate_t)
154+
return mdim_mat_fun_conj(X, len(self.sz), 1, self._evaluate_t)
124155

125156
def expand(self, x):
126157
"""

src/aspire/basis/dirac.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

src/aspire/basis/fb_2d.py

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from aspire.basis import FBBasisMixin, SteerableBasis2D
77
from aspire.basis.basis_utils import unique_coords_nd
8-
from aspire.image import Image
98
from aspire.utils import complex_type, real_type, roll_dim, unroll_dim
109
from aspire.utils.matlab_compat import m_flatten, m_reshape
1110

@@ -189,7 +188,7 @@ def basis_norm_2d(self, ell, k):
189188

190189
return rad_norm, ang_norm
191190

192-
def evaluate(self, v):
191+
def _evaluate(self, v):
193192
"""
194193
Evaluate coefficients in standard 2D coordinate basis from those in FB basis
195194
@@ -199,13 +198,6 @@ def evaluate(self, v):
199198
This is an array whose last dimensions equal `self.sz` and the remaining
200199
dimensions correspond to first dimensions of `v`.
201200
"""
202-
203-
if v.dtype != self.dtype:
204-
logger.warning(
205-
f"{self.__class__.__name__}::evaluate"
206-
f" Inconsistent dtypes v: {v.dtype} self: {self.dtype}"
207-
)
208-
209201
# Transpose here once, instead of several times below #RCOPT
210202
v = v.reshape(-1, self.count).T
211203

@@ -242,7 +234,7 @@ def evaluate(self, v):
242234

243235
return x
244236

245-
def evaluate_t(self, v):
237+
def _evaluate_t(self, v):
246238
"""
247239
Evaluate coefficient in FB basis from those in standard 2D coordinate basis
248240
@@ -253,17 +245,7 @@ def evaluate_t(self, v):
253245
`self.count` and whose first dimensions correspond to
254246
first dimensions of `v`.
255247
"""
256-
257-
if v.dtype != self.dtype:
258-
logger.warning(
259-
f"{self.__class__.__name__}::evaluate_t"
260-
f" Inconsistent dtypes v: {v.dtype} self: {self.dtype}"
261-
)
262-
263-
if isinstance(v, Image):
264-
v = v.asnumpy()
265-
266-
v = v.T # RCOPT
248+
v = v.asnumpy().T # RCOPT
267249

268250
x, sz_roll = unroll_dim(v, self.ndim + 1)
269251
x = m_reshape(

src/aspire/basis/fb_3d.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from aspire.basis.basis_utils import real_sph_harmonic, sph_bessel, unique_coords_nd
77
from aspire.utils import roll_dim, unroll_dim
88
from aspire.utils.matlab_compat import m_flatten, m_reshape
9+
from aspire.volume import Volume
910

1011
logger = logging.getLogger(__name__)
1112

@@ -140,7 +141,7 @@ def basis_norm_3d(self, ell, k):
140141
* np.sqrt((self.nres / 2) ** 3)
141142
)
142143

143-
def evaluate(self, v):
144+
def _evaluate(self, v):
144145
"""
145146
Evaluate coefficients in standard 3D coordinate basis from those in FB basis
146147
:param v: A coefficient vector (or an array of coefficient vectors) to
@@ -186,7 +187,7 @@ def evaluate(self, v):
186187

187188
return x.T
188189

189-
def evaluate_t(self, v):
190+
def _evaluate_t(self, v):
190191
"""
191192
Evaluate coefficient in FB basis from those in standard 3D coordinate basis
192193
@@ -197,7 +198,10 @@ def evaluate_t(self, v):
197198
equals `self.count` and whose remaining dimensions correspond
198199
to higher dimensions of `v`.
199200
"""
200-
201+
# v may be a Volume object or a 7D array passed from Basis.mat_evaluate_t
202+
# making this check important
203+
if isinstance(v, Volume):
204+
v = v.asnumpy()
201205
v = v.T
202206
x, sz_roll = unroll_dim(v, self.ndim + 1)
203207
x = m_reshape(

src/aspire/basis/ffb_2d.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from aspire.basis import FBBasis2D
88
from aspire.basis.basis_utils import lgwt
9-
from aspire.image import Image
109
from aspire.nufft import anufft, nufft
1110
from aspire.numeric import fft, xp
1211
from aspire.utils import complex_type
@@ -102,7 +101,7 @@ def get_radial(self):
102101
"""
103102
return self._precomp["radial"]
104103

105-
def evaluate(self, v):
104+
def _evaluate(self, v):
106105
"""
107106
Evaluate coefficients in standard 2D coordinate basis from those in FB basis
108107
@@ -112,13 +111,6 @@ def evaluate(self, v):
112111
coordinate basis. This is Image instance with resolution of `self.sz`
113112
and the first dimension correspond to remaining dimension of `v`.
114113
"""
115-
116-
if v.dtype != self.dtype:
117-
logger.debug(
118-
f"{self.__class__.__name__}::evaluate"
119-
f" Inconsistent dtypes v: {v.dtype} self: {self.dtype}"
120-
)
121-
122114
sz_roll = v.shape[:-1]
123115
v = v.reshape(-1, self.count)
124116

@@ -187,9 +179,9 @@ def evaluate(self, v):
187179
# Return X as Image instance with the last two dimensions as *self.sz
188180
x = x.reshape((*sz_roll, *self.sz))
189181

190-
return Image(x)
182+
return x
191183

192-
def evaluate_t(self, x):
184+
def _evaluate_t(self, x):
193185
"""
194186
Evaluate coefficient in FB basis from those in standard 2D coordinate basis
195187
@@ -199,20 +191,6 @@ def evaluate_t(self, x):
199191
This is an array of vectors whose last dimension equals `self.count`
200192
and whose first dimension correspond to `x.n_images`.
201193
"""
202-
203-
if x.dtype != self.dtype:
204-
logger.warning(
205-
f"{self.__class__.__name__}::evaluate_t"
206-
f" Inconsistent dtypes v: {x.dtype} self: {self.dtype}"
207-
)
208-
209-
if not isinstance(x, Image):
210-
logger.warning(
211-
f"{self.__class__.__name__}::evaluate_t"
212-
" passed numpy array instead of Image."
213-
)
214-
x = Image(x)
215-
216194
# get information on polar grids from precomputed data
217195
n_theta = np.size(self._precomp["freqs"], 2)
218196
n_r = np.size(self._precomp["freqs"], 1)

src/aspire/basis/ffb_3d.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def _precomp(self):
153153
"fourier_pts": fourier_pts,
154154
}
155155

156-
def evaluate(self, v):
156+
def _evaluate(self, v):
157157
"""
158158
Evaluate coefficients in standard 3D coordinate basis from those in 3D FB basis
159159
@@ -278,7 +278,7 @@ def evaluate(self, v):
278278
x = x.reshape((*sz_roll, *self.sz))
279279
return x
280280

281-
def evaluate_t(self, x):
281+
def _evaluate_t(self, x):
282282
"""
283283
Evaluate coefficient in FB basis from those in standard 3D coordinate basis
284284
@@ -289,6 +289,8 @@ def evaluate_t(self, x):
289289
`self.count` and whose remaining dimensions correspond to higher
290290
dimensions of `x`.
291291
"""
292+
x = x.asnumpy()
293+
292294
# roll dimensions
293295
sz_roll = x.shape[:-3]
294296
x = x.reshape((-1, *self.sz))

0 commit comments

Comments
 (0)