|
| 1 | +import unittest |
| 2 | +from unittest import TestCase |
| 3 | + |
| 4 | +import numpy as np |
| 5 | + |
| 6 | +from aspire.source import Simulation |
| 7 | +from aspire.utils import utest_tolerance |
| 8 | +from aspire.utils.matrix import anorm |
| 9 | +from aspire.utils.misc import gaussian_3d |
| 10 | +from aspire.volume import Volume |
| 11 | + |
| 12 | + |
| 13 | +class DownsampleTestCase(TestCase): |
| 14 | + def setUp(self): |
| 15 | + self.n = 128 |
| 16 | + self.dtype = np.float32 |
| 17 | + |
| 18 | + def tearDown(self): |
| 19 | + pass |
| 20 | + |
| 21 | + def _testDownsample2DCase(self, L, L_ds): |
| 22 | + # downsampling from size L to L_ds |
| 23 | + imgs_org, imgs_ds = self.createImages(L, L_ds) |
| 24 | + # check resolution is correct |
| 25 | + self.assertEqual((self.n, L_ds, L_ds), imgs_ds.shape) |
| 26 | + # check center points for all images |
| 27 | + self.assertTrue(self.checkCenterPoint(imgs_org, imgs_ds)) |
| 28 | + # check signal energy is conserved |
| 29 | + self.assertTrue(self.checkSignalEnergy(imgs_org, imgs_ds)) |
| 30 | + |
| 31 | + def testDownsample2D_EvenEven(self): |
| 32 | + # source resolution: 64 |
| 33 | + # target resolution: 32 |
| 34 | + self._testDownsample2DCase(64, 32) |
| 35 | + |
| 36 | + @unittest.skip( |
| 37 | + "Signal energy test fails for this case in current DS implementation" |
| 38 | + ) |
| 39 | + def testDownsample2D_EvenOdd(self): |
| 40 | + # source resolution: 64 |
| 41 | + # target resolution: 33 |
| 42 | + self._testDownsample2DCase(64, 33) |
| 43 | + |
| 44 | + def testDownsample2D_OddOdd(self): |
| 45 | + # source resolution: 65 |
| 46 | + # target resolution: 33 |
| 47 | + self._testDownsample2DCase(65, 33) |
| 48 | + |
| 49 | + def testDownsample2D_OddEven(self): |
| 50 | + # source resolution: 65 |
| 51 | + # target resolution: 32 |
| 52 | + self._testDownsample2DCase(65, 32) |
| 53 | + |
| 54 | + def checkCenterPoint(self, imgs_org, imgs_ds): |
| 55 | + # Check that center point is the same after ds |
| 56 | + L = imgs_org.res |
| 57 | + max_resolution = imgs_ds.res |
| 58 | + return np.allclose( |
| 59 | + imgs_org[:, L // 2, L // 2], |
| 60 | + imgs_ds[:, max_resolution // 2, max_resolution // 2], |
| 61 | + atol=utest_tolerance(self.dtype), |
| 62 | + ) |
| 63 | + |
| 64 | + def checkSignalEnergy(self, imgs_org, imgs_ds): |
| 65 | + # check conservation of energy after downsample |
| 66 | + L = imgs_org.res |
| 67 | + max_resolution = imgs_ds.res |
| 68 | + return np.allclose( |
| 69 | + anorm(imgs_org.asnumpy(), axes=(1, 2)) / L, |
| 70 | + anorm(imgs_ds.asnumpy(), axes=(1, 2)) / max_resolution, |
| 71 | + atol=utest_tolerance(self.dtype), |
| 72 | + ) |
| 73 | + |
| 74 | + def createImages(self, L, L_ds): |
| 75 | + # generate a 3D Gaussian volume |
| 76 | + sigma = 0.1 |
| 77 | + vol = gaussian_3d(L, sigma=(L * sigma,) * 3, dtype=self.dtype) |
| 78 | + # initialize a Simulation object to generate projections of the volume |
| 79 | + sim = Simulation( |
| 80 | + L, self.n, vols=Volume(vol), offsets=0.0, amplitudes=1.0, dtype=self.dtype |
| 81 | + ) |
| 82 | + |
| 83 | + # get images before downsample |
| 84 | + imgs_org = sim.images(start=0, num=self.n) |
| 85 | + |
| 86 | + # get images after downsample |
| 87 | + sim.downsample(L_ds) |
| 88 | + imgs_ds = sim.images(start=0, num=self.n) |
| 89 | + |
| 90 | + return imgs_org, imgs_ds |
0 commit comments