Skip to content

Commit 645fe24

Browse files
authored
Merge pull request #325 from ComputationalCryoEM/denoiser_test
Unit tests for Cov2d Denoiser
2 parents 6e9ade4 + dda2697 commit 645fe24

File tree

3 files changed

+50
-7
lines changed

3 files changed

+50
-7
lines changed

src/aspire/denoising/denoised_src.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@ def _images(self, start=0, num=np.inf, indices=None, batch_size=512):
4040
start = indices.min()
4141
end = indices.max()
4242

43-
im = np.empty((self.L, self.L, len(indices)))
43+
nimgs = len(indices)
44+
im = np.empty((nimgs, self.L, self.L))
4445

45-
logger.info(f"Loading {len(indices)} images complete")
46-
for istart in range(start, end, batch_size):
46+
logger.info(f"Loading {nimgs} images complete")
47+
for istart in range(start, end + 1, batch_size):
4748
imgs_denoised = self.denoiser.images(istart, batch_size)
48-
im = imgs_denoised.data
49+
iend = min(istart + batch_size, end + 1)
50+
im[istart:iend] = imgs_denoised.data
4951

5052
return Image(im)

src/aspire/denoising/denoiser_cov2d.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,7 @@ def denoise(self, covar_opt=None, batch_size=512):
130130

131131
# Initialize the rotationally invariant covariance matrix of 2D images
132132
# A fixed batch size is used to go through each image
133-
self.cov2d = BatchedRotCov2D(
134-
self.src, self.basis, batch_size=batch_size, dtype=self.dtype
135-
)
133+
self.cov2d = BatchedRotCov2D(self.src, self.basis, batch_size=batch_size)
136134

137135
default_opt = {
138136
"shrinker": "frobenius_norm",

tests/test_covar2d_denoiser.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from unittest import TestCase
2+
3+
import numpy as np
4+
5+
from aspire.basis.ffb_2d import FFBBasis2D
6+
from aspire.denoising.denoiser_cov2d import DenoiserCov2D
7+
from aspire.operators.filters import RadialCTFFilter, ScalarFilter
8+
from aspire.source.simulation import Simulation
9+
10+
11+
class BatchedRotCov2DTestCase(TestCase):
12+
def testMSE(self):
13+
# need larger numbers of images and higher resolution for good MSE
14+
dtype = np.float32
15+
img_size = 64
16+
num_imgs = 1024
17+
noise_var = 0.1848
18+
noise_filter = ScalarFilter(dim=2, value=noise_var)
19+
filters = [
20+
RadialCTFFilter(5, 200, defocus=d, Cs=2.0, alpha=0.1)
21+
for d in np.linspace(1.5e4, 2.5e4, 7)
22+
]
23+
# set simulation object
24+
sim = Simulation(
25+
L=img_size,
26+
n=num_imgs,
27+
unique_filters=filters,
28+
offsets=0.0,
29+
amplitudes=1.0,
30+
dtype=dtype,
31+
noise_filter=noise_filter,
32+
)
33+
imgs_clean = sim.projections()
34+
35+
# Specify the fast FB basis method for expending the 2D images
36+
ffbbasis = FFBBasis2D((img_size, img_size), dtype=dtype)
37+
denoiser = DenoiserCov2D(sim, ffbbasis, noise_var)
38+
denoised_src = denoiser.denoise(batch_size=64)
39+
imgs_denoised = denoised_src.images(0, num_imgs)
40+
# Calculate the normalized RMSE of the estimated images.
41+
nrmse_ims = (imgs_denoised - imgs_clean).norm() / imgs_clean.norm()
42+
43+
self.assertTrue(nrmse_ims < 0.25)

0 commit comments

Comments
 (0)