-
Notifications
You must be signed in to change notification settings - Fork 26
Logical rather than hardcoded downsample tests #567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4234081
ef6b837
b752286
412ca28
1c61523
213324d
695b17e
e616592
142555d
5f45a26
f85e179
f2bf0c1
887f1e6
555cd90
31e09db
faf608b
25b5914
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| import unittest | ||
| from unittest import TestCase | ||
|
|
||
| import numpy as np | ||
|
|
||
| from aspire.source import Simulation | ||
| from aspire.utils import utest_tolerance | ||
| from aspire.utils.matrix import anorm | ||
| from aspire.utils.misc import gaussian_3d | ||
| from aspire.volume import Volume | ||
|
|
||
|
|
||
| class DownsampleTestCase(TestCase): | ||
| def setUp(self): | ||
| self.n = 128 | ||
| self.dtype = np.float32 | ||
|
|
||
| def tearDown(self): | ||
| pass | ||
|
|
||
| def _testDownsample2DCase(self, L, L_ds): | ||
| # downsampling from size L to L_ds | ||
| imgs_org, imgs_ds = self.createImages(L, L_ds) | ||
| # check resolution is correct | ||
| self.assertEqual((self.n, L_ds, L_ds), imgs_ds.shape) | ||
| # check center points for all images | ||
| self.assertTrue(self.checkCenterPoint(imgs_org, imgs_ds)) | ||
| # check signal energy is conserved | ||
| self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) | ||
|
|
||
| def testDownsample2D_EvenEven(self): | ||
| # source resolution: 64 | ||
| # target resolution: 32 | ||
| self._testDownsample2DCase(64, 32) | ||
|
|
||
| @unittest.skip( | ||
| "Signal energy test fails for this case in current DS implementation" | ||
| ) | ||
| def testDownsample2D_EvenOdd(self): | ||
| # source resolution: 64 | ||
| # target resolution: 33 | ||
| self._testDownsample2DCase(64, 33) | ||
|
|
||
| def testDownsample2D_OddOdd(self): | ||
| # source resolution: 65 | ||
| # target resolution: 33 | ||
| self._testDownsample2DCase(65, 33) | ||
|
|
||
| def testDownsample2D_OddEven(self): | ||
| # source resolution: 65 | ||
| # target resolution: 32 | ||
| self._testDownsample2DCase(65, 32) | ||
|
|
||
| def checkCenterPoint(self, imgs_org, imgs_ds): | ||
| # Check that center point is the same after ds | ||
| L = imgs_org.res | ||
| max_resolution = imgs_ds.res | ||
| return np.allclose( | ||
| imgs_org[:, L // 2, L // 2], | ||
| imgs_ds[:, max_resolution // 2, max_resolution // 2], | ||
| atol=utest_tolerance(self.dtype), | ||
| ) | ||
|
|
||
| def checkSignalEnergy(self, imgs_org, imgs_ds): | ||
| # check conservation of energy after downsample | ||
| L = imgs_org.res | ||
| max_resolution = imgs_ds.res | ||
| return np.allclose( | ||
| anorm(imgs_org.asnumpy(), axes=(1, 2)) / L, | ||
| anorm(imgs_ds.asnumpy(), axes=(1, 2)) / max_resolution, | ||
| atol=utest_tolerance(self.dtype), | ||
| ) | ||
|
|
||
| def createImages(self, L, L_ds): | ||
| # generate a 3D Gaussian volume | ||
| sigma = 0.1 | ||
| vol = gaussian_3d(L, sigma=(L * sigma,) * 3, dtype=self.dtype) | ||
| # initialize a Simulation object to generate projections of the volume | ||
| sim = Simulation( | ||
| L, self.n, vols=Volume(vol), offsets=0.0, amplitudes=1.0, dtype=self.dtype | ||
| ) | ||
|
|
||
| # get images before downsample | ||
| imgs_org = sim.images(start=0, num=self.n) | ||
|
|
||
| # get images after downsample | ||
| sim.downsample(L_ds) | ||
| imgs_ds = sim.images(start=0, num=self.n) | ||
|
|
||
| return imgs_org, imgs_ds |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
| from pytest import raises | ||
|
|
||
| from aspire.utils import Rotation, grid_3d, powerset | ||
| from aspire.utils.matrix import anorm | ||
| from aspire.utils.types import utest_tolerance | ||
| from aspire.volume import Volume, gaussian_blob_vols | ||
|
|
||
|
|
@@ -313,11 +314,25 @@ def testFlip(self): | |
| self.assertTrue(isinstance(result, Volume)) | ||
|
|
||
| def testDownsample(self): | ||
| # Data files re-used from test_preprocess | ||
| vols = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy"))) | ||
|
|
||
| resv = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy"))) | ||
|
|
||
| result = vols.downsample((8, 8, 8)) | ||
| self.assertTrue(np.allclose(result, resv)) | ||
| self.assertTrue(isinstance(result, Volume)) | ||
| res = vols.resolution | ||
| ds_res = result.resolution | ||
|
|
||
| # check signal energy | ||
| self.assertTrue( | ||
| np.allclose( | ||
| anorm(vols.asnumpy(), axes=(1, 2, 3)) / res, | ||
| anorm(result.asnumpy(), axes=(1, 2, 3)) / ds_res, | ||
| atol=1e-3, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Absolute tolerance level for
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we use smooth Gaussian instead, we should get better energy conservation since we have much less energy in the high frequencies. It may make sense to do this (i.e., generate volumes using |
||
| ) | ||
| ) | ||
|
|
||
| # check gridpoints | ||
| self.assertTrue( | ||
| np.allclose( | ||
| vols[:, res // 2, res // 2, res // 2], | ||
| result[:, ds_res // 2, ds_res // 2, ds_res // 2], | ||
| atol=1e-4, | ||
| ) | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.