Skip to content

Commit 8d2ea65

Browse files
authored
Merge pull request #686 from ComputationalCryoEM/batched_cov2d_noctf
Batched cov2d noctf patch
2 parents 0b21be4 + 47db6b9 commit 8d2ea65

File tree

3 files changed

+168
-106
lines changed

3 files changed

+168
-106
lines changed

src/aspire/covariance/covar2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ def _build(self):
502502

503503
self.basis = FFBBasis2D((src.L, src.L), dtype=self.dtype)
504504

505-
if src.unique_filters is None:
505+
if not src.unique_filters:
506506
logger.info("CTF filters are not included in Cov2D denoising")
507507
# set all CTF filters to an identity filter
508508
self.ctf_idx = np.zeros(src.n, dtype=int)

tests/test_batched_covar2d.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,17 @@
1010

1111

1212
class BatchedRotCov2DTestCase(TestCase):
13+
"""
14+
Tests batched cov2d without providing any CTF filters.
15+
"""
16+
17+
filters = None
18+
ctf_idx = None
19+
ctf_fb = None
20+
1321
def setUp(self):
1422
n = 32
1523
L = 8
16-
filters = [
17-
RadialCTFFilter(5, 200, defocus=d, Cs=2.0, alpha=0.1)
18-
for d in np.linspace(1.5e4, 2.5e4, 7)
19-
]
2024
self.dtype = np.float32
2125
self.noise_var = 0.1848
2226

@@ -27,14 +31,15 @@ def setUp(self):
2731
noise_filter = ScalarFilter(dim=2, value=self.noise_var * 0.001)
2832

2933
self.src = Simulation(
30-
L, n, unique_filters=filters, dtype=self.dtype, noise_filter=noise_filter
34+
L,
35+
n,
36+
unique_filters=self.filters,
37+
dtype=self.dtype,
38+
noise_filter=noise_filter,
3139
)
3240
self.basis = FFBBasis2D((L, L), dtype=self.dtype)
3341
self.coeff = self.basis.evaluate_t(self.src.images(0, self.src.n))
3442

35-
self.ctf_idx = self.src.filter_indices
36-
self.ctf_fb = [f.fb_mat(self.basis) for f in self.src.unique_filters]
37-
3843
self.cov2d = RotCov2D(self.basis)
3944
self.bcov2d = BatchedRotCov2D(self.src, self.basis, batch_size=7)
4045

@@ -234,3 +239,24 @@ def testCWFCoeffCleanCTF(self):
234239
atol=utest_tolerance(self.dtype),
235240
)
236241
)
242+
243+
244+
class BatchedRotCov2DTestCaseCTF(BatchedRotCov2DTestCase):
245+
"""
246+
Tests batched cov2d with CTF information.
247+
"""
248+
249+
@property
250+
def filters(self):
251+
return [
252+
RadialCTFFilter(5, 200, defocus=d, Cs=2.0, alpha=0.1)
253+
for d in np.linspace(1.5e4, 2.5e4, 7)
254+
]
255+
256+
@property
257+
def ctf_idx(self):
258+
return self.src.filter_indices
259+
260+
@property
261+
def ctf_fb(self):
262+
return [f.fb_mat(self.basis) for f in self.src.unique_filters]

tests/test_covar2d.py

Lines changed: 133 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -17,41 +17,39 @@
1717

1818

1919
class Cov2DTestCase(TestCase):
20+
"""
21+
Cov2D Test without CTFFilters populated.
22+
"""
23+
24+
unique_filters = None
25+
h_idx = None
26+
h_ctf_fb = None
27+
2028
# These class variables support parameterized arg checking in `testShrinkers`
2129
shrinkers = [(None,), "frobenius_norm", "operator_norm", "soft_threshold"]
2230
bad_shrinker_inputs = ["None", "notashrinker", ""]
2331

2432
def setUp(self):
2533
self.dtype = np.float32
2634

27-
L = 8
35+
self.L = L = 8
2836
n = 32
29-
pixel_size = 5.0 * 65 / L
30-
voltage = 200
31-
defocus_min = 1.5e4
32-
defocus_max = 2.5e4
33-
defocus_ct = 7
3437

3538
self.noise_var = 1.3957e-4
3639
noise_filter = ScalarFilter(dim=2, value=self.noise_var)
3740

38-
unique_filters = [
39-
RadialCTFFilter(pixel_size, voltage, defocus=d, Cs=2.0, alpha=0.1)
40-
for d in np.linspace(defocus_min, defocus_max, defocus_ct)
41-
]
42-
4341
vols = Volume(
4442
np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")).astype(
4543
self.dtype
4644
)
4745
) # RCOPT
4846
vols = vols.downsample((L * np.ones(3, dtype=int))) * 1.0e3
4947
# Since FFBBasis2D doesn't yet implement dtype, we'll set this to double to match its built in types.
50-
sim = Simulation(
48+
self.sim = Simulation(
5149
n=n,
5250
L=L,
5351
vols=vols,
54-
unique_filters=unique_filters,
52+
unique_filters=self.unique_filters,
5553
offsets=0.0,
5654
amplitudes=1.0,
5755
dtype=self.dtype,
@@ -60,12 +58,9 @@ def setUp(self):
6058

6159
self.basis = FFBBasis2D((L, L), dtype=self.dtype)
6260

63-
self.h_idx = sim.filter_indices
64-
self.h_ctf_fb = [filt.fb_mat(self.basis) for filt in unique_filters]
65-
66-
self.imgs_clean = sim.projections()
67-
self.imgs_ctf_clean = sim.clean_images()
68-
self.imgs_ctf_noise = sim.images(start=0, num=n)
61+
self.imgs_clean = self.sim.projections()
62+
self.imgs_ctf_clean = self.sim.clean_images()
63+
self.imgs_ctf_noise = self.sim.images(start=0, num=n)
6964

7065
self.cov2d = RotCov2D(self.basis)
7166
self.coeff_clean = self.basis.evaluate_t(self.imgs_clean)
@@ -76,89 +71,34 @@ def tearDown(self):
7671

7772
def testGetMean(self):
7873
results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_cov2d_mean.npy"))
79-
self.mean_coeff = self.cov2d._get_mean(self.coeff_clean)
80-
self.assertTrue(np.allclose(results, self.mean_coeff))
74+
mean_coeff = self.cov2d._get_mean(self.coeff_clean)
75+
self.assertTrue(np.allclose(results, mean_coeff))
8176

8277
def testGetCovar(self):
8378
results = np.load(
8479
os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covar.npy"),
8580
allow_pickle=True,
8681
)
87-
self.covar_coeff = self.cov2d._get_covar(self.coeff_clean)
82+
covar_coeff = self.cov2d._get_covar(self.coeff_clean)
8883

8984
for im, mat in enumerate(results.tolist()):
90-
self.assertTrue(np.allclose(mat, self.covar_coeff[im]))
85+
self.assertTrue(np.allclose(mat, covar_coeff[im]))
9186

9287
def testGetMeanCTF(self):
93-
results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_cov2d_meanctf.npy"))
94-
self.mean_coeff_ctf = self.cov2d.get_mean(self.coeff, self.h_ctf_fb, self.h_idx)
95-
self.assertTrue(np.allclose(results, self.mean_coeff_ctf))
96-
97-
def testGetCovarCTF(self):
98-
results = np.load(
99-
os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf.npy"),
100-
allow_pickle=True,
101-
)
102-
self.covar_coeff_ctf = self.cov2d.get_covar(
103-
self.coeff, self.h_ctf_fb, self.h_idx, noise_var=self.noise_var
104-
)
105-
for im, mat in enumerate(results.tolist()):
106-
self.assertTrue(np.allclose(mat, self.covar_coeff_ctf[im]))
107-
108-
def testGetCovarCTFShrink(self):
109-
results = np.load(
110-
os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf_shrink.npy"),
111-
allow_pickle=True,
112-
)
113-
covar_opt = {
114-
"shrinker": "frobenius_norm",
115-
"verbose": 0,
116-
"max_iter": 250,
117-
"iter_callback": [],
118-
"store_iterates": False,
119-
"rel_tolerance": 1e-12,
120-
"precision": self.dtype,
121-
}
122-
self.covar_coeff_ctf_shrink = self.cov2d.get_covar(
123-
self.coeff,
124-
self.h_ctf_fb,
125-
self.h_idx,
126-
noise_var=self.noise_var,
127-
covar_est_opt=covar_opt,
128-
)
129-
130-
for im, mat in enumerate(results.tolist()):
131-
self.assertTrue(np.allclose(mat, self.covar_coeff_ctf_shrink[im]))
132-
133-
def testGetCWFCoeffs(self):
134-
results = np.load(
135-
os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coeff.npy")
136-
)
137-
self.coeff_cwf = self.cov2d.get_cwf_coeffs(
138-
self.coeff, self.h_ctf_fb, self.h_idx, noise_var=self.noise_var
139-
)
140-
self.assertTrue(
141-
np.allclose(results, self.coeff_cwf, atol=utest_tolerance(self.dtype))
142-
)
143-
144-
def testGetCWFCoeffsIdentityCTF(self):
145-
results = np.load(
146-
os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coeff_noCTF.npy")
147-
)
148-
self.coeff_cwf_noCTF = self.cov2d.get_cwf_coeffs(
149-
self.coeff, noise_var=self.noise_var
150-
)
151-
self.assertTrue(
152-
np.allclose(results, self.coeff_cwf_noCTF, atol=utest_tolerance(self.dtype))
153-
)
88+
"""
89+
Compare `get_mean` (no CTF args) with `_get_mean` (no CTF model).
90+
"""
91+
mean_coeff_ctf = self.cov2d.get_mean(self.coeff, self.h_ctf_fb, self.h_idx)
92+
mean_coeff = self.cov2d._get_mean(self.coeff_clean)
93+
self.assertTrue(np.allclose(mean_coeff_ctf, mean_coeff, atol=0.002))
15494

15595
def testGetCWFCoeffsClean(self):
15696
results = np.load(
15797
os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coeff_clean.npy")
15898
)
159-
self.coeff_cwf_clean = self.cov2d.get_cwf_coeffs(self.coeff_clean, noise_var=0)
99+
coeff_cwf_clean = self.cov2d.get_cwf_coeffs(self.coeff_clean, noise_var=0)
160100
self.assertTrue(
161-
np.allclose(results, self.coeff_cwf_clean, atol=utest_tolerance(self.dtype))
101+
np.allclose(results, coeff_cwf_clean, atol=utest_tolerance(self.dtype))
162102
)
163103

164104
def testGetCWFCoeffsCleanCTF(self):
@@ -180,17 +120,6 @@ def testGetCWFCoeffsCleanCTF(self):
180120
delta = np.mean(np.square((self.imgs_clean - img_est).asnumpy()))
181121
self.assertTrue(delta < 0.02)
182122

183-
def testGetCWFCoeffsCTFargs(self):
184-
"""
185-
Test we raise when user supplies incorrect CTF arguments,
186-
and that the error message matches.
187-
"""
188-
189-
with raises(RuntimeError, match=r".*Given ctf_fb.*"):
190-
_ = self.cov2d.get_cwf_coeffs(
191-
self.coeff, self.h_ctf_fb, None, noise_var=self.noise_var
192-
)
193-
194123
# Note, parameterized module can be removed at a later date
195124
# and replaced with pytest if ASPIRE-Python moves away from
196125
# the TestCase class style tests.
@@ -220,3 +149,110 @@ def testShrinkers(self, shrinker):
220149
self.assertTrue(
221150
np.allclose(mat, covar_coeff[im], atol=utest_tolerance(self.dtype))
222151
)
152+
153+
154+
class Cov2DTestCaseCTF(Cov2DTestCase):
155+
"""
156+
Cov2D Test with CTFFilters populated.
157+
"""
158+
159+
@property
160+
def unique_filters(self):
161+
return [
162+
RadialCTFFilter(5.0 * 65 / self.L, 200, defocus=d, Cs=2.0, alpha=0.1)
163+
for d in np.linspace(1.5e4, 2.5e4, 7)
164+
]
165+
166+
@property
167+
def h_idx(self):
168+
return self.sim.filter_indices
169+
170+
@property
171+
def h_ctf_fb(self):
172+
return [filt.fb_mat(self.basis) for filt in self.unique_filters]
173+
174+
def testGetCWFCoeffsCTFargs(self):
175+
"""
176+
Test we raise when user supplies incorrect CTF arguments,
177+
and that the error message matches.
178+
"""
179+
180+
with raises(RuntimeError, match=r".*Given ctf_fb.*"):
181+
_ = self.cov2d.get_cwf_coeffs(
182+
self.coeff, self.h_ctf_fb, None, noise_var=self.noise_var
183+
)
184+
185+
def testGetMeanCTF(self):
186+
"""
187+
Compare `get_mean` with saved legacy cov2d results.
188+
"""
189+
results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_cov2d_meanctf.npy"))
190+
mean_coeff_ctf = self.cov2d.get_mean(self.coeff, self.h_ctf_fb, self.h_idx)
191+
self.assertTrue(np.allclose(results, mean_coeff_ctf))
192+
193+
def testGetCWFCoeffs(self):
194+
"""
195+
Tests `get_cwf_coeffs` with poulated CTF.
196+
"""
197+
results = np.load(
198+
os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coeff.npy")
199+
)
200+
coeff_cwf = self.cov2d.get_cwf_coeffs(
201+
self.coeff, self.h_ctf_fb, self.h_idx, noise_var=self.noise_var
202+
)
203+
self.assertTrue(
204+
np.allclose(results, coeff_cwf, atol=utest_tolerance(self.dtype))
205+
)
206+
207+
# Note, I think this file is incorrectly named...
208+
# It appears to have come from operations on images with ctf applied.
209+
def testGetCWFCoeffsNoCTF(self):
210+
"""
211+
Tests `get_cwf_coeffs` without providing CTF. (Internally uses IdentityCTF).
212+
"""
213+
results = np.load(
214+
os.path.join(DATA_DIR, "clean70SRibosome_cov2d_cwf_coeff_noCTF.npy")
215+
)
216+
coeff_cwf_noCTF = self.cov2d.get_cwf_coeffs(
217+
self.coeff, noise_var=self.noise_var
218+
)
219+
220+
self.assertTrue(
221+
np.allclose(results, coeff_cwf_noCTF, atol=utest_tolerance(self.dtype))
222+
)
223+
224+
def testGetCovarCTF(self):
225+
results = np.load(
226+
os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf.npy"),
227+
allow_pickle=True,
228+
)
229+
covar_coeff_ctf = self.cov2d.get_covar(
230+
self.coeff, self.h_ctf_fb, self.h_idx, noise_var=self.noise_var
231+
)
232+
for im, mat in enumerate(results.tolist()):
233+
self.assertTrue(np.allclose(mat, covar_coeff_ctf[im]))
234+
235+
def testGetCovarCTFShrink(self):
236+
results = np.load(
237+
os.path.join(DATA_DIR, "clean70SRibosome_cov2d_covarctf_shrink.npy"),
238+
allow_pickle=True,
239+
)
240+
covar_opt = {
241+
"shrinker": "frobenius_norm",
242+
"verbose": 0,
243+
"max_iter": 250,
244+
"iter_callback": [],
245+
"store_iterates": False,
246+
"rel_tolerance": 1e-12,
247+
"precision": self.dtype,
248+
}
249+
covar_coeff_ctf_shrink = self.cov2d.get_covar(
250+
self.coeff,
251+
self.h_ctf_fb,
252+
self.h_idx,
253+
noise_var=self.noise_var,
254+
covar_est_opt=covar_opt,
255+
)
256+
257+
for im, mat in enumerate(results.tolist()):
258+
self.assertTrue(np.allclose(mat, covar_coeff_ctf_shrink[im]))

0 commit comments

Comments
 (0)