1212logger = logging .getLogger (__name__ )
1313
1414
15- def shrink_covar (covar_in , noise_var , gamma , shrinker = None ):
15+ def shrink_covar (covar , noise_var , gamma , shrinker = "frobenius_norm" ):
1616 """
1717 Shrink the covariance matrix
1818 :param covar_in: An input covariance matrix
@@ -22,18 +22,14 @@ def shrink_covar(covar_in, noise_var, gamma, shrinker=None):
2222 :return: The shrinked covariance matrix
2323 """
2424
25- if shrinker is None :
26- shrinker = "frobenius_norm"
2725 ensure (
2826 shrinker in ("frobenius_norm" , "operator_norm" , "soft_threshold" ),
2927 "Unsupported shrink method" ,
3028 )
3129
32- covar = covar_in / noise_var
33-
3430 lambs , eig_vec = eig (make_symmat (covar ))
3531
36- lambda_max = (1 + np .sqrt (gamma )) ** 2
32+ lambda_max = noise_var * (1 + np .sqrt (gamma )) ** 2
3733
3834 lambs [lambs < lambda_max ] = 0
3935
@@ -42,20 +38,36 @@ def shrink_covar(covar_in, noise_var, gamma, shrinker=None):
4238 lambdas = (
4339 1
4440 / 2
45- * (lambdas - gamma + 1 + np .sqrt ((lambdas - gamma + 1 ) ** 2 - 4 * lambdas ))
46- - 1
41+ * (
42+ lambdas
43+ - noise_var * (gamma - 1 )
44+ + np .sqrt (
45+ (lambdas - noise_var * (gamma - 1 )) ** 2
46+ - 4 * noise_var ** 2 * lambdas
47+ )
48+ )
49+ - noise_var
4750 )
51+
4852 lambs [lambs > lambda_max ] = lambdas
4953 elif shrinker == "frobenius_norm" :
5054 lambdas = lambs [lambs > lambda_max ]
5155 lambdas = (
5256 1
5357 / 2
54- * (lambdas - gamma + 1 + np .sqrt ((lambdas - gamma + 1 ) ** 2 - 4 * lambdas ))
55- - 1
58+ * (
59+ lambdas
60+ - noise_var * (gamma - 1 )
61+ + np .sqrt (
62+ (lambdas - noise_var * (gamma - 1 )) ** 2
63+ - 4 * noise_var ** 2 * lambdas
64+ )
65+ )
66+ - noise_var
5667 )
5768 c = np .divide (
58- (1 - np .divide (gamma , lambdas ** 2 )), (1 + np .divide (gamma , lambdas ))
69+ (1 - np .divide (noise_var ** 2 * gamma , lambdas ** 2 )),
70+ (1 + np .divide (noise_var * gamma , lambdas )),
5971 )
6072 lambdas = lambdas * c
6173 lambs [lambs > lambda_max ] = lambdas
@@ -68,7 +80,6 @@ def shrink_covar(covar_in, noise_var, gamma, shrinker=None):
6880 np .fill_diagonal (diag_lambs , lambs )
6981
7082 shrinked_covar = eig_vec @ diag_lambs @ eig_vec .conj ().T
71- shrinked_covar *= noise_var
7283
7384 return shrinked_covar
7485
@@ -206,7 +217,7 @@ def get_covar(
206217 ctf_idx = None ,
207218 mean_coeff = None ,
208219 do_refl = True ,
209- noise_var = 1 ,
220+ noise_var = 0 ,
210221 covar_est_opt = None ,
211222 make_psd = True ,
212223 ):
@@ -242,7 +253,7 @@ def identity(x):
242253 return x
243254
244255 default_est_opt = {
245- "shrinker" : " None" ,
256+ "shrinker" : None ,
246257 "verbose" : 0 ,
247258 "max_iter" : 250 ,
248259 "iter_callback" : [],
@@ -286,7 +297,7 @@ def identity(x):
286297 if not b_coeff .check_psd ():
287298 logger .warning ("Left side b in Cov2D is not positive semidefinite." )
288299
289- if covar_est_opt ["shrinker" ] == " None" :
300+ if covar_est_opt ["shrinker" ] is None :
290301 b = b_coeff - noise_var * b_noise
291302 else :
292303 b = self .shrink_covar_backward (
@@ -373,7 +384,7 @@ def get_cwf_coeffs(
373384 ctf_idx = None ,
374385 mean_coeff = None ,
375386 covar_coeff = None ,
376- noise_var = 1 ,
387+ noise_var = 0 ,
377388 ):
378389 """
379390 Estimate the expansion coefficients using the Covariance Wiener Filtering (CWF) method.
@@ -390,6 +401,7 @@ def get_cwf_coeffs(
390401 These are obtained using a Wiener filter with the specified covariance for the clean images
391402 and white noise of variance `noise_var` for the noise.
392403 """
404+
393405 if mean_coeff is None :
394406 mean_coeff = self .get_mean (coeffs , ctf_fb , ctf_idx )
395407
@@ -398,8 +410,15 @@ def get_cwf_coeffs(
398410 coeffs , ctf_fb , ctf_idx , mean_coeff , noise_var = noise_var
399411 )
400412
401- # should be none or both
402- if (ctf_fb is None ) or (ctf_idx is None ):
413+ # Handle CTF arguments.
414+ if (ctf_fb is None ) ^ (ctf_idx is None ):
415+ raise RuntimeError (
416+ "Both `ctf_fb` and `ctf_idx` should be provided,"
417+ " or both should be `None`."
418+ f' Given { "ctf_fb" if ctf_idx is None else "ctf_idx" } '
419+ )
420+ elif ctf_fb is None :
421+ # Setup defaults for CTF
403422 ctf_idx = np .zeros (coeffs .shape [0 ], dtype = int )
404423 ctf_fb = [BlkDiagMatrix .eye_like (covar_coeff )]
405424
@@ -411,15 +430,19 @@ def get_cwf_coeffs(
411430 coeff_k = coeffs [ctf_idx == k ]
412431 ctf_fb_k = ctf_fb [k ]
413432 ctf_fb_k_t = ctf_fb_k .T
414- sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t
415-
416- sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff
417433
418434 mean_coeff_k = ctf_fb_k .apply (mean_coeff )
419-
420435 coeff_est_k = coeff_k - mean_coeff_k
421- coeff_est_k = sig_noise_covar_coeff .solve (coeff_est_k .T ).T
422- coeff_est_k = (covar_coeff @ ctf_fb_k_t ).apply (coeff_est_k .T ).T
436+
437+ if noise_var == 0 :
438+ coeff_est_k = ctf_fb_k .solve (coeff_est_k .T ).T
439+ else :
440+ sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t
441+ sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff
442+
443+ coeff_est_k = sig_noise_covar_coeff .solve (coeff_est_k .T ).T
444+ coeff_est_k = (covar_coeff @ ctf_fb_k_t ).apply (coeff_est_k .T ).T
445+
423446 coeff_est_k = coeff_est_k + mean_coeff
424447 coeffs_est [ctf_idx == k ] = coeff_est_k
425448
@@ -586,7 +609,7 @@ def _mean_correct_covar_rhs(self, b_covar, b_mean, mean_coeff):
586609 return b_covar
587610
588611 def _noise_correct_covar_rhs (self , b_covar , b_noise , noise_var , shrinker ):
589- if shrinker == " None" :
612+ if shrinker is None :
590613 b_noise = - noise_var * b_noise
591614 b_covar += b_noise
592615 else :
@@ -652,7 +675,7 @@ def get_mean(self):
652675 return mean_coeff
653676
654677 def get_covar (
655- self , noise_var = 1 , mean_coeff = None , covar_est_opt = None , make_psd = True
678+ self , noise_var = 0 , mean_coeff = None , covar_est_opt = None , make_psd = True
656679 ):
657680 """
658681 Calculate the block diagonal covariance matrix in the basis
@@ -690,7 +713,7 @@ def identity(x):
690713 return x
691714
692715 default_est_opt = {
693- "shrinker" : " None" ,
716+ "shrinker" : None ,
694717 "verbose" : 0 ,
695718 "max_iter" : 250 ,
696719 "iter_callback" : [],
@@ -740,7 +763,7 @@ def identity(x):
740763 return covar_coeff
741764
742765 def get_cwf_coeffs (
743- self , coeffs , ctf_fb , ctf_idx , mean_coeff , covar_coeff , noise_var = 1
766+ self , coeffs , ctf_fb , ctf_idx , mean_coeff , covar_coeff , noise_var = 0
744767 ):
745768 """
746769 Estimate the expansion coefficients using the Covariance Wiener Filtering (CWF) method.
@@ -757,13 +780,22 @@ def get_cwf_coeffs(
757780 These are obtained using a Wiener filter with the specified covariance for the clean images
758781 and white noise of variance `noise_var` for the noise.
759782 """
783+
760784 if mean_coeff is None :
761785 mean_coeff = self .get_mean ()
762786
763787 if covar_coeff is None :
764788 covar_coeff = self .get_covar (noise_var = noise_var , mean_coeff = mean_coeff )
765789
766- if (ctf_fb is None ) or (ctf_idx is None ):
790+ # Handle CTF arguments.
791+ if (ctf_fb is None ) ^ (ctf_idx is None ):
792+ raise RuntimeError (
793+ "Both `ctf_fb` and `ctf_idx` should be provided,"
794+ " or both should be `None`."
795+ f' Given { "ctf_fb" if ctf_idx is None else "ctf_idx" } '
796+ )
797+ elif ctf_fb is None :
798+ # Setup defaults for CTF
767799 ctf_idx = np .zeros (coeffs .shape [0 ], dtype = int )
768800 ctf_fb = [BlkDiagMatrix .eye_like (covar_coeff )]
769801
@@ -775,14 +807,19 @@ def get_cwf_coeffs(
775807 coeff_k = coeffs [ctf_idx == k ]
776808 ctf_fb_k = ctf_fb [k ]
777809 ctf_fb_k_t = ctf_fb_k .T
778- sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t
779- sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff
780810
781811 mean_coeff_k = ctf_fb_k .apply (mean_coeff )
782-
783812 coeff_est_k = coeff_k - mean_coeff_k
784- coeff_est_k = sig_noise_covar_coeff .solve (coeff_est_k .T ).T
785- coeff_est_k = (covar_coeff @ ctf_fb_k_t ).apply (coeff_est_k .T ).T
813+
814+ if noise_var == 0 :
815+ coeff_est_k = ctf_fb_k .solve (coeff_est_k .T ).T
816+ else :
817+ sig_covar_coeff = ctf_fb_k @ covar_coeff @ ctf_fb_k_t
818+ sig_noise_covar_coeff = sig_covar_coeff + noise_covar_coeff
819+
820+ coeff_est_k = sig_noise_covar_coeff .solve (coeff_est_k .T ).T
821+ coeff_est_k = (covar_coeff @ ctf_fb_k_t ).apply (coeff_est_k .T ).T
822+
786823 coeff_est_k = coeff_est_k + mean_coeff
787824 coeffs_est [ctf_idx == k ] = coeff_est_k
788825
0 commit comments