Skip to content

Commit 05b561f

Browse files
Logical rather than hardcoded downsample tests (#567)
* downsample tests * cleanup and deleting test downsample from preprocess_pipeline * remove hardcoded ds / whiten test * flake8 * 3d volume downsample test * remove redundant test from tests/test_preprocess.py * remove unused import * remove numbered tests for uniformity * typo * skip one of the ds tests * use gaussian_3d to generate test volume * isort * checkCenterPoint * max_resolution -> L_ds * parametrize ds tests
1 parent b000a87 commit 05b561f

File tree

5 files changed

+117
-88
lines changed

5 files changed

+117
-88
lines changed

tests/test_downsample.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import unittest
2+
from unittest import TestCase
3+
4+
import numpy as np
5+
6+
from aspire.source import Simulation
7+
from aspire.utils import utest_tolerance
8+
from aspire.utils.matrix import anorm
9+
from aspire.utils.misc import gaussian_3d
10+
from aspire.volume import Volume
11+
12+
13+
class DownsampleTestCase(TestCase):
14+
def setUp(self):
15+
self.n = 128
16+
self.dtype = np.float32
17+
18+
def tearDown(self):
19+
pass
20+
21+
def _testDownsample2DCase(self, L, L_ds):
22+
# downsampling from size L to L_ds
23+
imgs_org, imgs_ds = self.createImages(L, L_ds)
24+
# check resolution is correct
25+
self.assertEqual((self.n, L_ds, L_ds), imgs_ds.shape)
26+
# check center points for all images
27+
self.assertTrue(self.checkCenterPoint(imgs_org, imgs_ds))
28+
# check signal energy is conserved
29+
self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds))
30+
31+
def testDownsample2D_EvenEven(self):
32+
# source resolution: 64
33+
# target resolution: 32
34+
self._testDownsample2DCase(64, 32)
35+
36+
@unittest.skip(
37+
"Signal energy test fails for this case in current DS implementation"
38+
)
39+
def testDownsample2D_EvenOdd(self):
40+
# source resolution: 64
41+
# target resolution: 33
42+
self._testDownsample2DCase(64, 33)
43+
44+
def testDownsample2D_OddOdd(self):
45+
# source resolution: 65
46+
# target resolution: 33
47+
self._testDownsample2DCase(65, 33)
48+
49+
def testDownsample2D_OddEven(self):
50+
# source resolution: 65
51+
# target resolution: 32
52+
self._testDownsample2DCase(65, 32)
53+
54+
def checkCenterPoint(self, imgs_org, imgs_ds):
55+
# Check that center point is the same after ds
56+
L = imgs_org.res
57+
max_resolution = imgs_ds.res
58+
return np.allclose(
59+
imgs_org[:, L // 2, L // 2],
60+
imgs_ds[:, max_resolution // 2, max_resolution // 2],
61+
atol=utest_tolerance(self.dtype),
62+
)
63+
64+
def checkSignalEnergy(self, imgs_org, imgs_ds):
65+
# check conservation of energy after downsample
66+
L = imgs_org.res
67+
max_resolution = imgs_ds.res
68+
return np.allclose(
69+
anorm(imgs_org.asnumpy(), axes=(1, 2)) / L,
70+
anorm(imgs_ds.asnumpy(), axes=(1, 2)) / max_resolution,
71+
atol=utest_tolerance(self.dtype),
72+
)
73+
74+
def createImages(self, L, L_ds):
75+
# generate a 3D Gaussian volume
76+
sigma = 0.1
77+
vol = gaussian_3d(L, sigma=(L * sigma,) * 3, dtype=self.dtype)
78+
# initialize a Simulation object to generate projections of the volume
79+
sim = Simulation(
80+
L, self.n, vols=Volume(vol), offsets=0.0, amplitudes=1.0, dtype=self.dtype
81+
)
82+
83+
# get images before downsample
84+
imgs_org = sim.images(start=0, num=self.n)
85+
86+
# get images after downsample
87+
sim.downsample(L_ds)
88+
imgs_ds = sim.images(start=0, num=self.n)
89+
90+
return imgs_org, imgs_ds

tests/test_preprocess.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from scipy.fftpack import fftn, fftshift
66

7-
from aspire.image import crop_pad, downsample, fuzzy_mask
7+
from aspire.image import crop_pad, fuzzy_mask
88
from aspire.volume import Volume
99

1010
DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data")
@@ -17,30 +17,22 @@ def setUp(self):
1717
def tearDown(self):
1818
pass
1919

20-
def test01CropPad(self):
20+
def testCropPad(self):
2121
results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_crop8.npy"))
2222
vols = np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy"))
2323
vols = vols[..., np.newaxis]
2424
vols_f = crop_pad(fftshift(fftn(vols[:, :, :, 0])), 8)
2525
self.assertTrue(np.allclose(results, vols_f, atol=1e-7))
2626

27-
def test02Downsample(self):
28-
results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy"))
29-
results = results[np.newaxis, ...]
30-
vols = np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy"))
31-
vols = vols[np.newaxis, ...]
32-
vols = downsample(vols, (8, 8, 8))
33-
self.assertTrue(np.allclose(results, vols, atol=1e-7))
34-
35-
def test03Vol2img(self):
27+
def testVol2img(self):
3628
results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_down8_imgs32.npy"))
3729
vols = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy")))
3830
rots = np.load(os.path.join(DATA_DIR, "rand_rot_matrices32.npy"))
3931
rots = np.moveaxis(rots, 2, 0)
4032
imgs_clean = vols.project(0, rots).asnumpy()
4133
self.assertTrue(np.allclose(results, imgs_clean, atol=1e-7))
4234

43-
def test04FuzzyMask(self):
35+
def testFuzzyMask(self):
4436
results = np.array(
4537
[
4638
[

tests/test_preprocess_pipeline.py

Lines changed: 2 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
import numpy as np
55

66
from aspire.noise import AnisotropicNoiseEstimator
7-
from aspire.operators.filters import FunctionFilter, RadialCTFFilter, ScalarFilter
7+
from aspire.operators.filters import FunctionFilter, RadialCTFFilter
88
from aspire.source import ArrayImageSource
99
from aspire.source.simulation import Simulation
10-
from aspire.utils import grid_2d, grid_3d, utest_tolerance
10+
from aspire.utils import grid_2d
1111
from aspire.utils.matrix import anorm
12-
from aspire.volume import Volume
1312

1413
DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data")
1514

@@ -45,58 +44,6 @@ def testPhaseFlip(self):
4544
)
4645
)
4746

48-
def testDownsample(self):
49-
# generate a 3D map with density decays as Gaussian function
50-
g3d = grid_3d(self.L, indexing="zyx", dtype=self.dtype)
51-
coords = np.array([g3d["x"].flatten(), g3d["y"].flatten(), g3d["z"].flatten()])
52-
sigma = 0.2
53-
vol = np.exp(-0.5 * np.sum(np.abs(coords / sigma) ** 2, axis=0)).astype(
54-
self.dtype
55-
)
56-
vol = np.reshape(vol, g3d["x"].shape)
57-
vols = Volume(vol)
58-
59-
# set noise to zero and CFT filters to unity for simulation object
60-
noise_var = 0
61-
noise_filter = ScalarFilter(dim=2, value=noise_var)
62-
sim = Simulation(
63-
L=self.L,
64-
n=self.n,
65-
vols=vols,
66-
offsets=0.0,
67-
amplitudes=1.0,
68-
unique_filters=[
69-
ScalarFilter(dim=2, value=1) for d in np.linspace(1.5e4, 2.5e4, 7)
70-
],
71-
noise_filter=noise_filter,
72-
dtype=self.dtype,
73-
)
74-
# get images before downsample
75-
imgs_org = sim.images(start=0, num=self.n)
76-
# get images after downsample
77-
max_resolution = 32
78-
sim.downsample(max_resolution)
79-
imgs_ds = sim.images(start=0, num=self.n)
80-
81-
# Check individual grid points
82-
self.assertTrue(
83-
np.allclose(
84-
imgs_org[:, 32, 32],
85-
imgs_ds[:, 16, 16],
86-
atol=utest_tolerance(self.dtype),
87-
)
88-
)
89-
# check resolution
90-
self.assertTrue(np.allclose(max_resolution, imgs_ds.shape[1]))
91-
# check energy conservation after downsample
92-
self.assertTrue(
93-
np.allclose(
94-
anorm(imgs_org.asnumpy(), axes=(1, 2)) / self.L,
95-
anorm(imgs_ds.asnumpy(), axes=(1, 2)) / max_resolution,
96-
atol=utest_tolerance(self.dtype),
97-
)
98-
)
99-
10047
def testNormBackground(self):
10148
bg_radius = 1.0
10249
grid = grid_2d(self.L, indexing="yx")

tests/test_starfile_stack.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import tests.saved_test_data
1010
from aspire.image import Image
11-
from aspire.operators import ScalarFilter
1211
from aspire.source.relion import RelionSource
1312

1413
DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data")
@@ -61,20 +60,6 @@ def testImageDownsample(self):
6160
first_image = self.src.images(0, 1)[0]
6261
self.assertEqual(first_image.shape, (16, 16))
6362

64-
def testImageDownsampleAndWhiten(self):
65-
self.src.downsample(16)
66-
self.src.whiten(noise_filter=ScalarFilter(dim=2, value=0.02450909546680349))
67-
first_whitened_image = self.src.images(0, 1)[0]
68-
self.assertTrue(
69-
np.allclose(
70-
first_whitened_image,
71-
np.load(
72-
os.path.join(DATA_DIR, "starfile_image_0_whitened.npy")
73-
).T, # RCOPT
74-
atol=1e-6,
75-
)
76-
)
77-
7863

7964
class StarFileSingleImage(StarFileTestCase):
8065
def setUp(self):

tests/test_volume.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pytest import raises
1010

1111
from aspire.utils import Rotation, grid_3d, powerset
12+
from aspire.utils.matrix import anorm
1213
from aspire.utils.types import utest_tolerance
1314
from aspire.volume import Volume, gaussian_blob_vols
1415

@@ -313,11 +314,25 @@ def testFlip(self):
313314
self.assertTrue(isinstance(result, Volume))
314315

315316
def testDownsample(self):
316-
# Data files re-used from test_preprocess
317317
vols = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")))
318-
319-
resv = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy")))
320-
321318
result = vols.downsample((8, 8, 8))
322-
self.assertTrue(np.allclose(result, resv))
323-
self.assertTrue(isinstance(result, Volume))
319+
res = vols.resolution
320+
ds_res = result.resolution
321+
322+
# check signal energy
323+
self.assertTrue(
324+
np.allclose(
325+
anorm(vols.asnumpy(), axes=(1, 2, 3)) / res,
326+
anorm(result.asnumpy(), axes=(1, 2, 3)) / ds_res,
327+
atol=1e-3,
328+
)
329+
)
330+
331+
# check gridpoints
332+
self.assertTrue(
333+
np.allclose(
334+
vols[:, res // 2, res // 2, res // 2],
335+
result[:, ds_res // 2, ds_res // 2, ds_res // 2],
336+
atol=1e-4,
337+
)
338+
)

0 commit comments

Comments
 (0)