Skip to content

Commit 0fd9e35

Browse files
authored
Merge pull request #421 from ComputationalCryoEM/fix_covar0
Fixup covar2d shrinker and get_cwf_coeffs for noise_var=0
2 parents 95dff05 + cbb27f6 commit 0fd9e35

File tree

5 files changed

+194
-35
lines changed

5 files changed

+194
-35
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def read(fname):
5252
"jupyter",
5353
"pyflakes",
5454
"pydocstyle",
55+
"parameterized",
5556
"pytest",
5657
"pytest-cov",
5758
"pytest-random-order",

src/aspire/covariance/covar2d.py

Lines changed: 71 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
logger = logging.getLogger(__name__)
1313

1414

15-
def shrink_covar(covar_in, noise_var, gamma, shrinker=None):
15+
def shrink_covar(covar, noise_var, gamma, shrinker="frobenius_norm"):
1616
"""
1717
Shrink the covariance matrix
1818
:param covar_in: An input covariance matrix
@@ -22,18 +22,14 @@ def shrink_covar(covar_in, noise_var, gamma, shrinker=None):
2222
:return: The shrinked covariance matrix
2323
"""
2424

25-
if shrinker is None:
26-
shrinker = "frobenius_norm"
2725
ensure(
2826
shrinker in ("frobenius_norm", "operator_norm", "soft_threshold"),
2927
"Unsupported shrink method",
3028
)
3129

32-
covar = covar_in / noise_var
33-
3430
lambs, eig_vec = eig(make_symmat(covar))
3531

36-
lambda_max = (1 + np.sqrt(gamma)) ** 2
32+
lambda_max = noise_var * (1 + np.sqrt(gamma)) ** 2
3733

3834
lambs[lambs < lambda_max] = 0
3935

@@ -42,20 +38,36 @@ def shrink_covar(covar_in, noise_var, gamma, shrinker=None):
4238
lambdas = (
4339
1
4440
/ 2
45-
* (lambdas - gamma + 1 + np.sqrt((lambdas - gamma + 1) ** 2 - 4 * lambdas))
46-
- 1
41+
* (
42+
lambdas
43+
- noise_var * (gamma - 1)
44+
+ np.sqrt(
45+
(lambdas - noise_var * (gamma - 1)) ** 2
46+
- 4 * noise_var ** 2 * lambdas
47+
)
48+
)
49+
- noise_var
4750
)
51+
4852
lambs[lambs > lambda_max] = lambdas
4953
elif shrinker == "frobenius_norm":
5054
lambdas = lambs[lambs > lambda_max]
5155
lambdas = (
5256
1
5357
/ 2
54-
* (lambdas - gamma + 1 + np.sqrt((lambdas - gamma + 1) ** 2 - 4 * lambdas))
55-
- 1
58+
* (
59+
lambdas
60+
- noise_var * (gamma - 1)
61+
+ np.sqrt(
62+
(lambdas - noise_var * (gamma - 1)) ** 2
63+
- 4 * noise_var ** 2 * lambdas
64+
)
65+
)
66+
- noise_var
5667
)
5768
c = np.divide(
58-
(1 - np.divide(gamma, lambdas ** 2)), (1 + np.divide(gamma, lambdas))
69+
(1 - np.divide(noise_var ** 2 * gamma, lambdas ** 2)),
70+
(1 + np.divide(noise_var * gamma, lambdas)),
5971
)
6072
lambdas = lambdas * c
6173
lambs[lambs > lambda_max] = lambdas
@@ -68,7 +80,6 @@ def shrink_covar(covar_in, noise_var, gamma, shrinker=None):
6880
np.fill_diagonal(diag_lambs, lambs)
6981

7082
shrinked_covar = eig_vec @ diag_lambs @ eig_vec.conj().T
71-
shrinked_covar *= noise_var
7283

7384
return shrinked_covar
7485

@@ -206,7 +217,7 @@ def get_covar(
206217
ctf_idx=None,
207218
mean_coeff=None,
208219
do_refl=True,
209-
noise_var=1,
220+
noise_var=0,
210221
covar_est_opt=None,
211222
make_psd=True,
212223
):
@@ -242,7 +253,7 @@ def identity(x):
242253
return x
243254

244255
default_est_opt = {
245-
"shrinker": "None",
256+
"shrinker": None,
246257
"verbose": 0,
247258
"max_iter": 250,
248259
"iter_callback": [],
@@ -286,7 +297,7 @@ def identity(x):
286297
if not b_coeff.check_psd():
287298
logger.warning("Left side b in Cov2D is not positive semidefinite.")
288299

289-
if covar_est_opt["shrinker"] == "None":
300+
if covar_est_opt["shrinker"] is None:
290301
b = b_coeff - noise_var * b_noise
291302
else:
292303
b = self.shrink_covar_backward(
@@ -373,7 +384,7 @@ def get_cwf_coeffs(
373384
ctf_idx=None,
374385
mean_coeff=None,
375386
covar_coeff=None,
376-
noise_var=1,
387+
noise_var=0,
377388
):
378389
"""
379390
Estimate the expansion coefficients using the Covariance Wiener Filtering (CWF) method.
@@ -390,6 +401,7 @@ def get_cwf_coeffs(
390401
These are obtained using a Wiener filter with the specified covariance for the clean images
391402
and white noise of variance `noise_var` for the noise.
392403
"""
404+
393405
if mean_coeff is None:
394406
mean_coeff = self.get_mean(coeffs, ctf_fb, ctf_idx)
395407

@@ -398,8 +410,15 @@ def get_cwf_coeffs(
398410
coeffs, ctf_fb, ctf_idx, mean_coeff, noise_var=noise_var
399411
)
400412

401-
# should be none or both
402-
if (ctf_fb is None) or (ctf_idx is None):
413+
# Handle CTF arguments.
414+
if (ctf_fb is None) ^ (ctf_idx is None):
415+
raise RuntimeError(
416+
"Both `ctf_fb` and `ctf_idx` should be provided,"
417+
" or both should be `None`."
418+
f' Given {"ctf_fb" if ctf_idx is None else "ctf_idx"}'
419+
)
420+
elif ctf_fb is None:
421+
# Setup defaults for CTF
403422
ctf_idx = np.zeros(coeffs.shape[0], dtype=int)
404423
ctf_fb = [BlkDiagMatrix.eye_like(covar_coeff)]
405424

@@ -411,15 +430,19 @@ def get_cwf_coeffs(
411430
coeff_k = coeffs[ctf_idx == k]
412431
ctf_fb_k = ctf_fb[k]
413432
ctf_fb_k_t = ctf_fb_k.T
414-
sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t
415-
416-
sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff
417433

418434
mean_coeff_k = ctf_fb_k.apply(mean_coeff)
419-
420435
coeff_est_k = coeff_k - mean_coeff_k
421-
coeff_est_k = sig_noise_covar_coeff.solve(coeff_est_k.T).T
422-
coeff_est_k = (covar_coeff @ ctf_fb_k_t).apply(coeff_est_k.T).T
436+
437+
if noise_var == 0:
438+
coeff_est_k = ctf_fb_k.solve(coeff_est_k.T).T
439+
else:
440+
sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t
441+
sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff
442+
443+
coeff_est_k = sig_noise_covar_coeff.solve(coeff_est_k.T).T
444+
coeff_est_k = (covar_coeff @ ctf_fb_k_t).apply(coeff_est_k.T).T
445+
423446
coeff_est_k = coeff_est_k + mean_coeff
424447
coeffs_est[ctf_idx == k] = coeff_est_k
425448

@@ -586,7 +609,7 @@ def _mean_correct_covar_rhs(self, b_covar, b_mean, mean_coeff):
586609
return b_covar
587610

588611
def _noise_correct_covar_rhs(self, b_covar, b_noise, noise_var, shrinker):
589-
if shrinker == "None":
612+
if shrinker is None:
590613
b_noise = -noise_var * b_noise
591614
b_covar += b_noise
592615
else:
@@ -652,7 +675,7 @@ def get_mean(self):
652675
return mean_coeff
653676

654677
def get_covar(
655-
self, noise_var=1, mean_coeff=None, covar_est_opt=None, make_psd=True
678+
self, noise_var=0, mean_coeff=None, covar_est_opt=None, make_psd=True
656679
):
657680
"""
658681
Calculate the block diagonal covariance matrix in the basis
@@ -690,7 +713,7 @@ def identity(x):
690713
return x
691714

692715
default_est_opt = {
693-
"shrinker": "None",
716+
"shrinker": None,
694717
"verbose": 0,
695718
"max_iter": 250,
696719
"iter_callback": [],
@@ -740,7 +763,7 @@ def identity(x):
740763
return covar_coeff
741764

742765
def get_cwf_coeffs(
743-
self, coeffs, ctf_fb, ctf_idx, mean_coeff, covar_coeff, noise_var=1
766+
self, coeffs, ctf_fb, ctf_idx, mean_coeff, covar_coeff, noise_var=0
744767
):
745768
"""
746769
Estimate the expansion coefficients using the Covariance Wiener Filtering (CWF) method.
@@ -757,13 +780,22 @@ def get_cwf_coeffs(
757780
These are obtained using a Wiener filter with the specified covariance for the clean images
758781
and white noise of variance `noise_var` for the noise.
759782
"""
783+
760784
if mean_coeff is None:
761785
mean_coeff = self.get_mean()
762786

763787
if covar_coeff is None:
764788
covar_coeff = self.get_covar(noise_var=noise_var, mean_coeff=mean_coeff)
765789

766-
if (ctf_fb is None) or (ctf_idx is None):
790+
# Handle CTF arguments.
791+
if (ctf_fb is None) ^ (ctf_idx is None):
792+
raise RuntimeError(
793+
"Both `ctf_fb` and `ctf_idx` should be provided,"
794+
" or both should be `None`."
795+
f' Given {"ctf_fb" if ctf_idx is None else "ctf_idx"}'
796+
)
797+
elif ctf_fb is None:
798+
# Setup defaults for CTF
767799
ctf_idx = np.zeros(coeffs.shape[0], dtype=int)
768800
ctf_fb = [BlkDiagMatrix.eye_like(covar_coeff)]
769801

@@ -775,14 +807,19 @@ def get_cwf_coeffs(
775807
coeff_k = coeffs[ctf_idx == k]
776808
ctf_fb_k = ctf_fb[k]
777809
ctf_fb_k_t = ctf_fb_k.T
778-
sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t
779-
sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff
780810

781811
mean_coeff_k = ctf_fb_k.apply(mean_coeff)
782-
783812
coeff_est_k = coeff_k - mean_coeff_k
784-
coeff_est_k = sig_noise_covar_coeff.solve(coeff_est_k.T).T
785-
coeff_est_k = (covar_coeff @ ctf_fb_k_t).apply(coeff_est_k.T).T
813+
814+
if noise_var == 0:
815+
coeff_est_k = ctf_fb_k.solve(coeff_est_k.T).T
816+
else:
817+
sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t
818+
sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff
819+
820+
coeff_est_k = sig_noise_covar_coeff.solve(coeff_est_k.T).T
821+
coeff_est_k = (covar_coeff @ ctf_fb_k_t).apply(coeff_est_k.T).T
822+
786823
coeff_est_k = coeff_est_k + mean_coeff
787824
coeffs_est[ctf_idx == k] = coeff_est_k
788825

tests/test_batched_covar2d.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ def setUp(self):
4141
def tearDown(self):
4242
pass
4343

44-
def blk_diag_allclose(self, blk_diag_a, blk_diag_b, atol=1e-8):
44+
def blk_diag_allclose(self, blk_diag_a, blk_diag_b, atol=None):
45+
if atol is None:
46+
atol = utest_tolerance(self.dtype)
47+
4548
close = True
4649
for blk_a, blk_b in zip(blk_diag_a, blk_diag_b):
4750
close = close and np.allclose(blk_a, blk_b, atol=atol)
@@ -181,3 +184,53 @@ def testCWFCoeff(self):
181184
atol=utest_tolerance(self.dtype),
182185
)
183186
)
187+
188+
def testCWFCoeffCleanCTF(self):
189+
"""
190+
Test case of clean images (coeff_clean and noise_var=0)
191+
while using a non Identity CTF.
192+
193+
This case may come up when a developer switches between
194+
clean and dirty images.
195+
"""
196+
197+
# Calculate CWF coefficients using Cov2D base class
198+
mean_cov2d = self.cov2d.get_mean(
199+
self.coeff, ctf_fb=self.ctf_fb, ctf_idx=self.ctf_idx
200+
)
201+
covar_cov2d = self.cov2d.get_covar(
202+
self.coeff,
203+
ctf_fb=self.ctf_fb,
204+
ctf_idx=self.ctf_idx,
205+
noise_var=self.noise_var,
206+
make_psd=True,
207+
)
208+
209+
coeff_cov2d = self.cov2d.get_cwf_coeffs(
210+
self.coeff,
211+
self.ctf_fb,
212+
self.ctf_idx,
213+
mean_coeff=mean_cov2d,
214+
covar_coeff=covar_cov2d,
215+
noise_var=0,
216+
)
217+
218+
# Calculate CWF coefficients using Batched Cov2D class
219+
mean_bcov2d = self.bcov2d.get_mean()
220+
covar_bcov2d = self.bcov2d.get_covar(noise_var=self.noise_var, make_psd=True)
221+
222+
coeff_bcov2d = self.bcov2d.get_cwf_coeffs(
223+
self.coeff,
224+
self.ctf_fb,
225+
self.ctf_idx,
226+
mean_bcov2d,
227+
covar_bcov2d,
228+
noise_var=0,
229+
)
230+
self.assertTrue(
231+
self.blk_diag_allclose(
232+
coeff_cov2d,
233+
coeff_bcov2d,
234+
atol=utest_tolerance(self.dtype),
235+
)
236+
)

0 commit comments

Comments
 (0)