|
1 | 1 | import os.path |
2 | 2 | from unittest import TestCase |
3 | | -from unittest.case import SkipTest |
4 | 3 |
|
5 | 4 | import numpy as np |
6 | 5 | from pytest import raises |
7 | 6 | from scipy.special import jv |
8 | 7 |
|
9 | 8 | from aspire.basis import FBBasis2D |
10 | 9 | from aspire.image import Image |
11 | | -from aspire.utils import complex_type, gaussian_2d, real_type, utest_tolerance |
| 10 | +from aspire.utils import complex_type, real_type, utest_tolerance |
12 | 11 | from aspire.utils.coor_trans import grid_2d |
13 | | -from aspire.utils.random import randn |
| 12 | + |
| 13 | +from ._basis_util import Steerable2DMixin |
14 | 14 |
|
15 | 15 | DATA_DIR = os.path.join(os.path.dirname(__file__), "saved_test_data") |
16 | 16 |
|
17 | 17 |
|
18 | | -class FBBasis2DTestCase(TestCase): |
| 18 | +class FBBasis2DTestCase(TestCase, Steerable2DMixin): |
19 | 19 | def setUp(self): |
20 | 20 | self.dtype = np.float32 |
21 | 21 | self.L = 8 |
@@ -329,28 +329,6 @@ def testFBBasis2DExpand(self): |
329 | 329 | ) |
330 | 330 | ) |
331 | 331 |
|
332 | | - def testIndices(self): |
333 | | - ell_max = self.basis.ell_max |
334 | | - k_max = self.basis.k_max |
335 | | - |
336 | | - indices = self.basis.indices() |
337 | | - |
338 | | - i = 0 |
339 | | - |
340 | | - for ell in range(ell_max + 1): |
341 | | - if ell == 0: |
342 | | - sgns = [1] |
343 | | - else: |
344 | | - sgns = [1, -1] |
345 | | - |
346 | | - for sgn in sgns: |
347 | | - for k in range(k_max[ell]): |
348 | | - self.assertTrue(indices["ells"][i] == ell) |
349 | | - self.assertTrue(indices["sgns"][i] == sgn) |
350 | | - self.assertTrue(indices["ks"][i] == k) |
351 | | - |
352 | | - i += 1 |
353 | | - |
354 | 332 | def testElement(self): |
355 | 333 | ell = 1 |
356 | 334 | sgn = -1 |
@@ -387,103 +365,6 @@ def testElement(self): |
387 | 365 | self.assertTrue(np.allclose(im, im_ref, atol=1e-4)) |
388 | 366 | self.assertTrue(np.allclose(coef, coef_ref, atol=1e-4)) |
389 | 367 |
|
390 | | - def testGaussianExpand(self): |
391 | | - # Offset slightly |
392 | | - x0 = 0.50 |
393 | | - y0 = 0.75 |
394 | | - |
395 | | - # Want sigma to be as large as possible without the Gaussian |
396 | | - # spilling too much outside the central disk. |
397 | | - sigma = self.L / 8 |
398 | | - im1 = gaussian_2d(self.L, x0=x0, y0=y0, sigma_x=sigma, sigma_y=sigma) |
399 | | - im1 = im1.astype(self.dtype) |
400 | | - |
401 | | - coef = self.basis.expand(im1) |
402 | | - im2 = self.basis.evaluate(coef) |
403 | | - |
404 | | - if isinstance(im2, Image): |
405 | | - im2 = im2.asnumpy() |
406 | | - |
407 | | - # For small L there's too much clipping at high freqs to get 1e-3 |
408 | | - # accuracy. |
409 | | - if self.L < 32: |
410 | | - atol = 1e-2 |
411 | | - else: |
412 | | - atol = 1e-3 |
413 | | - |
414 | | - self.assertTrue(np.allclose(im1, im2, atol=atol)) |
415 | | - |
416 | | - def testIsotropic(self): |
417 | | - sigma = self.L / 8 |
418 | | - im = gaussian_2d(self.L, sigma_x=sigma, sigma_y=sigma) |
419 | | - im = im.astype(self.dtype) |
420 | | - |
421 | | - coef = self.basis.expand(im) |
422 | | - |
423 | | - ells = self.basis.indices()["ells"] |
424 | | - |
425 | | - energy_outside = np.sum(np.abs(coef[ells != 0]) ** 2) |
426 | | - energy_total = np.sum(np.abs(coef) ** 2) |
427 | | - |
428 | | - energy_ratio = energy_outside / energy_total |
429 | | - |
430 | | - self.assertTrue(energy_ratio < 0.01) |
431 | | - |
432 | | - def testModulated(self): |
433 | | - if self.L < 32: |
434 | | - raise SkipTest |
435 | | - |
436 | | - ell = 1 |
437 | | - |
438 | | - sigma = self.L / 8 |
439 | | - im = gaussian_2d(self.L, sigma_x=sigma, sigma_y=sigma) |
440 | | - im = im.astype(self.dtype) |
441 | | - |
442 | | - g2d = grid_2d(self.L) |
443 | | - |
444 | | - for trig_fun in (np.sin, np.cos): |
445 | | - im1 = im * trig_fun(ell * g2d["phi"]) |
446 | | - |
447 | | - coef = self.basis.expand(im1) |
448 | | - |
449 | | - ells = self.basis.indices()["ells"] |
450 | | - |
451 | | - energy_outside = np.sum(np.abs(coef[ells != ell]) ** 2) |
452 | | - energy_total = np.sum(np.abs(coef) ** 2) |
453 | | - |
454 | | - energy_ratio = energy_outside / energy_total |
455 | | - |
456 | | - self.assertTrue(energy_ratio < 0.10) |
457 | | - |
458 | | - def testEvaluateExpand(self): |
459 | | - coef1 = randn(self.basis.count, seed=self.seed) |
460 | | - coef1 = coef1.astype(self.dtype) |
461 | | - |
462 | | - im = self.basis.evaluate(coef1) |
463 | | - if isinstance(im, Image): |
464 | | - im = im.asnumpy() |
465 | | - coef2 = self.basis.expand(im)[:, 0] |
466 | | - |
467 | | - self.assertTrue(np.allclose(coef1, coef2, atol=utest_tolerance(self.dtype))) |
468 | | - |
469 | | - def testAdjoint(self): |
470 | | - u = randn(self.basis.count, seed=self.seed) |
471 | | - u = u.astype(self.dtype) |
472 | | - |
473 | | - Au = self.basis.evaluate(u) |
474 | | - if isinstance(Au, Image): |
475 | | - Au = Au.asnumpy() |
476 | | - |
477 | | - x = randn(*self.basis.sz, seed=self.seed) |
478 | | - x = x.astype(self.dtype) |
479 | | - |
480 | | - ATx = self.basis.evaluate_t(x) |
481 | | - |
482 | | - Au_dot_x = np.sum(Au * x) |
483 | | - u_dot_ATx = np.sum(u * ATx) |
484 | | - |
485 | | - self.assertTrue(np.isclose(Au_dot_x, u_dot_ATx)) |
486 | | - |
487 | 368 | def testComplexCoversion(self): |
488 | 369 | # Load a reasonable input |
489 | 370 | x = np.load(os.path.join(DATA_DIR, "fbbasis_coefficients_8_8.npy")) |
|
0 commit comments