From 4234081cc40efb9746969641290370f5e071e85b Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Thu, 10 Feb 2022 15:24:51 -0500 Subject: [PATCH 01/15] downsample tests --- tests/test_downsample.py | 108 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 tests/test_downsample.py diff --git a/tests/test_downsample.py b/tests/test_downsample.py new file mode 100644 index 0000000000..2d80998aa1 --- /dev/null +++ b/tests/test_downsample.py @@ -0,0 +1,108 @@ +from unittest import TestCase + +import numpy as np + +from aspire.image import Image +from aspire.source import Simulation +from aspire.utils import utest_tolerance +from aspire.utils.coor_trans import grid_2d, grid_3d +from aspire.utils.matrix import anorm +from aspire.utils.misc import gaussian_1d, gaussian_2d, gaussian_3d +from aspire.volume import Volume + + +class DownsampleTestCase(TestCase): + def setUp(self): + self.n = 128 + self.dtype = np.float32 + + def tearDown(self): + pass + + def testDownsample2D_EvenEven(self): + # source resolution: 64 + # target resolution: 32 + imgs_org, imgs_ds = self.createImages(64, 32) + # check resolution is correct + self.assertEqual((self.n, 32, 32), imgs_ds.shape) + # check invidual gridpoints for all images + self.assertTrue(self.checkGridPoints(imgs_org, imgs_ds)) + # check signal energy is conserved + self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) + + ### Signal energy test fails for this case in current DS implementation + def _testDownsample2D_EvenOdd(self): + # source resolution: 64 + # target resolution: 33 + imgs_org, imgs_ds = self.createImages(64, 33) + # check resolution is correct + self.assertEqual((self.n, 33, 33), imgs_ds.shape) + # check invidual gridpoints for all images + self.assertTrue(self.checkGridPoints(imgs_org, imgs_ds)) + # check signal energy is conserved + self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) + + def testDownsample2D_OddOdd(self): + # source resolution: 65 + # target resolution: 33 + imgs_org, imgs_ds = self.createImages(65, 33) + # check resolution is correct + self.assertEqual((self.n, 33, 33), imgs_ds.shape) + # check invidual gridpoints for all images + self.assertTrue(self.checkGridPoints(imgs_org, imgs_ds)) + # check signal energy is conserved + self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) + + def testDownsample2D_OddEven(self): + # source resolution: 65 + # target resolution: 32 + imgs_org, imgs_ds = self.createImages(65, 32) + # check resolution is correct + self.assertEqual((self.n, 32, 32), imgs_ds.shape) + # check invidual gridpoints for all images + self.assertTrue(self.checkGridPoints(imgs_org, imgs_ds)) + # check signal energy is conserved + self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) + + def checkGridPoints(self, imgs_org, imgs_ds): + # Check individual grid points + L = imgs_org.res + max_resolution = imgs_ds.res + return np.allclose( + imgs_org[:, L // 2, L // 2], + imgs_ds[:, max_resolution // 2, max_resolution // 2], + atol=utest_tolerance(self.dtype), + ) + + def checkSignalEnergy(self, imgs_org, imgs_ds): + # check conservation of energy after downsample + L = imgs_org.res + max_resolution = imgs_ds.res + return np.allclose( + anorm(imgs_org.asnumpy(), axes=(1, 2)) / L, + anorm(imgs_ds.asnumpy(), axes=(1, 2)) / max_resolution, + atol=utest_tolerance(self.dtype), + ) + + def createImages(self, L, max_resolution): + # generate a 3D Gaussian volume + g3d = grid_3d(L, indexing="zyx", dtype=self.dtype) + coords = np.array([g3d["x"].flatten(), g3d["y"].flatten(), g3d["z"].flatten()]) + sigma = 0.2 + vol = np.exp(-0.5 * np.sum(np.abs(coords / sigma) ** 2, axis=0)).astype( + self.dtype + ) + vol = np.reshape(vol, g3d["x"].shape) + # initialize a Simulation object to generate projections of the volume + sim = Simulation( + L, self.n, vols=Volume(vol), offsets=0.0, amplitudes=1.0, dtype=self.dtype + ) + + # get images before downsample + imgs_org = sim.images(start=0, num=self.n) + + # get images after downsample + sim.downsample(max_resolution) + imgs_ds = sim.images(start=0, num=self.n) + + return imgs_org, imgs_ds From ef6b837a9ca198b936a4a35b9feaaa675d52ee4e Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Thu, 10 Feb 2022 15:29:12 -0500 Subject: [PATCH 02/15] cleanup and deleting test downsample from preprocess_pipeline --- tests/test_downsample.py | 10 +++--- tests/test_preprocess_pipeline.py | 58 ++----------------------------- 2 files changed, 6 insertions(+), 62 deletions(-) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index 2d80998aa1..d891bdd454 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -2,12 +2,10 @@ import numpy as np -from aspire.image import Image from aspire.source import Simulation from aspire.utils import utest_tolerance -from aspire.utils.coor_trans import grid_2d, grid_3d +from aspire.utils.coor_trans import grid_3d from aspire.utils.matrix import anorm -from aspire.utils.misc import gaussian_1d, gaussian_2d, gaussian_3d from aspire.volume import Volume @@ -30,7 +28,7 @@ def testDownsample2D_EvenEven(self): # check signal energy is conserved self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) - ### Signal energy test fails for this case in current DS implementation + # Signal energy test fails for this case in current DS implementation def _testDownsample2D_EvenOdd(self): # source resolution: 64 # target resolution: 33 @@ -97,10 +95,10 @@ def createImages(self, L, max_resolution): sim = Simulation( L, self.n, vols=Volume(vol), offsets=0.0, amplitudes=1.0, dtype=self.dtype ) - + # get images before downsample imgs_org = sim.images(start=0, num=self.n) - + # get images after downsample sim.downsample(max_resolution) imgs_ds = sim.images(start=0, num=self.n) diff --git a/tests/test_preprocess_pipeline.py b/tests/test_preprocess_pipeline.py index c1a4c6d8fa..476133b96e 100644 --- a/tests/test_preprocess_pipeline.py +++ b/tests/test_preprocess_pipeline.py @@ -4,13 +4,11 @@ import numpy as np from aspire.noise import AnisotropicNoiseEstimator -from aspire.operators.filters import FunctionFilter, RadialCTFFilter, ScalarFilter +from aspire.operators.filters import FunctionFilter, RadialCTFFilter from aspire.source import ArrayImageSource from aspire.source.simulation import Simulation -from aspire.utils import utest_tolerance -from aspire.utils.coor_trans import grid_2d, grid_3d +from aspire.utils.coor_trans import grid_2d from aspire.utils.matrix import anorm -from aspire.volume import Volume DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") @@ -46,58 +44,6 @@ def testPhaseFlip(self): ) ) - def testDownsample(self): - # generate a 3D map with density decays as Gaussian function - g3d = grid_3d(self.L, indexing="zyx", dtype=self.dtype) - coords = np.array([g3d["x"].flatten(), g3d["y"].flatten(), g3d["z"].flatten()]) - sigma = 0.2 - vol = np.exp(-0.5 * np.sum(np.abs(coords / sigma) ** 2, axis=0)).astype( - self.dtype - ) - vol = np.reshape(vol, g3d["x"].shape) - vols = Volume(vol) - - # set noise to zero and CFT filters to unity for simulation object - noise_var = 0 - noise_filter = ScalarFilter(dim=2, value=noise_var) - sim = Simulation( - L=self.L, - n=self.n, - vols=vols, - offsets=0.0, - amplitudes=1.0, - unique_filters=[ - ScalarFilter(dim=2, value=1) for d in np.linspace(1.5e4, 2.5e4, 7) - ], - noise_filter=noise_filter, - dtype=self.dtype, - ) - # get images before downsample - imgs_org = sim.images(start=0, num=self.n) - # get images after downsample - max_resolution = 32 - sim.downsample(max_resolution) - imgs_ds = sim.images(start=0, num=self.n) - - # Check individual grid points - self.assertTrue( - np.allclose( - imgs_org[:, 32, 32], - imgs_ds[:, 16, 16], - atol=utest_tolerance(self.dtype), - ) - ) - # check resolution - self.assertTrue(np.allclose(max_resolution, imgs_ds.shape[1])) - # check energy conservation after downsample - self.assertTrue( - np.allclose( - anorm(imgs_org.asnumpy(), axes=(1, 2)) / self.L, - anorm(imgs_ds.asnumpy(), axes=(1, 2)) / max_resolution, - atol=utest_tolerance(self.dtype), - ) - ) - def testNormBackground(self): bg_radius = 1.0 grid = grid_2d(self.L, indexing="yx") From b75228674d93e1a4279d7e7f96cf88ad4ea39c8a Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Thu, 10 Feb 2022 15:46:52 -0500 Subject: [PATCH 03/15] remove hardcoded ds / whiten test --- tests/test_starfile_stack.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/test_starfile_stack.py b/tests/test_starfile_stack.py index 57a58cfa43..301844b83d 100644 --- a/tests/test_starfile_stack.py +++ b/tests/test_starfile_stack.py @@ -61,20 +61,6 @@ def testImageDownsample(self): first_image = self.src.images(0, 1)[0] self.assertEqual(first_image.shape, (16, 16)) - def testImageDownsampleAndWhiten(self): - self.src.downsample(16) - self.src.whiten(noise_filter=ScalarFilter(dim=2, value=0.02450909546680349)) - first_whitened_image = self.src.images(0, 1)[0] - self.assertTrue( - np.allclose( - first_whitened_image, - np.load( - os.path.join(DATA_DIR, "starfile_image_0_whitened.npy") - ).T, # RCOPT - atol=1e-6, - ) - ) - class StarFileSingleImage(StarFileTestCase): def setUp(self): From 412ca28648d74ab2383e455b3ef0fc3e4b2796f2 Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Thu, 10 Feb 2022 15:48:51 -0500 Subject: [PATCH 04/15] flake8 --- tests/test_starfile_stack.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_starfile_stack.py b/tests/test_starfile_stack.py index 301844b83d..0a3f53b06f 100644 --- a/tests/test_starfile_stack.py +++ b/tests/test_starfile_stack.py @@ -8,7 +8,6 @@ import tests.saved_test_data from aspire.image import Image -from aspire.operators import ScalarFilter from aspire.source.relion import RelionSource DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") From 1c6152388b44a0b247ded0002350da0a9fdc0e72 Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Fri, 11 Feb 2022 11:49:59 -0500 Subject: [PATCH 05/15] 3d volume downsample test --- tests/test_volume.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/tests/test_volume.py b/tests/test_volume.py index c990c22fe7..8ffa754e0b 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -10,6 +10,7 @@ from aspire.utils import Rotation, powerset from aspire.utils.coor_trans import grid_3d +from aspire.utils.matrix import anorm from aspire.utils.types import utest_tolerance from aspire.volume import Volume, gaussian_blob_vols @@ -314,11 +315,25 @@ def testFlip(self): self.assertTrue(isinstance(result, Volume)) def testDownsample(self): - # Data files re-used from test_preprocess vols = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy"))) - - resv = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy"))) - result = vols.downsample((8, 8, 8)) - self.assertTrue(np.allclose(result, resv)) - self.assertTrue(isinstance(result, Volume)) + res = vols.resolution + ds_res = result.resolution + + # check signal energy + self.assertTrue( + np.allclose( + anorm(vols.asnumpy(), axes=(1, 2, 3)) / res, + anorm(result.asnumpy(), axes=(1, 2, 3)) / ds_res, + atol=1e-3, + ) + ) + + # check gridpoints + self.assertTrue( + np.allclose( + vols[:, res // 2, res // 2, res // 2], + result[:, ds_res // 2, ds_res // 2, ds_res // 2], + atol=1e-4, + ) + ) From 213324d1d567019eed1d18b3aa046162c6b19914 Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Fri, 11 Feb 2022 11:54:17 -0500 Subject: [PATCH 06/15] remove redundant test from tests/test_preprocess.py --- tests/test_preprocess.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index fa17264567..7bdebc489e 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -24,14 +24,6 @@ def test01CropPad(self): vols_f = crop_pad(fftshift(fftn(vols[:, :, :, 0])), 8) self.assertTrue(np.allclose(results, vols_f, atol=1e-7)) - def test02Downsample(self): - results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy")) - results = results[np.newaxis, ...] - vols = np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")) - vols = vols[np.newaxis, ...] - vols = downsample(vols, (8, 8, 8)) - self.assertTrue(np.allclose(results, vols, atol=1e-7)) - def test03Vol2img(self): results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_down8_imgs32.npy")) vols = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy"))) From 695b17ed3ad182d5e49938abd8c16b27b66ee6c4 Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Fri, 11 Feb 2022 12:36:28 -0500 Subject: [PATCH 07/15] remove unused import --- tests/test_preprocess.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index 7bdebc489e..c8d13cfa97 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -4,7 +4,7 @@ import numpy as np from scipy.fftpack import fftn, fftshift -from aspire.image import crop_pad, downsample, fuzzy_mask +from aspire.image import crop_pad, fuzzy_mask from aspire.volume import Volume DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") From 142555dbf770bc6f504bd69b5ce921eb4154e64e Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Thu, 24 Feb 2022 14:16:35 -0500 Subject: [PATCH 08/15] remove numbered tests for uniformity --- tests/test_preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_preprocess.py b/tests/test_preprocess.py index c8d13cfa97..f3f63ea679 100644 --- a/tests/test_preprocess.py +++ b/tests/test_preprocess.py @@ -17,14 +17,14 @@ def setUp(self): def tearDown(self): pass - def test01CropPad(self): + def testCropPad(self): results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_crop8.npy")) vols = np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol.npy")) vols = vols[..., np.newaxis] vols_f = crop_pad(fftshift(fftn(vols[:, :, :, 0])), 8) self.assertTrue(np.allclose(results, vols_f, atol=1e-7)) - def test03Vol2img(self): + def testVol2img(self): results = np.load(os.path.join(DATA_DIR, "clean70SRibosome_down8_imgs32.npy")) vols = Volume(np.load(os.path.join(DATA_DIR, "clean70SRibosome_vol_down8.npy"))) rots = np.load(os.path.join(DATA_DIR, "rand_rot_matrices32.npy")) @@ -32,7 +32,7 @@ def test03Vol2img(self): imgs_clean = vols.project(0, rots).asnumpy() self.assertTrue(np.allclose(results, imgs_clean, atol=1e-7)) - def test04FuzzyMask(self): + def testFuzzyMask(self): results = np.array( [ [ From 5f45a26291aa47fd153a4f71f5c92bd569ac73ed Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Thu, 24 Feb 2022 14:20:38 -0500 Subject: [PATCH 09/15] typo --- tests/test_downsample.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index d891bdd454..d08561946a 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -23,7 +23,7 @@ def testDownsample2D_EvenEven(self): imgs_org, imgs_ds = self.createImages(64, 32) # check resolution is correct self.assertEqual((self.n, 32, 32), imgs_ds.shape) - # check invidual gridpoints for all images + # check individual gridpoints for all images self.assertTrue(self.checkGridPoints(imgs_org, imgs_ds)) # check signal energy is conserved self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) From f85e179d3278f92c9e51a7ccd57f232804da2fbf Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Thu, 24 Feb 2022 14:27:55 -0500 Subject: [PATCH 10/15] skip one of the ds tests --- tests/test_downsample.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index d08561946a..0728d87253 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -7,7 +7,7 @@ from aspire.utils.coor_trans import grid_3d from aspire.utils.matrix import anorm from aspire.volume import Volume - +from aspire.utils.misc import gaussian_3d class DownsampleTestCase(TestCase): def setUp(self): @@ -28,8 +28,8 @@ def testDownsample2D_EvenEven(self): # check signal energy is conserved self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) - # Signal energy test fails for this case in current DS implementation - def _testDownsample2D_EvenOdd(self): + @unittest.skip("Signal energy test fails for this case in current DS implementation") + def testDownsample2D_EvenOdd(self): # source resolution: 64 # target resolution: 33 imgs_org, imgs_ds = self.createImages(64, 33) From f2bf0c13c71f043c4744010ad91ee50b38bcce3b Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Tue, 1 Mar 2022 10:18:46 -0500 Subject: [PATCH 11/15] use gaussian_3d to generate test volume --- tests/test_downsample.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index 0728d87253..efd873ebd7 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -1,14 +1,14 @@ from unittest import TestCase - +import unittest import numpy as np from aspire.source import Simulation from aspire.utils import utest_tolerance -from aspire.utils.coor_trans import grid_3d from aspire.utils.matrix import anorm from aspire.volume import Volume from aspire.utils.misc import gaussian_3d + class DownsampleTestCase(TestCase): def setUp(self): self.n = 128 @@ -28,7 +28,9 @@ def testDownsample2D_EvenEven(self): # check signal energy is conserved self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) - @unittest.skip("Signal energy test fails for this case in current DS implementation") + @unittest.skip( + "Signal energy test fails for this case in current DS implementation" + ) def testDownsample2D_EvenOdd(self): # source resolution: 64 # target resolution: 33 @@ -84,13 +86,8 @@ def checkSignalEnergy(self, imgs_org, imgs_ds): def createImages(self, L, max_resolution): # generate a 3D Gaussian volume - g3d = grid_3d(L, indexing="zyx", dtype=self.dtype) - coords = np.array([g3d["x"].flatten(), g3d["y"].flatten(), g3d["z"].flatten()]) sigma = 0.2 - vol = np.exp(-0.5 * np.sum(np.abs(coords / sigma) ** 2, axis=0)).astype( - self.dtype - ) - vol = np.reshape(vol, g3d["x"].shape) + vol = gaussian_3d(L, sigma=((L / 2) * sigma,) * 3, dtype=self.dtype) # initialize a Simulation object to generate projections of the volume sim = Simulation( L, self.n, vols=Volume(vol), offsets=0.0, amplitudes=1.0, dtype=self.dtype From 887f1e6eb8b1c7e372b18259fd6f68fde947128b Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Tue, 1 Mar 2022 11:41:12 -0500 Subject: [PATCH 12/15] isort --- tests/test_downsample.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index efd873ebd7..0e9ed862fe 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -1,12 +1,13 @@ -from unittest import TestCase import unittest +from unittest import TestCase + import numpy as np from aspire.source import Simulation from aspire.utils import utest_tolerance from aspire.utils.matrix import anorm -from aspire.volume import Volume from aspire.utils.misc import gaussian_3d +from aspire.volume import Volume class DownsampleTestCase(TestCase): @@ -86,8 +87,8 @@ def checkSignalEnergy(self, imgs_org, imgs_ds): def createImages(self, L, max_resolution): # generate a 3D Gaussian volume - sigma = 0.2 - vol = gaussian_3d(L, sigma=((L / 2) * sigma,) * 3, dtype=self.dtype) + sigma = 0.1 + vol = gaussian_3d(L, sigma=(L * sigma,) * 3, dtype=self.dtype) # initialize a Simulation object to generate projections of the volume sim = Simulation( L, self.n, vols=Volume(vol), offsets=0.0, amplitudes=1.0, dtype=self.dtype From 555cd90080bd7c5e3e7d1f068d6c5c59b3454757 Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Fri, 18 Mar 2022 12:01:18 -0400 Subject: [PATCH 13/15] checkCenterPoint --- tests/test_downsample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index 0e9ed862fe..24c634be50 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -65,8 +65,8 @@ def testDownsample2D_OddEven(self): # check signal energy is conserved self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) - def checkGridPoints(self, imgs_org, imgs_ds): - # Check individual grid points + def checkCenterPoint(self, imgs_org, imgs_ds): + # Check that center point is the same after ds L = imgs_org.res max_resolution = imgs_ds.res return np.allclose( From 31e09db24d7b814e14cde34f008b9093b45bf905 Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Fri, 18 Mar 2022 12:02:06 -0400 Subject: [PATCH 14/15] max_resolution -> L_ds --- tests/test_downsample.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index 24c634be50..4a0874bbb9 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -85,7 +85,7 @@ def checkSignalEnergy(self, imgs_org, imgs_ds): atol=utest_tolerance(self.dtype), ) - def createImages(self, L, max_resolution): + def createImages(self, L, L_ds): # generate a 3D Gaussian volume sigma = 0.1 vol = gaussian_3d(L, sigma=(L * sigma,) * 3, dtype=self.dtype) @@ -98,7 +98,7 @@ def createImages(self, L, max_resolution): imgs_org = sim.images(start=0, num=self.n) # get images after downsample - sim.downsample(max_resolution) + sim.downsample(L_ds) imgs_ds = sim.images(start=0, num=self.n) return imgs_org, imgs_ds From faf608bfec1d084df9cc0347d694bba56b5de3d3 Mon Sep 17 00:00:00 2001 From: Chris Langfield Date: Fri, 18 Mar 2022 12:22:03 -0400 Subject: [PATCH 15/15] parametrize ds tests --- tests/test_downsample.py | 42 ++++++++++++++-------------------------- 1 file changed, 14 insertions(+), 28 deletions(-) diff --git a/tests/test_downsample.py b/tests/test_downsample.py index 4a0874bbb9..a6da6aee4a 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -18,52 +18,38 @@ def setUp(self): def tearDown(self): pass - def testDownsample2D_EvenEven(self): - # source resolution: 64 - # target resolution: 32 - imgs_org, imgs_ds = self.createImages(64, 32) + def _testDownsample2DCase(self, L, L_ds): + # downsampling from size L to L_ds + imgs_org, imgs_ds = self.createImages(L, L_ds) # check resolution is correct - self.assertEqual((self.n, 32, 32), imgs_ds.shape) - # check individual gridpoints for all images - self.assertTrue(self.checkGridPoints(imgs_org, imgs_ds)) + self.assertEqual((self.n, L_ds, L_ds), imgs_ds.shape) + # check center points for all images + self.assertTrue(self.checkCenterPoint(imgs_org, imgs_ds)) # check signal energy is conserved self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) + def testDownsample2D_EvenEven(self): + # source resolution: 64 + # target resolution: 32 + self._testDownsample2DCase(64, 32) + @unittest.skip( "Signal energy test fails for this case in current DS implementation" ) def testDownsample2D_EvenOdd(self): # source resolution: 64 # target resolution: 33 - imgs_org, imgs_ds = self.createImages(64, 33) - # check resolution is correct - self.assertEqual((self.n, 33, 33), imgs_ds.shape) - # check invidual gridpoints for all images - self.assertTrue(self.checkGridPoints(imgs_org, imgs_ds)) - # check signal energy is conserved - self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) + self._testDownsample2DCase(64, 33) def testDownsample2D_OddOdd(self): # source resolution: 65 # target resolution: 33 - imgs_org, imgs_ds = self.createImages(65, 33) - # check resolution is correct - self.assertEqual((self.n, 33, 33), imgs_ds.shape) - # check invidual gridpoints for all images - self.assertTrue(self.checkGridPoints(imgs_org, imgs_ds)) - # check signal energy is conserved - self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) + self._testDownsample2DCase(65, 33) def testDownsample2D_OddEven(self): # source resolution: 65 # target resolution: 32 - imgs_org, imgs_ds = self.createImages(65, 32) - # check resolution is correct - self.assertEqual((self.n, 32, 32), imgs_ds.shape) - # check invidual gridpoints for all images - self.assertTrue(self.checkGridPoints(imgs_org, imgs_ds)) - # check signal energy is conserved - self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) + self._testDownsample2DCase(65, 32) def checkCenterPoint(self, imgs_org, imgs_ds): # Check that center point is the same after ds