Skip to content

Commit a7f08f1

Browse files
committed
Cleanup based on request feedback
1 parent e9d7820 commit a7f08f1

File tree

3 files changed

+20
-23
lines changed

3 files changed

+20
-23
lines changed

src/aspire/estimation/covar2d.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def get_mean(self, coeffs, ctf_fb=None, ctf_idx=None):
110110

111111
if (ctf_fb is None) or (ctf_idx is None):
112112
ctf_idx = np.zeros(coeffs.shape[1], dtype=int)
113-
ctf_fb = [BlkDiagMatrix.eye(get_partition(RadialCTFFilter().fb_mat(self.basis)),dtype=coeffs.dtype)]
113+
ctf_fb = [BlkDiagMatrix.eye(get_partition(RadialCTFFilter().fb_mat(self.basis)), dtype=coeffs.dtype)]
114114

115115
b = np.zeros(self.basis.count, dtype=coeffs.dtype)
116116

@@ -121,7 +121,7 @@ def get_mean(self, coeffs, ctf_fb=None, ctf_idx=None):
121121
mean_coeff_k = self._get_mean(coeff_k)
122122
ctf_fb_k = ctf_fb[k]
123123
ctf_fb_k_t = ctf_fb_k.T
124-
b = b + ctf_fb_k_t.apply(mean_coeff_k) * weight
124+
b += weight * ctf_fb_k_t.apply(mean_coeff_k)
125125
A += weight * (ctf_fb_k_t @ ctf_fb_k)
126126

127127
mean_coeff = A.solve(b)
@@ -142,7 +142,7 @@ def get_covar(self, coeffs, ctf_fb=None, ctf_idx=None, mean_coeff=None,
142142
:param covar_est_opt: The optimization parameter list for obtaining the Cov2D matrix.
143143
:return: The basis coefficients of the covariance matrix in
144144
the form of cell array representing a block diagonal matrix. These
145-
block diagonal matrices may be manipulated using the `BlkDiagMatrix` functions.
145+
block diagonal matrices are implemented as BlkDiagMatrix instances.
146146
The covariance is calculated from the images represented by the coeffs array,
147147
along with all possible rotations and reflections. As a result, the computed covariance
148148
matrix is invariant to both reflection and rotation. The effect of the filters in ctf_fb
@@ -187,16 +187,16 @@ def identity(x):
187187
mean_coeff_k = ctf_fb_k.apply(mean_coeff)
188188
covar_coeff_k = self._get_covar(coeff_k, mean_coeff_k)
189189

190-
b_coeff += ctf_fb_k_t @ covar_coeff_k @ (ctf_fb_k * weight)
190+
b_coeff += weight * (ctf_fb_k_t @ covar_coeff_k @ ctf_fb_k)
191191

192-
A_temp = ctf_fb_k_t @ ctf_fb_k
193-
b_noise += A_temp * weight
192+
ctf_fb_k_sq = ctf_fb_k_t @ ctf_fb_k
193+
b_noise += weight * ctf_fb_k_sq
194194

195-
A[k] = A_temp * np.sqrt(weight)
195+
A[k] = np.sqrt(weight) * ctf_fb_k_sq
196196
M += A[k]
197197

198198
if covar_est_opt['shrinker'] == 'None':
199-
b = b_coeff + (-noise_var * b_noise)
199+
b = b_coeff - noise_var * b_noise
200200
else:
201201
b = self.shrink_covar_backward(b_coeff, b_noise, np.size(coeffs, 1),
202202
noise_var, covar_est_opt['shrinker'])
@@ -293,15 +293,14 @@ def get_cwf_coeffs(self, coeffs, ctf_fb=None, ctf_idx=None, mean_coeff=None, cov
293293
coeff_k = coeffs[:, ctf_idx == k]
294294
ctf_fb_k = ctf_fb[k]
295295
ctf_fb_k_t = ctf_fb_k.T
296-
sig_covar_coeff = ctf_fb_k @ (covar_coeff @ ctf_fb_k_t)
296+
sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t
297297
sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff
298298

299299
mean_coeff_k = ctf_fb_k.apply(mean_coeff[:, np.newaxis])[:, 0]
300300

301301
coeff_est_k = coeff_k - mean_coeff_k[:, np.newaxis]
302302
coeff_est_k = sig_noise_covar_coeff.solve(coeff_est_k)
303-
tmp = covar_coeff @ ctf_fb_k_t
304-
coeff_est_k = tmp.apply(coeff_est_k)
303+
coeff_est_k = (covar_coeff @ ctf_fb_k_t).apply(coeff_est_k)
305304
coeff_est_k = coeff_est_k + mean_coeff[:, np.newaxis]
306305
coeffs_est[:, ctf_idx == k] = coeff_est_k
307306

@@ -423,11 +422,11 @@ def _calc_op(self):
423422
ctf_fb_k = ctf_fb[k]
424423
ctf_fb_k_t = ctf_fb_k.T
425424

426-
A_temp = ctf_fb_k_t @ ctf_fb_k
427-
A_mean_k = A_temp * weight
425+
ctf_fb_k_sq = ctf_fb_k_t @ ctf_fb_k
426+
A_mean_k = weight * ctf_fb_k_sq
428427
A_mean += A_mean_k
429428

430-
A_covar_k = A_temp * np.sqrt(weight)
429+
A_covar_k = np.sqrt(weight) * ctf_fb_k_sq
431430
A_covar[k] = A_covar_k
432431

433432
M_covar += A_covar_k
@@ -444,12 +443,9 @@ def _mean_correct_covar_rhs(self, b_covar, b_mean, mean_coeff):
444443

445444
partition = get_partition(ctf_fb[0])
446445

447-
# GBW, Why are they deep copying, where are they using refs...
448-
449446
# Note: If we don't do this, we'll be modifying the stored `b_covar`
450447
# since the operations below are in-place.
451-
b_covar.data = [blk.copy() for blk in b_covar]
452-
448+
b_covar = b_covar.copy()
453449

454450
for k in np.unique(ctf_idx):
455451
weight = np.count_nonzero(ctf_idx == k) / src.n
@@ -473,7 +469,7 @@ def _mean_correct_covar_rhs(self, b_covar, b_mean, mean_coeff):
473469

474470
def _noise_correct_covar_rhs(self, b_covar, b_noise, noise_var, shrinker):
475471
if shrinker == 'None':
476-
b_noise = b_noise * -noise_var
472+
b_noise = -noise_var * b_noise
477473
b_covar += b_noise
478474
else:
479475
b_covar = self.shrink_covar_backward(b_covar, b_noise, self.src.n,
@@ -566,8 +562,8 @@ def get_covar(self, noise_var=1, mean_coeff=None, covar_est_opt=None):
566562
- 'precision': Precision of conjugate gradient algorithm (see
567563
documentation for `conj_grad`, default `'float64'`)
568564
:return: The block diagonal matrix containing the basis coefficients (in
569-
`self.basis`) for the estimated covariance matrix. These may be
570-
manipulated using the `BlkDiagMatrix` functions.
565+
`self.basis`) for the estimated covariance matrix. These are
566+
implemented using `BlkDiagMatrix`.
571567
"""
572568

573569
def identity(x):

src/aspire/utils/blk_diag_matrix.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class BlkDiagMatrix:
2121
block diagonal matrices as used by ASPIRE.
2222
"""
2323

24+
# Developers' Note:
2425
# All instances of this class should have priority over ndarray ops
2526
# because we implement them here ourselves.
2627
# This is a more np current implementation of __array_priority__
@@ -572,7 +573,7 @@ def from_mat(mat, blk_partition, dtype=np.float64):
572573
rows = blk_partition[:, 0]
573574
cols = blk_partition[:, 1]
574575
cellarray = Cell2D(rows, cols, dtype=mat.dtype)
575-
blk_diag = cellarray.mat2blk_diag(mat, rows, cols)
576+
blk_diag = cellarray.mat_to_blk_diag(mat, rows, cols)
576577
A.data = BlkDiagMatrix.from_blk_diag(blk_diag)
577578
return A
578579

src/aspire/utils/cell.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def mat2cell(self, mat, rows, cols):
4343
offsetr += rows[i]
4444
return self.cell_list
4545

46-
def mat2blk_diag(self, mat, rows, cols):
46+
def mat_to_blk_diag(self, mat, rows, cols):
4747
self.mat2cell(mat, rows, cols)
4848
blk_diag=[]
4949
offset = 0

0 commit comments

Comments
 (0)