33import numpy as np
44
55from aspire .basis .ffb_2d import FFBBasis2D
6- from aspire .estimation .covar2d import BatchedRotCov2D
76from aspire .denoising .denoiser_cov2d import DenoiserCov2D
7+ from aspire .estimation .covar2d import BatchedRotCov2D
88from aspire .source .simulation import Simulation
99from aspire .utils .filters import RadialCTFFilter , ScalarFilter
1010
@@ -24,11 +24,14 @@ def setUp(self):
2424 defocus_max = 2.5e4
2525 defocus_ct = 7
2626
27- filters = [RadialCTFFilter (pixel_size , voltage , defocus = d , Cs = 2.0 , alpha = 0.1 )
28- for d in np .linspace (defocus_min , defocus_max , defocus_ct )]
27+ filters = [
28+ RadialCTFFilter (pixel_size , voltage , defocus = d , Cs = 2.0 , alpha = 0.1 )
29+ for d in np .linspace (defocus_min , defocus_max , defocus_ct )
30+ ]
2931
30- src = Simulation (L , n , unique_filters = filters , dtype = self .dtype ,
31- noise_filter = noise_filter )
32+ src = Simulation (
33+ L , n , unique_filters = filters , dtype = self .dtype , noise_filter = noise_filter
34+ )
3235
3336 basis = FFBBasis2D ((L , L ), dtype = self .dtype )
3437
@@ -44,15 +47,20 @@ def setUp(self):
4447 self .denoised_src = self .denoisor .denoise (batch_size = 7 )
4548 self .src = src
4649 self .basis = basis
47- self .covar_est_opt = {'shrinker' : 'frobenius_norm' , 'verbose' : 0 ,
48- 'max_iter' : 250 , 'iter_callback' : [],
49- 'store_iterates' : False , 'rel_tolerance' : 1e-12 ,
50- 'precision' : self .dtype }
50+ self .covar_est_opt = {
51+ "shrinker" : "frobenius_norm" ,
52+ "verbose" : 0 ,
53+ "max_iter" : 250 ,
54+ "iter_callback" : [],
55+ "store_iterates" : False ,
56+ "rel_tolerance" : 1e-12 ,
57+ "precision" : self .dtype ,
58+ }
5159
5260 def blk_diag_allclose (self , blk_diag_a , blk_diag_b , atol = 1e-8 ):
5361 close = True
5462 for blk_a , blk_b in zip (blk_diag_a , blk_diag_b ):
55- close = ( close and np .allclose (blk_a , blk_b , atol = atol ) )
63+ close = close and np .allclose (blk_a , blk_b , atol = atol )
5664 return close
5765
5866 def testMean (self ):
@@ -62,22 +70,31 @@ def testMean(self):
6270 self .assertTrue (np .allclose (mean_denoisor , mean_bcov2d ))
6371
6472 def testCovar (self ):
65- covar_bcov2d = self .bcov2d .get_covar (noise_var = self .noise_var ,
66- covar_est_opt = self .covar_est_opt )
73+ covar_bcov2d = self .bcov2d .get_covar (
74+ noise_var = self .noise_var , covar_est_opt = self .covar_est_opt
75+ )
6776 covar_denoisor = self .denoisor .covar_est
6877
6978 self .assertTrue (self .blk_diag_allclose (covar_denoisor , covar_bcov2d ))
7079
7180 def testCWFCeoffs (self ):
7281 mean_bcov2d = self .bcov2d .get_mean ()
73- covar_bcov2d = self .bcov2d .get_covar (noise_var = self .noise_var ,
74- covar_est_opt = self .covar_est_opt )
82+ covar_bcov2d = self .bcov2d .get_covar (
83+ noise_var = self .noise_var , covar_est_opt = self .covar_est_opt
84+ )
7585 coeffs_bcov2d = self .bcov2d .get_cwf_coeffs (
76- self .coeff , self .ctf_fb , self .ctf_idx ,
77- mean_coeff = mean_bcov2d , covar_coeff = covar_bcov2d ,
78- noise_var = self .noise_var )
86+ self .coeff ,
87+ self .ctf_fb ,
88+ self .ctf_idx ,
89+ mean_coeff = mean_bcov2d ,
90+ covar_coeff = covar_bcov2d ,
91+ noise_var = self .noise_var ,
92+ )
7993 imgs_denoised_bcov2d = self .basis .evaluate (coeffs_bcov2d )
8094 imgs_denoised_denoisor = self .denoised_src .images (0 , self .src .n )
8195
82- self .assertTrue (np .allclose (imgs_denoised_bcov2d .asnumpy (),
83- imgs_denoised_denoisor .asnumpy ()))
96+ self .assertTrue (
97+ np .allclose (
98+ imgs_denoised_bcov2d .asnumpy (), imgs_denoised_denoisor .asnumpy ()
99+ )
100+ )
0 commit comments